SAX sur Cloud TPU v5e

Cluster SAX (cellule SAX)

Le serveur d'administration SAX et le serveur de modèle SAX sont deux composants essentiels pour exécuter un cluster SAX.

Serveur d'administration SAX

Le serveur d'administration SAX surveille et coordonne tous les serveurs de modèle SAX d'un cluster SAX. Dans un cluster SAX, vous pouvez lancer plusieurs serveurs d'administration SAX, dont un seul est désigné comme étant actif par le biais d'une "élection du leader". Les autres sont des serveurs de secours. Lorsqu'un serveur d'administration actif présente une défaillance, un serveur d'administration de secours devient actif. Le serveur d'administration SAX actif attribue des instances répliquées de modèles et des requêtes d'inférence aux serveurs de modèle SAX disponibles.

Bucket de stockage d'administration SAX

Chaque cluster SAX nécessite un bucket Cloud Storage pour stocker les configurations et les emplacements des serveurs d'administration et de modèle SAX du cluster SAX.

Serveur de modèle SAX

Le serveur de modèle SAX charge un point de contrôle de modèle et exécute l'inférence avec GSPMD. Un serveur de modèle SAX s'exécute sur un seul nœud de calcul de VM TPU. La mise en service de modèle TPU à hôte unique nécessite un seul serveur de modèle SAX sur une VM TPU à hôte unique. La mise en service de modèles TPU multi-hôtes nécessite un groupe de serveurs de modèle SAX sur une tranche de TPU multi-hôtes. La mise en service de modèles multi-hôtes n'est pas disponible pour le moment, mais ce document fournit un exemple avec un modèle de test 175B en guise d'aperçu.

Mise en service du modèle SAX

La section suivante présente le workflow de mise en service de modèles de langage à l'aide de SAX. Nous utilisons ici le modèle GPT-J 6B comme exemple pour la mise en service de modèle à hôte unique.

Avant de commencer, installez les images Docker SAX Cloud TPU sur votre VM TPU :

sudo usermod -a -G docker ${USER}
newgrp docker

gcloud auth configure-docker us-docker.pkg.dev

SAX_ADMIN_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-admin-server"
SAX_MODEL_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server"
SAX_UTIL_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-util"

SAX_VERSION=v1.0.0

export SAX_ADMIN_SERVER_IMAGE_URL=${SAX_ADMIN_SERVER_IMAGE_NAME}:${SAX_VERSION}
export SAX_MODEL_SERVER_IMAGE_URL=${SAX_MODEL_SERVER_IMAGE_NAME}:${SAX_VERSION}
export SAX_UTIL_IMAGE_URL="${SAX_UTIL_IMAGE_NAME}:${sax_version}"

docker pull ${SAX_ADMIN_SERVER_IMAGE_URL}
docker pull ${SAX_MODEL_SERVER_IMAGE_URL}
docker pull ${SAX_UTIL_IMAGE_URL}

Définissez d'autres variables que vous utiliserez plus tard :

export SAX_ADMIN_SERVER_DOCKER_NAME="sax-admin-server"
export SAX_MODEL_SERVER_DOCKER_NAME="sax-model-server"
export SAX_CELL="/sax/test"

Exemple de mise en service de modèle GPT-J 6B à hôte unique

La mise en service de modèle à hôte unique s'applique aux tranches de TPU à hôte unique (v5litepod-1, v5litepod-4 et v5litepod-8).

  1. Créez un cluster SAX.

    1. Créez un bucket de stockage Cloud Storage pour le cluster SAX :

      SAX_ADMIN_STORAGE_BUCKET=${your_admin_storage_bucket}
      
      gcloud storage buckets create gs://${SAX_ADMIN_STORAGE_BUCKET} \
      --project=${PROJECT_ID}

      Vous aurez peut-être besoin d'un autre bucket Cloud Storage pour stocker le point de contrôle.

      SAX_DATA_STORAGE_BUCKET=${your_data_storage_bucket}
    2. Connectez-vous en SSH à votre VM TPU dans un terminal pour lancer le serveur d'administration SAX :

      docker run \
      --name ${SAX_ADMIN_SERVER_DOCKER_NAME} \
      -it \
      -d \
      --rm \
      --network host \
      --env GSBUCKET=${SAX_ADMIN_STORAGE_BUCKET} \
      ${SAX_ADMIN_SERVER_IMAGE_URL}

      Pour consulter le journal Docker :

      docker logs -f ${SAX_ADMIN_SERVER_DOCKER_NAME}

      La sortie dans le journal ressemblera à ceci :

      I0829 01:22:31.184198       7 config.go:111] Creating config fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.347883       7 config.go:115] Created config fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.360837      24 admin_server.go:44] Starting the server
      I0829 01:22:31.361420      24 ipaddr.go:39] Skipping non-global IP address 127.0.0.1/8.
      I0829 01:22:31.361455      24 ipaddr.go:39] Skipping non-global IP address ::1/128.
      I0829 01:22:31.361462      24 ipaddr.go:39] Skipping non-global IP address fe80::4001:aff:fe8e:fc8/64.
      I0829 01:22:31.361469      24 ipaddr.go:39] Skipping non-global IP address fe80::42:bfff:fef9:1bd3/64.
      I0829 01:22:31.361474      24 ipaddr.go:39] Skipping non-global IP address fe80::20fb:c3ff:fe5b:baac/64.
      I0829 01:22:31.361482      24 ipaddr.go:56] IPNet address 10.142.15.200
      I0829 01:22:31.361488      24 ipaddr.go:56] IPNet address 172.17.0.1
      I0829 01:22:31.456952      24 admin.go:305] Loaded config: fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.609323      24 addr.go:105] SetAddr /gcs/test_sax_admin/sax-root/sax/test/location.proto "10.142.15.200:10000"
      I0829 01:22:31.656021      24 admin.go:325] Updated config: fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.773245      24 mgr.go:781] Loaded manager state
      I0829 01:22:31.773260      24 mgr.go:784] Refreshing manager state every 10s
      I0829 01:22:31.773285      24 admin.go:350] Starting the server on port 10000
      I0829 01:22:31.773292      24 cloud.go:506] Starting the HTTP server on port 8080
  2. Lancez un serveur de modèle SAX à hôte unique dans le cluster SAX :

    À ce stade, le cluster SAX ne contient que le serveur d'administration SAX. Vous pouvez vous connecter à votre VM TPU via SSH dans un deuxième terminal pour lancer un serveur de modèle SAX dans votre cluster SAX :

    docker run \
        --privileged  \
        -it \
        -d \
        --rm \
        --network host \
        --name ${SAX_MODEL_SERVER_DOCKER_NAME} \
        --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
        ${SAX_MODEL_SERVER_IMAGE_URL} \
           --sax_cell=${SAX_CELL} \
           --port=10001 \
           --platform_chip=tpuv4 \
           --platform_topology=1x1
  3. Convertissez le point de contrôle du modèle :

    Vous devez installer PyTorch et Transformers pour télécharger le point de contrôle GPT-J depuis EleutherAI :

    pip3 install accelerate
    pip3 install torch
    pip3 install transformers

    Pour convertir le point de contrôle en point de contrôle SAX, vous devez installer paxml :

    pip3 install paxml==1.1.0

    Le script suivant convertit le point de contrôle GPT-J en point de contrôle SAX :

    python3 -m convert_gptj_ckpt --base EleutherAI/gpt-j-6b --pax pax_6b

    Une fois la conversion terminée :

    ls checkpoint_00000000/

    Vous devez créer un fichier commit_success et le placer dans les sous-répertoires :

    gcloud storage cp checkpoint_00000000 ${CHECKPOINT_PATH} --recursive
    
    touch commit_success.txt
    gcloud storage cp commit_success.txt ${CHECKPOINT_PATH}/
    gcloud storage cp commit_success.txt ${CHECKPOINT_PATH}/metadata/
    gcloud storage cp commit_success.txt ${CHECKPOINT_PATH}/state/
    
  4. Publiez le modèle sur le cluster SAX.

    Vous pouvez maintenant publier GPT-J avec le point de contrôle que vous avez converti lors de l'étape précédente.

    MODEL_NAME=gptjtokenizedbf16bs32
    MODEL_CONFIG_PATH=saxml.server.pax.lm.params.gptj.GPTJ4TokenizedBF16BS32
    REPLICA=1

    Pour publier GPT-J (et pour les étapes suivantes), utilisez SSH pour vous connecter à votre VM TPU dans un troisième terminal :

    docker run \
     ${SAX_UTIL_IMAGE_URL} \
       --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
       publish \
         ${SAX_CELL}/${MODEL_NAME} \
         ${MODEL_CONFIG_PATH} \
         ${CHECKPOINT_PATH} \
         ${REPLICA}

    Vous observerez beaucoup d'activité dans le journal Docker du serveur de modèle jusqu'à ce que quelque chose de semblable à ce qui suit s'affiche pour indiquer que le modèle a bien été chargé :

    I0829 01:33:49.287459 139865140229696 servable_model.py:697] loading completed.
    
  5. Générez des résultats d'inférence.

    Pour GPT-J, l'entrée et la sortie doivent respecter le format suivant : chaîne d'ID de jetons séparés par une virgule. Vous devrez tokeniser l'entrée de texte.

    TEXT = "Below is an instruction that describes a task, paired with
    an input that provides further context. Write a response that
    appropriately completes the request.\n\n### Instruction\:\nSummarize the
    following news article\:\n\n### Input\:\nMarch 10, 2015 . We're truly
    international in scope on Tuesday. We're visiting Italy, Russia, the
    United Arab Emirates, and the Himalayan Mountains. Find out who's
    attempting to circumnavigate the globe in a plane powered partially by the
    sun, and explore the mysterious appearance of craters in northern Asia.
    You'll also get a view of Mount Everest that was previously reserved for
    climbers. On this page you will find today's show Transcript and a place
    for you to request to be on the CNN Student News Roll Call. TRANSCRIPT .
    Click here to access the transcript of today's CNN Student News program.
    Please note that there may be a delay between the time when the video is
    available and when the transcript is published. CNN Student News is
    created by a team of journalists who consider the Common Core State
    Standards, national standards in different subject areas, and state
    standards when producing the show. ROLL CALL . For a chance to be
    mentioned on the next CNN Student News, comment on the bottom of this page
    with your school name, mascot, city and state. We will be selecting
    schools from the comments of the previous show. You must be a teacher or a
    student age 13 or older to request a mention on the CNN Student News Roll
    Call! Thank you for using CNN Student News!\n\n### Response\:

    Vous pouvez obtenir la chaîne d'ID de jetons via le tokenizer EleutherAI/gpt-j-6b :

    from transformers import GPT2Tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-j-6b")                  :

    Tokenisez le texte d'entrée :

    encoded_example = tokenizer(TEXT)
    input_ids = encoded_example.input_ids
    INPUT_STR = ",".join([str(input_id) for input_id in input_ids])

    Vous pouvez vous attendre à une chaîne d'ID de jeton semblable à celle-ci :

    >>> INPUT_STR
    '21106,318,281,12064,326,8477,257,4876,11,20312,351,281,5128,326,3769,2252,4732,13,19430,257,2882,326,20431,32543,262,2581,13,198,198,21017,46486,25,198,13065,3876,1096,262,1708,1705,2708,25,198,198,21017,23412,25,198,16192,838,11,1853,764,775,821,4988,3230,287,8354,319,3431,13,775,821,10013,8031,11,3284,11,262,1578,4498,24880,11,290,262,42438,22931,21124,13,9938,503,508,338,9361,284,2498,4182,615,10055,262,13342,287,257,6614,13232,12387,416,262,4252,11,290,7301,262,11428,5585,286,1067,8605,287,7840,7229,13,921,1183,635,651,257,1570,286,5628,41336,326,373,4271,10395,329,39311,13,1550,428,2443,345,481,1064,1909,338,905,42978,290,257,1295,329,345,284,2581,284,307,319,262,8100,13613,3000,8299,4889,13,48213,6173,46023,764,6914,994,284,1895,262,14687,286,1909,338,8100,13613,3000,1430,13,4222,3465,326,612,743,307,257,5711,1022,262,640,618,262,2008,318,1695,290,618,262,14687,318,3199,13,8100,13613,3000,318,2727,416,257,1074,286,9046,508,2074,262,8070,7231,1812,20130,11,2260,5423,287,1180,2426,3006,11,290,1181,5423,618,9194,262,905,13,15107,3069,42815,764,1114,257,2863,284,307,4750,319,262,1306,8100,13613,3000,11,2912,319,262,4220,286,428,2443,351,534,1524,1438,11,37358,11,1748,290,1181,13,775,481,307,17246,4266,422,262,3651,286,262,2180,905,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,6952,345,329,1262,8100,13613,3000,0,198,198,21017,18261,25'

    Pour générer un résumé de votre article :

    docker run \
      ${SAX_UTIL_IMAGE_URL} \
        --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
        lm.generate \
          ${SAX_CELL}/${MODEL_NAME} \
          ${INPUT_STR}
    

    Vous pouvez vous attendre à un résultat semblable à celui-ci :

    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
    |                                                                                                                                                    GENERATE                                                                                                                                                    |    SCORE     |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256 | -0.023136413 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,50256  |  -0.91842502 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256     |   -1.1726116 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256                            |   -1.2472695 |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+

    Pour détokeniser la chaîne d'ID de jetons de sortie :

    output_token_ids = [int(token_id) for token_id in OUTPUT_STR.split(',')]
    OUTPUT_TEXT = tokenizer.decode(output_token_ids, skip_special_tokens=True)
    

    Vous pouvez vous attendre à ce que le texte détokenisé ressemble à ceci :

    >>> OUTPUT_TEXT
    'This page includes the show Transcript.\nUse the Transcript to help
    students with reading comprehension and vocabulary.\nAt the bottom of
    the page, comment for a chance to be mentioned on CNN Student News.
    You must be a teacher or a student age 13 or older to request a mention on the CNN Student News Roll Call.'
  6. Nettoyez vos conteneurs Docker et vos buckets de stockage Cloud Storage.

Aperçu de la mise en service de modèles multi-hôtes 175B

Certains grands modèles de langage nécessitent une tranche de TPU multi-hôtes, du type v5litepod-16 ou supérieure. Dans ce cas, tous les hôtes TPU multi-hôtes doivent disposer d'une copie d'un serveur de modèle SAX, et tous les serveurs de modèle fonctionnent comme un groupe de serveurs de modèle SAX pour diffuser le grand modèle sur une tranche TPU multi-hôtes.

  1. Créez un cluster SAX.

    Vous pouvez suivre l'étape de création d'un cluster SAX décrite dans le tutoriel GPT-J pour créer un cluster SAX et un serveur d'administration SAX.

    Si vous disposez déjà d'un cluster SAX, vous pouvez lancer un serveur de modèle multi-hôtes dans votre cluster SAX.

  2. Lancez un serveur de modèle SAX multi-hôtes dans un cluster SAX.

    Pour créer une tranche de TPU multi-hôtes, utilisez la même commande que pour une tranche de TPU à hôte unique, mais en spécifiant le type d'accélérateur multi-hôtes approprié :

    ACCELERATOR_TYPE=v5litepod-32
    ZONE=us-east1-c
    
    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --service-account ${SERVICE_ACCOUNT} \
      --reserved
    

    Pour extraire l'image de serveur de modèle SAX vers tous les hôtes/nœuds de calcul TPU et les lancer :

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --worker=all \
      --command="
        gcloud auth configure-docker \
          us-docker.pkg.dev
        # Pull sax model server image
        docker pull ${SAX_MODEL_SERVER_IMAGE_URL}
        # Run model server
        docker run \
          --privileged  \
          -it \
          -d \
          --rm \
          --network host \
          --name ${SAX_MODEL_SERVER_DOCKER_NAME} \
          --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
          ${SAX_MODEL_SERVER_IMAGE_URL} \
            --sax_cell=${SAX_CELL} \
            --port=10001 \
            --platform_chip=tpuv4 \
            --platform_topology=1x1"
    
  3. Publiez le modèle sur le cluster SAX.

    Cet exemple utilise un modèle LmCloudSpmd175B32Test :

    MODEL_NAME=lmcloudspmd175b32test
    MODEL_CONFIG_PATH=saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test
    CHECKPOINT_PATH=None
    REPLICA=1

    Pour publier le modèle de test :

    docker run \
     ${SAX_UTIL_IMAGE_URL} \
       --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
       publish \
         ${SAX_CELL}/${MODEL_NAME} \
         ${MODEL_CONFIG_PATH} \
         ${CHECKPOINT_PATH} \
         ${REPLICA}
  4. Générez des résultats d'inférence.

    docker run \
      ${SAX_UTIL_IMAGE_URL} \
        --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
        lm.generate \
          ${SAX_CELL}/${MODEL_NAME} \
          "Q:  Who is Harry Porter's mother? A\: "
    

    Notez que, comme cet exemple utilise un modèle de test avec des pondérations aléatoires, il se peut que la sortie n'ait pas de sens.

  5. Effectuez un nettoyage.

    Arrêtez les conteneurs Docker :

    docker stop ${SAX_ADMIN_SERVER_DOCKER_NAME}
    docker stop ${SAX_MODEL_SERVER_DOCKER_NAME}

    Supprimez votre bucket de stockage administrateur Cloud Storage et tout bucket de stockage de données à l'aide de la gcloud CLI, comme indiqué ci-dessous :

    gcloud storage rm gs://${SAX_ADMIN_STORAGE_BUCKET} --recursive --continue-on-error
    gcloud storage rm gs://${SAX_DATA_STORAGE_BUCKET} --recursive --continue-on-error