Batcharbeitslast mit Pathways ausführen

Für dieses Dokument werden Batcharbeitslasten als JAX-Arbeitslasten definiert, die vollständig ausgeführt werden und im selben GKE-Cluster wie der Pathways-Cluster bereitgestellt werden, insbesondere neben den Pathways-Controllerkomponenten (IFRT-Proxyserver und Pathways-Ressourcenmanager). Wenn die JAX-Arbeitslast abgeschlossen ist, werden die Pathways-Clusterkomponenten beendet. In dieser Anleitung wird eine JAX-Trainingsarbeitslast verwendet, um dies zu demonstrieren.

Hinweise

Sie benötigen Folgendes:

Trainings-Image mit MaxText erstellen

MaxText ist ein von Google entwickeltes Open-Source-Projekt für Large Language Models (LLMs). Es ist in JAX geschrieben und für hohe Leistung und Skalierbarkeit konzipiert. Es lässt sich effizient auf Google Cloud-TPUs und ‑GPUs ausführen.

Wenn Sie ein MaxText-Docker-Image mit der neuesten Version von stable JAX aus dem OSS-GitHub-Repository erstellen möchten, führen Sie den folgenden Befehl aus:

git clone https://github.com/AI-Hypercomputer/maxtext
cd maxtext/dependencies/scripts
gcloud config set project PROJECT
bash ./docker_build_dependency_image.sh MODE=stable
gcloud auth configure-docker
bash ./docker_upload_runner.sh CLOUD_IMAGE_NAME=USER_runner # This script needs bash version >= 4.2 to execute.

Mit diesem Befehl wird das MaxText-Kubernetes-Image an gcr.io/$PROJECT/${USER}_runner übertragen. Sie können dieses Docker-Image verwenden, um das Training auf TPUs mit dem Pathways-Backend auszuführen.

Batcharbeitslast mit der PathwaysJob API ausführen

Mit dem folgenden Manifest werden die Pathways-Komponenten bereitgestellt und ein MaxText-Arbeitslast mit der PathwaysJob API ausgeführt. Die Arbeitslast ist im Container main gekapselt und nutzt train.py.

Kopieren Sie das folgende YAML-Manifest in eine Datei mit dem Namen pathways-job-batch-training.yaml und aktualisieren Sie die bearbeitbaren Werte.

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  maxRestarts: MAX_RESTARTS
  workers:
  - type: TPU_MACHINE_TYPE
    topology: TOPOLOGY
    numSlices: WORKLOAD_NODEPOOL_COUNT
  pathwaysDir: gs://BUCKET_NAME
  controller:
    deploymentMode: default
    template:
      spec:
        containers:
        - name: main
          image: gcr.io/PROJECT/USER_runner
          command:
          - bash
          - -c
          - |
            python3 -m MaxText.train /deps/src/MaxText/configs/base.yml \
            base_output_directory=gs://BUCKET_NAME \
            run_name=RUN_NAME \
            per_device_batch_size=1 \
            enable_checkpointing=false \
            remat_policy=full \
            global_parameter_scale=1 \
            steps=20 \
            max_target_length=2048 \
            use_iota_embed=true \
            reuse_example_batch=1 \
            dataset_type=synthetic \
            attention=flash \
            gcs_metrics=True \
            enable_single_controller=True

Ersetzen Sie Folgendes:

  • USER : Ihre Google Cloud Nutzer-ID
  • MAX_RESTARTS : die maximale Anzahl an Neustarts des Jobs
  • TPU_MACHINE_TYPE : der TPU-Maschinentyp
  • TOPOLOGY : Die TPU v4- oder höhere Topologie. Weitere Informationen zu TPU-Versionen und unterstützten Topologien finden Sie unter TPU-Versionen.
  • WORKLOAD_NODEPOOL_COUNT : Die Anzahl der Knotenpools, die von einer Pathways-Arbeitslast verwendet werden.
  • BUCKET_NAME : Ein Cloud Storage-Bucket zum Speichern temporärer Dateien
  • PROJECT : Ihre Google Cloud Projekt-ID
  • RUN_NAME : Ein vom Nutzer zugewiesener Name zur Identifizierung des Workflow-Laufs.

Sie können das PathwaysJob-YAML so bereitstellen:

kubectl apply -f pathways-job-batch-training.yaml

Verwenden Sie den folgenden Befehl, um die mit dem vorherigen Befehl erstellte PathwaysJob-Instanz aufzurufen:

kubectl get pathwaysjob

Die Ausgabe sollte in etwa so aussehen:

NAME             AGE
pathways-trial   9s

Wenn Sie ein Attribut der PathwaysJob-Instanz ändern möchten, löschen Sie die PathwaysJob-Instanz, ändern Sie das YAML und wenden Sie es an, um eine neue PathwaysJob-Instanz zu erstellen.

Sie können den Fortschritt Ihrer Arbeitslast im Log-Explorer für Ihren JAX-Container verfolgen, indem Sie unter dem Filter „Containername“ die Option main auswählen.

Sie sollten Logs wie die folgenden sehen, die darauf hinweisen, dass das Training fortschreitet. Die Arbeitslast wird nach 30 Schritten abgeschlossen.

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

Mit dem folgenden Befehl können Sie die Instanz PathwaysJob löschen:

kubectl delete -f pathways-job-batch-training.yaml

Batcharbeitslast mit XPK ausführen

Sie können das vorgefertigte MaxText-Docker-Image jetzt mit XPK mit demselben Befehl wie zuvor einreichen.

xpk workload create-pathways \
--workload=WORKLOAD \
--cluster=CLUSTER \
--num-slices=WORKLOAD_NODEPOOL_COUNT \
--tpu-type=TPU_TYPE \
--project=PROJECT \
--zone=ZONE \
--docker-image='gcr.io/PROJECT/USER_runner' \
--command="python3 -m MaxText.train /deps/src/MaxText/configs/base.yml base_output_directory=gs://BUCKET_NAME per_device_batch_size=1 enable_checkpointing=false remat_policy=full global_parameter_scale=1 steps=20 max_target_length=2048 use_iota_embed=true reuse_example_batch=1 dataset_type=synthetic attention=flash gcs_metrics=True enable_single_controller=True run_name=RUN_NAME-pathways-job"

Ersetzen Sie Folgendes:

  • WORKLOAD: ein eindeutiger Name zur Identifizierung Ihres Workloads
  • CLUSTER: der Name Ihres GKE-Cluster
  • WORKLOAD_NODEPOOL_COUNT : die maximale Anzahl von Neustarts des Jobs
  • TPU_TYPE: Der TPU-Typ gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten TPU-Typen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen.
  • PROJECT : Ihre Google Cloud Projekt-ID
  • ZONE: Die Zone, in der Sie Ihre Arbeitslast ausführen möchten
  • USER : Ihre Google Cloud Nutzer-ID
  • RUN_NAME : Ein vom Nutzer zugewiesener Name zur Identifizierung des Workflow-Laufs.

Die Ausgabe sollte etwa so aussehen:

[XPK] Follow your Pathways workload and other resources here : https://console.cloud.google.com/logs/query;query=resource.type%3D"k8s_container"%0Aresource.labels.project_id%3D"<project-name>"%0Aresource.labels.location%3D"<your-zone>"%0Aresource.labels.cluster_name%3D"<your-cluster-name>"%0Aresource.labels.pod_name:"<your-pod-name>"%0Aseverity>%3DDEFAULT

Verwenden Sie den Link in der Ausgabe des vorherigen XPK-Befehls, um den Fortschritt Ihrer Arbeitslast zu verfolgen. Sie können die Logs für Ihren JAX-Container filtern, indem Sie unter dem Filter „Container Name“ (Containername) jax-tpu auswählen.

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

Die Arbeitslast wird nach der angegebenen Anzahl von Schritten abgeschlossen. Wenn Sie sie vorzeitig beenden möchten, verwenden Sie den folgenden Befehl:

xpk workload delete --workload=WORKLOAD --cluster=CLUSTER --project=PROJECT --zone=ZONE

Nächste Schritte