Ejecuta una carga de trabajo por lotes con Pathways

A los efectos de este documento, las cargas de trabajo por lotes se definen como cargas de trabajo de JAX que se ejecutan hasta completarse y se implementan en el mismo clúster de GKE que el clúster de Pathways, específicamente junto con los componentes del controlador de Pathways (servidor proxy de IFRT y administrador de recursos de Pathways). Cuando se completa la carga de trabajo de JAX, se finalizan los componentes del clúster de Pathways. En esta guía, se usa una carga de trabajo de entrenamiento de JAX para demostrar esto.

Antes de comenzar

Asegúrate de tener lo siguiente:

Compila una imagen de entrenamiento con MaxText

MaxText es un proyecto de modelo de lenguaje grande (LLM) de código abierto desarrollado por Google. Está escrito en JAX y diseñado para ser altamente eficiente y escalable, y se ejecuta de manera eficiente en las TPU y GPU de Google Cloud.

Para compilar una imagen de Docker de MaxText con la versión estable más reciente de JAX del repositorio de GitHub de OSS, ejecuta el siguiente comando:

git clone https://github.com/AI-Hypercomputer/maxtext
cd maxtext/dependencies/scripts
gcloud config set project PROJECT
bash ./docker_build_dependency_image.sh MODE=stable
gcloud auth configure-docker
bash ./docker_upload_runner.sh CLOUD_IMAGE_NAME=USER_runner # This script needs bash version >= 4.2 to execute.

Este comando envía la imagen de Kubernetes de MaxText a gcr.io/$PROJECT/${USER}_runner. Puedes usar esta imagen de Docker para ejecutar el entrenamiento en TPU con el backend de Pathways.

Ejecuta una carga de trabajo por lotes con la API de PathwaysJob

El siguiente manifiesto implementa los componentes de Pathways y ejecuta una carga de trabajo de MaxText con la API de PathwaysJob. La carga de trabajo se encapsula en el contenedor main y ejercita train.py.

Copia el siguiente código YAML en un archivo llamado pathways-job-batch-training.yaml y actualiza los valores editables.

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  maxRestarts: MAX_RESTARTS
  workers:
  - type: TPU_MACHINE_TYPE
    topology: TOPOLOGY
    numSlices: WORKLOAD_NODEPOOL_COUNT
  pathwaysDir: gs://BUCKET_NAME
  controller:
    deploymentMode: default
    template:
      spec:
        containers:
        - name: main
          image: gcr.io/PROJECT/USER_runner
          command:
          - bash
          - -c
          - |
            python3 -m MaxText.train /deps/src/MaxText/configs/base.yml \
            base_output_directory=gs://BUCKET_NAME \
            run_name=RUN_NAME \
            per_device_batch_size=1 \
            enable_checkpointing=false \
            remat_policy=full \
            global_parameter_scale=1 \
            steps=20 \
            max_target_length=2048 \
            use_iota_embed=true \
            reuse_example_batch=1 \
            dataset_type=synthetic \
            attention=flash \
            gcs_metrics=True \
            enable_single_controller=True

Reemplaza lo siguiente:

  • USER : Tu Google Cloud ID de usuario
  • MAX_RESTARTS : Es la cantidad máxima de veces que se puede reiniciar el trabajo.
  • TPU_MACHINE_TYPE : Es el tipo de máquina de TPU.
  • TOPOLOGY : Es la topología de la TPU v4 o posterior. Para obtener más información sobre las versiones de TPU y las topologías compatibles, consulta Versiones de TPU.
  • WORKLOAD_NODEPOOL_COUNT : Es la cantidad de grupos de nodos que usa una carga de trabajo de Pathways.
  • BUCKET_NAME : Un bucket de Cloud Storage para almacenar archivos temporales
  • PROJECT : El ID de tu proyecto de Google Cloud
  • RUN_NAME : Es un nombre asignado por el usuario para identificar la ejecución del flujo de trabajo.

Puedes implementar el archivo PathwaysJob YAML de la siguiente manera:

kubectl apply -f pathways-job-batch-training.yaml

Para ver la instancia PathwaysJob creada por el comando anterior, usa el siguiente comando:

kubectl get pathwaysjob

El resultado debería verse así:

NAME             AGE
pathways-trial   9s

Para modificar un atributo de la instancia PathwaysJob, borra la instancia PathwaysJob, modifica el archivo YAML y aplícalo para crear una instancia PathwaysJob nueva.

Para seguir el progreso de tu carga de trabajo, navega al Explorador de registros de tu contenedor de JAX y elige main en el filtro Nombre del contenedor.

Deberías ver registros como los siguientes, que indican que el entrenamiento está progresando. La carga de trabajo se completará después de 30 pasos.

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

Para borrar la instancia PathwaysJob, puedes usar el siguiente comando:

kubectl delete -f pathways-job-batch-training.yaml

Ejecuta una carga de trabajo por lotes con XPK

Ahora puedes enviar la imagen de Docker de MaxText compilada previamente con XPK usando el mismo comando que usaste antes.

xpk workload create-pathways \
--workload=WORKLOAD \
--cluster=CLUSTER \
--num-slices=WORKLOAD_NODEPOOL_COUNT \
--tpu-type=TPU_TYPE \
--project=PROJECT \
--zone=ZONE \
--docker-image='gcr.io/PROJECT/USER_runner' \
--command="python3 -m MaxText.train /deps/src/MaxText/configs/base.yml base_output_directory=gs://BUCKET_NAME per_device_batch_size=1 enable_checkpointing=false remat_policy=full global_parameter_scale=1 steps=20 max_target_length=2048 use_iota_embed=true reuse_example_batch=1 dataset_type=synthetic attention=flash gcs_metrics=True enable_single_controller=True run_name=RUN_NAME-pathways-job"

Reemplaza lo siguiente:

  • WORKLOAD: Es un nombre único para identificar tu carga de trabajo.
  • CLUSTER: Es el nombre del clúster de GKE.
  • WORKLOAD_NODEPOOL_COUNT : Es la cantidad máxima de veces que se puede reiniciar el trabajo.
  • TPU_TYPE: El tipo de TPU especifica la versión y el tamaño de la Cloud TPU que deseas crear. Para obtener más información sobre los tipos de TPU compatibles con cada versión de TPU, consulta Versiones de TPU.
  • PROJECT : ID de tu proyecto de Google Cloud
  • ZONE: Es la zona en la que planeas ejecutar tu carga de trabajo.
  • USER : Tu Google Cloud ID de usuario
  • RUN_NAME : Es un nombre asignado por el usuario para identificar la ejecución del flujo de trabajo.

Deberías ver un resultado como el siguiente:

[XPK] Follow your Pathways workload and other resources here : https://console.cloud.google.com/logs/query;query=resource.type%3D"k8s_container"%0Aresource.labels.project_id%3D"<project-name>"%0Aresource.labels.location%3D"<your-zone>"%0Aresource.labels.cluster_name%3D"<your-cluster-name>"%0Aresource.labels.pod_name:"<your-pod-name>"%0Aseverity>%3DDEFAULT

Usa el vínculo del resultado del comando XPK anterior para seguir el progreso de tu carga de trabajo. Puedes filtrar los registros de tu contenedor de JAX si eliges jax-tpu en el filtro Nombre del contenedor.

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

La carga de trabajo se completa después de la cantidad especificada de pasos. Si deseas finalizarlo antes de tiempo, usa el siguiente comando:

xpk workload delete --workload=WORKLOAD --cluster=CLUSTER --project=PROJECT --zone=ZONE

¿Qué sigue?