SAX su Cloud TPU v5e
Cluster SAX (cella SAX)
Il server di amministrazione SAX e il server di modelli SAX sono due componenti essenziali per eseguire un cluster SAX.
Server di amministrazione SAX
Il server di amministrazione SAX monitora e coordina tutti i server di modelli SAX in un cluster SAX. In un cluster SAX, puoi avviare più server di amministrazione SAX, in cui solo uno dei server di amministrazione SAX è attivo tramite l'elezione del leader, mentre gli altri sono server di riserva. Quando il server amministrativo attivo si guasta, un server amministrativo di riserva diventa attivo. Il server di amministrazione SAX attivo assegna le richieste di inferenza e le repliche dei modelli ai server di modelli SAX disponibili.
Bucket di archiviazione amministratore SAX
Ogni cluster SAX richiede un bucket Cloud Storage per archiviare le configurazioni e le posizioni dei server di amministrazione SAX e dei server di modelli SAX nel cluster SAX.
Server del modello SAX
Il server di modelli SAX carica un checkpoint del modello ed esegue l'inferenza con GSPMD. Un server di modelli SAX viene eseguito su un singolo worker VM TPU. La pubblicazione di modelli TPU a singolo host richiede un singolo server di modelli SAX su una VM TPU a singolo host. La pubblicazione di modelli TPU multi-host richiede un gruppo di server di modelli SAX su uno slice TPU multi-host. La pubblicazione di modelli multi-host non è attualmente disponibile, ma questo documento fornisce un esempio con un modello di test di 175 miliardi di elementi per l'anteprima.
Erogazione del modello SAX
La sezione seguente illustra il flusso di lavoro per la pubblicazione di modelli linguistici utilizzando SAX. Utilizza il modello GPT-J 6B come esempio per il servizio di modelli su un singolo host.
Prima di iniziare, installa le immagini Docker Cloud TPU SAX sulla tua VM TPU:
sudo usermod -a -G docker ${USER} newgrp docker gcloud auth configure-docker us-docker.pkg.dev SAX_ADMIN_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-admin-server" SAX_MODEL_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server" SAX_UTIL_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-util" SAX_VERSION=v1.0.0 export SAX_ADMIN_SERVER_IMAGE_URL=${SAX_ADMIN_SERVER_IMAGE_NAME}:${SAX_VERSION} export SAX_MODEL_SERVER_IMAGE_URL=${SAX_MODEL_SERVER_IMAGE_NAME}:${SAX_VERSION} export SAX_UTIL_IMAGE_URL="${SAX_UTIL_IMAGE_NAME}:${sax_version}" docker pull ${SAX_ADMIN_SERVER_IMAGE_URL} docker pull ${SAX_MODEL_SERVER_IMAGE_URL} docker pull ${SAX_UTIL_IMAGE_URL}
Imposta altre variabili che utilizzerai in seguito:
export SAX_ADMIN_SERVER_DOCKER_NAME="sax-admin-server" export SAX_MODEL_SERVER_DOCKER_NAME="sax-model-server" export SAX_CELL="/sax/test"
Esempio di distribuzione di un modello GPT-J 6B a host singolo
La pubblicazione di modelli a singolo host è applicabile al slice TPU a singolo host, ovvero v5litepod-1, v5litepod-4 e v5litepod-8.
Crea un cluster SAX
Crea un bucket di archiviazione Cloud Storage per il cluster SAX:
SAX_ADMIN_STORAGE_BUCKET=${your_admin_storage_bucket} gcloud storage buckets create gs://${SAX_ADMIN_STORAGE_BUCKET} \ --project=${PROJECT_ID}
Potresti aver bisogno di un altro bucket di archiviazione Cloud Storage per memorizzare il checkpoint.
SAX_DATA_STORAGE_BUCKET=${your_data_storage_bucket}
Accedi tramite SSH alla VM TPU in un terminale per avviare il server di amministrazione SAX:
docker run \ --name ${SAX_ADMIN_SERVER_DOCKER_NAME} \ -it \ -d \ --rm \ --network host \ --env GSBUCKET=${SAX_ADMIN_STORAGE_BUCKET} \ ${SAX_ADMIN_SERVER_IMAGE_URL}
Per controllare il log di Docker:
docker logs -f ${SAX_ADMIN_SERVER_DOCKER_NAME}
L'output nel log sarà simile al seguente:
I0829 01:22:31.184198 7 config.go:111] Creating config fs_root: "gs://test_sax_admin/sax-fs-root" I0829 01:22:31.347883 7 config.go:115] Created config fs_root: "gs://test_sax_admin/sax-fs-root" I0829 01:22:31.360837 24 admin_server.go:44] Starting the server I0829 01:22:31.361420 24 ipaddr.go:39] Skipping non-global IP address 127.0.0.1/8. I0829 01:22:31.361455 24 ipaddr.go:39] Skipping non-global IP address ::1/128. I0829 01:22:31.361462 24 ipaddr.go:39] Skipping non-global IP address fe80::4001:aff:fe8e:fc8/64. I0829 01:22:31.361469 24 ipaddr.go:39] Skipping non-global IP address fe80::42:bfff:fef9:1bd3/64. I0829 01:22:31.361474 24 ipaddr.go:39] Skipping non-global IP address fe80::20fb:c3ff:fe5b:baac/64. I0829 01:22:31.361482 24 ipaddr.go:56] IPNet address 10.142.15.200 I0829 01:22:31.361488 24 ipaddr.go:56] IPNet address 172.17.0.1 I0829 01:22:31.456952 24 admin.go:305] Loaded config: fs_root: "gs://test_sax_admin/sax-fs-root" I0829 01:22:31.609323 24 addr.go:105] SetAddr /gcs/test_sax_admin/sax-root/sax/test/location.proto "10.142.15.200:10000" I0829 01:22:31.656021 24 admin.go:325] Updated config: fs_root: "gs://test_sax_admin/sax-fs-root" I0829 01:22:31.773245 24 mgr.go:781] Loaded manager state I0829 01:22:31.773260 24 mgr.go:784] Refreshing manager state every 10s I0829 01:22:31.773285 24 admin.go:350] Starting the server on port 10000 I0829 01:22:31.773292 24 cloud.go:506] Starting the HTTP server on port 8080
Avvia un server di modelli SAX a un solo host nel cluster SAX:
A questo punto, il cluster SAX contiene solo il server di amministrazione SAX. Puoi connetterti alla VM TPU tramite SSH in un secondo terminale per avviare un server di modelli SAX nel cluster SAX:
docker run \ --privileged \ -it \ -d \ --rm \ --network host \ --name ${SAX_MODEL_SERVER_DOCKER_NAME} \ --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \ ${SAX_MODEL_SERVER_IMAGE_URL} \ --sax_cell=${SAX_CELL} \ --port=10001 \ --platform_chip=tpuv4 \ --platform_topology=1x1
Converti il checkpoint del modello:
Per scaricare il checkpoint GPT-J da EleutherAI, devi installare PyTorch e Transformers:
pip3 install accelerate pip3 install torch pip3 install transformers
Per convertire il checkpoint in un checkpoint SAX, devi installare
paxml:pip3 install paxml==1.1.0
Il seguente script converte il checkpoint GPT-J in checkpoint SAX:
python3 -m convert_gptj_ckpt --base EleutherAI/gpt-j-6b --pax pax_6b
Al termine della conversione:
ls checkpoint_00000000/Devi creare un file commit_success e posizionarlo nelle sottodirectory:
gcloud storage cp checkpoint_00000000 ${CHECKPOINT_PATH} --recursive touch commit_success.txt gcloud storage cp commit_success.txt ${CHECKPOINT_PATH}/ gcloud storage cp commit_success.txt ${CHECKPOINT_PATH}/metadata/ gcloud storage cp commit_success.txt ${CHECKPOINT_PATH}/state/Pubblica il modello nel cluster SAX
Ora puoi pubblicare GPT-J con il checkpoint convertito nel passaggio precedente.
MODEL_NAME=gptjtokenizedbf16bs32 MODEL_CONFIG_PATH=saxml.server.pax.lm.params.gptj.GPTJ4TokenizedBF16BS32 REPLICA=1
Per pubblicare GPT-J (e i passaggi successivi), utilizza SSH per connetterti alla VM TPU in un terzo terminale:
docker run \ ${SAX_UTIL_IMAGE_URL} \ --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \ publish \ ${SAX_CELL}/${MODEL_NAME} \ ${MODEL_CONFIG_PATH} \ ${CHECKPOINT_PATH} \ ${REPLICA}
Vedrai molte attività nel log Docker del server del modello fino a quando non vedrai un messaggio simile al seguente che indica che il modello è stato caricato correttamente:
I0829 01:33:49.287459 139865140229696 servable_model.py:697] loading completed.
Generare risultati dell'inferenza
Per GPT-J, l'input e l'output devono essere formattati come stringa di ID token separata da virgole. Dovrai tokenizzare l'input di testo.
TEXT = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction\:\nSummarize the following news article\:\n\n### Input\:\nMarch 10, 2015 . We're truly international in scope on Tuesday. We're visiting Italy, Russia, the United Arab Emirates, and the Himalayan Mountains. Find out who's attempting to circumnavigate the globe in a plane powered partially by the sun, and explore the mysterious appearance of craters in northern Asia. You'll also get a view of Mount Everest that was previously reserved for climbers. On this page you will find today's show Transcript and a place for you to request to be on the CNN Student News Roll Call. TRANSCRIPT . Click here to access the transcript of today's CNN Student News program. Please note that there may be a delay between the time when the video is available and when the transcript is published. CNN Student News is created by a team of journalists who consider the Common Core State Standards, national standards in different subject areas, and state standards when producing the show. ROLL CALL . For a chance to be mentioned on the next CNN Student News, comment on the bottom of this page with your school name, mascot, city and state. We will be selecting schools from the comments of the previous show. You must be a teacher or a student age 13 or older to request a mention on the CNN Student News Roll Call! Thank you for using CNN Student News!\n\n### Response\:
Puoi ottenere la stringa degli ID token tramite il tokenizer EleutherAI/gpt-j-6b:
from transformers import GPT2Tokenizer tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-j-6b") :
Tokenizza il testo di input:
encoded_example = tokenizer(TEXT) input_ids = encoded_example.input_ids INPUT_STR = ",".join([str(input_id) for input_id in input_ids])
Dovresti visualizzare una stringa di ID token simile alla seguente:
>>> INPUT_STR '21106,318,281,12064,326,8477,257,4876,11,20312,351,281,5128,326,3769,2252,4732,13,19430,257,2882,326,20431,32543,262,2581,13,198,198,21017,46486,25,198,13065,3876,1096,262,1708,1705,2708,25,198,198,21017,23412,25,198,16192,838,11,1853,764,775,821,4988,3230,287,8354,319,3431,13,775,821,10013,8031,11,3284,11,262,1578,4498,24880,11,290,262,42438,22931,21124,13,9938,503,508,338,9361,284,2498,4182,615,10055,262,13342,287,257,6614,13232,12387,416,262,4252,11,290,7301,262,11428,5585,286,1067,8605,287,7840,7229,13,921,1183,635,651,257,1570,286,5628,41336,326,373,4271,10395,329,39311,13,1550,428,2443,345,481,1064,1909,338,905,42978,290,257,1295,329,345,284,2581,284,307,319,262,8100,13613,3000,8299,4889,13,48213,6173,46023,764,6914,994,284,1895,262,14687,286,1909,338,8100,13613,3000,1430,13,4222,3465,326,612,743,307,257,5711,1022,262,640,618,262,2008,318,1695,290,618,262,14687,318,3199,13,8100,13613,3000,318,2727,416,257,1074,286,9046,508,2074,262,8070,7231,1812,20130,11,2260,5423,287,1180,2426,3006,11,290,1181,5423,618,9194,262,905,13,15107,3069,42815,764,1114,257,2863,284,307,4750,319,262,1306,8100,13613,3000,11,2912,319,262,4220,286,428,2443,351,534,1524,1438,11,37358,11,1748,290,1181,13,775,481,307,17246,4266,422,262,3651,286,262,2180,905,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,6952,345,329,1262,8100,13613,3000,0,198,198,21017,18261,25'
Per generare un riepilogo del tuo articolo:
docker run \ ${SAX_UTIL_IMAGE_URL} \ --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \ lm.generate \ ${SAX_CELL}/${MODEL_NAME} \ ${INPUT_STR}Dovresti visualizzare qualcosa di simile a quanto segue:
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+ | GENERATE | SCORE | +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+ | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256 | -0.023136413 | | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,50256 | -0.91842502 | | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256 | -1.1726116 | | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256 | -1.2472695 | +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
Per detokenizzare la stringa degli ID token di output:
output_token_ids = [int(token_id) for token_id in OUTPUT_STR.split(',')] OUTPUT_TEXT = tokenizer.decode(output_token_ids, skip_special_tokens=True)Il testo detokenizzato sarà il seguente:
>>> OUTPUT_TEXT 'This page includes the show Transcript.\nUse the Transcript to help students with reading comprehension and vocabulary.\nAt the bottom of the page, comment for a chance to be mentioned on CNN Student News. You must be a teacher or a student age 13 or older to request a mention on the CNN Student News Roll Call.'
Libera spazio nei container Docker e nei bucket di archiviazione Cloud Storage.
Anteprima del servizio di pubblicazione di modelli multi-host da 175 miliardi
Alcuni modelli linguistici di grandi dimensioni richiedono uno slice TPU multi-host, ovvero v5litepod-16 e versioni successive. In questi casi, tutti gli host TPU multi-host dovranno avere una copia di un server di modelli SAX e tutti i server di modelli funzioneranno come gruppo di server di modelli SAX per pubblicare il modello di grandi dimensioni su uno slice TPU multi-host.
Creare un nuovo cluster SAX
Puoi seguire la stessa procedura per creare un cluster SAX descritta nella guida di GPT-J per creare un nuovo cluster SAX e un server di amministrazione SAX.
In alternativa, se hai già un cluster SAX, puoi avviare un server di modelli multi-host nel cluster SAX.
Avvia un server di modelli SAX multi-host in un cluster SAX
Utilizza lo stesso comando per creare una sezione TPU multi-host che per una sezione TPU mono-host, specifica solo il tipo di acceleratore multi-host appropriato:
ACCELERATOR_TYPE=v5litepod-32 ZONE=us-east1-c gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --service-account ${SERVICE_ACCOUNT} \ --reservedPer estrarre l'immagine del server del modello SAX su tutti gli host/worker TPU e lanciarli:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command=" gcloud auth configure-docker \ us-docker.pkg.dev # Pull sax model server image docker pull ${SAX_MODEL_SERVER_IMAGE_URL} # Run model server docker run \ --privileged \ -it \ -d \ --rm \ --network host \ --name ${SAX_MODEL_SERVER_DOCKER_NAME} \ --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \ ${SAX_MODEL_SERVER_IMAGE_URL} \ --sax_cell=${SAX_CELL} \ --port=10001 \ --platform_chip=tpuv4 \ --platform_topology=1x1"Pubblica il modello nel cluster SAX
Questo esempio utilizza un modello LmCloudSpmd175B32Test:
MODEL_NAME=lmcloudspmd175b32test MODEL_CONFIG_PATH=saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test CHECKPOINT_PATH=None REPLICA=1
Per pubblicare il modello di test:
docker run \ ${SAX_UTIL_IMAGE_URL} \ --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \ publish \ ${SAX_CELL}/${MODEL_NAME} \ ${MODEL_CONFIG_PATH} \ ${CHECKPOINT_PATH} \ ${REPLICA}
Generare risultati dell'inferenza
docker run \ ${SAX_UTIL_IMAGE_URL} \ --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \ lm.generate \ ${SAX_CELL}/${MODEL_NAME} \ "Q: Who is Harry Porter's mother? A\: "Tieni presente che, poiché questo esempio utilizza un modello di test con pesi casuali, il risultato potrebbe non essere significativo.
Esegui la pulizia
Interrompi i container Docker:
docker stop ${SAX_ADMIN_SERVER_DOCKER_NAME} docker stop ${SAX_MODEL_SERVER_DOCKER_NAME}
Elimina il bucket di archiviazione amministrativo di Cloud Storage e qualsiasi bucket di archiviazione dei dati utilizzando gcloud CLI come mostrato di seguito.
gcloud storage rm gs://${SAX_ADMIN_STORAGE_BUCKET} --recursive --continue-on-error gcloud storage rm gs://${SAX_DATA_STORAGE_BUCKET} --recursive --continue-on-error