Execute um cálculo numa VM de TPU na nuvem com o JAX
Este documento fornece uma breve introdução ao trabalho com o JAX e a Cloud TPU.
Antes de começar
Antes de executar os comandos neste documento, tem de criar uma conta do Google Cloud, instalar a CLI Google Cloud e configurar o comando gcloud
. Para mais informações, consulte o artigo Configure o ambiente do Cloud TPU.
Crie uma VM da Cloud TPU com gcloud
Defina algumas variáveis de ambiente para facilitar a utilização dos comandos.
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-east5-a export ACCELERATOR_TYPE=v5litepod-8 export RUNTIME_VERSION=v2-alpha-tpuv5-lite
Descrições das variáveis de ambiente
Variável Descrição PROJECT_ID
O seu Google Cloud ID do projeto. Use um projeto existente ou crie um novo. TPU_NAME
O nome da TPU. ZONE
A zona na qual criar a VM da TPU. Para mais informações sobre as zonas suportadas, consulte o artigo Regiões e zonas de TPUs. ACCELERATOR_TYPE
O tipo de acelerador especifica a versão e o tamanho do Cloud TPU que quer criar. Para mais informações sobre os tipos de aceleradores suportados para cada versão da TPU, consulte o artigo Versões da TPU. RUNTIME_VERSION
A versão do software do Cloud TPU. Crie a VM de TPU executando o seguinte comando a partir de um Cloud Shell ou do terminal do computador onde a CLI Google Cloud está instalada.
$ gcloud compute tpus tpu-vm create $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
Estabeleça ligação à sua VM da Cloud TPU
Estabeleça ligação à VM de TPU através de SSH com o seguinte comando:
$ gcloud compute tpus tpu-vm ssh $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE
Se não conseguir estabelecer ligação a uma VM de TPU através de SSH, pode dever-se ao facto de a VM de TPU não ter um endereço IP externo. Para aceder a uma VM da TPU sem um endereço IP externo, siga as instruções em Estabeleça ligação a uma VM da TPU sem um endereço IP público.
Instale o JAX na sua VM do Cloud TPU
(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Verificação do sistema
Verifique se o JAX consegue aceder à TPU e executar operações básicas:
Inicie o intérprete do Python 3:
(vm)$ python3
>>> import jax
Apresente o número de núcleos da TPU disponíveis:
>>> jax.device_count()
É apresentado o número de núcleos da TPU. O número de núcleos apresentados depende da versão da TPU que está a usar. Para mais informações, consulte o artigo Versões da TPU.
Fazer um cálculo
>>> jax.numpy.add(1, 1)
É apresentado o resultado da adição numpy:
Resultado do comando:
Array(2, dtype=int32, weak_type=True)
Saia do intérprete Python
>>> exit()
Executar código JAX numa VM de TPU
Agora, pode executar qualquer código JAX que quiser. Os exemplos do Flax são um ótimo ponto de partida para executar modelos de ML padrão no JAX. Por exemplo, para preparar uma rede convolucional MNIST básica:
Instale as dependências dos exemplos do Flax:
(vm)$ pip install --upgrade clu (vm)$ pip install tensorflow (vm)$ pip install tensorflow_datasets
Instale o Flax:
(vm)$ git clone https://github.com/google/flax.git (vm)$ pip install --user flax
Execute o script de preparação do Flax MNIST:
(vm)$ cd flax/examples/mnist (vm)$ python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5
O script transfere o conjunto de dados e inicia a preparação. O resultado do script deve ter o seguinte aspeto:
I0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88 I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72 I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04 I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15 I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18
Limpar
Para evitar incorrer em cobranças na sua Google Cloud conta pelos recursos usados nesta página, siga estes passos.
Quando terminar de usar a VM de TPU, siga estes passos para limpar os recursos.
Desassocie a instância do Cloud TPU, se ainda não o tiver feito:
(vm)$ exit
O comando deve ser agora username@projectname, o que indica que está no Cloud Shell.
Elimine a sua Cloud TPU:
$ gcloud compute tpus tpu-vm delete $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE
Valide se os recursos foram eliminados executando o seguinte comando. Certifique-se de que a TPU já não está listada. A eliminação pode demorar alguns minutos.
$ gcloud compute tpus tpu-vm list \ --zone=$ZONE
Notas de desempenho
Seguem-se alguns detalhes importantes particularmente relevantes para a utilização de TPUs no JAX.
Preenchimento
Uma das causas mais comuns do desempenho lento nas TPUs é a introdução de preenchimento não intencional:
- As matrizes na Cloud TPU são divididas em mosaicos. Isto implica preencher uma das dimensões com um múltiplo de 8 e uma dimensão diferente com um múltiplo de 128.
- A unidade de multiplicação de matrizes tem o melhor desempenho com pares de matrizes grandes que minimizam a necessidade de preenchimento.
Tipo de dados bfloat16
Por predefinição, a multiplicação de matrizes no JAX em TPUs usa bfloat16
com acumulação float32. Isto pode ser controlado com o argumento de precisão nas chamadas de função jax.numpy
relevantes (matmul, dot, einsum, etc.). Concretamente:
precision=jax.lax.Precision.DEFAULT
: usa precisão mista bfloat16 (mais rápido)precision=jax.lax.Precision.HIGH
: usa vários passes de MXU para alcançar uma maior precisãoprecision=jax.lax.Precision.HIGHEST
: usa ainda mais passes de MXU para alcançar a precisão total de float32
O JAX também adiciona o dtype bfloat16, que pode usar para converter explicitamente matrizes em
bfloat16
. Por exemplo,
jax.numpy.array(x, dtype=jax.numpy.bfloat16)
.
O que se segue?
Para mais informações sobre a Cloud TPU, consulte: