Executar o código do PyTorch em frações de TPU

Antes de executar os comandos neste documento, verifique se você seguiu as instruções da seção Configurar uma conta e um projeto do Cloud TPU.

Depois de executar o código PyTorch em uma VM de TPU única, é possível escalonar verticalmente o código executando-o em uma fração de TPU. As frações de TPU são várias placas de TPU conectadas entre si por conexões de rede dedicadas de alta velocidade. Este documento é uma introdução à execução do código PyTorch em frações de TPU.

Criar uma fração do Cloud TPU

  1. Defina algumas variáveis de ambiente para facilitar o uso dos comandos.

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5p-32
    export RUNTIME_VERSION=v2-alpha-tpuv5

    Descrições de variáveis de ambiente

    Variável Descrição
    PROJECT_ID O ID do projeto do Google Cloud . Use um projeto atual ou crie um novo.
    TPU_NAME O nome da TPU.
    ZONE A zona em que a VM de TPU será criada. Para mais informações sobre as zonas disponíveis, consulte Zonas e regiões de TPU.
    ACCELERATOR_TYPE O tipo de acelerador especifica a versão e o tamanho do Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores aceitos por cada versão de TPU, consulte Versões de TPU.
    RUNTIME_VERSION A versão do software do Cloud TPU.

  2. Crie a VM de TPU executando o seguinte comando:

    $ gcloud compute tpus tpu-vm create ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --version=${RUNTIME_VERSION}

Instalar o PyTorch/XLA na sua fração

Depois de criar a fração de TPU, instale o PyTorch em todos os hosts dessa fração. Para isso, use o comando gcloud compute tpus tpu-vm ssh com os parâmetros --worker=all e --commamnd.

Se os comandos a seguir falharem devido a um erro de conexão SSH, talvez seja porque as VMs de TPU não têm endereços IP externos. Para acessar uma VM de TPU sem um endereço IP externo, siga as instruções em Conectar-se a uma VM de TPU sem um endereço IP público.

  1. Instale o PyTorch/XLA em todos os workers da VM de TPU:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --worker=all \
        --command="pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
  2. Clone o XLA em todos os workers da VM de TPU:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --worker=all \
        --command="git clone https://github.com/pytorch/xla.git"

Executar um script de treinamento na fração da TPU

Execute o script de treinamento em todos os workers. O script de treinamento usa uma estratégia de fragmentação de programa único e vários dados (SPMD). Para mais informações sobre SPMD, consulte o Guia do usuário do SPMD para PyTorch/XLA.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --worker=all \
   --command="PJRT_DEVICE=TPU python3 ~/xla/test/spmd/test_train_spmd_imagenet.py  \
   --fake_data \
   --model=resnet50  \
   --num_epochs=1 2>&1 | tee ~/logs.txt"

O treinamento leva cerca de 15 minutos. Ao final, você vai receber uma mensagem parecida com esta:

Epoch 1 test end 23:49:15, Accuracy=100.00
     10.164.0.11 [0] Max Accuracy: 100.00%

Limpeza

Quando terminar de usar a VM de TPU, siga estas etapas para limpar os recursos.

  1. Desconecte-se da instância do Cloud TPU, caso ainda não tenha feito isso:

    (vm)$ exit

    Agora o prompt precisa ser username@projectname, mostrando que você está no Cloud Shell.

  2. Exclua os recursos do Cloud TPU.

    $ gcloud compute tpus tpu-vm delete  \
        --zone=${ZONE}
  3. Execute gcloud compute tpus tpu-vm list para verificar se os recursos foram excluídos. A exclusão pode levar vários minutos. A saída do comando abaixo não pode incluir nenhum dos recursos criados neste tutorial:

    $ gcloud compute tpus tpu-vm list --zone=${ZONE}