SAX en TPU de Cloud v5e

Clúster SAX (celda SAX)

El servidor de administración de SAX y el servidor de modelos de SAX son dos componentes esenciales que ejecutan un clúster de SAX.

Servidor de administración de SAX

El servidor de administración de SAX monitoriza y coordina todos los servidores de modelos de SAX en un clúster de SAX. En un clúster de SAX, puedes iniciar varios servidores de administración de SAX, de los cuales solo uno está activo mediante la elección de líder, y los demás son servidores de reserva. Si falla el servidor de administrador activo, se activará un servidor de administrador de reserva. El servidor de administrador SAX activo asigna réplicas de modelos y solicitudes de inferencia a los servidores de modelos SAX disponibles.

Segmento de almacenamiento de administrador de SAX

Cada clúster de SAX requiere un segmento de Cloud Storage para almacenar las configuraciones y las ubicaciones de los servidores de administración y de los servidores de modelos de SAX en el clúster de SAX.

Servidor de modelo SAX

El servidor de modelos SAX carga un punto de control de un modelo y ejecuta la inferencia con GSPMD. Un servidor de modelos SAX se ejecuta en un solo trabajador de máquina virtual de TPU. Para servir modelos de TPU de un solo host, se necesita un solo servidor de modelos SAX en una VM de TPU de un solo host. Para servir modelos de TPU de varios hosts, se necesita un grupo de servidores de modelos SAX en una porción de TPU de varios hosts. Actualmente, no se puede usar el servicio de modelos multihost, pero en este documento se ofrece un ejemplo con un modelo de prueba de 175 B para que lo pruebes.

Servicio de modelos SAX

En la siguiente sección se explica el flujo de trabajo para servir modelos de lenguaje con SAX. Usa el modelo GPT-J 6B como ejemplo de servicio de modelo de un solo host.

Antes de empezar, instala las imágenes Docker de SAX de TPU de Cloud en tu máquina virtual de TPU:

sudo usermod -a -G docker ${USER}
newgrp docker

gcloud auth configure-docker us-docker.pkg.dev

SAX_ADMIN_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-admin-server"
SAX_MODEL_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server"
SAX_UTIL_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-util"

SAX_VERSION=v1.0.0

export SAX_ADMIN_SERVER_IMAGE_URL=${SAX_ADMIN_SERVER_IMAGE_NAME}:${SAX_VERSION}
export SAX_MODEL_SERVER_IMAGE_URL=${SAX_MODEL_SERVER_IMAGE_NAME}:${SAX_VERSION}
export SAX_UTIL_IMAGE_URL="${SAX_UTIL_IMAGE_NAME}:${sax_version}"

docker pull ${SAX_ADMIN_SERVER_IMAGE_URL}
docker pull ${SAX_MODEL_SERVER_IMAGE_URL}
docker pull ${SAX_UTIL_IMAGE_URL}

Define otras variables que usarás más adelante:

export SAX_ADMIN_SERVER_DOCKER_NAME="sax-admin-server"
export SAX_MODEL_SERVER_DOCKER_NAME="sax-model-server"
export SAX_CELL="/sax/test"

Ejemplo de aplicación de modelo de un solo host de GPT-J 6B

El servicio de modelos de un solo host se aplica a las slices de TPU de un solo host, es decir, v5litepod-1, v5litepod-4 y v5litepod-8.

  1. Crear un clúster de SAX

    1. Crea un segmento de Cloud Storage para el clúster de SAX:

      SAX_ADMIN_STORAGE_BUCKET=${your_admin_storage_bucket}
      
      gcloud storage buckets create gs://${SAX_ADMIN_STORAGE_BUCKET} \
      --project=${PROJECT_ID}

      Es posible que necesites otro segmento de Cloud Storage para almacenar el punto de control.

      SAX_DATA_STORAGE_BUCKET=${your_data_storage_bucket}
    2. Accede a tu VM de TPU mediante SSH en una terminal para iniciar el servidor de administración de SAX:

      docker run \
      --name ${SAX_ADMIN_SERVER_DOCKER_NAME} \
      -it \
      -d \
      --rm \
      --network host \
      --env GSBUCKET=${SAX_ADMIN_STORAGE_BUCKET} \
      ${SAX_ADMIN_SERVER_IMAGE_URL}

      Para consultar el registro de Docker, sigue estos pasos:

      docker logs -f ${SAX_ADMIN_SERVER_DOCKER_NAME}

      La salida del registro tendrá un aspecto similar al siguiente:

      I0829 01:22:31.184198       7 config.go:111] Creating config fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.347883       7 config.go:115] Created config fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.360837      24 admin_server.go:44] Starting the server
      I0829 01:22:31.361420      24 ipaddr.go:39] Skipping non-global IP address 127.0.0.1/8.
      I0829 01:22:31.361455      24 ipaddr.go:39] Skipping non-global IP address ::1/128.
      I0829 01:22:31.361462      24 ipaddr.go:39] Skipping non-global IP address fe80::4001:aff:fe8e:fc8/64.
      I0829 01:22:31.361469      24 ipaddr.go:39] Skipping non-global IP address fe80::42:bfff:fef9:1bd3/64.
      I0829 01:22:31.361474      24 ipaddr.go:39] Skipping non-global IP address fe80::20fb:c3ff:fe5b:baac/64.
      I0829 01:22:31.361482      24 ipaddr.go:56] IPNet address 10.142.15.200
      I0829 01:22:31.361488      24 ipaddr.go:56] IPNet address 172.17.0.1
      I0829 01:22:31.456952      24 admin.go:305] Loaded config: fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.609323      24 addr.go:105] SetAddr /gcs/test_sax_admin/sax-root/sax/test/location.proto "10.142.15.200:10000"
      I0829 01:22:31.656021      24 admin.go:325] Updated config: fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.773245      24 mgr.go:781] Loaded manager state
      I0829 01:22:31.773260      24 mgr.go:784] Refreshing manager state every 10s
      I0829 01:22:31.773285      24 admin.go:350] Starting the server on port 10000
      I0829 01:22:31.773292      24 cloud.go:506] Starting the HTTP server on port 8080
  2. Inicia un servidor de modelos SAX de un solo host en el clúster SAX:

    En este punto, el clúster SAX solo contiene el servidor de administración de SAX. Puedes conectarte a tu VM de TPU a través de SSH en una segunda terminal para iniciar un servidor de modelos SAX en tu clúster SAX:

    docker run \
        --privileged  \
        -it \
        -d \
        --rm \
        --network host \
        --name ${SAX_MODEL_SERVER_DOCKER_NAME} \
        --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
        ${SAX_MODEL_SERVER_IMAGE_URL} \
           --sax_cell=${SAX_CELL} \
           --port=10001 \
           --platform_chip=tpuv4 \
           --platform_topology=1x1
  3. Convertir punto de control del modelo:

    Debes instalar PyTorch y Transformers para descargar el punto de control de GPT-J de EleutherAI:

    pip3 install accelerate
    pip3 install torch
    pip3 install transformers

    Para convertir el punto de control en un punto de control SAX, debes instalar paxml:

    pip3 install paxml==1.1.0

    El siguiente script convierte el punto de control de GPT-J en un punto de control de SAX:

    python3 -m convert_gptj_ckpt --base EleutherAI/gpt-j-6b --pax pax_6b

    Una vez completada la conversión:

    ls checkpoint_00000000/

    Debe crear un archivo commit_success y colocarlo en los subdirectorios:

    gcloud storage cp checkpoint_00000000 ${CHECKPOINT_PATH} --recursive
    
    touch commit_success.txt
    gcloud storage cp commit_success.txt ${CHECKPOINT_PATH}/
    gcloud storage cp commit_success.txt ${CHECKPOINT_PATH}/metadata/
    gcloud storage cp commit_success.txt ${CHECKPOINT_PATH}/state/
    
  4. Publicar el modelo en el clúster de SAX

    Ahora puedes publicar GPT-J con el punto de control convertido en el paso anterior.

    MODEL_NAME=gptjtokenizedbf16bs32
    MODEL_CONFIG_PATH=saxml.server.pax.lm.params.gptj.GPTJ4TokenizedBF16BS32
    REPLICA=1

    Para publicar GPT-J (y los pasos posteriores), usa SSH para conectarte a tu VM de TPU en un tercer terminal:

    docker run \
     ${SAX_UTIL_IMAGE_URL} \
       --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
       publish \
         ${SAX_CELL}/${MODEL_NAME} \
         ${MODEL_CONFIG_PATH} \
         ${CHECKPOINT_PATH} \
         ${REPLICA}

    Verá mucha actividad en el registro de Docker del servidor de modelos hasta que aparezca algo similar a lo siguiente, que indica que el modelo se ha cargado correctamente:

    I0829 01:33:49.287459 139865140229696 servable_model.py:697] loading completed.
    
  5. Generar resultados de inferencia

    En el caso de GPT-J, la entrada y la salida deben tener el formato de una cadena de IDs de tokens separados por comas. Deberás tokenizar el texto introducido.

    TEXT = "Below is an instruction that describes a task, paired with
    an input that provides further context. Write a response that
    appropriately completes the request.\n\n### Instruction\:\nSummarize the
    following news article\:\n\n### Input\:\nMarch 10, 2015 . We're truly
    international in scope on Tuesday. We're visiting Italy, Russia, the
    United Arab Emirates, and the Himalayan Mountains. Find out who's
    attempting to circumnavigate the globe in a plane powered partially by the
    sun, and explore the mysterious appearance of craters in northern Asia.
    You'll also get a view of Mount Everest that was previously reserved for
    climbers. On this page you will find today's show Transcript and a place
    for you to request to be on the CNN Student News Roll Call. TRANSCRIPT .
    Click here to access the transcript of today's CNN Student News program.
    Please note that there may be a delay between the time when the video is
    available and when the transcript is published. CNN Student News is
    created by a team of journalists who consider the Common Core State
    Standards, national standards in different subject areas, and state
    standards when producing the show. ROLL CALL . For a chance to be
    mentioned on the next CNN Student News, comment on the bottom of this page
    with your school name, mascot, city and state. We will be selecting
    schools from the comments of the previous show. You must be a teacher or a
    student age 13 or older to request a mention on the CNN Student News Roll
    Call! Thank you for using CNN Student News!\n\n### Response\:

    Puedes obtener la cadena de IDs de token a través del tokenizador EleutherAI/gpt-j-6b:

    from transformers import GPT2Tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-j-6b")                  :

    Tokeniza el texto de entrada:

    encoded_example = tokenizer(TEXT)
    input_ids = encoded_example.input_ids
    INPUT_STR = ",".join([str(input_id) for input_id in input_ids])

    Verás una cadena de ID de token similar a la siguiente:

    >>> INPUT_STR
    '21106,318,281,12064,326,8477,257,4876,11,20312,351,281,5128,326,3769,2252,4732,13,19430,257,2882,326,20431,32543,262,2581,13,198,198,21017,46486,25,198,13065,3876,1096,262,1708,1705,2708,25,198,198,21017,23412,25,198,16192,838,11,1853,764,775,821,4988,3230,287,8354,319,3431,13,775,821,10013,8031,11,3284,11,262,1578,4498,24880,11,290,262,42438,22931,21124,13,9938,503,508,338,9361,284,2498,4182,615,10055,262,13342,287,257,6614,13232,12387,416,262,4252,11,290,7301,262,11428,5585,286,1067,8605,287,7840,7229,13,921,1183,635,651,257,1570,286,5628,41336,326,373,4271,10395,329,39311,13,1550,428,2443,345,481,1064,1909,338,905,42978,290,257,1295,329,345,284,2581,284,307,319,262,8100,13613,3000,8299,4889,13,48213,6173,46023,764,6914,994,284,1895,262,14687,286,1909,338,8100,13613,3000,1430,13,4222,3465,326,612,743,307,257,5711,1022,262,640,618,262,2008,318,1695,290,618,262,14687,318,3199,13,8100,13613,3000,318,2727,416,257,1074,286,9046,508,2074,262,8070,7231,1812,20130,11,2260,5423,287,1180,2426,3006,11,290,1181,5423,618,9194,262,905,13,15107,3069,42815,764,1114,257,2863,284,307,4750,319,262,1306,8100,13613,3000,11,2912,319,262,4220,286,428,2443,351,534,1524,1438,11,37358,11,1748,290,1181,13,775,481,307,17246,4266,422,262,3651,286,262,2180,905,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,6952,345,329,1262,8100,13613,3000,0,198,198,21017,18261,25'

    Para generar un resumen de tu artículo, sigue estos pasos:

    docker run \
      ${SAX_UTIL_IMAGE_URL} \
        --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
        lm.generate \
          ${SAX_CELL}/${MODEL_NAME} \
          ${INPUT_STR}
    

    Puedes esperar algo similar a lo siguiente:

    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
    |                                                                                                                                                    GENERATE                                                                                                                                                    |    SCORE     |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256 | -0.023136413 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,50256  |  -0.91842502 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256     |   -1.1726116 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256                            |   -1.2472695 |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+

    Para destokenizar la cadena de IDs de tokens de salida, sigue estos pasos:

    output_token_ids = [int(token_id) for token_id in OUTPUT_STR.split(',')]
    OUTPUT_TEXT = tokenizer.decode(output_token_ids, skip_special_tokens=True)
    

    El texto destokenizado será el siguiente:

    >>> OUTPUT_TEXT
    'This page includes the show Transcript.\nUse the Transcript to help
    students with reading comprehension and vocabulary.\nAt the bottom of
    the page, comment for a chance to be mentioned on CNN Student News.
    You must be a teacher or a student age 13 or older to request a mention on the CNN Student News Roll Call.'
  6. Limpia tus contenedores Docker y tus segmentos de Cloud Storage.

Vista previa del servicio de modelos multihost de 175.000 millones de parámetros

Algunos de los modelos de lenguaje extensos requerirán un slice de TPU multihost, es decir, v5litepod-16 y versiones posteriores. En esos casos, todos los hosts de TPU de varios hosts deberán tener una copia de un servidor de modelos SAX, y todos los servidores de modelos funcionarán como un grupo de servidores de modelos SAX para servir el modelo grande en un segmento de TPU de varios hosts.

  1. Crear un clúster de SAX

    Puedes seguir el mismo paso para crear un clúster SAX en la guía de GPT-J para crear un clúster SAX y un servidor de administrador SAX.

    Si ya tienes un clúster de SAX, puedes iniciar un servidor de modelos de varios hosts en él.

  2. Lanzar un servidor de modelos SAX de varios hosts en un clúster SAX

    Usa el mismo comando para crear un segmento de TPU de varios hosts que para crear un segmento de TPU de un solo host. Solo tienes que especificar el tipo de acelerador de varios hosts adecuado:

    ACCELERATOR_TYPE=v5litepod-32
    ZONE=us-east1-c
    
    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --service-account ${SERVICE_ACCOUNT} \
      --reserved
    

    Para extraer la imagen del servidor del modelo SAX a todos los hosts o trabajadores de TPU e iniciarlos, haz lo siguiente:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --worker=all \
      --command="
        gcloud auth configure-docker \
          us-docker.pkg.dev
        # Pull sax model server image
        docker pull ${SAX_MODEL_SERVER_IMAGE_URL}
        # Run model server
        docker run \
          --privileged  \
          -it \
          -d \
          --rm \
          --network host \
          --name ${SAX_MODEL_SERVER_DOCKER_NAME} \
          --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
          ${SAX_MODEL_SERVER_IMAGE_URL} \
            --sax_cell=${SAX_CELL} \
            --port=10001 \
            --platform_chip=tpuv4 \
            --platform_topology=1x1"
    
  3. Publicar el modelo en el clúster de SAX

    En este ejemplo se usa un modelo LmCloudSpmd175B32Test:

    MODEL_NAME=lmcloudspmd175b32test
    MODEL_CONFIG_PATH=saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test
    CHECKPOINT_PATH=None
    REPLICA=1

    Para publicar el modelo de prueba, haz lo siguiente:

    docker run \
     ${SAX_UTIL_IMAGE_URL} \
       --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
       publish \
         ${SAX_CELL}/${MODEL_NAME} \
         ${MODEL_CONFIG_PATH} \
         ${CHECKPOINT_PATH} \
         ${REPLICA}
  4. Generar resultados de inferencia

    docker run \
      ${SAX_UTIL_IMAGE_URL} \
        --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
        lm.generate \
          ${SAX_CELL}/${MODEL_NAME} \
          "Q:  Who is Harry Porter's mother? A\: "
    

    Ten en cuenta que, como en este ejemplo se usa un modelo de prueba con pesos aleatorios, puede que el resultado no sea significativo.

  5. Limpieza

    Detén los contenedores Docker:

    docker stop ${SAX_ADMIN_SERVER_DOCKER_NAME}
    docker stop ${SAX_MODEL_SERVER_DOCKER_NAME}

    Elimina tu segmento de almacenamiento de administrador de Cloud Storage y cualquier segmento de almacenamiento de datos con la CLI de gcloud, tal como se muestra a continuación.

    gcloud storage rm gs://${SAX_ADMIN_STORAGE_BUCKET} --recursive --continue-on-error
    gcloud storage rm gs://${SAX_DATA_STORAGE_BUCKET} --recursive --continue-on-error