Treinar um modelo usando a TPU v6e

Neste documento, você vai aprender sobre o treinamento de modelos no Cloud TPU v6e (também chamado de Trillium), a configuração do ambiente e a otimização do desempenho, além de conferir exemplos práticos de treinamento com o JAX e o PyTorch/XLA.

A TPU v6e, também chamada de Trillium, é a 6ª geração de TPUs do Google. Em todas as plataformas técnicas, como a API e os registros, e ao longo deste documento, a Trillium será chamada de v6e. Com 256 chips por pod, a arquitetura da TPU v6e é muito parecida com a da v5e. A TPU v6e é otimizada para treinamento, ajuste e disponibilização de transformadores, conversão de texto em imagem e redes neurais convolucionais (CNNs). Para mais informações sobre a arquitetura e as configurações do sistema da TPU v6e, consulte TPU v6e.

Para saber como executar a inferência no Cloud TPU v6e, consulte estes tutoriais:

Antes de começar

Antes de começar, faça o seguinte:

  • Crie uma conta e um projeto do Google Cloud com o faturamento ativado.
  • Instale os componentes Alfa da CLI do Google Cloud.
  • Ative a API Cloud TPU.
  • Crie um agente de serviço do Cloud TPU.
  • Crie uma conta de serviço do Cloud TPU e conceda permissões.

Para mais informações, consulte Configurar o ambiente do Cloud TPU.

Verificar cota e permissões

Verifique se o projeto tem estas cotas:

Ao usar o GKE com o XPK, você precisa de outras permissões no console do Google Cloud . Para mais informações, consulte Permissões necessárias no console doGoogle Cloud .

Provisionar TPUs

É possível provisionar e gerenciar TPUs v6e usando os seguintes métodos:

  • GKE: é possível usar o GKE para provisionar e gerenciar TPUs como um pool de aceleradores para cargas de trabalho conteinerizadas de machine learning. Para mais informações, consulte Sobre TPUs no GKE.
  • GKE e XPK: o XPK é uma ferramenta de linha de comando que simplifica a criação de clusters e a execução de cargas de trabalho no GKE. Com ele, os profissionais de ML podem provisionar TPUs e executar jobs de treinamento sem ter grande familiaridade com o Kubernetes. Para mais informações, consulte o repositório do GitHub do XPK.
  • Recursos em fila do Cloud TPU: com os recursos em fila, é possível solicitar capacidade de TPU, e ela é provisionada assim que fica disponível. Eles são ideais para jobs em lote e cargas de trabalho tolerantes a falhas que podem esperar em uma fila. É possível especificar um período para a solicitação. Para mais informações, consulte Gerenciar recursos em fila.

Provisionar Cloud TPUs v6e com o GKE e o XPK

Ao usar comandos do GKE com a v6e, use os comandos do Kubernetes ou do XPK para provisionar Cloud TPUs e treinar ou disponibilizar modelos. Consulte Planejamento para o uso de Cloud TPUs no GKE para saber como planejar as configurações do Cloud TPU em clusters do GKE. As seções a seguir fornecem comandos para criar um cluster do XPK com disponibilidade para uma e várias NICs.

Criar um cluster do XPK que aceite NIC única

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT_ID} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
   --create-vertex-tensorboard

Descrições de flags de comando

Variável Descrição
CLUSTER_NAME O nome atribuído pelo usuário para o cluster do XPK.
PROJECT_ID Nome do projeto doGoogle Cloud . Use um projeto atual ou crie um novo. Para mais informações, consulte Configurar o projeto do Google Cloud .
ZONE Consulte o documento Regiões e zonas do Cloud TPU para saber quais são as zonas disponíveis.
TPU_TYPE Consulte Tipos de aceleradores.
NUM_SLICES O número de frações que você quer criar.
CLUSTER_ARGUMENTS A rede e a sub-rede a serem usadas.

Por exemplo: --network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}

NUM_SLICES O número de frações a serem criadas.
NETWORK_NAME O nome de uma rede secundária a ser usada.
NETWORK_FW_NAME O nome de um firewall de rede secundário a ser usado.

Criar um cluster do XPK que aceite várias NICs

export CLUSTER_NAME=xpk-cluster-name
export REGION=your-region
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \
   --network=${NETWORK_NAME_1} \
   --range=10.11.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_1} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_1} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_2} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \
   --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE}  \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
   --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
   --create-vertex-tensorboard

Descrições de flags de comando

Variável Descrição
CLUSTER_NAME O nome atribuído pelo usuário para o cluster do XPK.
PROJECT_ID Nome do projeto doGoogle Cloud . Use um projeto atual ou crie um novo. Para mais informações, consulte Configurar o projeto do Google Cloud .
ZONE Consulte o documento Regiões e zonas do Cloud TPU para saber quais são as zonas disponíveis.
TPU_TYPE Consulte Tipos de aceleradores.
NUM_SLICES O número de frações que você quer criar.
CLUSTER_ARGUMENTS A rede e a sub-rede a serem usadas.

Por exemplo: --enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}

NODE_POOL_ARGUMENTS Rede de nós adicional a ser usada.

Por exemplo: --additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}

NUM_SLICES O número de frações a serem criadas (necessário apenas para várias frações).
NETWORK_NAME O nome de uma rede secundária a ser usada.
NETWORK_FW_NAME O nome de um firewall de rede secundário a ser usado.

Configurar o JAX ou o PyTorch

Os recursos a seguir mostram como configurar o JAX ou o PyTorch no Cloud TPU, de acordo com o método de provisionamento e o gerenciamento usados:

Para configurar e executar o XPK com o MaxText, consulte Como executar o MaxText em grande escala com o XPK .

Otimizar o desempenho da rede

Esta seção descreve como otimizar o desempenho da rede com a configuração da unidade máxima de transmissão (MTU), o uso de várias NICs para ambientes de várias frações e a melhora das configurações de TCP.

Configurar a MTU

Para ter o melhor desempenho de rede possível, use uma rede com uma MTU de 8.896 bytes.

Por padrão, uma nuvem privada virtual (VPC) só fornece uma MTU de 1.460 bytes, o que resulta em um desempenho de rede abaixo do ideal. É possível definir a MTU de uma rede VPC com qualquer valor de 1.300 a 8.896 bytes. Os tamanhos personalizados comuns de MTU são 1.500 bytes (Ethernet padrão) ou 8.896 bytes (o máximo possível). Para mais informações, consulte Tamanhos válidos para a MTU da rede VPC.

Para saber como mudar a configuração da MTU de uma rede padrão ou atual, consulte Alterar a configuração da MTU de uma rede VPC.

O exemplo a seguir cria uma rede com MTU de 8.896 bytes e uma regra de firewall correspondente que permite o tráfego de TCP, ICMP e UDP na rede.

export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
    --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
    --allow tcp,icmp,udp --project=${PROJECT_ID}

Substitua your-resource-name por um nome base para a rede e o firewall.

Usar a opção de várias NICs para várias frações

Ao usar um ambiente de várias frações, defina as seguintes variáveis de ambiente, que são necessárias para uma sub-rede secundária:

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region

Use os comandos a seguir para criar um roteamento de IP personalizado para a rede e a sub-rede.

  1. Crie a rede secundária.

    gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
    --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
    
  2. Crie uma sub-rede para a rede secundária.

    gcloud compute networks subnets create ${SUBNET_NAME_2} \
    --network=${NETWORK_NAME_2} \
    --range=10.10.0.0/18 --region=${REGION} \
    --project=${PROJECT_ID}
    
  3. Crie uma regra de firewall para permitir o tráfego na nova sub-rede.

    gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
    --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
    --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
    
  4. Crie um Cloud Router para a rede secundária.

    gcloud compute routers create ${ROUTER_NAME} \
    --project=${PROJECT_ID} \
    --network=${NETWORK_NAME_2} \
    --region=${REGION}
    
  5. Crie uma configuração NAT para o Cloud Router.

    gcloud compute routers nats create ${NAT_CONFIG} \
    --router=${ROUTER_NAME} \
    --region=${REGION} \
    --auto-allocate-nat-external-ips \
    --nat-all-subnet-ip-ranges \
    --project=${PROJECT_ID} \
    --enable-logging
    

Depois de criar uma fração de várias redes, verifique se as duas placas de rede (NICs) estão sendo usadas. Para isso, configure um cluster do XPK e adicione a flag --command ifconfig ao comando de criação de carga de trabalho do XPK.

  1. Use o comando workload create a seguir para mostrar a saída do comando ifconfig nos registros do console do Google Cloud e verifique se eth0 e eth1 têm uma MTU de 8.896 bytes definida.

    python3 xpk.py workload create \
        --cluster CLUSTER_NAME \
        {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
        --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
        --tpu-type=${ACCELERATOR_TYPE} \
        --num-slices=${NUM_SLICES}  \
        --on-demand \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --command "ifconfig"

    Se você quiser ativar os registros de depuração ou usar o TensorBoard da Vertex AI, adicione os seguintes argumentos opcionais ao comando:

    --enable-debug-logs \
    --use-vertex-tensorboard
  2. Verifique se eth0 e eth1 têm uma MTU de 8.896 bytes definida. Para isso, confira a saída da carga de trabalho do XPK nos registros do console do Google Cloud .

Melhorar as configurações de TCP

Se você tiver provisionado os Cloud TPUs usando recursos em fila, execute o comando a seguir para melhorar o desempenho da rede aumentando os limites do buffer de recebimento de TCP.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
    --project "${PROJECT_ID}" \
    --zone "${ZONE}" \
    --node=all \
    --worker=all \
    --command='
    sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'

Otimizar o desempenho da alocação de memória

A biblioteca tcmalloc é usada por padrão em VMs do Cloud TPU para melhorar o desempenho de modelos com alocações de memória consideráveis e frequentes. Para fazer essa configuração, use a variável de ambiente LD_PRELOAD.

No entanto, tcmalloc pode causar lentidão para algumas cargas de trabalho, como DLRM com alocações de tabela de embedding muito grandes. Nesses casos, é possível fazer a reversão para a função malloc padrão. Para isso, remova a definição da variável LD_PRELOAD na sessão do shell antes de executar o script de treinamento:

unset LD_PRELOAD

Usar o SkyPilot

Você pode usar o Cloud TPU v6e com o SkyPilot. O SkyPilot é um framework de código aberto que simplifica a execução, o gerenciamento e o escalonamento de cargas de trabalho de IA. É possível adicionar ao SkyPilot informações de localização e preços relacionadas à v6e. Para mais informações, consulte o exemplo de TPU v6e do SkyPilot.

Exemplos de treinamento

As seções a seguir fornecem exemplos de treinamento para os modelos MaxText, MaxDiffusion e PyTorch no Cloud TPU v6e.

Esses exemplos foram testados com as seguintes versões de software:

  • Python 3.10 ou mais recente.
  • Versões noturnas de software:
    • JAX 0.4.32.dev20240912 noturno.
    • LibTPU 0.1.dev20240912+nightly noturna.
  • Versões estáveis de software:
    • JAX + JAX Lib v0.4.37.

Treinar o MaxText e o MaxDiffusion no Cloud TPU v6e

As seções a seguir abordam o ciclo de vida de treinamento dos modelos MaxText e MaxDiffusion.

Confira abaixo as etapas gerais:

  1. Crie a imagem de base da carga de trabalho.
  2. Execute a carga de trabalho usando o XPK.
    1. Crie o comando de treinamento para a carga de trabalho.
    2. Implante a carga de trabalho.
  3. Monitore a carga de trabalho e confira as métricas.
  4. Exclua a carga de trabalho do XPK se ela não for mais necessária.
  5. Exclua o cluster do XPK quando ele não for mais necessário.

Criar imagem de base

Instale o MaxText ou o MaxDiffusion e crie a imagem Docker:

  1. Clone o repositório que você quer usar e mude para o diretório dele:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    
  2. Configure o Docker para usar a CLI do Google Cloud:

    gcloud auth configure-docker
    
  3. Crie a imagem do Docker usando o comando a seguir ou uma imagem de IA do JAX. Para mais informações sobre as imagens de IA do JAX, consulte Imagens de IA do JAX.

    MaxText:

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    

    MaxDiffusion:

    bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=${PROJECT_ID} LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
    
  4. Defina o ID do projeto na configuração ativa da gcloud CLI:

    gcloud config set project ${PROJECT_ID}
    
  5. Ao iniciar a carga de trabalho em uma máquina que não tem a imagem criada localmente, faça o upload dela.

    1. Defina a variável de ambiente CLOUD_IMAGE_NAME:

      export CLOUD_IMAGE_NAME=${USER}_runner
      
    2. Faça upload da imagem:

      bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
      

Executar cargas de trabalho usando o XPK

  1. Defina as seguintes variáveis de ambiente se você não estiver usando os valores padrão definidos pelo MaxText ou pelo MaxDiffusion:

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. Crie o script do modelo. Esse script será copiado como um comando de treinamento em uma etapa posterior.

    Não execute o script do modelo ainda.

    MaxText

    O MaxText é um LLM de código aberto e alto desempenho altamente escalonável escrito em Python e JAX puros e destinado a TPUs e GPUs do Google Cloud para treinamento e inferência.

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python3 -m MaxText.train MaxText/configs/base.yml \
         base_output_directory=${BASE_OUTPUT_DIR} \
         dataset_type=synthetic \
         per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
         enable_checkpointing=false \
         gcs_metrics=true \
         profiler=xplane \
         skip_first_n_steps_for_profiler=5 \
         steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    O Gemma é uma família de LLMs de peso aberto desenvolvidos pelo Google DeepMind com base na pesquisa e na tecnologia do Gemini.

    python3 -m MaxText.train MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    O Mixtral é um modelo de IA de última geração desenvolvido pela Mistral AI que usa uma arquitetura esparsa de combinação de especialistas (MoE).

    python3 -m MaxText.train MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    O Llama é uma família de LLMs de peso aberto desenvolvidos pela Meta.

    Para conferir um exemplo de como executar o Llama3 no PyTorch, consulte Modelos torch_xla no repositório torchprime.

    MaxDiffusion

    O MaxDiffusion é uma coleção de implementações de referência de vários modelos de difusão baseados em espaço latente e escritos em Python e JAX puros que são executados em dispositivos XLA, incluindo GPUs e Cloud TPUs. O Stable Diffusion é um modelo de conversão de texto em imagem baseado em espaço latente que gera imagens realistas com base em qualquer entrada de texto.

    Você precisa instalar uma ramificação específica do Git para executar o MaxDiffusion, conforme mostrado no script de treinamento a seguir.

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    && pip install -r requirements.txt && pip install .
    && pip install huggingface_hub==0.30.2 && OUT_DIR=${BASE_OUTPUT_DIR}
    && python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR} \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash \
        run_name=sdxl-ddp-v6e
    
  3. Exporte as seguintes variáveis:

    export CLUSTER_NAME=CLUSTER_NAME
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=NUM_SLICES
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT

    Descrições de variáveis de ambiente

    Variável Descrição
    CLUSTER_NAME O nome do cluster do XPK.
    ACCELERATOR_TYPE O tipo de acelerador especifica a versão e o tamanho do Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores disponíveis em cada versão de TPU, consulte Versões de TPU.
    NUM_SLICES O número de frações de TPU.
    YOUR_MODEL_SCRIPT O script do modelo a ser executado como um comando de treinamento.
  4. Execute o modelo usando o script criado na etapa anterior. Você precisa especificar a flag --base-docker-image para usar a imagem de base do MaxText ou a flag --docker-image e a imagem que você quer usar.

    Você pode adicionar as seguintes flags opcionais:

    • Para ativar o registro de depuração, inclua a flag --enable-debug-logs. Para mais informações, consulte Depurar o JAX no MaxText.
    • É possível criar um Experimento da Vertex AI para fazer upload de dados para o TensorBoard da Vertex AI incluindo a flag --use-vertex-tensorboard. Para mais informações, consulte Monitorar o JAX no MaxText usando a Vertex AI.
    python3 xpk.py workload create \
      --cluster ${CLUSTER_NAME} \
      {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \
      --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
      --tpu-type=${ACCELERATOR_TYPE} \
      --num-slices=${NUM_SLICES}  \
      --on-demand \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --command="${YOUR_MODEL_SCRIPT}"

    A saída inclui um link para monitorar a carga de trabalho. Abra o link e clique na guia Registros para monitorar a carga de trabalho em tempo real.

Depurar o JAX no MaxText

Use comandos complementares do XPK para diagnosticar por que o cluster ou a carga de trabalho não está em execução:

Monitorar o JAX no MaxText usando a Vertex AI

Para usar o TensorBoard, a conta de usuário do Google Cloud precisa ter o papel aiplatform.user. Execute o seguinte comando para conceder esse papel:

gcloud projects add-iam-policy-binding your-project-id \
   --member='user:your-email' \
   --role='roles/aiplatform.user'

Confira dados escalares e de perfil no TensorBoard gerenciado da Vertex AI.

  1. Aumente as solicitações de gerenciamento de recursos (CRUD) para a zona em uso de 600 para 5.000. Isso pode não ser um problema em cargas de trabalho pequenas que usam menos de 16 VMs.

  2. Instale dependências como cloud-accelerator-diagnostics para a Vertex AI:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Crie o cluster do XPK usando a flag --create-vertex-tensorboard, conforme documentado em Criar o TensorBoard da Vertex AI. Você também pode executar esse comando em clusters atuais.

  4. Crie seu Experimento da Vertex AI ao executar a carga de trabalho do XPK usando a flag --use-vertex-tensorboard e a flag opcional --experiment-name. Para conferir a lista completa de etapas, consulte Criar um Experimento da Vertex AI para fazer upload de dados no TensorBoard da Vertex AI.

Os registros incluem um link para um TensorBoard da Vertex AI, como este:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

Também é possível encontrar o link do TensorBoard da Vertex AI no console do Google Cloud . Acesse Experimentos da Vertex AI no console do Google Cloud . Selecione a região apropriada no menu suspenso.

O diretório do TensorBoard também é gravado no bucket do Cloud Storage especificado com ${BASE_OUTPUT_DIR}.

Excluir uma carga de trabalho do XPK

Use o comando xpk workload delete para excluir uma ou mais cargas de trabalho com base no prefixo ou no status do job. Esse comando pode ser útil se você tiver jobs presos na fila ou cargas de trabalho do XPK enviadas que não precisam mais de execução.

Excluir um cluster do XPK

Use o comando xpk cluster delete para excluir o cluster:

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
    --zone=${ZONE} --project=${PROJECT_ID}

Resultados do comparativo do MaxDiffusion

O script de treinamento do MaxDiffusion foi executado em uma v6e-4, uma v6e-16 e duas v6e-16. A tabela a seguir mostra a medição das capacidades de processamento.

v6e-4 v6e-16 Duas v6e-16
Etapas de treinamento 0,069 0,073 0,13
Tamanho global do lote 8 32 64
Capacidade de processamento (exemplos/segundo) 115,9 438,4 492,3

Treinar modelos Llama usando o PyTorch/XLA no Cloud TPU v6e

Esta seção descreve como treinar modelos Llama usando o PyTorch/XLA no Cloud TPU v6e com o conjunto de dados WikiText.

Acessar o Hugging Face e o modelo Llama 3

Você precisa de um token de acesso de usuário do Hugging Face para este exemplo. Para saber como criar tokens de acesso de usuário, consulte a documentação do Hugging Face sobre o assunto.

Você também precisa de permissão para acessar o modelo Llama-3-8B no Hugging Face. Para isso, navegue até o modelo Meta-Llama-3-8B no Hugging Face e solicite o acesso.

Criar uma VM do Cloud TPU

Crie um Cloud TPU v6e com oito chips para este exemplo.

  1. Configure as variáveis de ambiente:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east1-d
    export ACCELERATOR_TYPE=v6e-8
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    Descrições de variáveis de ambiente

    Variável Descrição
    PROJECT_ID O ID do projeto do Google Cloud . Use um projeto atual ou crie um novo.
    TPU_NAME O nome da TPU.
    ZONE A zona em que a VM de TPU será criada. Para mais informações sobre as zonas disponíveis, consulte Zonas e regiões de TPU.
    ACCELERATOR_TYPE O tipo de acelerador especifica a versão e o tamanho do Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores aceitos por cada versão de TPU, consulte Versões de TPU.
    RUNTIME_VERSION A versão do software do Cloud TPU.

  2. Crie uma VM do Cloud TPU:

    gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${RUNTIME_VERSION} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --zone=${ZONE} \
       --project=${PROJECT_ID}

Instalação

Instale o fork pytorch-tpu/transformers da biblioteca Transformers do Hugging Face e as respectivas dependências. Esse exemplo foi testado com as seguintes versões de dependência:

  • torch: aceita a 2.5.0
  • torch_xla[tpu]: aceita a 2.5.0
  • jax: 0.4.33
  • jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone ${ZONE} \
   --worker=all \
   --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
   cd transformers
   sudo pip3 install -e .
   pip3 install datasets
   pip3 install evaluate
   pip3 install scikit-learn
   pip3 install accelerate
   pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
   pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'

Configurar arquivos de configuração de modelo

O comando de treinamento na próxima seção, Executar o modelo, usa dois arquivos de configuração JSON para definir parâmetros de modelo e uma configuração de paralelismo de dados totalmente fragmentados (FSDP). A fragmentação do FSDP permite usar um tamanho de lote maior durante o treinamento porque fragmenta os pesos do modelo em várias TPUs. Ao treinar com modelos menores, pode ser suficiente usar o paralelismo de dados e replicar os pesos em cada dispositivo. Para saber como fragmentar tensores em dispositivos no PyTorch/XLA, consulte o guia do usuário do SPMD do PyTorch/XLA.

  1. Crie o arquivo de configuração de parâmetros do modelo. Confira a seguir a configuração de parâmetros do modelo Llama-3-8B. Para outros modelos, encontre o arquivo de configuração no Hugging Face. Por exemplo, consulte a configuração do Llama-2-7B.

    cat > llama-config.json << EOF
    {
      "architectures": [
        "LlamaForCausalLM"
      ],
      "attention_bias": false,
      "attention_dropout": 0.0,
      "bos_token_id": 128000,
      "eos_token_id": 128001,
      "hidden_act": "silu",
      "hidden_size": 4096,
      "initializer_range": 0.02,
      "intermediate_size": 14336,
      "max_position_embeddings": 8192,
      "model_type": "llama",
      "num_attention_heads": 32,
      "num_hidden_layers": 32,
      "num_key_value_heads": 8,
      "pretraining_tp": 1,
      "rms_norm_eps": 1e-05,
      "rope_scaling": null,
      "rope_theta": 500000.0,
      "tie_word_embeddings": false,
      "torch_dtype": "bfloat16",
      "transformers_version": "4.40.0.dev0",
      "use_cache": false,
      "vocab_size": 128256
    }
    EOF
    
  2. Crie o arquivo de configuração do FSDP:

    cat > fsdp-config.json << EOF
    {
      "fsdp_transformer_layer_cls_to_wrap": [
        "LlamaDecoderLayer"
      ],
      "xla": true,
      "xla_fsdp_v2": true,
      "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    Para mais informações sobre o FSDP, consulte Paralelismo de dados totalmente fragmentados com SPMD.

  3. Faça upload dos arquivos de configuração nas VMs do Cloud TPU usando este comando:

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
       --worker=all \
       --project=${PROJECT_ID} \
       --zone=${ZONE}

Executar o modelo

Usando os arquivos de configuração criados na seção anterior, execute o script run_clm.py para treinar o modelo Llama-3-8B no conjunto de dados WikiText. O script de treinamento leva aproximadamente 10 minutos para ser executado em um Cloud TPU v6e-8.

  1. No Cloud TPU, faça login no Hugging Face usando o seguinte comando:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       pip3 install "huggingface_hub[cli]"
       huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. Execute o treinamento do modelo:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       export PJRT_DEVICE=TPU
       export XLA_USE_SPMD=1
       export ENABLE_PJRT_COMPATIBILITY=true
       # Optional variables for debugging:
       export XLA_IR_DEBUG=1
       export XLA_HLO_DEBUG=1
       export PROFILE_EPOCH=0
       export PROFILE_STEP=3
       export PROFILE_DURATION_MS=100000
       # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
       export PROFILE_LOGDIR=PROFILE_PATH
       python3 transformers/examples/pytorch/language-modeling/run_clm.py \
         --dataset_name wikitext \
         --dataset_config_name wikitext-2-raw-v1 \
         --per_device_train_batch_size 16 \
         --do_train \
         --output_dir /home/$USER/tmp/test-clm \
         --overwrite_output_dir \
         --config_name /home/$USER/llama-config.json \
         --cache_dir /home/$USER/cache \
         --tokenizer_name meta-llama/Meta-Llama-3-8B \
         --block_size 8192 \
         --optim adafactor \
         --save_strategy no \
         --logging_strategy no \
         --fsdp "full_shard" \
         --fsdp_config /home/$USER/fsdp-config.json \
         --torch_dtype bfloat16 \
         --dataloader_drop_last yes \
         --flash_attention \
         --max_steps 20'

Solução de problemas do PyTorch/XLA

Se você definiu as variáveis opcionais para depuração na seção anterior, o perfil do modelo será armazenado no local especificado pela variável PROFILE_LOGDIR. É possível extrair o arquivo xplane.pb armazenado nesse local e usar tensorboard para conferir os perfis no navegador seguindo as instruções do TensorBoard.

Se o PyTorch/XLA não estiver funcionando como esperado, consulte o Guia de solução de problemas, que tem sugestões para depurar, criar perfis e otimizar o modelo.