Esegui un carico di lavoro batch con Pathways

Ai fini di questo documento, i carichi di lavoro batch sono definiti come carichi di lavoro JAX che vengono eseguiti fino al completamento e di cui viene eseguito il deployment nello stesso cluster GKE del cluster Pathways, in particolare insieme ai componenti del controller Pathways (server proxy IFRT e gestore delle risorse Pathways). Il completamento del carico di lavoro JAX termina i componenti del cluster Pathways. Questa guida utilizza un carico di lavoro di addestramento JAX per dimostrarlo.

Prima di iniziare

Assicurati di avere:

Creare un'immagine di addestramento utilizzando MaxText

MaxText è un progetto di modello linguistico di grandi dimensioni (LLM) open source sviluppato da Google. È scritto in JAX ed è progettato per essere altamente performante e scalabile, con un'esecuzione efficiente su TPU e GPU Google Cloud.

Per creare un'immagine Docker MaxText utilizzando l'ultima versione di JAX stabile dal repository GitHub OSS, esegui il seguente comando:

git clone https://github.com/AI-Hypercomputer/maxtext
cd maxtext/dependencies/scripts
gcloud config set project PROJECT
bash ./docker_build_dependency_image.sh MODE=stable
gcloud auth configure-docker
bash ./docker_upload_runner.sh CLOUD_IMAGE_NAME=USER_runner # This script needs bash version >= 4.2 to execute.

Questo comando esegue il push dell'immagine Kubernetes MaxText in gcr.io/$PROJECT/${USER}_runner. Puoi utilizzare questa immagine Docker per eseguire l'addestramento sulle TPU utilizzando il backend Pathways.

Eseguire un carico di lavoro batch utilizzando l'API PathwaysJob

Il seguente manifest esegue il deployment dei componenti Pathways ed esegue un carico di lavoro MaxText utilizzando l'API PathwaysJob. Il carico di lavoro è incapsulato nel container main ed esegue train.py.

Copia il seguente YAML in un file denominato pathways-job-batch-training.yaml e aggiorna i valori modificabili.

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  maxRestarts: MAX_RESTARTS
  workers:
  - type: TPU_MACHINE_TYPE
    topology: TOPOLOGY
    numSlices: WORKLOAD_NODEPOOL_COUNT
  pathwaysDir: gs://BUCKET_NAME
  controller:
    deploymentMode: default
    template:
      spec:
        containers:
        - name: main
          image: gcr.io/PROJECT/USER_runner
          command:
          - bash
          - -c
          - |
            python3 -m MaxText.train /deps/src/MaxText/configs/base.yml \
            base_output_directory=gs://BUCKET_NAME \
            run_name=RUN_NAME \
            per_device_batch_size=1 \
            enable_checkpointing=false \
            remat_policy=full \
            global_parameter_scale=1 \
            steps=20 \
            max_target_length=2048 \
            use_iota_embed=true \
            reuse_example_batch=1 \
            dataset_type=synthetic \
            attention=flash \
            gcs_metrics=True \
            enable_single_controller=True

Sostituisci quanto segue:

  • USER : il tuo Google Cloud ID utente
  • MAX_RESTARTS : il numero massimo di volte in cui il job può essere riavviato
  • TPU_MACHINE_TYPE : il tipo di macchina TPU
  • TOPOLOGY : la topologia TPU v4 o versioni successive. Per saperne di più sulle versioni TPU e sulle topologie supportate, consulta Versioni TPU
  • WORKLOAD_NODEPOOL_COUNT : il numero di node pool utilizzati da un carico di lavoro Pathways
  • BUCKET_NAME : un bucket Cloud Storage per l'archiviazione dei file temporanei
  • PROJECT : il tuo Google Cloud ID progetto
  • RUN_NAME : un nome assegnato dall'utente per identificare l'esecuzione del workflow

Puoi eseguire il deployment del file YAML PathwaysJob nel seguente modo:

kubectl apply -f pathways-job-batch-training.yaml

Per visualizzare l'istanza PathwaysJob creata dal comando precedente, utilizza:

kubectl get pathwaysjob

L'output dovrebbe essere simile al seguente:

NAME             AGE
pathways-trial   9s

Per modificare un attributo dell'istanza PathwaysJob, elimina l'istanza PathwaysJob, modifica il file YAML e applicalo per creare una nuova istanza PathwaysJob.

Puoi seguire l'avanzamento del carico di lavoro accedendo a Esplora log per il container JAX scegliendo main nel filtro Nome container.

Dovresti visualizzare log simili ai seguenti, che indicano che l'addestramento è in corso. Il carico di lavoro verrà completato dopo 30 passaggi.

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

Per eliminare l'istanza PathwaysJob, puoi utilizzare il seguente comando:

kubectl delete -f pathways-job-batch-training.yaml

Eseguire un carico di lavoro batch utilizzando XPK

Ora puoi inviare l'immagine Docker MaxText precompilata utilizzando XPK con lo stesso comando che hai utilizzato in precedenza.

xpk workload create-pathways \
--workload=WORKLOAD \
--cluster=CLUSTER \
--num-slices=WORKLOAD_NODEPOOL_COUNT \
--tpu-type=TPU_TYPE \
--project=PROJECT \
--zone=ZONE \
--docker-image='gcr.io/PROJECT/USER_runner' \
--command="python3 -m MaxText.train /deps/src/MaxText/configs/base.yml base_output_directory=gs://BUCKET_NAME per_device_batch_size=1 enable_checkpointing=false remat_policy=full global_parameter_scale=1 steps=20 max_target_length=2048 use_iota_embed=true reuse_example_batch=1 dataset_type=synthetic attention=flash gcs_metrics=True enable_single_controller=True run_name=RUN_NAME-pathways-job"

Sostituisci quanto segue:

  • WORKLOAD: un nome univoco per identificare il carico di lavoro
  • CLUSTER: il nome del cluster GKE
  • WORKLOAD_NODEPOOL_COUNT : il numero massimo di volte in cui il job può essere riavviato
  • TPU_TYPE: il tipo di TPU specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per saperne di più sui tipi di TPU supportati per ogni versione TPU, consulta Versioni TPU
  • PROJECT : il tuo ID progetto Google Cloud
  • ZONE: la zona in cui prevedi di eseguire il carico di lavoro
  • USER : il tuo Google Cloud ID utente
  • RUN_NAME : un nome assegnato dall'utente per identificare l'esecuzione del workflow

Dovresti vedere un output simile al seguente:

[XPK] Follow your Pathways workload and other resources here : https://console.cloud.google.com/logs/query;query=resource.type%3D"k8s_container"%0Aresource.labels.project_id%3D"<project-name>"%0Aresource.labels.location%3D"<your-zone>"%0Aresource.labels.cluster_name%3D"<your-cluster-name>"%0Aresource.labels.pod_name:"<your-pod-name>"%0Aseverity>%3DDEFAULT

Utilizza il link nell'output del comando XPK precedente per seguire l'avanzamento del carico di lavoro. Puoi filtrare i log per il container JAX scegliendo jax-tpu nel filtro Nome container.

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

Il carico di lavoro viene completato dopo il numero di passaggi specificato. Se vuoi terminarlo prima del tempo, utilizza il seguente comando:

xpk workload delete --workload=WORKLOAD --cluster=CLUSTER --project=PROJECT --zone=ZONE

Passaggi successivi