Entrena un modelo con la TPU v6e

En este documento, se explica cómo entrenar modelos con la Cloud TPU v6e (también llamada Trillium) y se cubren temas como la configuración del entorno y la optimización del rendimiento, además de brindar ejemplos prácticos de entrenamiento con JAX y PyTorch/XLA.

La TPU v6e, también llamada Trillium, es la 6ª generación de TPU de Google. En todas las plataformas técnicas, como la API y los registros, y en todo este documento, se hará referencia a Trillium como v6e. Con 256 chips por Pod, la arquitectura de la TPU v6e comparte muchas similitudes con la v5e. La TPU v6e está optimizada para el entrenamiento, el ajuste y la entrega de transformadores, redes neuronales convolucionales (CNN) y modelos de texto a imagen. Para obtener más información sobre la arquitectura y los parámetros de configuración del sistema de la TPU v6e, consulta TPU v6e.

Para obtener información sobre cómo ejecutar la inferencia en la Cloud TPU v6e, consulta los instructivos que se indican a continuación:

Antes de empezar

Antes de dar el primer paso, debes completar los pasos que se indican a continuación:

  • Crear una cuenta y un proyecto de Google Cloud con la facturación habilitada
  • Instalar los componentes alfa de la Google Cloud CLI
  • Habilitar la API de Cloud TPU
  • Crear un agente de servicio de Cloud TPU
  • Crear una cuenta de servicio de Cloud TPU y otorgar permisos

Para obtener más información, consulta Configura el entorno de Cloud TPU.

Verifica la cuota y los permisos

Verifica que tu proyecto tenga las cuotas que se indican más abajo:

Si usas GKE con XPK, necesitas permisos adicionales en la consola de Google Cloud . Para obtener más información, consulta Permisos necesarios en la consola deGoogle Cloud .

Aprovisiona TPU

Para aprovisionar y administrar la TPU v6e, sigue los pasos que se indican a continuación:

  • GKE: Puedes usar GKE con el objetivo de aprovisionar y administrar TPU como un grupo de aceleradores para las cargas de trabajo de aprendizaje automático alojadas en contenedores. Para obtener más información, consulta Acerca de las TPU en GKE.
  • GKE y XPK: XPK es una herramienta de línea de comandos que simplifica la creación de clústeres y la ejecución de cargas de trabajo en GKE. Está diseñada para que los profesionales del AA aprovisionen TPU y ejecuten trabajos de entrenamiento sin necesidad de tener un conocimiento profundo de Kubernetes. Para obtener más información, consulta el repositorio de XPK en GitHub.
  • Recursos en cola de Cloud TPU: Los recursos en cola permiten solicitar la capacidad de TPU que se aprovisiona cuando está disponible. Es ideal para trabajos por lotes y cargas de trabajo tolerantes a errores que pueden esperar en una cola. Para la solicitud, se puede especificar un período. Para obtener más información, consulta Administra recursos en cola.

Aprovisiona la Cloud TPU v6e con GKE y XPK

Si usas comandos de GKE con la v6e, puedes usar comandos de Kubernetes o XPK para aprovisionar Cloud TPU y entrenar o entregar modelos. Para obtener información sobre cómo planificar la configuración de la Cloud TPU en clústeres de GKE, consulta Planifica las Cloud TPU en GKE. En las secciones siguientes, se proporcionan comandos que crean un clúster de XPK con compatibilidad para una sola NIC y varias NIC.

Crea un clúster de XPK con compatibilidad con una sola NIC

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT_ID} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
   --create-vertex-tensorboard

Descripciones de las marcas de comandos

Variable Descripción
CLUSTER_NAME Es el nombre que el usuario asigna al clúster de XPK.
PROJECT_ID Es el nombre del proyecto deGoogle Cloud . Usa un proyecto existente o crea uno nuevo. Para obtener más información, consulta Configura tu proyecto de Google Cloud .
ZONE Consulta el documento Regiones y zonas de Cloud TPU para conocer las zonas compatibles.
TPU_TYPE Consulta Tipos de aceleradores.
NUM_SLICES Es la cantidad de porciones que quieres crear.
CLUSTER_ARGUMENTS Es la red y la subred que se usarán.

Por ejemplo: --network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}

NUM_SLICES Es la cantidad de porciones que se crearán.
NETWORK_NAME Es el nombre de una red secundaria que se usará.
NETWORK_FW_NAME Es el nombre de un firewall de red secundario que se usará.

Crea un clúster de XPK con compatibilidad para varias NIC

export CLUSTER_NAME=xpk-cluster-name
export REGION=your-region
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \
   --network=${NETWORK_NAME_1} \
   --range=10.11.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_1} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_1} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_2} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \
   --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE}  \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
   --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
   --create-vertex-tensorboard

Descripciones de las marcas de comandos

Variable Descripción
CLUSTER_NAME Es el nombre que el usuario asigna al clúster de XPK.
PROJECT_ID Es el nombre del proyecto deGoogle Cloud . Usa un proyecto existente o crea uno nuevo. Para obtener más información, consulta Configura tu proyecto de Google Cloud .
ZONE Consulta el documento Regiones y zonas de Cloud TPU para conocer las zonas compatibles.
TPU_TYPE Consulta Tipos de aceleradores.
NUM_SLICES Es la cantidad de porciones que quieres crear.
CLUSTER_ARGUMENTS Es la red y la subred que se usarán.

Por ejemplo: --enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}

NODE_POOL_ARGUMENTS Es la red de nodos adicional que se usará.

Por ejemplo: --additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}

NUM_SLICES Es la cantidad de porciones que se crearán (solo es necesario para porciones múltiples).
NETWORK_NAME Es el nombre de una red secundaria que se usará.
NETWORK_FW_NAME Es el nombre de un firewall de red secundario que se usará.

Configura JAX o PyTorch

En los recursos siguientes, se muestra se cómo configura JAX o PyTorch en tu Cloud TPU según el método de aprovisionamiento y administración que uses:

Para configurar y ejecutar XPK con MaxText, consulta Ejecuta MaxText a gran escala con XPK .

Optimiza el rendimiento de la red

En esta sección, se describe cómo optimizar el rendimiento de tu red configurando la unidad máxima de transmisión (MTU), usando varias NIC en entornos de porciones múltiples y mejorando la configuración de la TCP.

Configura la MTU

Para obtener el mejor rendimiento de la red, usa una red con una MTU (unidad de transmisión máxima) de 8,896 bytes.

De forma predeterminada, una nube privada virtual (VPC) solo proporciona una MTU de 1,460 bytes, lo que genera un rendimiento de red deficiente. Puedes configurar la MTU de una red de VPC con cualquier valor de entre 1,300 bytes y 8,896 bytes (inclusive). Los tamaños de MTU personalizados comunes son de 1,500 bytes (Ethernet estándar) o de 8,896 bytes (el máximo posible). Para obtener más información, consulta Tamaños válidos de MTU de las redes de VPC.

Para obtener más información sobre cómo cambiar la configuración de la MTU de una red existente o predeterminada, consulta Cambia la configuración de la MTU de una red de VPC.

En el ejemplo siguiente, se crea una red con una MTU de 8,896 bytes y una regla de firewall correspondiente que permite el tráfico de TCP, ICMP y UDP dentro de la red.

export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
    --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
    --allow tcp,icmp,udp --project=${PROJECT_ID}

Reemplaza your-resource-name por un nombre base para la red y el firewall.

Usa la opción de varias NIC para porciones múltiples

Si usas un entorno de porciones múltiples, configura las siguientes variables de entorno, las cuales son obligatorias para una subred secundaria:

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region

Usa los comandos siguientes con el objetivo de crear un enrutamiento de IP personalizado para la red y la subred.

  1. Crea la red secundaria.

    gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
    --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
    
  2. Crea una subred para la red secundaria.

    gcloud compute networks subnets create ${SUBNET_NAME_2} \
    --network=${NETWORK_NAME_2} \
    --range=10.10.0.0/18 --region=${REGION} \
    --project=${PROJECT_ID}
    
  3. Crea una regla de firewall que permita el tráfico dentro de la subred nueva.

    gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
    --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
    --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
    
  4. Crea un Cloud Router para la red de secundaria.

    gcloud compute routers create ${ROUTER_NAME} \
    --project=${PROJECT_ID} \
    --network=${NETWORK_NAME_2} \
    --region=${REGION}
    
  5. Crea una configuración de NAT para el Cloud Router.

    gcloud compute routers nats create ${NAT_CONFIG} \
    --router=${ROUTER_NAME} \
    --region=${REGION} \
    --auto-allocate-nat-external-ips \
    --nat-all-subnet-ip-ranges \
    --project=${PROJECT_ID} \
    --enable-logging
    

Después de crear una porción de varias redes, puedes validar que las dos tarjetas de interfaz de red (NIC) se estén usando configurando un clúster de XPK y agregando la marca --command ifconfig al comando de creación de la carga de trabajo de XPK.

  1. Usa el comando siguiente workload create para mostrar el resultado del comando ifconfig en los registros de la consola de Google Cloud y comprueba que eth0 y eth1 tengan la MTU establecida en 8,896 bytes.

    python3 xpk.py workload create \
        --cluster CLUSTER_NAME \
        {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
        --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
        --tpu-type=${ACCELERATOR_TYPE} \
        --num-slices=${NUM_SLICES}  \
        --on-demand \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --command "ifconfig"

    Si quieres habilitar los registros de depuración o usar Vertex AI TensorBoard, agrega los siguientes argumentos opcionales al comando:

    --enable-debug-logs \
    --use-vertex-tensorboard
  2. Comprueba que eth0 y eth1 tengan la MTU establecida en 8,896 bytes. Para ello, revisa el resultado de la carga de trabajo de XPK en los registros de la consola de Google Cloud .

Mejora la configuración del TCP

Si aprovisionaste las Cloud TPU con recursos en cola, puedes ejecutar el comando siguiente con el objetivo de mejorar el rendimiento de la red; para ello, aumenta los límites de recepción del búfer del TCP.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
    --project "${PROJECT_ID}" \
    --zone "${ZONE}" \
    --node=all \
    --worker=all \
    --command='
    sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'

Optimiza el rendimiento de la asignación de memoria

La biblioteca tcmalloc se usa de forma predeterminada en las VMs de Cloud TPU para mejorar el rendimiento de los modelos con asignaciones de memoria frecuentes y de gran tamaño. Esto se configura a través de la variable de entorno LD_PRELOAD.

Sin embargo, para algunas cargas de trabajo (por ejemplo, DLRM con asignaciones de tablas de incorporación muy grandes), tcmalloc puede ralentizar el proceso. En esos casos, puedes volver a la función malloc estándar anulando la variable LD_PRELOAD de la sesión de shell antes de ejecutar la secuencia de comandos de entrenamiento:

unset LD_PRELOAD

Usa SkyPilot

La Cloud TPU v6e se puede usar con SkyPilot. SkyPilot es un framework de código abierto que simplifica el proceso de ejecución, administración y escalamiento de cargas de trabajo de IA. Además, puedes agregar información sobre la ubicación y los precios relacionados con la v6e a SkyPilot. Para obtener más información, consulta el ejemplo de la TPU v6e con SkyPilot.

Ejemplos de entrenamiento

En las secciones siguientes, se proporcionan ejemplos para entrenar modelos de MaxText, MaxDiffusion y PyTorch en la Cloud TPU v6e.

Estos ejemplos se probaron con las versiones de software que se indican a continuación:

  • Python 3.10 o una versión posterior
  • Versiones nocturnas de software:
    • JAX nocturno 0.4.32.dev20240912
    • LibTPU nocturna 0.1.dev20240912+nightly
  • Versiones estables de software:
    • JAX y JAX Lib versión 0.4.37

Entrena MaxText y MaxDiffusion con la Cloud TPU v6e

En las secciones siguientes, se analiza el ciclo de vida del entrenamiento de los modelos de MaxText y MaxDiffusion.

En general, los pasos generales son los que se indican a continuación:

  1. Crea la imagen base de la carga de trabajo.
  2. Ejecuta la carga de trabajo con XPK.
    1. Crea el comando de entrenamiento para la carga de trabajo.
    2. Implementa la carga de trabajo.
  3. Sigue la carga de trabajo y consulta sus métricas.
  4. Borra la carga de trabajo de XPK si no es necesaria.
  5. Borra el clúster de XPK cuando ya no lo necesites.

Crea la imagen base

Instala MaxText o MaxDiffusion y crea la imagen de Docker siguiendo los pasos que se indican a continuación:

  1. Clona el repositorio que quieres usar y cambia al directorio del repositorio:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    
  2. Configura Docker para usar la Google Cloud CLI:

    gcloud auth configure-docker
    
  3. Crea la imagen de Docker con el comando siguiente o con una imagen generada por IA de JAX. Para obtener más información sobre las imágenes generadas por IA de JAX, consulta Imágenes generadas por IA de JAX.

    MaxText:

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    

    MaxDiffusion:

    bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=${PROJECT_ID} LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
    
  4. Establece el ID del proyecto en la configuración activa de la gcloud CLI:

    gcloud config set project ${PROJECT_ID}
    
  5. Si inicias la carga de trabajo desde una máquina que no tiene la imagen creada a nivel local, súbela.

    1. Establece la variable de entorno CLOUD_IMAGE_NAME:

      export CLOUD_IMAGE_NAME=${USER}_runner
      
    2. Sube la imagen:

      bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
      

Ejecuta la carga de trabajo con XPK

  1. Establece las siguientes variables de entorno si no usas los valores predeterminados establecidos por MaxText o MaxDiffusion:

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. Crea la secuencia de comandos del modelo. Esta se copiará como un comando de entrenamiento en un paso posterior.

    Aún no ejecutes la secuencia de comandos del modelo.

    MaxText

    MaxText es un LLM de código abierto, de alto rendimiento y altamente escalable escrito en Python y JAX en su totalidad, y orientado a TPU y GPU de Google Cloud para el entrenamiento y la inferencia.

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python3 -m MaxText.train MaxText/configs/base.yml \
         base_output_directory=${BASE_OUTPUT_DIR} \
         dataset_type=synthetic \
         per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
         enable_checkpointing=false \
         gcs_metrics=true \
         profiler=xplane \
         skip_first_n_steps_for_profiler=5 \
         steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma es una familia de LLM de código abierto desarrollada por Google DeepMind que se basa en la investigación y tecnología de Gemini.

    python3 -m MaxText.train MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral es un modelo de IA de vanguardia desarrollado por Mistral AI y que usa una arquitectura de mezcla de expertos (MoE) dispersa.

    python3 -m MaxText.train MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama es una familia de LLMs de código abierto desarrollada por Meta.

    Para ver un ejemplo de cómo ejecutar Llama3 en PyTorch, consulta los modelos de torch_xla en el repositorio de torchprime.

    MaxDiffusion

    MaxDiffusion es un conjunto de implementaciones de referencia de varios modelos de difusión latentes escritos en Python y JAX en su totalidad que se ejecutan en dispositivos XLA, incluidas las Cloud TPU y las GPU. El modelo Stable Diffusion es un modelo latente de texto a imagen que genera imágenes fotorrealistas a partir de cualquier entrada de texto.

    Para ejecutar MaxDiffusion, debes instalar una rama de Git específica, como se muestra en la siguiente secuencia de comandos de entrenamiento.

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    && pip install -r requirements.txt && pip install .
    && pip install huggingface_hub==0.30.2 && OUT_DIR=${BASE_OUTPUT_DIR}
    && python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR} \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash \
        run_name=sdxl-ddp-v6e
    
  3. Exporta las variables siguientes que se indican a continuación:

    export CLUSTER_NAME=CLUSTER_NAME
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=NUM_SLICES
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT

    Descripciones de las variables de entorno

    Variable Descripción
    CLUSTER_NAME Es el nombre del clúster de XPK.
    ACCELERATOR_TYPE Es el tipo de acelerador que especifica la versión y el tamaño de la Cloud TPU que quieres crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU.
    NUM_SLICES Es la cantidad de porciones de TPU.
    YOUR_MODEL_SCRIPT Es la secuencia de comandos del modelo que se ejecutará como un comando de entrenamiento.
  4. Ejecuta el modelo con la secuencia de comandos que creaste en el paso anterior. Debes especificar la marca --base-docker-image para usar la imagen base de MaxText o la marca --docker-image y la imagen que quieras usar.

    Puedes agregar las marcas opcionales que se indican a continuación:

    • Para habilitar el registro de depuración, incluye la marca --enable-debug-logs. Para obtener más información, consulta Depura JAX en MaxText.
    • Para crear un experimento de Vertex AI Experiments con el objetivo de subir datos a Vertex AI TensorBoard, incluye la marca --use-vertex-tensorboard. Para obtener más información, consulta Supervisa JAX en MaxText con Vertex AI.
    python3 xpk.py workload create \
      --cluster ${CLUSTER_NAME} \
      {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \
      --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
      --tpu-type=${ACCELERATOR_TYPE} \
      --num-slices=${NUM_SLICES}  \
      --on-demand \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --command="${YOUR_MODEL_SCRIPT}"

    El resultado incluye un vínculo para seguir la carga de trabajo. Abre el vínculo y haz clic en la pestaña Registros para hacer un seguimiento de la carga de trabajo en tiempo real.

Depura JAX con MaxText

Usa comandos de XPK complementarios para diagnosticar por qué no se ejecuta el clúster o la carga de trabajo:

Supervisa JAX con MaxText a través de Vertex AI

Para usar TensorBoard, tu cuenta de usuario de Google Cloud debe tener el rol de aiplatform.user. Ejecuta el comando que se indica a continuación para otorgar este rol:

gcloud projects add-iam-policy-binding your-project-id \
   --member='user:your-email' \
   --role='roles/aiplatform.user'

Consulta los datos escalares y de perfil a través de TensorBoard administrado por Vertex AI.

  1. Aumenta las solicitudes de administración de recursos (CRUD) para la zona que usas de 600 a 5,000. Es posible que esto no sea un problema para las cargas de trabajo pequeñas que usan menos de 16 VMs.

  2. Sigue los pasos que se indican a continuación con el objetivo de instalar dependencias como cloud-accelerator-diagnostics para Vertex AI:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Crea tu clúster de XPK con la marca --create-vertex-tensorboard, como se documenta en Crea Vertex AI TensorBoard. También puedes ejecutar este comando en clústeres existentes.

  4. Crea tu experimento de Vertex AI Experiments cuando ejecutes tu carga de trabajo de XPK con la marca --use-vertex-tensorboard y la marca opcional --experiment-name. Para obtener la lista completa de pasos, consulta Crea un experimento de Vertex AI Experiments para subir datos a Vertex AI TensorBoard.

Los registros incluyen un vínculo a un Vertex AI TensorBoard, similar al que se muestra más abajo:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

También puedes encontrar el vínculo al Vertex AI TensorBoard AI en la consola de Google Cloud . Accede a Vertex AI Experiments en la consola de Google Cloud . Elige la región adecuada en el menú desplegable.

El directorio de TensorBoard también se escribe en el bucket de Cloud Storage que especificaste con ${BASE_OUTPUT_DIR}.

Borra la carga de trabajo de XPK

Usa el comando xpk workload delete para borrar una o más cargas de trabajo según el prefijo o el estado del trabajo. Este comando puede ser útil si enviaste cargas de trabajo de XPK que ya no es necesario ejecutar o si hay trabajos que permanecen en la cola.

Borra el clúster de XPK

Usa el comando xpk cluster delete para borrar el clúster:

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
    --zone=${ZONE} --project=${PROJECT_ID}

Resultados de las comparativas de MaxDiffusion

Ejecutamos la secuencia de comandos de entrenamiento de MaxDiffusion en una v6e-4, una v6e-16 y dos v6e-16. En la tabla siguiente, se muestran las capacidades de procesamiento medidas.

v6e-4 v6e-16 Dos v6e-16
Pasos de entrenamiento 0.069 0.073 0.13
Tamaño del lote global 8 32 64
Capacidad de procesamiento (ejemplos/s) 115.9 438.4 492.3

Entrena modelos de Llama a través de PyTorch/XLA con la Cloud TPU v6e

En esta sección, se describe cómo entrenar modelos de Llama a través de PyTorch/XLA con la Cloud TPU v6e usando el conjunto de datos de WikiText.

Obtén acceso a Hugging Face y al modelo de Llama 3

Para este ejemplo, necesitas un token de acceso de usuario de Hugging Face. Para obtener información sobre cómo crear tokens de acceso de usuario, consulta la documentación de Hugging Face sobre tokens de acceso de usuario.

También necesitas permiso para acceder al modelo Llama-3-8B en Hugging Face. Para ello, accede al modelo Meta-Llama-3-8B en Hugging Face y solicita acceso.

Crea una VM de Cloud TPU

Crea una Cloud TPU v6e con 8 chips para este ejemplo.

  1. Configura las variables de entorno:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east1-d
    export ACCELERATOR_TYPE=v6e-8
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    Descripciones de las variables de entorno

    Variable Descripción
    PROJECT_ID Es el ID de tu proyecto de Google Cloud . Usa un proyecto existente o crea uno nuevo.
    TPU_NAME Es el nombre de la TPU.
    ZONE Es la zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas compatibles, consulta Regiones y zonas de TPU.
    ACCELERATOR_TYPE El tipo de acelerador especifica la versión y el tamaño de Cloud TPU que deseas crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU.
    RUNTIME_VERSION Es la versión del software de la Cloud TPU.

  2. Crea una VM de Cloud TPU:

    gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${RUNTIME_VERSION} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --zone=${ZONE} \
       --project=${PROJECT_ID}

Instalación

Instala la bifurcación pytorch-tpu/transformers de los transformadores de Hugging Face y las dependencias. Este ejemplo se probó con las versiones de dependencias que se indican a continuación:

  • torch: Compatible con 2.5.0
  • torch_xla[tpu]: Compatible con 2.5.0
  • jax: 0.4.33
  • jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone ${ZONE} \
   --worker=all \
   --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
   cd transformers
   sudo pip3 install -e .
   pip3 install datasets
   pip3 install evaluate
   pip3 install scikit-learn
   pip3 install accelerate
   pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
   pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'

Establece los archivos de configuración del modelo

El comando de entrenamiento de la sección siguiente, Ejecuta el modelo, usa dos archivos de configuración JSON para definir los parámetros del modelo y la configuración de paralelismo de datos completamente fragmentados (FSDP). La fragmentación de FSDP permite usar un tamaño de lote más grande durante el entrenamiento fragmentando los pesos del modelo en varias TPU. Cuando se entrena con modelos más pequeños, puede ser suficiente usar el paralelismo de datos y replicar los pesos en cada dispositivo. Para obtener más información sobre cómo fragmentar tensores entre dispositivos en PyTorch/XLA, consulta la guía del usuario de SPMD para PyTorch/XLA.

  1. Crea el archivo de configuración de los parámetros del modelo. A continuación, se muestra la configuración de los parámetros del modelo para Llama-3-8B. Para otros modelos, busca el archivo de configuración en Hugging Face. Por ejemplo, consulta la configuración de Llama-2-7B.

    cat > llama-config.json << EOF
    {
      "architectures": [
        "LlamaForCausalLM"
      ],
      "attention_bias": false,
      "attention_dropout": 0.0,
      "bos_token_id": 128000,
      "eos_token_id": 128001,
      "hidden_act": "silu",
      "hidden_size": 4096,
      "initializer_range": 0.02,
      "intermediate_size": 14336,
      "max_position_embeddings": 8192,
      "model_type": "llama",
      "num_attention_heads": 32,
      "num_hidden_layers": 32,
      "num_key_value_heads": 8,
      "pretraining_tp": 1,
      "rms_norm_eps": 1e-05,
      "rope_scaling": null,
      "rope_theta": 500000.0,
      "tie_word_embeddings": false,
      "torch_dtype": "bfloat16",
      "transformers_version": "4.40.0.dev0",
      "use_cache": false,
      "vocab_size": 128256
    }
    EOF
    
  2. Crea el archivo de configuración de FSDP:

    cat > fsdp-config.json << EOF
    {
      "fsdp_transformer_layer_cls_to_wrap": [
        "LlamaDecoderLayer"
      ],
      "xla": true,
      "xla_fsdp_v2": true,
      "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    Para obtener más información sobre FSDP, consulta Paralelismo de datos completamente fragmentados con SPMD.

  3. Sube los archivos de configuración a tus VMs de Cloud TPU con el comando que se indica más abajo:

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
       --worker=all \
       --project=${PROJECT_ID} \
       --zone=${ZONE}

Ejecuta el modelo

Con los archivos de configuración que creaste en la sección anterior, ejecuta la secuencia de comandos run_clm.py para entrenar el modelo de Llama-3-8B en el conjunto de datos de WikiText. La secuencia de comandos de entrenamiento tarda alrededor de 10 minutos en ejecutarse en una Cloud TPU v6e-8.

  1. Accede a Hugging Face desde tu Cloud TPU con el comando que se indica a continuación:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       pip3 install "huggingface_hub[cli]"
       huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. Ejecuta el entrenamiento del modelo:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       export PJRT_DEVICE=TPU
       export XLA_USE_SPMD=1
       export ENABLE_PJRT_COMPATIBILITY=true
       # Optional variables for debugging:
       export XLA_IR_DEBUG=1
       export XLA_HLO_DEBUG=1
       export PROFILE_EPOCH=0
       export PROFILE_STEP=3
       export PROFILE_DURATION_MS=100000
       # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
       export PROFILE_LOGDIR=PROFILE_PATH
       python3 transformers/examples/pytorch/language-modeling/run_clm.py \
         --dataset_name wikitext \
         --dataset_config_name wikitext-2-raw-v1 \
         --per_device_train_batch_size 16 \
         --do_train \
         --output_dir /home/$USER/tmp/test-clm \
         --overwrite_output_dir \
         --config_name /home/$USER/llama-config.json \
         --cache_dir /home/$USER/cache \
         --tokenizer_name meta-llama/Meta-Llama-3-8B \
         --block_size 8192 \
         --optim adafactor \
         --save_strategy no \
         --logging_strategy no \
         --fsdp "full_shard" \
         --fsdp_config /home/$USER/fsdp-config.json \
         --torch_dtype bfloat16 \
         --dataloader_drop_last yes \
         --flash_attention \
         --max_steps 20'

Solución de problemas de PyTorch/XLA

Si configuraste las variables opcionales para la depuración en la sección anterior, el perfil del modelo se almacenará en la ubicación que especifica la variable PROFILE_LOGDIR. Puedes extraer el archivo xplane.pb almacenado en esta ubicación y usar tensorboard para ver los perfiles en el navegador con las instrucciones de TensorBoard.

Si PyTorch/XLA no funciona como se espera, consulta la guía de solución de problemas, que incluye sugerencias para depurar y optimizar el modelo, además de generar perfiles.