Exécuter une charge de travail par lot avec Pathways

Dans ce document, les charges de travail par lot sont définies comme des charges de travail JAX qui s'exécutent jusqu'à la fin et sont déployées dans le même cluster GKE que le cluster Pathways, en particulier à côté des composants du contrôleur Pathways (serveur proxy IFRT et gestionnaire de ressources Pathways). Une fois la charge de travail JAX terminée, les composants du cluster Pathways sont arrêtés. Ce guide utilise une charge de travail d'entraînement JAX pour illustrer ce point.

Avant de commencer

Vérifiez que vous disposez bien des éléments suivants :

Créer une image d'entraînement à l'aide de MaxText

MaxText est un projet de modèle de langage étendu (LLM) Open Source développé par Google. Il est écrit en JAX et conçu pour être très performant et évolutif, et s'exécute efficacement sur les TPU et les GPU Google Cloud.

Pour créer une image Docker MaxText à l'aide de la dernière version stable de JAX à partir du dépôt GitHub OSS, exécutez la commande suivante :

git clone https://github.com/AI-Hypercomputer/maxtext
cd maxtext/dependencies/scripts
gcloud config set project PROJECT
bash ./docker_build_dependency_image.sh MODE=stable
gcloud auth configure-docker
bash ./docker_upload_runner.sh CLOUD_IMAGE_NAME=USER_runner # This script needs bash version >= 4.2 to execute.

Cette commande transfère l'image Kubernetes MaxText vers gcr.io/$PROJECT/${USER}_runner. Vous pouvez utiliser cette image Docker pour exécuter l'entraînement sur des TPU à l'aide du backend Pathways.

Exécuter une charge de travail par lot à l'aide de l'API PathwaysJob

Le manifeste suivant déploie les composants Pathways et exécute une charge de travail MaxText à l'aide de l'API PathwaysJob. La charge de travail est encapsulée dans le main conteneur et exécute train.py.

Copiez le fichier YAML suivant dans un fichier nommé pathways-job-batch-training.yaml, puis mettez à jour les valeurs modifiables.

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  maxRestarts: MAX_RESTARTS
  workers:
  - type: TPU_MACHINE_TYPE
    topology: TOPOLOGY
    numSlices: WORKLOAD_NODEPOOL_COUNT
  pathwaysDir: gs://BUCKET_NAME
  controller:
    deploymentMode: default
    template:
      spec:
        containers:
        - name: main
          image: gcr.io/PROJECT/USER_runner
          command:
          - bash
          - -c
          - |
            python3 -m MaxText.train /deps/src/MaxText/configs/base.yml \
            base_output_directory=gs://BUCKET_NAME \
            run_name=RUN_NAME \
            per_device_batch_size=1 \
            enable_checkpointing=false \
            remat_policy=full \
            global_parameter_scale=1 \
            steps=20 \
            max_target_length=2048 \
            use_iota_embed=true \
            reuse_example_batch=1 \
            dataset_type=synthetic \
            attention=flash \
            gcs_metrics=True \
            enable_single_controller=True

Remplacez les éléments suivants :

  • USER : ID de votre Google Cloud utilisateur
  • MAX_RESTARTS : nombre maximal de redémarrages de la tâche
  • TPU_MACHINE_TYPE : le type de machine TPU
  • TOPOLOGY : topologie TPU v4 ou ultérieure. Pour en savoir plus sur les versions de TPU et les topologies compatibles, consultez Versions de TPU.
  • WORKLOAD_NODEPOOL_COUNT : nombre de pools de nœuds utilisés par une charge de travail Pathways
  • BUCKET_NAME : bucket Cloud Storage pour stocker les fichiers temporaires
  • PROJECT : ID de votre Google Cloud projet
  • RUN_NAME : nom attribué par l'utilisateur pour identifier l'exécution du workflow

Vous pouvez déployer le fichier YAML PathwaysJob comme suit :

kubectl apply -f pathways-job-batch-training.yaml

Pour afficher l'instance PathwaysJob créée par la commande précédente, utilisez :

kubectl get pathwaysjob

Le résultat doit se présenter comme suit :

NAME             AGE
pathways-trial   9s

Pour modifier un attribut de l'instance PathwaysJob, supprimez-la, modifiez le fichier YAML et appliquez-le pour créer une instance PathwaysJob.PathwaysJob

Vous pouvez suivre la progression de votre charge de travail en accédant à l'explorateur de journaux pour votre conteneur JAX et en sélectionnant main dans le filtre "Nom du conteneur".

Vous devriez voir des journaux semblables à ceux ci-dessous, indiquant que l'entraînement progresse. La charge de travail se terminera après 30 étapes.

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

Pour supprimer l'instance PathwaysJob, vous pouvez utiliser la commande suivante :

kubectl delete -f pathways-job-batch-training.yaml

Exécuter une charge de travail par lot à l'aide de XPK

Vous pouvez maintenant envoyer l'image Docker MaxText prédéfinie à l'aide de XPK avec la même commande que celle utilisée précédemment.

xpk workload create-pathways \
--workload=WORKLOAD \
--cluster=CLUSTER \
--num-slices=WORKLOAD_NODEPOOL_COUNT \
--tpu-type=TPU_TYPE \
--project=PROJECT \
--zone=ZONE \
--docker-image='gcr.io/PROJECT/USER_runner' \
--command="python3 -m MaxText.train /deps/src/MaxText/configs/base.yml base_output_directory=gs://BUCKET_NAME per_device_batch_size=1 enable_checkpointing=false remat_policy=full global_parameter_scale=1 steps=20 max_target_length=2048 use_iota_embed=true reuse_example_batch=1 dataset_type=synthetic attention=flash gcs_metrics=True enable_single_controller=True run_name=RUN_NAME-pathways-job"

Remplacez les éléments suivants :

  • WORKLOAD : nom unique pour identifier votre charge de travail
  • CLUSTER : nom de votre cluster GKE
  • WORKLOAD_NODEPOOL_COUNT : nombre maximal de redémarrages de la tâche
  • TPU_TYPE : le type de TPU spécifie la version et la taille du Cloud TPU que vous souhaitez créer. Pour en savoir plus sur les types de TPU compatibles avec chaque version de TPU, consultez Versions de TPU
  • PROJECT : ID de votre Google Cloud projet
  • ZONE : zone dans laquelle vous prévoyez d'exécuter votre charge de travail
  • USER : ID de votre Google Cloud utilisateur
  • RUN_NAME : nom attribué par l'utilisateur pour identifier l'exécution du workflow

Vous devriez voir une sortie semblable à ce qui suit.

[XPK] Follow your Pathways workload and other resources here : https://console.cloud.google.com/logs/query;query=resource.type%3D"k8s_container"%0Aresource.labels.project_id%3D"<project-name>"%0Aresource.labels.location%3D"<your-zone>"%0Aresource.labels.cluster_name%3D"<your-cluster-name>"%0Aresource.labels.pod_name:"<your-pod-name>"%0Aseverity>%3DDEFAULT

Utilisez le lien dans le résultat de la commande XPK précédente pour suivre la progression de votre charge de travail. Vous pouvez filtrer les journaux de votre conteneur JAX en sélectionnant jax-tpu dans le filtre "Nom du conteneur".

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

La charge de travail se termine après le nombre d'étapes spécifié. Si vous souhaitez l'arrêter prématurément, utilisez la commande suivante :

xpk workload delete --workload=WORKLOAD --cluster=CLUSTER --project=PROJECT --zone=ZONE

Étape suivante