Executar o código JAX em frações de TPU
Antes de executar os comandos neste documento, verifique se você seguiu as instruções da seção Configurar uma conta e um projeto do Cloud TPU.
Depois de executar o código JAX em uma única placa de TPU, é possível escaloná-lo verticalmente com a execução em uma fração de TPU. As frações de TPU são várias placas de TPU conectadas entre si por conexões de rede dedicadas de alta velocidade. Este documento é uma introdução à execução do código JAX nas frações de TPU. Para informações mais detalhadas, consulte Como usar o JAX em ambientes com vários hosts e processos.
Criar uma fração do Cloud TPU
Crie algumas variáveis de ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5litepod-32 export RUNTIME_VERSION=v2-alpha-tpuv5-lite
Descrições de variáveis de ambiente
Variável Descrição PROJECT_IDO ID do projeto do Google Cloud . Use um projeto atual ou crie um novo. TPU_NAMEO nome da TPU. ZONEA zona em que a VM de TPU será criada. Para mais informações sobre as zonas disponíveis, consulte Zonas e regiões de TPU. ACCELERATOR_TYPEO tipo de acelerador especifica a versão e o tamanho do Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores aceitos por cada versão de TPU, consulte Versões de TPU. RUNTIME_VERSIONA versão do software do Cloud TPU. Crie uma fração de TPU usando o comando
gcloud. Por exemplo, para criar uma fração v5litepod-32, use o seguinte comando:$ gcloud compute tpus tpu-vm create ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION}
Instalar o JAX na sua fração
Depois de criar a fração de TPU, é necessário instalar o JAX em todos os hosts
dessa fração. Para isso, use o comando gcloud compute tpus tpu-vm ssh com os
parâmetros --worker=all e --commamnd.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Executar o código JAX na fração
Para executar o código JAX em uma fração de TPU, é preciso executá-lo em cada host
dessa fração. A chamada jax.device_count() para de responder até ser
chamada em cada host na fração. O exemplo a seguir ilustra como
executar um cálculo JAX em uma fração de TPU.
Preparar o código
Você precisa da versão gcloud >= 344.0.0 (para o
comando scp).
Use gcloud --version para verificar a versão de gcloud
e execute gcloud components upgrade, se necessário.
Crie um arquivo chamado example.py executando o seguinte código:
import jax
# The total number of TPU cores in the slice
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()
# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
print('global device count:', jax.device_count())
print('local device count:', jax.local_device_count())
print('pmap result:', r)
Copie example.py para todas as VMs de worker de TPU na fração.
$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \ --worker=all \ --zone=${ZONE} \ --project=${PROJECT_ID}
Se você nunca usou o comando scp, pode
receber um erro parecido com este:
ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.
Para resolver o erro, execute o comando ssh-add conforme exibido na
mensagem de erro e, em seguida, execute-o novamente.
Executar o código na fração
Inicie o programa example.py em todas as VMs:
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="python3 ./example.py"
Saída (produzida com uma fração v5litepod-32):
global device count: 32
local device count: 4
pmap result: [32. 32. 32. 32.]
Limpeza
Quando terminar de usar a VM de TPU, siga as etapas abaixo para limpar os recursos.
Exclua os recursos do Cloud TPU e do Compute Engine.
$ gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID}
Execute
gcloud compute tpus execution-groups listpara verificar se os recursos foram excluídos. A exclusão pode levar vários minutos. A saída do comando abaixo não pode incluir nenhum dos recursos criados neste tutorial:$ gcloud compute tpus tpu-vm list --zone=${ZONE} \ --project=${PROJECT_ID}