在 Cloud TPU v5e 上运行 SAX
SAX 集群(SAX 单元)
SAX 管理服务器和 SAX 模型服务器是运行 SAX 集群的两个基本组件。
SAX 管理服务器
SAX 管理服务器会监控和协调 SAX 集群中的所有 SAX 模型服务器。在 SAX 集群中,您可以启动多个 SAX 管理服务器,其中只有一个 SAX 管理服务器通过主节点选举处于活跃状态,其他服务器为备用服务器。当活跃管理服务器发生故障时,备用管理服务器将变为活跃管理服务器。活跃 SAX 管理服务器会将模型副本和推理请求分配给可用的 SAX 模型服务器。
SAX 管理存储桶
每个 SAX 集群都需要一个 Cloud Storage 存储桶来存储 SAX 集群中 SAX 管理服务器和 SAX 模型服务器的配置和位置。
SAX 模型服务器
SAX 模型服务器会加载模型检查点并使用 GSPMD 运行推理。SAX 模型服务器在单个 TPU 虚拟机工作器上运行。单主机 TPU 模型部署需要在单主机 TPU 虚拟机上运行单个 SAX 模型服务器。多主机 TPU 模型部署需要在多主机 TPU 切片上运行一组 SAX 模型服务器。目前无法提供多主机模型部署,但本文档提供了一个包含 175B 测试模型的示例供您预览。
SAX 模型部署
以下部分逐步介绍了使用 SAX 部署语言模型的工作流。它使用 GPT-J 6B 模型作为单主机模型部署的示例。
在开始之前,请在 TPU 虚拟机上安装 Cloud TPU SAX Docker 映像:
sudo usermod -a -G docker ${USER} newgrp docker gcloud auth configure-docker us-docker.pkg.dev SAX_ADMIN_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-admin-server" SAX_MODEL_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server" SAX_UTIL_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-util" SAX_VERSION=v1.0.0 export SAX_ADMIN_SERVER_IMAGE_URL=${SAX_ADMIN_SERVER_IMAGE_NAME}:${SAX_VERSION} export SAX_MODEL_SERVER_IMAGE_URL=${SAX_MODEL_SERVER_IMAGE_NAME}:${SAX_VERSION} export SAX_UTIL_IMAGE_URL="${SAX_UTIL_IMAGE_NAME}:${sax_version}" docker pull ${SAX_ADMIN_SERVER_IMAGE_URL} docker pull ${SAX_MODEL_SERVER_IMAGE_URL} docker pull ${SAX_UTIL_IMAGE_URL}
设置一些稍后要用到的其他变量:
export SAX_ADMIN_SERVER_DOCKER_NAME="sax-admin-server" export SAX_MODEL_SERVER_DOCKER_NAME="sax-model-server" export SAX_CELL="/sax/test"
GPT-J 6B 单主机模型部署示例
单主机模型部署适用于单主机 TPU 切片,即 v5litepod-1、v5litepod-4 和 v5litepod-8。
创建 SAX 集群
为 SAX 集群创建一个 Cloud Storage 存储桶:
SAX_ADMIN_STORAGE_BUCKET=${your_admin_storage_bucket} gcloud storage buckets create gs://${SAX_ADMIN_STORAGE_BUCKET} \ --project=${PROJECT_ID}
您可能需要另一个 Cloud Storage 存储桶来存储检查点。
SAX_DATA_STORAGE_BUCKET=${your_data_storage_bucket}
通过 SSH 在终端中连接到 TPU 虚拟机,以启动 SAX 管理服务器:
docker run \ --name ${SAX_ADMIN_SERVER_DOCKER_NAME} \ -it \ -d \ --rm \ --network host \ --env GSBUCKET=${SAX_ADMIN_STORAGE_BUCKET} \ ${SAX_ADMIN_SERVER_IMAGE_URL}
您可以通过以下方式查看 Docker 日志:
docker logs -f ${SAX_ADMIN_SERVER_DOCKER_NAME}
日志中的输出类似于以下内容:
I0829 01:22:31.184198 7 config.go:111] Creating config fs_root: "gs://test_sax_admin/sax-fs-root" I0829 01:22:31.347883 7 config.go:115] Created config fs_root: "gs://test_sax_admin/sax-fs-root" I0829 01:22:31.360837 24 admin_server.go:44] Starting the server I0829 01:22:31.361420 24 ipaddr.go:39] Skipping non-global IP address 127.0.0.1/8. I0829 01:22:31.361455 24 ipaddr.go:39] Skipping non-global IP address ::1/128. I0829 01:22:31.361462 24 ipaddr.go:39] Skipping non-global IP address fe80::4001:aff:fe8e:fc8/64. I0829 01:22:31.361469 24 ipaddr.go:39] Skipping non-global IP address fe80::42:bfff:fef9:1bd3/64. I0829 01:22:31.361474 24 ipaddr.go:39] Skipping non-global IP address fe80::20fb:c3ff:fe5b:baac/64. I0829 01:22:31.361482 24 ipaddr.go:56] IPNet address 10.142.15.200 I0829 01:22:31.361488 24 ipaddr.go:56] IPNet address 172.17.0.1 I0829 01:22:31.456952 24 admin.go:305] Loaded config: fs_root: "gs://test_sax_admin/sax-fs-root" I0829 01:22:31.609323 24 addr.go:105] SetAddr /gcs/test_sax_admin/sax-root/sax/test/location.proto "10.142.15.200:10000" I0829 01:22:31.656021 24 admin.go:325] Updated config: fs_root: "gs://test_sax_admin/sax-fs-root" I0829 01:22:31.773245 24 mgr.go:781] Loaded manager state I0829 01:22:31.773260 24 mgr.go:784] Refreshing manager state every 10s I0829 01:22:31.773285 24 admin.go:350] Starting the server on port 10000 I0829 01:22:31.773292 24 cloud.go:506] Starting the HTTP server on port 8080
将单主机 SAX 模型服务器发布到 SAX 集群中:
此时,SAX 集群仅包含 SAX 管理服务器。您可以在第二个终端中通过 SSH 连接到 TPU 虚拟机,以在 SAX 集群中启动 SAX 模型服务器:
docker run \ --privileged \ -it \ -d \ --rm \ --network host \ --name ${SAX_MODEL_SERVER_DOCKER_NAME} \ --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \ ${SAX_MODEL_SERVER_IMAGE_URL} \ --sax_cell=${SAX_CELL} \ --port=10001 \ --platform_chip=tpuv4 \ --platform_topology=1x1
转换模型检查点:
您需要安装 PyTorch 和 Transformer 才能从 EleutherAI 下载 GPT-J 检查点:
pip3 install accelerate pip3 install torch pip3 install transformers
如需将该检查点转换为 SAX 检查点,您需要安装
paxml:pip3 install paxml==1.1.0
以下脚本可将 GPT-J 检查点转换为 SAX 检查点:
python3 -m convert_gptj_ckpt --base EleutherAI/gpt-j-6b --pax pax_6b
转换完成后:
ls checkpoint_00000000/您需要创建一个 commit_success 文件并将其放入子目录中:
gcloud storage cp checkpoint_00000000 ${CHECKPOINT_PATH} --recursive touch commit_success.txt gcloud storage cp commit_success.txt ${CHECKPOINT_PATH}/ gcloud storage cp commit_success.txt ${CHECKPOINT_PATH}/metadata/ gcloud storage cp commit_success.txt ${CHECKPOINT_PATH}/state/将模型发布到 SAX 集群
现在,您可以使用上一步中转换的检查点发布 GPT-J。
MODEL_NAME=gptjtokenizedbf16bs32 MODEL_CONFIG_PATH=saxml.server.pax.lm.params.gptj.GPTJ4TokenizedBF16BS32 REPLICA=1
如需发布 GPT-J(后续步骤),请使用 SSH 在第三个终端中连接到 TPU 虚拟机:
docker run \ ${SAX_UTIL_IMAGE_URL} \ --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \ publish \ ${SAX_CELL}/${MODEL_NAME} \ ${MODEL_CONFIG_PATH} \ ${CHECKPOINT_PATH} \ ${REPLICA}
您会看到模型服务器 Docker 日志中的大量活动,直到您看到如下所示的内容,表明模型已成功加载:
I0829 01:33:49.287459 139865140229696 servable_model.py:697] loading completed.
生成推理结果
对于 GPT-J,输入和输出必须采用以英文逗号分隔的 token ID 字符串的格式。您需要对文本输入进行词元化处理。
TEXT = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction\:\nSummarize the following news article\:\n\n### Input\:\nMarch 10, 2015 . We're truly international in scope on Tuesday. We're visiting Italy, Russia, the United Arab Emirates, and the Himalayan Mountains. Find out who's attempting to circumnavigate the globe in a plane powered partially by the sun, and explore the mysterious appearance of craters in northern Asia. You'll also get a view of Mount Everest that was previously reserved for climbers. On this page you will find today's show Transcript and a place for you to request to be on the CNN Student News Roll Call. TRANSCRIPT . Click here to access the transcript of today's CNN Student News program. Please note that there may be a delay between the time when the video is available and when the transcript is published. CNN Student News is created by a team of journalists who consider the Common Core State Standards, national standards in different subject areas, and state standards when producing the show. ROLL CALL . For a chance to be mentioned on the next CNN Student News, comment on the bottom of this page with your school name, mascot, city and state. We will be selecting schools from the comments of the previous show. You must be a teacher or a student age 13 or older to request a mention on the CNN Student News Roll Call! Thank you for using CNN Student News!\n\n### Response\:
您可以通过 EleutherAI/gpt-j-6b 词元化器获取 token ID 字符串:
from transformers import GPT2Tokenizer tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-j-6b") :
对输入文本进行词元化处理:
encoded_example = tokenizer(TEXT) input_ids = encoded_example.input_ids INPUT_STR = ",".join([str(input_id) for input_id in input_ids])
token ID 字符串可能类似于以下内容:
>>> INPUT_STR '21106,318,281,12064,326,8477,257,4876,11,20312,351,281,5128,326,3769,2252,4732,13,19430,257,2882,326,20431,32543,262,2581,13,198,198,21017,46486,25,198,13065,3876,1096,262,1708,1705,2708,25,198,198,21017,23412,25,198,16192,838,11,1853,764,775,821,4988,3230,287,8354,319,3431,13,775,821,10013,8031,11,3284,11,262,1578,4498,24880,11,290,262,42438,22931,21124,13,9938,503,508,338,9361,284,2498,4182,615,10055,262,13342,287,257,6614,13232,12387,416,262,4252,11,290,7301,262,11428,5585,286,1067,8605,287,7840,7229,13,921,1183,635,651,257,1570,286,5628,41336,326,373,4271,10395,329,39311,13,1550,428,2443,345,481,1064,1909,338,905,42978,290,257,1295,329,345,284,2581,284,307,319,262,8100,13613,3000,8299,4889,13,48213,6173,46023,764,6914,994,284,1895,262,14687,286,1909,338,8100,13613,3000,1430,13,4222,3465,326,612,743,307,257,5711,1022,262,640,618,262,2008,318,1695,290,618,262,14687,318,3199,13,8100,13613,3000,318,2727,416,257,1074,286,9046,508,2074,262,8070,7231,1812,20130,11,2260,5423,287,1180,2426,3006,11,290,1181,5423,618,9194,262,905,13,15107,3069,42815,764,1114,257,2863,284,307,4750,319,262,1306,8100,13613,3000,11,2912,319,262,4220,286,428,2443,351,534,1524,1438,11,37358,11,1748,290,1181,13,775,481,307,17246,4266,422,262,3651,286,262,2180,905,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,6952,345,329,1262,8100,13613,3000,0,198,198,21017,18261,25'
如需生成文章摘要,请执行以下操作:
docker run \ ${SAX_UTIL_IMAGE_URL} \ --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \ lm.generate \ ${SAX_CELL}/${MODEL_NAME} \ ${INPUT_STR}您可能会看到如下所示的内容:
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+ | GENERATE | SCORE | +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+ | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256 | -0.023136413 | | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,50256 | -0.91842502 | | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256 | -1.1726116 | | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256 | -1.2472695 | +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
如需对输出 token ID 字符串进行去词元化处理,请执行以下操作:
output_token_ids = [int(token_id) for token_id in OUTPUT_STR.split(',')] OUTPUT_TEXT = tokenizer.decode(output_token_ids, skip_special_tokens=True)去词元化处理的文本可能为:
>>> OUTPUT_TEXT 'This page includes the show Transcript.\nUse the Transcript to help students with reading comprehension and vocabulary.\nAt the bottom of the page, comment for a chance to be mentioned on CNN Student News. You must be a teacher or a student age 13 or older to request a mention on the CNN Student News Roll Call.'
清理 Docker 容器和 Cloud Storage 存储桶。
175B 多主机模型部署预览
某些大语言模型需要多主机 TPU 切片,即 v5litepod-16 及更高版本。在这些情况下,所有多主机 TPU 主机都需要有一个 SAX 模型服务器副本,并且所有模型服务器都作为 SAX 模型服务器组在多主机 TPU 切片上部署大型模型。
创建新的 SAX 集群
您可以按照 GPT-J 演示中的“创建 SAX 集群”步骤创建新的 SAX 集群和 SAX 管理服务器。
或者,如果您已有 SAX 集群,则可以将多主机模型服务器发布到 SAX 集群中。
将多主机 SAX 模型服务器发布到 SAX 集群中
使用创建单主机 TPU 切片时使用的相同命令创建多主机 TPU 切片,只需指定相应的多主机加速器类型:
ACCELERATOR_TYPE=v5litepod-32 ZONE=us-east1-c gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --service-account ${SERVICE_ACCOUNT} \ --reserved如需将 SAX 模型服务器映像拉取到所有 TPU 主机/工作器并启动它们,请执行以下操作:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command=" gcloud auth configure-docker \ us-docker.pkg.dev # Pull sax model server image docker pull ${SAX_MODEL_SERVER_IMAGE_URL} # Run model server docker run \ --privileged \ -it \ -d \ --rm \ --network host \ --name ${SAX_MODEL_SERVER_DOCKER_NAME} \ --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \ ${SAX_MODEL_SERVER_IMAGE_URL} \ --sax_cell=${SAX_CELL} \ --port=10001 \ --platform_chip=tpuv4 \ --platform_topology=1x1"将模型发布到 SAX 集群
此示例使用 LmCloudSpmd175B32Test 模型:
MODEL_NAME=lmcloudspmd175b32test MODEL_CONFIG_PATH=saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test CHECKPOINT_PATH=None REPLICA=1
如需发布测试模型,请执行以下操作:
docker run \ ${SAX_UTIL_IMAGE_URL} \ --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \ publish \ ${SAX_CELL}/${MODEL_NAME} \ ${MODEL_CONFIG_PATH} \ ${CHECKPOINT_PATH} \ ${REPLICA}
生成推理结果
docker run \ ${SAX_UTIL_IMAGE_URL} \ --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \ lm.generate \ ${SAX_CELL}/${MODEL_NAME} \ "Q: Who is Harry Porter's mother? A\: "请注意,由于此示例使用的是具有随机权重的测试模型,因此输出可能没有意义。
清理
停止 Docker 容器:
docker stop ${SAX_ADMIN_SERVER_DOCKER_NAME} docker stop ${SAX_MODEL_SERVER_DOCKER_NAME}
使用 gcloud CLI 删除 Cloud Storage 管理存储桶和任何数据存储桶,如下所示。
gcloud storage rm gs://${SAX_ADMIN_STORAGE_BUCKET} --recursive --continue-on-error gcloud storage rm gs://${SAX_DATA_STORAGE_BUCKET} --recursive --continue-on-error