Devido à natureza distribuída do JAX com o Pathways, algumas operações podem não ser bem dimensionadas devido a sobrecargas de comunicação. Embora o Pathways minimize essas sobrecargas com recursos como o envio assíncrono, há algumas coisas que você precisa saber ao migrar cargas de trabalho do JAX para o Pathways ou dimensionar uma carga de trabalho do JAX com o Pathways para um grande número de aceleradores.
Antes de começar
Você precisa ter:
- Ferramentas do Kubernetes instaladas
- A CLI gcloud instalada
- A API TPU ativada
- API do Google Kubernetes Engine ativada
Índice de processos
O JAX com o Pathways trata todos os dispositivos no cluster do Pathways como locais. Isso simplifica o gerenciamento de dispositivos e permite que o JAX utilize todos os recursos disponíveis. Na prática, isso significa:
jax.process_index()é sempre 0 para todos os dispositivos.jax.devices()ejax.local_devices()retornam todos os dispositivos TPU em todo o job.
Tipo de hardware e colocalização
Para melhor desempenho, coloque todos os componentes do Pathways e o job do usuário na
mesma Google Cloud zona do Cloud. Use uma CPU grande, como o proxy IFRT e o gerenciador de recursos. Recomendamos pelo menos um n2-standard-64 dedicado, que vem com 64 vCPUs e 256 GB de memória.
PathwaysUtils
O Pathways-utils é um repositório do GitHub baseado em Python que oferece utilitários e ferramentas essenciais para simplificar a implantação e a execução de cargas de trabalho do JAX na arquitetura do Pathways on Cloud. Esse pacote processa as adaptações necessárias para o ambiente de nuvem, permitindo que os desenvolvedores do JAX se concentrem nos fluxos de trabalho principais de machine learning com configuração mínima específica da plataforma. Especificamente, ele oferece:
- Um back-end JAX "proxy": esse back-end personalizado permite que o aplicativo JAX use a infraestrutura do Pathways definindo a variável de ambiente
JAX_PLATFORMS=proxy. - Utilitários de criação de perfil integrados: recursos de criação de perfil que permitem entender o desempenho do aplicativo. Ao usar APIs de criação de perfil JAX padrão, como
jax.profiler.start_traceejax.profiler.start_server, é possível criar o perfil não apenas do código JAX, mas também dos componentes do Pathways subjacentes, oferecendo uma visão holística da execução no ambiente de nuvem. - Checkpoint distribuído com o Orbax: um handler de checkpoint Orbax personalizado que permite usar checkpoints distribuídos e restaurar os checkpoints ao usar a biblioteca Orbax no ambiente do Pathways. Essa integração funciona sem exigir mudanças no código de checkpoint Orbax atual, desde que ele importe
pathwaysutils. - Primitivos de treinamento elástico: fornece primitivos de treinamento elástico fundamentais que podem ser usados para criar fluxos de trabalho de treinamento robustos e escalonáveis usando o Pathways. Esses primitivos permitem que os jobs de treinamento se adaptem dinamicamente às mudanças nos recursos disponíveis, melhorando a eficiência e a resiliência em ambientes de nuvem.
Criação de checkpoints
Orbax é totalmente testado com o Pathways para
checkpoint distribuído e restauração com o Cloud Storage. Quando você
define a variável de ambiente ENABLE_PATHWAYS_PERSISTENCE=1 e chama
import pathwaysutils; pathwaysutils.initialize() no seu train.py, um
ArrayHandler personalizado é registrado para processar com eficiência as operações de checkpoint
pelo proxy IFRT, permitindo que os workers do Pathways em aceleradores salvem e restaurem dados diretamente.
Python colocalizado
O Python colocalizado é uma API JAX de código aberto que permite executar o código Python especificado pelo usuário diretamente nos hosts de TPU ou GPU, o que é mais simples no JAX de vários controladores JAX. Isso permite que tarefas mais intensivas em computação, como carregamento de dados e checkpoint, evitem a transferência de dados entre o cliente e as máquinas de TPU. Para configurar o cluster do Pathways para executar a API Python JAX colocalizada, siga as instruções no arquivo README do Python colocalizado. Essas instruções explicam como iniciar um sidecar Python colocalizado junto com os workers do Pathways.
Carregamento de dados
Durante o treinamento, lotes de um conjunto de dados são carregados repetidamente para alimentar o modelo. Ter um carregador de dados assíncrono e eficiente que fragmenta o lote em vários hosts é importante para evitar que os aceleradores fiquem sem trabalho. Ao executar o treinamento com o Pathways, o carregador de dados é executado em uma VM de CPU (ao contrário de uma VM de TPU, que é usada em configurações de vários controladores) e envia dados para VMs de TPU. Isso gera uma latência maior na leitura de dados, mas é parcialmente atenuado pela leitura antecipada de X lotes no host da CPU e pelo envio assíncrono dos dados lidos para as TPUs. Essa solução é suficiente ao executar em escala pequena a média.
Para um desempenho ideal em escala, recomendamos a colocalização do pipeline de dados de entrada usando o Python colocalizado para executar o pipeline de dados diretamente nos aceleradores. Isso elimina o gargalo da CPU e aproveita as interconexões rápidas da TPU para transferência de dados.
É possível encontrar uma implementação de referência da migração de um pipeline de entrada baseado em TFDS
input pipeline na RemoteIterator implementação em
multihost_dataloading.py.
Essa implementação funciona no JAX de vários controladores e no Pathways de maneira distribuída usando a API Python JAX colocalizada.
Controle de versões do Jax
As versões do Pathways são fortemente acopladas às versões do JAX para garantir compatibilidade e estabilidade. Para evitar possíveis problemas, verifique se os artefatos do Pathways e a versão do JAX estão alinhados. Cada versão do Pathways especifica claramente as versões compatíveis do JAX usando uma tag do formulário jax-<version>.
Cache de compilação
O cache de compilação persistente do Pathways é um recurso que permite que os servidores do Pathways armazenem executáveis XLA compilados em um local persistente, como o Cloud Storage, para evitar a compilação redundante. Esse recurso é ativado por padrão. O local do cache é transmitido como a flag --gcs_scratch_location para o gerenciador de recursos e os contêineres de worker do Pathways. Para manter os custos de armazenamento associados no mínimo, o cache anexa uma política de ciclo de vida ao local do Cloud Storage. Há um limite de 50 políticas por bucket do Cloud Storage. Portanto, recomendamos o uso de um local comum do Cloud Storage em todas as cargas de trabalho.
Esse cache é semelhante ao cache de compilação do JAX
que é desativado por pathwaysutils.initialize() para cargas de trabalho do Pathways.
As seguintes permissões do Cloud Storage são necessárias para o cache de compilação:
storage.buckets.get: para recuperar metadados do bucket.storage.buckets.update: essencial para que o Pathways configure políticas de ciclo de vida de objetos para aplicar o TTL para remoção de cache.storage.objects.list: para listar objetos de cache existentes no bucket.storage.objects.create: para gravar novos executáveis compilados no cache.storage.objects.get: para ler executáveis armazenados em cache do bucket.
Criação de perfil
É possível usar o criador de perfil JAX para gerar rastreamentos de um programa JAX. Há duas maneiras comuns com suporte do Pathways:
- Programática
- Capturar perfis de maneira programática no código JAX
- Manual
- Capturar perfis on demand depois de iniciar o servidor do criador de perfil no código JAX
Em ambos os casos, os perfis são gravados em um bucket do Cloud Storage. Vários arquivos de rastreamento serão criados no bucket do Cloud Storage, possivelmente em pastas de carimbo de data/hora diferentes. Por exemplo:
- Processo principal do Python que invocou o rastreamento (normalmente a VM do notebook):
<jax-client-vm-name>.xplane.pb - Proxy IFRT do Pathways:
client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - Gerenciador de recursos do Pathways:
server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - Workers do Pathways:
server.*<tpu-node-name>.xplane.pb
Esses arquivos de rastreamento podem ser analisados com o TensorBoard executando o comando a seguir. Para mais informações sobre o TensorBoard e todas as ferramentas de criação de perfil, consulte Otimizar o desempenho do TensorFlow usando o Profiler.
# verify trace files are present gsutil ls -l -r gs://BUCKET/PREFIX # View on tensorboard tensorboard --logdir=gs://BUCKET/PREFIX
Substitua:
BUCKET: um bucket do Cloud Storage para armazenar os arquivos de rastreamentoPREFIX: um caminho no bucket do Cloud Storage para armazenar os arquivos de rastreamento
Captura de perfil programática
Capture um perfil de dentro do código. Os perfis são salvos em
gs://<bucket>/<prefix> em um diretório de carimbo de data/hora
import jax import pathwaysutils pathwaysutils.initialize() jax.profiler.start_trace("gs://BUCKET/PREFIX") # Run the operations to be profiled key = jax.random.key(0) x = jax.random.normal(key, (5000, 5000)) y = x @ x y.block_until_ready() jax.profiler.stop_trace()
Captura de perfil manual
Para capturar manualmente as informações do perfil, inicie o servidor do criador de perfil no código Python:
import jax
import pathwaysutils
pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)
# Your JAX code
jax.profiler.stop_server() # this is functionally a no-op
Enquanto o servidor do criador de perfil estiver em execução, é possível capturar um perfil e exportar os dados para o local de destino do Cloud Storage:
export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>
curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port>
É possível encontrar informações de tempo para métodos de cliente de proxy IFRT, como Compile e Execute, no rastreamento do programa. Esses eventos, que detalham as interações com o servidor gRPC do proxy IFRT durante a compilação e a execução, aparecem na linha de execução chamada GrpcClientSessionUserFuturesWorkQueue. Ao examinar essa linha de execução no rastreamento, é possível ter insights sobre a performance dessas operações.
Flags XLA
Ao usar o Pathways, é necessário definir as flags XLA no contêiner de proxy do Pathways. É possível fazer isso usando o XPK ou a API PathwaysJob.
Ao usar o XPK, defina flags XLA como as seguintes:
--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"
Ao usar a API PathwaysJob, defina flags XLA como as seguintes:
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customFlags: - --xla_flag_1=value1 - --xla_flag_2=value2
Substitua:
USER: seu Google Cloud nome de usuáriovalue[n]: as flags XLA que você quer definir
Dump HLO
Para analisar detalhadamente as entradas do High Level Optimizer (HLO) fornecidas ao compilador XLA, é possível configurar o Pathways para fazer o dump do HLO em um local especificado do Cloud Storage da seguinte maneira:
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customEnv: - name: XLA_FLAGS value: "--xla_dump_to=gs://your-bucket/your-desired-prefix/"
A seguir
- Criar um cluster do GKE com o Pathways
- Inferência de vários hosts com o Pathways
- Cargas de trabalho em lote com o Pathways
- Modo interativo do Pathways
- Treinamento resiliente com o Pathways
- Solução de problemas do Pathways