Exécuter une charge de travail par lot avec Pathways

Dans le présent document, les charges de travail par lot sont définies comme des charges de travail JAX qui s'exécutent jusqu'à la fin et sont déployées dans le même cluster GKE que le cluster Pathways, plus précisément aux côtés des composants du contrôleur Pathways (serveur proxy IFRT et gestionnaire de ressources Pathways). Une fois la charge de travail JAX terminée, les composants du cluster Pathways sont arrêtés. Ce guide utilise une charge de travail d'entraînement JAX pour illustrer ce point.

Avant de commencer

Vérifiez que vous disposez bien des éléments suivants :

Créer une image d'entraînement à l'aide de MaxText

MaxText est un projet de grand modèle de langage (LLM) Open Source développé par Google. Il est écrit en JAX et conçu pour être très performant et évolutif, et s'exécute efficacement sur les TPU et GPU Google Cloud.

Pour créer une image Docker MaxText à l'aide de la dernière version stable de JAX à partir du dépôt GitHub OSS, exécutez la commande suivante :

git clone https://github.com/AI-Hypercomputer/maxtext
cd maxtext/dependencies/scripts
gcloud config set project PROJECT
bash ./docker_build_dependency_image.sh MODE=stable
gcloud auth configure-docker
bash ./docker_upload_runner.sh CLOUD_IMAGE_NAME=USER_runner # This script needs bash version >= 4.2 to execute.

Cette commande transfère l'image Kubernetes MaxText vers gcr.io/$PROJECT/${USER}_runner. Vous pouvez utiliser cette image Docker pour exécuter l'entraînement sur des TPU à l'aide du backend Pathways.

Exécuter une charge de travail par lot à l'aide de l'API PathwaysJob

Le fichier manifeste suivant déploie les composants Pathways et exécute une charge de travail MaxText à l'aide de l'API PathwaysJob. La charge de travail est encapsulée dans le conteneur main et exerce train.py.

Copiez le fichier YAML suivant dans un fichier nommé pathways-job-batch-training.yaml et mettez à jour les valeurs modifiables.

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  maxRestarts: MAX_RESTARTS
  workers:
  - type: TPU_MACHINE_TYPE
    topology: TOPOLOGY
    numSlices: WORKLOAD_NODEPOOL_COUNT
  pathwaysDir: gs://BUCKET_NAME
  controller:
    deploymentMode: default
    template:
      spec:
        containers:
        - name: main
          image: gcr.io/PROJECT/USER_runner
          command:
          - bash
          - -c
          - |
            python3 -m MaxText.train /deps/src/MaxText/configs/base.yml \
            base_output_directory=gs://BUCKET_NAME \
            run_name=RUN_NAME \
            per_device_batch_size=1 \
            enable_checkpointing=false \
            remat_policy=full \
            global_parameter_scale=1 \
            steps=20 \
            max_target_length=2048 \
            use_iota_embed=true \
            reuse_example_batch=1 \
            dataset_type=synthetic \
            attention=flash \
            gcs_metrics=True \
            enable_single_controller=True

Remplacez les éléments suivants :

  • USER : votre ID utilisateur Google Cloud
  • MAX_RESTARTS : nombre maximal de fois où le Job peut être redémarré
  • TPU_MACHINE_TYPE : type de machine TPU
  • TOPOLOGY : topologie TPU v4 ou ultérieure. Pour en savoir plus sur les versions de TPU et les topologies compatibles, consultez Versions de TPU.
  • WORKLOAD_NODEPOOL_COUNT : nombre de pools de nœuds utilisés par une charge de travail Pathways
  • BUCKET_NAME : bucket Cloud Storage pour stocker les fichiers temporaires
  • PROJECT : ID de votre projet Google Cloud
  • RUN_NAME : nom attribué par l'utilisateur pour identifier l'exécution du workflow

Vous pouvez déployer le fichier YAML PathwaysJob comme suit :

kubectl apply -f pathways-job-batch-training.yaml

Pour afficher l'instance PathwaysJob créée par la commande précédente, utilisez :

kubectl get pathwaysjob

Le résultat doit se présenter comme suit :

NAME             AGE
pathways-trial   9s

Pour modifier un attribut de l'instance PathwaysJob, supprimez l'instance PathwaysJob, modifiez le fichier YAML et appliquez-le pour créer une instance PathwaysJob.

Vous pouvez suivre la progression de votre charge de travail en accédant à l'explorateur de journaux de votre conteneur JAX en sélectionnant main sous le filtre "Nom du conteneur".

Des journaux semblables à ceux ci-dessous devraient s'afficher, indiquant que l'entraînement progresse. La charge de travail se termine après 30 étapes.

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

Pour supprimer l'instance PathwaysJob, vous pouvez utiliser la commande suivante :

kubectl delete -f pathways-job-batch-training.yaml

Exécuter une charge de travail par lot à l'aide de XPK

Vous pouvez désormais envoyer l'image Docker MaxText prédéfinie à l'aide de XPK avec la même commande que celle utilisée précédemment.

xpk workload create-pathways \
--workload=WORKLOAD \
--cluster=CLUSTER \
--num-slices=WORKLOAD_NODEPOOL_COUNT \
--tpu-type=TPU_TYPE \
--project=PROJECT \
--zone=ZONE \
--docker-image='gcr.io/PROJECT/USER_runner' \
--command="python3 -m MaxText.train /deps/src/MaxText/configs/base.yml base_output_directory=gs://BUCKET_NAME per_device_batch_size=1 enable_checkpointing=false remat_policy=full global_parameter_scale=1 steps=20 max_target_length=2048 use_iota_embed=true reuse_example_batch=1 dataset_type=synthetic attention=flash gcs_metrics=True enable_single_controller=True run_name=RUN_NAME-pathways-job"

Remplacez les éléments suivants :

  • WORKLOAD : nom unique permettant d'identifier votre charge de travail
  • CLUSTER : nom de votre cluster GKE
  • WORKLOAD_NODEPOOL_COUNT : nombre maximal de fois où le job peut être redémarré
  • TPU_TYPE : le type de TPU spécifie la version et la taille du Cloud TPU que vous souhaitez créer. Pour en savoir plus sur les types de TPU compatibles avec chaque version de TPU, consultez Versions de TPU.
  • PROJECT : ID de votre projet Google Cloud
  • ZONE : zone dans laquelle vous prévoyez d'exécuter votre charge de travail.
  • USER : votre ID utilisateur Google Cloud
  • RUN_NAME : nom attribué par l'utilisateur pour identifier l'exécution du workflow

Vous devriez voir une sortie semblable à ce qui suit.

[XPK] Follow your Pathways workload and other resources here : https://console.cloud.google.com/logs/query;query=resource.type%3D"k8s_container"%0Aresource.labels.project_id%3D"<project-name>"%0Aresource.labels.location%3D"<your-zone>"%0Aresource.labels.cluster_name%3D"<your-cluster-name>"%0Aresource.labels.pod_name:"<your-pod-name>"%0Aseverity>%3DDEFAULT

Utilisez le lien figurant dans le résultat de la commande XPK précédente pour suivre la progression de votre charge de travail. Vous pouvez filtrer les journaux de votre conteneur JAX en sélectionnant jax-tpu sous le filtre "Nom du conteneur".

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

La charge de travail se termine après le nombre d'étapes spécifié. Si vous souhaitez l'arrêter prématurément, utilisez la commande suivante :

xpk workload delete --workload=WORKLOAD --cluster=CLUSTER --project=PROJECT --zone=ZONE

Étapes suivantes