Portar cargas de trabalho do JAX para o Pathways

Devido à natureza distribuída do JAX com Pathways, algumas operações podem não ser escalonadas bem devido a sobrecargas de comunicação. Embora o Pathways minimize esses overheads com recursos como envio assíncrono, há algumas coisas que você precisa saber ao portar cargas de trabalho JAX para o Pathways ou escalonar uma carga de trabalho JAX com o Pathways para um grande número de aceleradores.

Antes de começar

Você precisa ter:

Índice do processo

O JAX com Pathways trata todos os dispositivos no cluster do Pathways como locais. Isso simplifica o gerenciamento de dispositivos e permite que o JAX use todos os recursos disponíveis. Na prática, isso significa:

  • jax.process_index() é sempre 0 para todos os dispositivos.
  • jax.devices() e jax.local_devices() retornam todos os dispositivos de TPU em todo o job.

Tipo de hardware e colocation

Para ter o melhor desempenho, coloque todos os componentes dos programas e o trabalho do usuário na mesma Google Cloud zona do Cloud. Use uma CPU grande, como o proxy IFRT e o gerenciador de recursos. Recomendamos pelo menos um n2-standard-64 dedicado, que vem com 64 vCPUs e 256 GB de memória.

PathwaysUtils

O Pathways-utils é um repositório do GitHub baseado em Python que oferece utilitários e ferramentas essenciais para simplificar a implantação e a execução de cargas de trabalho do JAX na arquitetura Pathways no Cloud. Esse pacote processa as adaptações necessárias para o ambiente de nuvem, permitindo que os desenvolvedores do JAX se concentrem nos principais fluxos de trabalho de aprendizado de máquina com configuração mínima específica da plataforma. Especificamente, ele oferece:

  • Um back-end JAX "proxy": esse back-end personalizado permite que seu aplicativo JAX use a infraestrutura do Pathways definindo a variável de ambiente JAX_PLATFORMS=proxy.
  • Utilitários de criação de perfis integrados: recursos de criação de perfis que permitem entender o desempenho do seu aplicativo. Ao usar APIs de criação de perfil JAX padrão, como jax.profiler.start_trace e jax.profiler.start_server, é possível criar perfis não apenas do seu código JAX, mas também dos componentes do Pathways subjacentes, fornecendo uma visão holística da execução no ambiente de nuvem.
  • Checkpoint distribuído com Orbax: um manipulador de checkpoint Orbax personalizado que permite usar checkpoints distribuídos e restaurar seus checkpoints ao usar a biblioteca Orbax no ambiente do Pathways. Essa integração visa funcionar sem exigir mudanças no código de checkpoint do Orbax desde que ele importe pathwaysutils.
  • Primitivos de treinamento elástico: fornecem primitivos de treinamento elástico fundamentais que podem ser usados para criar fluxos de trabalho de treinamento robustos e escalonáveis usando os programas. Essas primitivas permitem que seus jobs de treinamento se adaptem dinamicamente às mudanças nos recursos disponíveis, melhorando a eficiência e a resiliência em ambientes de nuvem.

Como estabelecer pontos de verificação

O Orbax é totalmente testado com o Pathways para checkpointing e restauração distribuídos com o Cloud Storage. Quando você chama import pathwaysutils; pathwaysutils.initialize() no seu train.py, um ArrayHandler personalizado é registrado para processar com eficiência as operações de checkpoint pelo proxy IFRT, permitindo que os workers do Pathways em aceleradores salvem e restaurem dados diretamente.

Python colocados

O Python colocated é uma API JAX de código aberto que permite executar o código Python especificado pelo usuário diretamente nos hosts de TPU ou GPU, o que é mais simples no JAX multicontrolador. Isso permite que tarefas mais exigentes em termos de computação, como carregamento de dados e criação de pontos de verificação, evitem a transferência de dados entre o cliente e as máquinas de TPU. Para configurar seu cluster do Pathways para executar a API JAX do Python colocada, siga as instruções no README do Python colocado. Essas instruções explicam como iniciar um sidecar do Python colocado ao lado dos workers do Pathways.

Carregamento de dados

Durante o treinamento, lotes de um conjunto de dados são carregados repetidamente para alimentar o modelo. Ter um carregador de dados assíncrono e eficiente que fragmenta o lote em vários hosts é importante para evitar que os aceleradores fiquem sem trabalho. Ao executar o treinamento com os programas de aprendizado, o carregador de dados é executado em uma VM de CPU (ao contrário de uma VM de TPU, que é usada em configurações de vários controladores) e envia dados para VMs de TPU. Isso causa uma latência maior na leitura de dados, mas é parcialmente mitigado pela leitura antecipada de X lotes no host da CPU e pelo envio assíncrono dos dados lidos para as TPUs. Essa solução é suficiente quando executada em escala pequena a média.

Para um desempenho ideal em grande escala, recomendamos a colocalização do pipeline de dados de entrada usando Python colocalizado para executar o pipeline de dados diretamente nos aceleradores. Isso elimina o gargalo da CPU e aproveita as interconexões rápidas da TPU para transferência de dados.

Você pode encontrar uma implementação de referência da migração de um pipeline de entrada baseado no TFDS na implementação RemoteIterator em multihost_dataloading.py. Essa implementação funciona no JAX e no Pathways com vários controladores de maneira distribuída usando a API Python JAX colocada.

Controle de versões do Jax

As versões do Pathways são fortemente acopladas às versões do JAX para garantir compatibilidade e estabilidade. Para evitar possíveis problemas, verifique se os artefatos do programa de aprendizado e a versão do JAX estão alinhados. Cada versão do Pathways especifica claramente as versões compatíveis do JAX usando uma tag no formato jax-<version>.

Cache de compilação

O cache de compilação persistente do programa de aprendizado é um recurso que permite que os servidores do programa de aprendizado armazenem executáveis XLA compilados em um local persistente, como o Cloud Storage, para evitar compilação redundante. Esse recurso é ativado por padrão. O local do cache é transmitido como uma flag --gcs_scratch_location para os contêineres do gerenciador de recursos e do worker do Pathways. Para manter os custos de armazenamento associados no mínimo, o cache anexa uma política de ciclo de vida ao local do Cloud Storage. Há um limite de 50 políticas por bucket do Cloud Storage. Portanto, recomendamos usar um local comum do Cloud Storage em todas as cargas de trabalho.

Esse cache é semelhante ao cache de compilação do JAX, que é desativado pelo pathwaysutils.initialize() para cargas de trabalho do Pathways.

Criação de perfil

É possível usar o criador de perfil do JAX para gerar rastreamentos de um programa JAX. Há duas maneiras comuns de fazer isso com os programas:

  • Programática
    • Capturar perfis de maneira programática no seu código JAX
  • Manual
    • Capturar perfis on demand depois de iniciar o servidor do criador de perfil no seu código JAX

Em ambos os casos, os perfis são gravados em um bucket do Cloud Storage. Vários arquivos de rastreamento serão criados no bucket do Cloud Storage, possivelmente em pastas de carimbo de data/hora diferentes. Por exemplo:

  • Processo principal do Python que invocou o rastreamento (normalmente a VM do notebook): <jax-client-vm-name>.xplane.pb
  • Proxy IFRT de programas: client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Gerenciador de recursos dos programas: server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Trabalhadores de programas de aprendizado: server.*<tpu-node-name>.xplane.pb

Esses arquivos de rastreamento podem ser analisados com o TensorBoard executando o seguinte comando. Para mais informações sobre o TensorBoard e todas as ferramentas de criação de perfis dele, consulte Otimizar o desempenho do TensorFlow usando o Profiler.

# verify trace files are present
gsutil ls -l -r gs://BUCKET/PREFIX

# View on tensorboard
tensorboard --logdir=gs://BUCKET/PREFIX

Substitua:

  • BUCKET : um bucket do Cloud Storage para armazenar os arquivos de rastreamento
  • PREFIX: um caminho no bucket do Cloud Storage para armazenar os arquivos de rastreamento

Captura programática de perfil

Capture um perfil de dentro do seu código. Os perfis são salvos em gs://<bucket>/<prefix> em um diretório de carimbo de data/hora

import jax
import pathwaysutils

pathwaysutils.initialize()

jax.profiler.start_trace("gs://BUCKET/PREFIX")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

Captura manual de perfil

Para capturar manualmente informações de perfil, inicie o servidor do criador de perfil no código Python:

import jax
import pathwaysutils

pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)

# Your JAX code
jax.profiler.stop_server() # this is functinoally a no-op

Enquanto o servidor do criador de perfis está em execução, é possível capturar um perfil e exportar os dados para o local de destino do Cloud Storage:

export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>

curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port

É possível encontrar informações de tempo para métodos de cliente proxy IFRT, como Compile e Execute, no rastreamento do seu programa. Esses eventos, que detalham as interações com o servidor gRPC do proxy IFRT durante a compilação e a execução, aparecem na linha de execução chamada GrpcClientSessionUserFuturesWorkQueue. Ao examinar essa linha de execução no seu rastreamento, você pode ter insights sobre a performance dessas operações.

Flags do XLA

Ao usar o Pathways, você precisa definir as flags XLA no contêiner pathways-proxy. É possível fazer isso usando XPK ou a API PathwaysJob.

Ao usar o XPK, defina flags XLA como as seguintes:

--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"

Ao usar a API PathwaysJob, defina flags XLA como as seguintes:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customFlags:
    - --xla_flag_1=value1
    - --xla_flag_2=value2

Substitua:

  • USER : seu Google Cloud nome de usuário
  • value[n]: as flags do XLA que você quer definir.

Despejo de HLO

Para se aprofundar nas entradas do High Level Optimizer (HLO) fornecidas ao compilador XLA, configure o Pathways para despejar o HLO em um local especificado do Cloud Storage da seguinte maneira:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customEnv:
    - name: XLA_FLAGS
      value: "--xla_dump_to=gs://your-gcs-bucket/your-desired-prefix/"

A seguir