SAX su Cloud TPU v5e

Cluster SAX (cella SAX)

Il server di amministrazione SAX e il server di modelli SAX sono due componenti essenziali per eseguire un cluster SAX.

Server di amministrazione SAX

Il server di amministrazione SAX monitora e coordina tutti i server di modelli SAX in un cluster SAX. In un cluster SAX, puoi avviare più server di amministrazione SAX, in cui solo uno dei server di amministrazione SAX è attivo tramite l'elezione del leader, mentre gli altri sono server di riserva. Quando il server amministrativo attivo si guasta, un server amministrativo di riserva diventa attivo. Il server di amministrazione SAX attivo assegna le richieste di inferenza e le repliche dei modelli ai server di modelli SAX disponibili.

Bucket di archiviazione amministratore SAX

Ogni cluster SAX richiede un bucket Cloud Storage per archiviare le configurazioni e le posizioni dei server di amministrazione SAX e dei server di modelli SAX nel cluster SAX.

Server del modello SAX

Il server di modelli SAX carica un checkpoint del modello ed esegue l'inferenza con GSPMD. Un server di modelli SAX viene eseguito su un singolo worker VM TPU. La pubblicazione di modelli TPU a singolo host richiede un singolo server di modelli SAX su una VM TPU a singolo host. La pubblicazione di modelli TPU multi-host richiede un gruppo di server di modelli SAX su uno slice TPU multi-host. La pubblicazione di modelli multi-host non è attualmente disponibile, ma questo documento fornisce un esempio con un modello di test di 175 miliardi di elementi per l'anteprima.

Erogazione del modello SAX

La sezione seguente illustra il flusso di lavoro per la pubblicazione di modelli linguistici utilizzando SAX. Utilizza il modello GPT-J 6B come esempio per il servizio di modelli su un singolo host.

Prima di iniziare, installa le immagini Docker Cloud TPU SAX sulla tua VM TPU:

sudo usermod -a -G docker ${USER}
newgrp docker

gcloud auth configure-docker us-docker.pkg.dev

SAX_ADMIN_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-admin-server"
SAX_MODEL_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server"
SAX_UTIL_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-util"

SAX_VERSION=v1.0.0

export SAX_ADMIN_SERVER_IMAGE_URL=${SAX_ADMIN_SERVER_IMAGE_NAME}:${SAX_VERSION}
export SAX_MODEL_SERVER_IMAGE_URL=${SAX_MODEL_SERVER_IMAGE_NAME}:${SAX_VERSION}
export SAX_UTIL_IMAGE_URL="${SAX_UTIL_IMAGE_NAME}:${sax_version}"

docker pull ${SAX_ADMIN_SERVER_IMAGE_URL}
docker pull ${SAX_MODEL_SERVER_IMAGE_URL}
docker pull ${SAX_UTIL_IMAGE_URL}

Imposta altre variabili che utilizzerai in seguito:

export SAX_ADMIN_SERVER_DOCKER_NAME="sax-admin-server"
export SAX_MODEL_SERVER_DOCKER_NAME="sax-model-server"
export SAX_CELL="/sax/test"

Esempio di distribuzione di un modello GPT-J 6B a host singolo

La pubblicazione di modelli a singolo host è applicabile al slice TPU a singolo host, ovvero v5litepod-1, v5litepod-4 e v5litepod-8.

  1. Crea un cluster SAX

    1. Crea un bucket di archiviazione Cloud Storage per il cluster SAX:

      SAX_ADMIN_STORAGE_BUCKET=${your_admin_storage_bucket}
      
      gcloud storage buckets create gs://${SAX_ADMIN_STORAGE_BUCKET} \
      --project=${PROJECT_ID}

      Potresti aver bisogno di un altro bucket di archiviazione Cloud Storage per memorizzare il checkpoint.

      SAX_DATA_STORAGE_BUCKET=${your_data_storage_bucket}
    2. Accedi tramite SSH alla VM TPU in un terminale per avviare il server di amministrazione SAX:

      docker run \
      --name ${SAX_ADMIN_SERVER_DOCKER_NAME} \
      -it \
      -d \
      --rm \
      --network host \
      --env GSBUCKET=${SAX_ADMIN_STORAGE_BUCKET} \
      ${SAX_ADMIN_SERVER_IMAGE_URL}

      Per controllare il log di Docker:

      docker logs -f ${SAX_ADMIN_SERVER_DOCKER_NAME}

      L'output nel log sarà simile al seguente:

      I0829 01:22:31.184198       7 config.go:111] Creating config fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.347883       7 config.go:115] Created config fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.360837      24 admin_server.go:44] Starting the server
      I0829 01:22:31.361420      24 ipaddr.go:39] Skipping non-global IP address 127.0.0.1/8.
      I0829 01:22:31.361455      24 ipaddr.go:39] Skipping non-global IP address ::1/128.
      I0829 01:22:31.361462      24 ipaddr.go:39] Skipping non-global IP address fe80::4001:aff:fe8e:fc8/64.
      I0829 01:22:31.361469      24 ipaddr.go:39] Skipping non-global IP address fe80::42:bfff:fef9:1bd3/64.
      I0829 01:22:31.361474      24 ipaddr.go:39] Skipping non-global IP address fe80::20fb:c3ff:fe5b:baac/64.
      I0829 01:22:31.361482      24 ipaddr.go:56] IPNet address 10.142.15.200
      I0829 01:22:31.361488      24 ipaddr.go:56] IPNet address 172.17.0.1
      I0829 01:22:31.456952      24 admin.go:305] Loaded config: fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.609323      24 addr.go:105] SetAddr /gcs/test_sax_admin/sax-root/sax/test/location.proto "10.142.15.200:10000"
      I0829 01:22:31.656021      24 admin.go:325] Updated config: fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.773245      24 mgr.go:781] Loaded manager state
      I0829 01:22:31.773260      24 mgr.go:784] Refreshing manager state every 10s
      I0829 01:22:31.773285      24 admin.go:350] Starting the server on port 10000
      I0829 01:22:31.773292      24 cloud.go:506] Starting the HTTP server on port 8080
  2. Avvia un server di modelli SAX a un solo host nel cluster SAX:

    A questo punto, il cluster SAX contiene solo il server di amministrazione SAX. Puoi connetterti alla VM TPU tramite SSH in un secondo terminale per avviare un server di modelli SAX nel cluster SAX:

    docker run \
        --privileged  \
        -it \
        -d \
        --rm \
        --network host \
        --name ${SAX_MODEL_SERVER_DOCKER_NAME} \
        --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
        ${SAX_MODEL_SERVER_IMAGE_URL} \
           --sax_cell=${SAX_CELL} \
           --port=10001 \
           --platform_chip=tpuv4 \
           --platform_topology=1x1
  3. Converti il checkpoint del modello:

    Per scaricare il checkpoint GPT-J da EleutherAI, devi installare PyTorch e Transformers:

    pip3 install accelerate
    pip3 install torch
    pip3 install transformers

    Per convertire il checkpoint in un checkpoint SAX, devi installare paxml:

    pip3 install paxml==1.1.0

    Il seguente script converte il checkpoint GPT-J in checkpoint SAX:

    python3 -m convert_gptj_ckpt --base EleutherAI/gpt-j-6b --pax pax_6b

    Al termine della conversione:

    ls checkpoint_00000000/

    Devi creare un file commit_success e posizionarlo nelle sottodirectory:

    gcloud storage cp checkpoint_00000000 ${CHECKPOINT_PATH} --recursive
    
    touch commit_success.txt
    gcloud storage cp commit_success.txt ${CHECKPOINT_PATH}/
    gcloud storage cp commit_success.txt ${CHECKPOINT_PATH}/metadata/
    gcloud storage cp commit_success.txt ${CHECKPOINT_PATH}/state/
    
  4. Pubblica il modello nel cluster SAX

    Ora puoi pubblicare GPT-J con il checkpoint convertito nel passaggio precedente.

    MODEL_NAME=gptjtokenizedbf16bs32
    MODEL_CONFIG_PATH=saxml.server.pax.lm.params.gptj.GPTJ4TokenizedBF16BS32
    REPLICA=1

    Per pubblicare GPT-J (e i passaggi successivi), utilizza SSH per connetterti alla VM TPU in un terzo terminale:

    docker run \
     ${SAX_UTIL_IMAGE_URL} \
       --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
       publish \
         ${SAX_CELL}/${MODEL_NAME} \
         ${MODEL_CONFIG_PATH} \
         ${CHECKPOINT_PATH} \
         ${REPLICA}

    Vedrai molte attività nel log Docker del server del modello fino a quando non vedrai un messaggio simile al seguente che indica che il modello è stato caricato correttamente:

    I0829 01:33:49.287459 139865140229696 servable_model.py:697] loading completed.
    
  5. Generare risultati dell'inferenza

    Per GPT-J, l'input e l'output devono essere formattati come stringa di ID token separata da virgole. Dovrai tokenizzare l'input di testo.

    TEXT = "Below is an instruction that describes a task, paired with
    an input that provides further context. Write a response that
    appropriately completes the request.\n\n### Instruction\:\nSummarize the
    following news article\:\n\n### Input\:\nMarch 10, 2015 . We're truly
    international in scope on Tuesday. We're visiting Italy, Russia, the
    United Arab Emirates, and the Himalayan Mountains. Find out who's
    attempting to circumnavigate the globe in a plane powered partially by the
    sun, and explore the mysterious appearance of craters in northern Asia.
    You'll also get a view of Mount Everest that was previously reserved for
    climbers. On this page you will find today's show Transcript and a place
    for you to request to be on the CNN Student News Roll Call. TRANSCRIPT .
    Click here to access the transcript of today's CNN Student News program.
    Please note that there may be a delay between the time when the video is
    available and when the transcript is published. CNN Student News is
    created by a team of journalists who consider the Common Core State
    Standards, national standards in different subject areas, and state
    standards when producing the show. ROLL CALL . For a chance to be
    mentioned on the next CNN Student News, comment on the bottom of this page
    with your school name, mascot, city and state. We will be selecting
    schools from the comments of the previous show. You must be a teacher or a
    student age 13 or older to request a mention on the CNN Student News Roll
    Call! Thank you for using CNN Student News!\n\n### Response\:

    Puoi ottenere la stringa degli ID token tramite il tokenizer EleutherAI/gpt-j-6b:

    from transformers import GPT2Tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-j-6b")                  :

    Tokenizza il testo di input:

    encoded_example = tokenizer(TEXT)
    input_ids = encoded_example.input_ids
    INPUT_STR = ",".join([str(input_id) for input_id in input_ids])

    Dovresti visualizzare una stringa di ID token simile alla seguente:

    >>> INPUT_STR
    '21106,318,281,12064,326,8477,257,4876,11,20312,351,281,5128,326,3769,2252,4732,13,19430,257,2882,326,20431,32543,262,2581,13,198,198,21017,46486,25,198,13065,3876,1096,262,1708,1705,2708,25,198,198,21017,23412,25,198,16192,838,11,1853,764,775,821,4988,3230,287,8354,319,3431,13,775,821,10013,8031,11,3284,11,262,1578,4498,24880,11,290,262,42438,22931,21124,13,9938,503,508,338,9361,284,2498,4182,615,10055,262,13342,287,257,6614,13232,12387,416,262,4252,11,290,7301,262,11428,5585,286,1067,8605,287,7840,7229,13,921,1183,635,651,257,1570,286,5628,41336,326,373,4271,10395,329,39311,13,1550,428,2443,345,481,1064,1909,338,905,42978,290,257,1295,329,345,284,2581,284,307,319,262,8100,13613,3000,8299,4889,13,48213,6173,46023,764,6914,994,284,1895,262,14687,286,1909,338,8100,13613,3000,1430,13,4222,3465,326,612,743,307,257,5711,1022,262,640,618,262,2008,318,1695,290,618,262,14687,318,3199,13,8100,13613,3000,318,2727,416,257,1074,286,9046,508,2074,262,8070,7231,1812,20130,11,2260,5423,287,1180,2426,3006,11,290,1181,5423,618,9194,262,905,13,15107,3069,42815,764,1114,257,2863,284,307,4750,319,262,1306,8100,13613,3000,11,2912,319,262,4220,286,428,2443,351,534,1524,1438,11,37358,11,1748,290,1181,13,775,481,307,17246,4266,422,262,3651,286,262,2180,905,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,6952,345,329,1262,8100,13613,3000,0,198,198,21017,18261,25'

    Per generare un riepilogo del tuo articolo:

    docker run \
      ${SAX_UTIL_IMAGE_URL} \
        --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
        lm.generate \
          ${SAX_CELL}/${MODEL_NAME} \
          ${INPUT_STR}
    

    Dovresti visualizzare qualcosa di simile a quanto segue:

    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
    |                                                                                                                                                    GENERATE                                                                                                                                                    |    SCORE     |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256 | -0.023136413 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,50256  |  -0.91842502 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256     |   -1.1726116 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256                            |   -1.2472695 |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+

    Per detokenizzare la stringa degli ID token di output:

    output_token_ids = [int(token_id) for token_id in OUTPUT_STR.split(',')]
    OUTPUT_TEXT = tokenizer.decode(output_token_ids, skip_special_tokens=True)
    

    Il testo detokenizzato sarà il seguente:

    >>> OUTPUT_TEXT
    'This page includes the show Transcript.\nUse the Transcript to help
    students with reading comprehension and vocabulary.\nAt the bottom of
    the page, comment for a chance to be mentioned on CNN Student News.
    You must be a teacher or a student age 13 or older to request a mention on the CNN Student News Roll Call.'
  6. Libera spazio nei container Docker e nei bucket di archiviazione Cloud Storage.

Anteprima del servizio di pubblicazione di modelli multi-host da 175 miliardi

Alcuni modelli linguistici di grandi dimensioni richiedono uno slice TPU multi-host, ovvero v5litepod-16 e versioni successive. In questi casi, tutti gli host TPU multi-host dovranno avere una copia di un server di modelli SAX e tutti i server di modelli funzioneranno come gruppo di server di modelli SAX per pubblicare il modello di grandi dimensioni su uno slice TPU multi-host.

  1. Creare un nuovo cluster SAX

    Puoi seguire la stessa procedura per creare un cluster SAX descritta nella guida di GPT-J per creare un nuovo cluster SAX e un server di amministrazione SAX.

    In alternativa, se hai già un cluster SAX, puoi avviare un server di modelli multi-host nel cluster SAX.

  2. Avvia un server di modelli SAX multi-host in un cluster SAX

    Utilizza lo stesso comando per creare una sezione TPU multi-host che per una sezione TPU mono-host, specifica solo il tipo di acceleratore multi-host appropriato:

    ACCELERATOR_TYPE=v5litepod-32
    ZONE=us-east1-c
    
    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --service-account ${SERVICE_ACCOUNT} \
      --reserved
    

    Per estrarre l'immagine del server del modello SAX su tutti gli host/worker TPU e lanciarli:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --worker=all \
      --command="
        gcloud auth configure-docker \
          us-docker.pkg.dev
        # Pull sax model server image
        docker pull ${SAX_MODEL_SERVER_IMAGE_URL}
        # Run model server
        docker run \
          --privileged  \
          -it \
          -d \
          --rm \
          --network host \
          --name ${SAX_MODEL_SERVER_DOCKER_NAME} \
          --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
          ${SAX_MODEL_SERVER_IMAGE_URL} \
            --sax_cell=${SAX_CELL} \
            --port=10001 \
            --platform_chip=tpuv4 \
            --platform_topology=1x1"
    
  3. Pubblica il modello nel cluster SAX

    Questo esempio utilizza un modello LmCloudSpmd175B32Test:

    MODEL_NAME=lmcloudspmd175b32test
    MODEL_CONFIG_PATH=saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test
    CHECKPOINT_PATH=None
    REPLICA=1

    Per pubblicare il modello di test:

    docker run \
     ${SAX_UTIL_IMAGE_URL} \
       --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
       publish \
         ${SAX_CELL}/${MODEL_NAME} \
         ${MODEL_CONFIG_PATH} \
         ${CHECKPOINT_PATH} \
         ${REPLICA}
  4. Generare risultati dell'inferenza

    docker run \
      ${SAX_UTIL_IMAGE_URL} \
        --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
        lm.generate \
          ${SAX_CELL}/${MODEL_NAME} \
          "Q:  Who is Harry Porter's mother? A\: "
    

    Tieni presente che, poiché questo esempio utilizza un modello di test con pesi casuali, il risultato potrebbe non essere significativo.

  5. Esegui la pulizia

    Interrompi i container Docker:

    docker stop ${SAX_ADMIN_SERVER_DOCKER_NAME}
    docker stop ${SAX_MODEL_SERVER_DOCKER_NAME}

    Elimina il bucket di archiviazione amministrativo di Cloud Storage e qualsiasi bucket di archiviazione dei dati utilizzando gcloud CLI come mostrato di seguito.

    gcloud storage rm gs://${SAX_ADMIN_STORAGE_BUCKET} --recursive --continue-on-error
    gcloud storage rm gs://${SAX_DATA_STORAGE_BUCKET} --recursive --continue-on-error