SAX on Cloud TPU v5e

SAX cluster (SAX cell)

The SAX admin server and the SAX model server are the two essential components that run a SAX cluster.

SAX admin server

The SAX admin server monitors and coordinates all SAX model servers in a SAX cluster. You can launch multiple SAX admin servers in a SAX cluster, but leader election makes only one of them active; the others are standby servers. If the active admin server fails, a standby admin server becomes active. The active SAX admin server assigns model replicas and inference requests to the available SAX model servers.
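As a toy model of the behavior described above (one active admin server chosen by leader election, the others on standby, failover when the active one fails), the following sketch keeps exactly one healthy server active. It is not SAX's election mechanism, which this guide does not describe; the class names and the "first healthy server wins" rule are illustrative assumptions.

```python
from dataclasses import dataclass


@dataclass
class AdminServer:
    """Hypothetical stand-in for one SAX admin server process."""
    name: str
    healthy: bool = True


@dataclass
class AdminGroup:
    """Toy model: exactly one healthy admin server is active, the rest stand by."""
    servers: list
    active: str = ""

    def elect(self) -> str:
        """Keep a healthy leader; otherwise promote the first healthy standby."""
        leader = next((s for s in self.servers if s.name == self.active), None)
        if leader is not None and leader.healthy:
            return self.active
        # No healthy leader: promote the first healthy server, or "" if none.
        self.active = next((s.name for s in self.servers if s.healthy), "")
        return self.active

    def standbys(self) -> list:
        """Healthy servers that are not currently active."""
        return [s.name for s in self.servers if s.healthy and s.name != self.active]
```

Marking the active server unhealthy and calling `elect()` again promotes the next healthy server, which mirrors the failover described above.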

SAX admin storage bucket

Each SAX cluster requires a Cloud Storage bucket to store the configuration and the locations of the SAX admin servers and SAX model servers in the cluster.

SAX model server

The SAX model server loads model checkpoints and runs inference with GSPMD. A SAX model server runs on a single TPU VM worker. Single-host TPU model serving requires one SAX model server on a single-host TPU VM. Multi-host TPU model serving requires a group of SAX model servers on a multi-host TPU slice. Multi-host model serving is not currently available, but this document provides an example with a 175B test model as a preview.

SAX model serving

The following sections walk through the workflow for serving language models with SAX. The single-host example uses the GPT-J 6B model.

Before you begin, pull the Cloud TPU SAX Docker images onto your TPU VM:

sudo usermod -a -G docker ${USER}
newgrp docker

gcloud auth configure-docker us-docker.pkg.dev

SAX_ADMIN_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-admin-server"
SAX_MODEL_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server"
SAX_UTIL_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-util"

SAX_VERSION=v1.0.0

export SAX_ADMIN_SERVER_IMAGE_URL=${SAX_ADMIN_SERVER_IMAGE_NAME}:${SAX_VERSION}
export SAX_MODEL_SERVER_IMAGE_URL=${SAX_MODEL_SERVER_IMAGE_NAME}:${SAX_VERSION}
export SAX_UTIL_IMAGE_URL=${SAX_UTIL_IMAGE_NAME}:${SAX_VERSION}

docker pull ${SAX_ADMIN_SERVER_IMAGE_URL}
docker pull ${SAX_MODEL_SERVER_IMAGE_URL}
docker pull ${SAX_UTIL_IMAGE_URL}

Set a few other variables that you will use later:

export SAX_ADMIN_SERVER_DOCKER_NAME="sax-admin-server"
export SAX_MODEL_SERVER_DOCKER_NAME="sax-model-server"
export SAX_CELL="/sax/test"
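SAX_CELL is reused in several places: the admin server log shown later records its address under .../sax-root/sax/test/location.proto, and the publish and lm.generate commands address models as ${SAX_CELL}/${MODEL_NAME}. The sketch below derives those strings from the variables in this guide. The helper names are hypothetical, the `/sax/<name>` cell format check is an assumption based on SAX_CELL=/sax/test, and the log prints the location path with a /gcs/ prefix, so rendering it as a gs:// URL is also an assumption.

```python
SAX_ROOT_DIR = "sax-root"  # directory used in SAX_ROOT and --sax_root in this guide


def sax_root(admin_bucket: str) -> str:
    """The SAX root passed as SAX_ROOT and --sax_root in this guide."""
    return f"gs://{admin_bucket}/{SAX_ROOT_DIR}"


def check_cell(cell: str) -> str:
    """Assumed cell format: /sax/<name>, as in SAX_CELL=/sax/test."""
    parts = cell.split("/")
    if len(parts) != 3 or parts[0] != "" or parts[1] != "sax" or not parts[2]:
        raise ValueError(f"expected a cell like /sax/<name>, got {cell!r}")
    return cell


def model_id(cell: str, model_name: str) -> str:
    """Model ID used by publish and lm.generate: ${SAX_CELL}/${MODEL_NAME}."""
    return f"{check_cell(cell)}/{model_name}"


def admin_location_file(admin_bucket: str, cell: str) -> str:
    """Where the admin server records its address, inferred from its log."""
    return f"{sax_root(admin_bucket)}{check_cell(cell)}/location.proto"
```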

GPT-J 6B single-host model serving example

Single-host model serving applies to single-host TPU slices, namely v5litepod-1, v5litepod-4, and v5litepod-8.
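Based only on the accelerator types listed in this guide (v5litepod-1, v5litepod-4, and v5litepod-8 are single-host; v5litepod-16 and larger are multi-host), a deployment script could choose between the single-host and multi-host paths as sketched below. The function names are hypothetical, and the sketch assumes the numeric suffix of a v5litepod type is its chip count.

```python
import re

# Cutoff derived from this guide: v5litepod-1/4/8 are single-host slices,
# v5litepod-16 and larger are multi-host slices.
MAX_SINGLE_HOST_CHIPS = 8


def chip_count(accelerator_type: str) -> int:
    """Parse the chip count from a v5e accelerator type such as 'v5litepod-8'."""
    match = re.fullmatch(r"v5litepod-(\d+)", accelerator_type)
    if match is None:
        raise ValueError(f"not a v5e accelerator type: {accelerator_type!r}")
    return int(match.group(1))


def is_single_host(accelerator_type: str) -> bool:
    """True if one SAX model server on one TPU VM can serve the model."""
    return chip_count(accelerator_type) <= MAX_SINGLE_HOST_CHIPS
```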

  1. Create a SAX cluster

    1. Create a Cloud Storage bucket for the SAX cluster:

      SAX_ADMIN_STORAGE_BUCKET=${your_admin_storage_bucket}
      
      gcloud storage buckets create gs://${SAX_ADMIN_STORAGE_BUCKET} \
      --project=${PROJECT_ID}

      You might need another Cloud Storage bucket to store checkpoints:

      SAX_DATA_STORAGE_BUCKET=${your_data_storage_bucket}
    2. Use SSH to connect to your TPU VM in a terminal and launch the SAX admin server:

      docker run \
      --name ${SAX_ADMIN_SERVER_DOCKER_NAME} \
      -it \
      -d \
      --rm \
      --network host \
      --env GSBUCKET=${SAX_ADMIN_STORAGE_BUCKET} \
      ${SAX_ADMIN_SERVER_IMAGE_URL}

      You can check the Docker logs with:

      docker logs -f ${SAX_ADMIN_SERVER_DOCKER_NAME}

      The log output should look similar to the following:

      I0829 01:22:31.184198       7 config.go:111] Creating config fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.347883       7 config.go:115] Created config fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.360837      24 admin_server.go:44] Starting the server
      I0829 01:22:31.361420      24 ipaddr.go:39] Skipping non-global IP address 127.0.0.1/8.
      I0829 01:22:31.361455      24 ipaddr.go:39] Skipping non-global IP address ::1/128.
      I0829 01:22:31.361462      24 ipaddr.go:39] Skipping non-global IP address fe80::4001:aff:fe8e:fc8/64.
      I0829 01:22:31.361469      24 ipaddr.go:39] Skipping non-global IP address fe80::42:bfff:fef9:1bd3/64.
      I0829 01:22:31.361474      24 ipaddr.go:39] Skipping non-global IP address fe80::20fb:c3ff:fe5b:baac/64.
      I0829 01:22:31.361482      24 ipaddr.go:56] IPNet address 10.142.15.200
      I0829 01:22:31.361488      24 ipaddr.go:56] IPNet address 172.17.0.1
      I0829 01:22:31.456952      24 admin.go:305] Loaded config: fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.609323      24 addr.go:105] SetAddr /gcs/test_sax_admin/sax-root/sax/test/location.proto "10.142.15.200:10000"
      I0829 01:22:31.656021      24 admin.go:325] Updated config: fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.773245      24 mgr.go:781] Loaded manager state
      I0829 01:22:31.773260      24 mgr.go:784] Refreshing manager state every 10s
      I0829 01:22:31.773285      24 admin.go:350] Starting the server on port 10000
      I0829 01:22:31.773292      24 cloud.go:506] Starting the HTTP server on port 8080
  2. Launch a single-host SAX model server into the SAX cluster:

    At this point, the SAX cluster contains only the SAX admin server. Connect to the TPU VM over SSH in a second terminal and launch a SAX model server in the SAX cluster:

    docker run \
        --privileged  \
        -it \
        -d \
        --rm \
        --network host \
        --name ${SAX_MODEL_SERVER_DOCKER_NAME} \
        --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
        ${SAX_MODEL_SERVER_IMAGE_URL} \
           --sax_cell=${SAX_CELL} \
           --port=10001 \
           --platform_chip=tpuv4 \
           --platform_topology=1x1
  3. Convert the model checkpoint:

    Install PyTorch and Transformers to download the GPT-J checkpoint from EleutherAI:

    pip3 install accelerate
    pip3 install torch
    pip3 install transformers

    To convert the checkpoint to a SAX checkpoint, install paxml:

    pip3 install paxml==1.1.0

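    The following command converts the GPT-J checkpoint to a SAX checkpoint. Because it runs with python3 -m, the convert_gptj_ckpt module must be importable, for example as convert_gptj_ckpt.py in the current directory: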

    python3 -m convert_gptj_ckpt --base EleutherAI/gpt-j-6b --pax pax_6b

    After the conversion completes, list the resulting checkpoint directory:

    ls checkpoint_00000000/

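    Copy the checkpoint to ${CHECKPOINT_PATH}, the Cloud Storage location where you keep checkpoints (for example, a path in your data storage bucket). Then create a commit_success file and place it in the checkpoint directory and in its subdirectories: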

    gcloud storage cp checkpoint_00000000 ${CHECKPOINT_PATH} --recursive
    
    touch commit_success.txt
    gcloud storage cp commit_success.txt ${CHECKPOINT_PATH}/
    gcloud storage cp commit_success.txt ${CHECKPOINT_PATH}/metadata/
    gcloud storage cp commit_success.txt ${CHECKPOINT_PATH}/state/
    
  4. Publish the model to the SAX cluster

    You can now publish GPT-J with the checkpoint you converted in the previous step.

    MODEL_NAME=gptjtokenizedbf16bs32
    MODEL_CONFIG_PATH=saxml.server.pax.lm.params.gptj.GPTJ4TokenizedBF16BS32
    REPLICA=1

    To publish GPT-J (and for the steps that follow), use SSH to connect to the TPU VM in a third terminal:

    docker run \
     ${SAX_UTIL_IMAGE_URL} \
       --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
       publish \
         ${SAX_CELL}/${MODEL_NAME} \
         ${MODEL_CONFIG_PATH} \
         ${CHECKPOINT_PATH} \
         ${REPLICA}

    You will see a lot of activity in the model server's Docker log. The model has loaded successfully when you see a line like the following:

    I0829 01:33:49.287459 139865140229696 servable_model.py:697] loading completed.
    
  5. Generate inference results

    For GPT-J, the input and output must be formatted as comma-separated strings of token IDs, so you need to tokenize the text input.

    TEXT = (
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. Write a response that appropriately "
        "completes the request.\n\n### Instruction:\nSummarize the following "
        "news article:\n\n### Input:\nMarch 10, 2015 . We're truly international "
        "in scope on Tuesday. We're visiting Italy, Russia, the United Arab "
        "Emirates, and the Himalayan Mountains. Find out who's attempting to "
        "circumnavigate the globe in a plane powered partially by the sun, and "
        "explore the mysterious appearance of craters in northern Asia. You'll "
        "also get a view of Mount Everest that was previously reserved for "
        "climbers. On this page you will find today's show Transcript and a "
        "place for you to request to be on the CNN Student News Roll Call. "
        "TRANSCRIPT . Click here to access the transcript of today's CNN "
        "Student News program. Please note that there may be a delay between "
        "the time when the video is available and when the transcript is "
        "published. CNN Student News is created by a team of journalists who "
        "consider the Common Core State Standards, national standards in "
        "different subject areas, and state standards when producing the show. "
        "ROLL CALL . For a chance to be mentioned on the next CNN Student News, "
        "comment on the bottom of this page with your school name, mascot, "
        "city and state. We will be selecting schools from the comments of the "
        "previous show. You must be a teacher or a student age 13 or older to "
        "request a mention on the CNN Student News Roll Call! Thank you for "
        "using CNN Student News!\n\n### Response:"
    )

    You can get the token ID string with the EleutherAI/gpt-j-6b tokenizer:

    from transformers import GPT2Tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-j-6b")

    Tokenize the input text:

    encoded_example = tokenizer(TEXT)
    input_ids = encoded_example.input_ids
    INPUT_STR = ",".join([str(input_id) for input_id in input_ids])

    You should see a token ID string similar to the following:

    >>> INPUT_STR
    '21106,318,281,12064,326,8477,257,4876,11,20312,351,281,5128,326,3769,2252,4732,13,19430,257,2882,326,20431,32543,262,2581,13,198,198,21017,46486,25,198,13065,3876,1096,262,1708,1705,2708,25,198,198,21017,23412,25,198,16192,838,11,1853,764,775,821,4988,3230,287,8354,319,3431,13,775,821,10013,8031,11,3284,11,262,1578,4498,24880,11,290,262,42438,22931,21124,13,9938,503,508,338,9361,284,2498,4182,615,10055,262,13342,287,257,6614,13232,12387,416,262,4252,11,290,7301,262,11428,5585,286,1067,8605,287,7840,7229,13,921,1183,635,651,257,1570,286,5628,41336,326,373,4271,10395,329,39311,13,1550,428,2443,345,481,1064,1909,338,905,42978,290,257,1295,329,345,284,2581,284,307,319,262,8100,13613,3000,8299,4889,13,48213,6173,46023,764,6914,994,284,1895,262,14687,286,1909,338,8100,13613,3000,1430,13,4222,3465,326,612,743,307,257,5711,1022,262,640,618,262,2008,318,1695,290,618,262,14687,318,3199,13,8100,13613,3000,318,2727,416,257,1074,286,9046,508,2074,262,8070,7231,1812,20130,11,2260,5423,287,1180,2426,3006,11,290,1181,5423,618,9194,262,905,13,15107,3069,42815,764,1114,257,2863,284,307,4750,319,262,1306,8100,13613,3000,11,2912,319,262,4220,286,428,2443,351,534,1524,1438,11,37358,11,1748,290,1181,13,775,481,307,17246,4266,422,262,3651,286,262,2180,905,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,6952,345,329,1262,8100,13613,3000,0,198,198,21017,18261,25'

    To generate a summary of the article:

    docker run \
      ${SAX_UTIL_IMAGE_URL} \
        --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
        lm.generate \
          ${SAX_CELL}/${MODEL_NAME} \
          ${INPUT_STR}
    

    You should see output similar to the following:

    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
    |                                                                                                                                                    GENERATE                                                                                                                                                    |    SCORE     |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256 | -0.023136413 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,50256  |  -0.91842502 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256     |   -1.1726116 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256                            |   -1.2472695 |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+

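    To detokenize an output token ID string, set OUTPUT_STR to one of the GENERATE values from the table (for example, the first row) and decode it:

    # First GENERATE value from the table above.
    OUTPUT_STR = "1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256"
    output_token_ids = [int(token_id) for token_id in OUTPUT_STR.split(',')]
    OUTPUT_TEXT = tokenizer.decode(output_token_ids, skip_special_tokens=True)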
    

    You can expect the detokenized text to look like this:

    >>> OUTPUT_TEXT
    'This page includes the show Transcript.\nUse the Transcript to help
    students with reading comprehension and vocabulary.\nAt the bottom of
    the page, comment for a chance to be mentioned on CNN Student News.
    You must be a teacher or a student age 13 or older to request a mention on the CNN Student News Roll Call.'
  6. Clean up your Docker containers and Cloud Storage buckets, using the same commands as in the cleanup step at the end of this guide.
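Several checks in the walkthrough above can be scripted: reading the admin server address from its log, watching the model server log for `loading completed.`, converting between token IDs and comma-separated strings, and picking a candidate from the lm.generate table. The following pure-Python sketch is one way to do that. All function names are hypothetical, the log patterns are copied from the sample output above, and `best_candidate` assumes a higher SCORE means a better candidate. It does not call SAX, Docker, or Transformers.

```python
import re
from typing import Iterable, List, Optional, Sequence, Tuple

# Matches the admin server line that records its address, e.g.
#   ... addr.go:105] SetAddr /gcs/.../location.proto "10.142.15.200:10000"
_SET_ADDR = re.compile(r'\] SetAddr (\S+) "([^"]+)"')


def find_admin_address(log_lines: Iterable[str]) -> Optional[str]:
    """Return host:port from the first SetAddr line of the admin log, if any."""
    for line in log_lines:
        match = _SET_ADDR.search(line)
        if match:
            return match.group(2)
    return None


def model_loaded(log_text: str) -> bool:
    """True once a model server log line reports 'loading completed.'."""
    return any("servable_model.py" in line and "loading completed." in line
               for line in log_text.splitlines())


def ids_to_str(token_ids: Sequence[int]) -> str:
    """Join token IDs into the comma-separated string passed to lm.generate."""
    return ",".join(str(int(t)) for t in token_ids)


def str_to_ids(token_str: str) -> List[int]:
    """Split a comma-separated token ID string back into ints."""
    return [int(t) for t in token_str.split(",") if t.strip()]


def parse_generate_table(table: str) -> List[Tuple[str, float]]:
    """Extract (GENERATE, SCORE) rows from the ASCII table printed by lm.generate."""
    rows = []
    for line in table.splitlines():
        line = line.strip()
        if not line.startswith("|"):
            continue  # border lines (+---+) and blank lines
        cells = [cell.strip() for cell in line.strip("|").split("|")]
        if len(cells) != 2 or cells[1] == "SCORE":
            continue  # header row
        rows.append((cells[0], float(cells[1])))
    return rows


def best_candidate(table: str) -> str:
    """Return the GENERATE value with the highest SCORE (assumed: higher is better)."""
    rows = parse_generate_table(table)
    if not rows:
        raise ValueError("no GENERATE rows found")
    return max(rows, key=lambda row: row[1])[0]
```

For example, `str_to_ids(best_candidate(table))` yields a list you can pass to `tokenizer.decode`. When feeding `find_admin_address` or `model_loaded` from `docker logs`, note that the container's log lines may arrive on stderr rather than stdout.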

175B multi-host model serving preview

Some large language models require a multi-host TPU slice, that is, v5litepod-16 or larger. In that case, every host in the multi-host TPU slice must run a copy of the SAX model server, and all of these model servers act together as a SAX model server group to serve the large model on the multi-host TPU slice.

  1. Create a new SAX cluster

    Follow the same steps as in Create a SAX cluster in the GPT-J example to create a new SAX cluster and SAX admin server.

    Alternatively, if you already have a SAX cluster, you can launch the multi-host model server into that existing cluster.

  2. Launch a multi-host SAX model server into the SAX cluster

    Create the multi-host TPU slice with the same command that you use for a single-host TPU slice, specifying the appropriate multi-host accelerator type:

    ACCELERATOR_TYPE=v5litepod-32
    ZONE=us-east1-c
    
    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --service-account ${SERVICE_ACCOUNT} \
      --reserved
    

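    To pull the SAX model server image onto all TPU hosts (workers) and launch it there, run the following command. Because the --command string is in double quotes, your local shell expands the ${...} variables before the command is sent to each worker: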

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --worker=all \
      --command="
        gcloud auth configure-docker \
          us-docker.pkg.dev
        # Pull sax model server image
        docker pull ${SAX_MODEL_SERVER_IMAGE_URL}
        # Run model server
        docker run \
          --privileged  \
          -it \
          -d \
          --rm \
          --network host \
          --name ${SAX_MODEL_SERVER_DOCKER_NAME} \
          --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
          ${SAX_MODEL_SERVER_IMAGE_URL} \
            --sax_cell=${SAX_CELL} \
            --port=10001 \
            --platform_chip=tpuv4 \
            --platform_topology=1x1"
    
  3. Publish the model to the SAX cluster

    This example uses the LmCloudSpmd175B32Test model:

    MODEL_NAME=lmcloudspmd175b32test
    MODEL_CONFIG_PATH=saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test
    CHECKPOINT_PATH=None
    REPLICA=1

    To publish the test model:

    docker run \
     ${SAX_UTIL_IMAGE_URL} \
       --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
       publish \
         ${SAX_CELL}/${MODEL_NAME} \
         ${MODEL_CONFIG_PATH} \
         ${CHECKPOINT_PATH} \
         ${REPLICA}
  4. Generate inference results

    docker run \
      ${SAX_UTIL_IMAGE_URL} \
        --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
        lm.generate \
          ${SAX_CELL}/${MODEL_NAME} \
          "Q:  Who is Harry Porter's mother? A: "
    

    Because this example uses a test model with random weights, the output might not be meaningful.

  5. Clean up

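    Stop the Docker containers. They were started with --rm, so stopping them also removes them. On a multi-host slice, the model server container runs on every worker, so stop it on each worker: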

    docker stop ${SAX_ADMIN_SERVER_DOCKER_NAME}
    docker stop ${SAX_MODEL_SERVER_DOCKER_NAME}

    Delete the admin Cloud Storage bucket and any data storage bucket with the gcloud CLI, as shown below:

    gcloud storage rm gs://${SAX_ADMIN_STORAGE_BUCKET} --recursive --continue-on-error
    gcloud storage rm gs://${SAX_DATA_STORAGE_BUCKET} --recursive --continue-on-error