Modell mit TPU v6e trainieren

In diesem Dokument wird beschrieben, wie Sie Modelle auf Cloud TPU v6e (Trillium) trainieren. Wir behandeln dabei die Einrichtung der Umgebung, die Leistungsoptimierung und praktische Trainingsbeispiele mit JAX und PyTorch/XLA.

TPU v6e (Trillium) ist die 6. Generation der TPUs von Google. Auf allen technischen Oberflächen, z. B. in der API und in Logs, sowie in diesem Dokument wird Trillium als v6e bezeichnet. Mit 256 Chips pro Pod hat die Architektur von TPU v6e viele Ähnlichkeiten mit v5e. TPU v6e ist für das Training, die Feinabstimmung und die Bereitstellung von Transformer-, Text-zu-Bild- und CNN-Modellen (Convolutional Neural Network) optimiert. Weitere Informationen zur Systemarchitektur und zu den Konfigurationen von TPU v6e finden Sie unter TPU v6e.

Informationen zum Ausführen von Inferenzen auf Cloud TPU v6e finden Sie in den folgenden Anleitungen:

Vorbereitung

Zur Vorbereitung sind folgende Schritte erforderlich:

  • Konto und Projekt in Google Cloud mit aktivierter Abrechnung erstellen
  • Google Cloud CLI-Alphakomponenten installieren
  • Cloud TPU API aktivieren
  • Cloud TPU-Dienst-Agent erstellen
  • Cloud TPU-Dienstkonto erstellen und Berechtigungen erteilen

Weitere Informationen finden Sie unter Cloud TPU-Umgebung einrichten.

Kontingent und Berechtigungen prüfen

Prüfen Sie, ob Ihr Projekt die folgenden Kontingente hat:

Wenn Sie GKE mit XPK verwenden, benötigen Sie zusätzliche Berechtigungen in der Console von Google Cloud . Weitere Informationen finden Sie in der XPK-README unter Erforderliche Berechtigungen für die Console vonGoogle Cloud .

TPUs bereitstellen

Sie können TPU v6e mit den folgenden Methoden bereitstellen und verwalten:

  • GKE: Mit der GKE können Sie TPUs als Pool von Beschleunigern für Ihre containerisierten ML-Arbeitslasten bereitstellen und verwalten. Weitere Informationen finden Sie unter TPUs in GKE.
  • GKE und XPK: XPK ist ein Befehlszeilentool, das die Clustererstellung und die Ausführung von Arbeitslasten in der GKE vereinfacht. Es wurde für ML-Nutzer ohne umfassende Kubernetes-Kenntnisse entwickelt, die TPUs bereitstellen und Trainingsjobs ausführen möchten. Weitere Informationen finden Sie im XPK-GitHub-Repository.
  • In die Warteschlange gestellte Cloud TPU-Ressourcen: Sie können TPU-Kapazität anfordern, die bereitgestellt wird, sobald sie verfügbar ist. Diese Möglichkeit eignet sich ideal für Batchjobs und fehlertolerante Arbeitslasten, die in einer Warteschlange warten können. Sie können ein Zeitfenster für Ihre Anfrage angeben. Weitere Informationen finden Sie unter In die Warteschlange gestellte Ressourcen verwalten.

v6e-Cloud TPUs mit GKE und XPK bereitstellen

Wenn Sie GKE-Befehle mit v6e verwenden, können Sie Kubernetes-Befehle oder XPK verwenden, um Cloud TPUs bereitzustellen und Modelle zu trainieren oder bereitzustellen. Unter TPUs in GKE planen erfahren Sie, wie Sie Ihre Cloud TPU-Konfigurationen in GKE-Clustern planen. In den folgenden Abschnitten finden Sie Befehle zum Erstellen eines XPK-Clusters mit Unterstützung für eine einzelne NIC und für mehrere NICs.

XPK-Cluster mit Unterstützung für eine einzelne NIC erstellen

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT_ID} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
   --create-vertex-tensorboard

Beschreibung der Befehls-Flags

Variable Beschreibung
CLUSTER_NAME Vom Nutzer zugewiesener Name für den XPK-Cluster
PROJECT_ID Projektname inGoogle Cloud . Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues Projekt. Weitere Informationen finden Sie unter Ihr Projekt in Google Cloud einrichten.
ZONE Eine Liste der unterstützten Zonen finden Sie im Dokument Cloud TPU-Regionen und -Zonen.
TPU_TYPE Weitere Informationen finden Sie unter Beschleunigertypen.
NUM_SLICES Anzahl der Slices, die Sie erstellen möchten
CLUSTER_ARGUMENTS Zu verwendendes Netzwerk und Subnetzwerk.

Beispiel: --network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}

NUM_SLICES Anzahl der zu erstellenden Slices
NETWORK_NAME Name eines zu verwendenden sekundären Netzwerks
NETWORK_FW_NAME Name einer zu verwendenden sekundären Netzwerkfirewall

XPK-Cluster mit Unterstützung für mehrere NICs erstellen

export CLUSTER_NAME=xpk-cluster-name
export REGION=your-region
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \
   --network=${NETWORK_NAME_1} \
   --range=10.11.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_1} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_1} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_2} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \
   --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE}  \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
   --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
   --create-vertex-tensorboard

Beschreibung der Befehls-Flags

Variable Beschreibung
CLUSTER_NAME Vom Nutzer zugewiesener Name für den XPK-Cluster
PROJECT_ID Projektname inGoogle Cloud . Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues Projekt. Weitere Informationen finden Sie unter Ihr Projekt in Google Cloud einrichten.
ZONE Eine Liste der unterstützten Zonen finden Sie im Dokument Cloud TPU-Regionen und -Zonen.
TPU_TYPE Weitere Informationen finden Sie unter Beschleunigertypen.
NUM_SLICES Anzahl der Slices, die Sie erstellen möchten
CLUSTER_ARGUMENTS Zu verwendendes Netzwerk und Subnetzwerk.

Beispiel: --enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}

NODE_POOL_ARGUMENTS Zusätzliches zu verwendendes Knotennetzwerk.

Beispiel: --additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}

NUM_SLICES Anzahl der zu erstellenden Slices (nur für Multislice erforderlich)
NETWORK_NAME Name eines zu verwendenden sekundären Netzwerks
NETWORK_FW_NAME Name einer zu verwendenden sekundären Netzwerkfirewall

JAX oder PyTorch einrichten

In den folgenden Ressourcen wird beschrieben, wie Sie JAX oder PyTorch auf Ihrer Cloud TPU einrichten, je nachdem, welche Bereitstellungs- und Verwaltungsmethode Sie verwenden:

Informationen zum Einrichten und Ausführen von XPK mit MaxText finden Sie unter Running MaxText at Scale with XPK.

Netzwerkleistung optimieren

In diesem Abschnitt wird beschrieben, wie Sie die Netzwerkleistung optimieren, indem Sie die maximale Übertragungseinheit (Maximum Transmission Unit, MTU) konfigurieren, mehrere NICs in Multislice-Umgebungen verwenden und die TCP-Einstellungen verbessern.

MTU konfigurieren

Um die beste Netzwerkleistung zu erzielen, sollten Sie ein Netzwerk mit einer MTU von 8.896 verwenden.

Standardmäßig bietet eine Virtual Private Cloud (VPC) nur eine MTU von 1.460 Byte, was zu einer suboptimalen Netzwerkleistung führt. Sie können die MTU eines VPC-Netzwerks auf einen beliebigen Wert zwischen 1.300 Byte und 8.896 Byte (einschließlich) festlegen. Gängige benutzerdefinierte MTU-Größen sind 1.500 Byte (Standard-Ethernet) und 8.896 Byte (größtmöglicher Wert). Weitere Informationen finden Sie unter Gültige MTU-Größen für VPC-Netzwerke.

Weitere Informationen zum Ändern der MTU-Einstellung für ein bestehendes oder Standardnetzwerk finden Sie unter MTU-Einstellung eines VPC-Netzwerks ändern.

Im folgenden Beispiel wird ein Netzwerk mit einer MTU von 8.896 und einer entsprechenden Firewallregel erstellt, die TCP-, ICMP- und UDP-Traffic innerhalb des Netzwerks zulässt.

export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
    --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
    --allow tcp,icmp,udp --project=${PROJECT_ID}

Ersetzen Sie your-resource-name durch einen Basisnamen für das Netzwerk und die Firewall.

Option für mehrere NICs in Multislice verwenden

Wenn Sie eine Multislice-Umgebung verwenden, legen Sie die folgenden Umgebungsvariablen fest, die für ein sekundäres Subnetz erforderlich sind:

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region

Verwenden Sie die folgenden Befehle, um benutzerdefiniertes IP-Routing für das Netzwerk und das Subnetz zu erstellen.

  1. Erstellen Sie das sekundäre Netzwerk.

    gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
    --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
    
  2. Erstellen Sie ein Subnetzwerk für das sekundäre Netzwerk.

    gcloud compute networks subnets create ${SUBNET_NAME_2} \
    --network=${NETWORK_NAME_2} \
    --range=10.10.0.0/18 --region=${REGION} \
    --project=${PROJECT_ID}
    
  3. Erstellen Sie eine Firewallregel, die Traffic innerhalb des neuen Subnetzwerks zulässt.

    gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
    --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
    --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
    
  4. Erstellen Sie einen Cloud Router für das sekundäre Netzwerk.

    gcloud compute routers create ${ROUTER_NAME} \
    --project=${PROJECT_ID} \
    --network=${NETWORK_NAME_2} \
    --region=${REGION}
    
  5. Erstellen Sie eine NAT-Konfiguration für den Cloud Router.

    gcloud compute routers nats create ${NAT_CONFIG} \
    --router=${ROUTER_NAME} \
    --region=${REGION} \
    --auto-allocate-nat-external-ips \
    --nat-all-subnet-ip-ranges \
    --project=${PROJECT_ID} \
    --enable-logging
    

Nachdem Sie einen Slice mit mehreren Netzwerken erstellt haben, können Sie prüfen, ob beide Netzwerkkarten (NICs) verwendet werden. Dazu richten Sie einen XPK-Cluster ein und fügen dem Befehl zum Erstellen einer XPK-Arbeitslast das Flag --command ifconfig hinzu.

  1. Verwenden Sie den folgenden Befehl workload create, um die Ausgabe des Befehls ifconfig in den Logs der Console von Google Cloud anzuzeigen und zu prüfen, ob die MTU für „eth0“ und „eth1“ auf 8.896 festgelegt ist.

    python3 xpk.py workload create \
        --cluster CLUSTER_NAME \
        {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
        --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
        --tpu-type=${ACCELERATOR_TYPE} \
        --num-slices=${NUM_SLICES}  \
        --on-demand \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --command "ifconfig"

    Wenn Sie Debugging-Logs aktivieren oder Vertex AI TensorBoard verwenden möchten, fügen Sie dem Befehl die folgenden optionalen Argumente hinzu:

    --enable-debug-logs \
    --use-vertex-tensorboard
  2. Prüfen Sie, ob die MTU für „eth0“ und „eth1“ auf 8.896 festgelegt ist. Sehen Sie dazu in der Ausgabe der XPK-Arbeitslast in den Logs der Console von Google Cloud nach.

TCP-Einstellungen verbessern

Wenn Sie Ihre Cloud TPUs mit in die Warteschlange gestellten Ressourcen bereitgestellt haben, können Sie die Netzwerkleistung verbessern, indem Sie die Limits für den TCP-Empfangszwischenspeicher erhöhen. Führen Sie dazu den folgenden Befehl aus.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
    --project "${PROJECT_ID}" \
    --zone "${ZONE}" \
    --node=all \
    --worker=all \
    --command='
    sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'

Leistung der Arbeitsspeicherzuweisung optimieren

Die Bibliothek tcmalloc wird standardmäßig auf Cloud TPU-VMs verwendet, um die Leistung von Modellen mit umfangreichen, häufigen Arbeitsspeicherzuweisungen zu verbessern. Die Konfiguration erfolgt über die Umgebungsvariable LD_PRELOAD.

Bei einigen Arbeitslasten (z. B. DLRM mit sehr großen Zuweisungen für Einbettungstabellen) kann tcmalloc jedoch zu einer Verlangsamung führen. In solchen Fällen können Sie zur Standardfunktion malloc zurückkehren, indem Sie die Festlegung der Variable LD_PRELOAD in Ihrer Shell-Sitzung vor dem Ausführen des Trainingsscripts aufheben:

unset LD_PRELOAD

SkyPilot verwenden

Sie können Cloud TPU v6e mit SkyPilot verwenden. SkyPilot ist ein Open-Source-Framework, das das Ausführen, Verwalten und Skalieren von KI-Arbeitslasten vereinfacht. Sie können SkyPilot v6e-bezogene Standort- und Preisinformationen hinzufügen. Weitere Informationen finden Sie im SkyPilot-Beispiel für TPU v6e.

Trainingsbeispiele

In den folgenden Abschnitten finden Sie Beispiele für das Training von MaxText-, MaxDiffusion- und PyTorch-Modellen auf Cloud TPU v6e.

Diese Beispiele wurden mit den folgenden Softwareversionen getestet:

  • Python 3.10 oder höher
  • Nightly-Softwareversionen:
    • Nightly JAX 0.4.32.dev20240912
    • Nightly LibTPU 0.1.dev20240912+nightly
  • Stabile Softwareversionen:
    • JAX + jaxlib v0.4.37

MaxText und MaxDiffusion auf Cloud TPU v6e trainieren

In den folgenden Abschnitten wird der Trainingslebenszyklus der MaxText- und MaxDiffusion-Modelle beschrieben.

Es müssen allgemein folgende Schritte ausgeführt werden:

  1. Basis-Image für die Arbeitslast erstellen.
  2. Arbeitslast mit XPK ausführen.
    1. Trainingsbefehl für die Arbeitslast erstellen.
    2. Arbeitslast bereitstellen.
  3. Arbeitslast verfolgen und Messwerte ansehen.
  4. XPK-Arbeitslast löschen, wenn sie nicht benötigt wird.
  5. XPK-Cluster löschen, wenn er nicht mehr benötigt wird.

Basis-Image erstellen

Installieren Sie MaxText oder MaxDiffusion und erstellen Sie das Docker-Image:

  1. Klonen Sie das gewünschte Repository und wechseln Sie zum Verzeichnis des Repositorys:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    
  2. Konfigurieren Sie Docker für die Verwendung der Google Cloud CLI:

    gcloud auth configure-docker
    
  3. Erstellen Sie das Docker-Image mit dem folgenden Befehl oder mit einem JAX AI-Image. Weitere Informationen zu JAX AI-Images finden Sie unter JAX AI-Images.

    MaxText:

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    

    MaxDiffusion:

    bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=${PROJECT_ID} LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
    
  4. Legen Sie Ihre Projekt-ID in der aktiven gcloud CLI-Konfiguration fest:

    gcloud config set project ${PROJECT_ID}
    
  5. Wenn Sie die Arbeitslast von einem Computer aus starten, auf dem das Image nicht lokal erstellt wurde, laden Sie das Image hoch.

    1. Legen Sie die Umgebungsvariable CLOUD_IMAGE_NAME fest:

      export CLOUD_IMAGE_NAME=${USER}_runner
      
    2. Laden Sie das Image hoch:

      bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
      

Arbeitslast mit XPK ausführen

  1. Legen Sie die folgenden Umgebungsvariablen fest, wenn Sie nicht die Standardwerte von MaxText oder MaxDiffusion verwenden:

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. Erstellen Sie Ihr Modellscript. Dieses Script wird in einem späteren Schritt als Trainingsbefehl kopiert.

    Führen Sie das Modellscript noch nicht aus.

    MaxText

    MaxText ist ein leistungsstarkes, hochgradig skalierbares Open-Source-LLM, das in reinem Python und JAX geschrieben und auf TPUs von Google Cloud und GPUs für Training und Inferenz ausgerichtet ist.

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python3 -m MaxText.train MaxText/configs/base.yml \
         base_output_directory=${BASE_OUTPUT_DIR} \
         dataset_type=synthetic \
         per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
         enable_checkpointing=false \
         gcs_metrics=true \
         profiler=xplane \
         skip_first_n_steps_for_profiler=5 \
         steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma ist eine LLM-Reihe mit offenen Gewichtungen, die von Google DeepMind entwickelt wurde und auf der Forschung und Technologie von Gemini basiert.

    python3 -m MaxText.train MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7B

    Mixtral ist ein hochmodernes KI-Modell, das von Mistral AI entwickelt wurde und eine dünnbesetzte MoE-Architektur (Mixture of Experts) nutzt.

    python3 -m MaxText.train MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama ist eine LLM-Reihe mit offenen Gewichtungen, die von Meta entwickelt wurden.

    Ein Beispiel für die Ausführung von Llama3 in PyTorch finden Sie in den torch_xla-Modellen im torchprime-Repository.

    MaxDiffusion

    MaxDiffusion ist eine Sammlung von Referenzimplementierungen verschiedener latenter Diffusionsmodelle, die in reinem Python und JAX geschrieben sind und auf XLA-Geräten ausgeführt werden, einschließlich von Cloud TPUs und GPUs. Stable Diffusion ist ein latentes Text-zu-Bild-Modell, das fotorealistische Bilder aus beliebigen Texteingaben generiert.

    Sie müssen einen bestimmten Git-Branch installieren, um MaxDiffusion auszuführen, wie im folgenden Trainingsscript gezeigt.

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    && pip install -r requirements.txt && pip install .
    && pip install huggingface_hub==0.30.2 && OUT_DIR=${BASE_OUTPUT_DIR}
    && python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR} \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash \
        run_name=sdxl-ddp-v6e
    
  3. Exportieren Sie die folgenden Variablen:

    export CLUSTER_NAME=CLUSTER_NAME
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=NUM_SLICES
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT

    Beschreibungen der Umgebungsvariablen

    Variable Beschreibung
    CLUSTER_NAME Name des XPK-Clusters
    ACCELERATOR_TYPE Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen.
    NUM_SLICES Anzahl der TPU-Slices
    YOUR_MODEL_SCRIPT Modellscript, das als Trainingsbefehl ausgeführt werden soll
  4. Führen Sie das Modell mit dem Script aus, das Sie im vorherigen Schritt erstellt haben. Sie müssen entweder das Flag --base-docker-image angeben, um das MaxText-Basis-Image zu verwenden, oder das Flag --docker-image und das gewünschte Image.

    Sie können die folgenden optionalen Flags hinzufügen:

    • Sie können das Debugging-Logging aktivieren, indem Sie das Flag --enable-debug-logs einfügen. Weitere Informationen finden Sie unter JAX in MaxText debuggen.
    • Sie können ein Vertex AI-Experiment erstellen, um Daten in Vertex AI TensorBoard hochzuladen. Dazu müssen Sie das Flag --use-vertex-tensorboard einfügen. Weitere Informationen finden Sie unter JAX mit Vertex AI auf MaxText überwachen.
    python3 xpk.py workload create \
      --cluster ${CLUSTER_NAME} \
      {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \
      --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
      --tpu-type=${ACCELERATOR_TYPE} \
      --num-slices=${NUM_SLICES}  \
      --on-demand \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --command="${YOUR_MODEL_SCRIPT}"

    Die Ausgabe enthält einen Link, über den Sie Ihre Arbeitslast verfolgen können. Öffnen Sie den Link und klicken Sie auf den Tab Logs, um Ihre Arbeitslast in Echtzeit zu verfolgen.

JAX in MaxText debuggen

Verwenden Sie zusätzliche XPK-Befehle, um herauszufinden, warum der Cluster oder die Arbeitslast nicht ausgeführt wird:

  • XPK workload list
  • XPK inspector
  • Aktivieren Sie das ausführliche Logging in Ihren Arbeitslastlogs mit dem Flag --enable-debug-logs, wenn Sie die XPK-Arbeitslast erstellen.

JAX mit Vertex AI auf MaxText überwachen

Damit Sie TensorBoard verwenden können, muss Ihrem Nutzerkonto in Google Cloud die Rolle aiplatform.user zugewiesen sein. Führen Sie den folgenden Befehl aus, um diese Rolle zuzuweisen:

gcloud projects add-iam-policy-binding your-project-id \
   --member='user:your-email' \
   --role='roles/aiplatform.user'

So können Sie skalare und Profildaten über das von Vertex AI verwaltete TensorBoard ansehen:

  1. Erhöhen Sie die Resource-Management-Anfragen (CRUD) für die Zone, die Sie verwenden, von 600 auf 5.000. Bei kleinen Arbeitslasten mit weniger als 16 VMs ist das möglicherweise kein Problem.

  2. Installieren Sie Abhängigkeiten wie cloud-accelerator-diagnostics für Vertex AI:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Erstellen Sie Ihren XPK-Cluster mit dem Flag --create-vertex-tensorboard, wie in der XPK-README unter Create Vertex AI TensorBoard beschrieben. Sie können diesen Befehl auch auf vorhandenen Clustern ausführen.

  4. Erstellen Sie Ihr Vertex AI-Experiment mit dem Flag --use-vertex-tensorboard und dem optionalen Flag --experiment-name, wenn Sie Ihre XPK-Arbeitslast ausführen. Eine vollständige Liste der Schritte finden Sie in der XPK-README unter Create Vertex AI Experiment to upload data to Vertex AI Tensorboard.

Die Logs enthalten einen Link zu einem Vertex AI TensorBoard, der in etwa so aussieht:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

Sie finden den Link zum Vertex AI TensorBoard auch in der Console von Google Cloud . Rufen Sie Vertex AI Experiments in der Console von Google Cloud auf. Wählen Sie im Drop-down-Menü die gewünschte Region aus.

Das TensorBoard-Verzeichnis wird auch in den Cloud Storage-Bucket geschrieben, den Sie mit ${BASE_OUTPUT_DIR} angegeben haben.

XPK-Arbeitslast löschen

Verwenden Sie den Befehl xpk workload delete, um eine oder mehrere Arbeitslasten basierend auf dem Jobpräfix oder dem Jobstatus zu löschen. Dieser Befehl kann nützlich sein, wenn Sie XPK-Arbeitslasten gesendet haben, die nicht mehr ausgeführt werden müssen, oder wenn Jobs in der Warteschlange hängen.

XPK-Cluster löschen

Verwenden Sie den Befehl xpk cluster delete, um den Cluster zu löschen:

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
    --zone=${ZONE} --project=${PROJECT_ID}

Benchmark-Ergebnisse für MaxDiffusion

Wir haben das Trainingsscript für MaxDiffusion auf einer v6e-4, einer v6e-16 und zwei v6e-16 ausgeführt. In der folgenden Tabelle sehen Sie die gemessenen Durchsätze.

v6e-4 v6e-16 Zwei v6e-16
Trainingsschritte 0,069 0,073 0,13
Globale Batchgröße 8 32 64
Durchsatz (Beispiele/s) 115,9 438,4 492,3

Llama-Modelle mit PyTorch/XLA auf Cloud TPU v6e trainieren

In diesem Abschnitt wird beschrieben, wie Sie Llama-Modelle mit PyTorch/XLA auf Cloud TPU v6e mit dem WikiText-Dataset trainieren.

Zugriff auf Hugging Face und das Llama 3-Modell erhalten

Für dieses Beispiel benötigen Sie ein Hugging Face-Nutzerzugriffstoken. Informationen zum Erstellen von Nutzerzugriffstokens finden Sie in der Hugging Face-Dokumentation zu Nutzerzugriffstokens.

Außerdem benötigen Sie die Berechtigung für den Zugriff auf das Modell „Llama-3-8B“ auf Hugging Face. Rufen Sie dazu das Meta-Llama-3-8B-Modell auf Hugging Face auf und beantragen Sie den Zugriff.

Cloud TPU-VM erstellen

Erstellen Sie für dieses Beispiel eine Cloud TPU v6e mit 8 Chips.

  1. Richten Sie Umgebungsvariablen ein:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east1-d
    export ACCELERATOR_TYPE=v6e-8
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    Beschreibungen der Umgebungsvariablen

    Variable Beschreibung
    PROJECT_ID Ihre Google Cloud Projekt-ID. Verwenden Sie ein vorhandenes oder erstellen Sie ein neues Projekt.
    TPU_NAME Name der TPU
    ZONE Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und -Zonen.
    ACCELERATOR_TYPE Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen.
    RUNTIME_VERSION Softwareversion der Cloud TPU

  2. Erstellen Sie eine Cloud TPU-VM:

    gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${RUNTIME_VERSION} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --zone=${ZONE} \
       --project=${PROJECT_ID}

Installation

Installieren Sie den pytorch-tpu/transformers-Fork von Hugging Face Transformers und die Abhängigkeiten. Dieses Beispiel wurde mit den folgenden Abhängigkeitsversionen getestet:

  • torch: kompatibel mit 2.5.0
  • torch_xla[tpu]: kompatibel mit 2.5.0
  • jax: 0.4.33
  • jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone ${ZONE} \
   --worker=all \
   --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
   cd transformers
   sudo pip3 install -e .
   pip3 install datasets
   pip3 install evaluate
   pip3 install scikit-learn
   pip3 install accelerate
   pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
   pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'

Modellkonfigurationsdateien einrichten

Im Trainingsbefehl im nächsten Abschnitt, Modell ausführen, werden zwei JSON-Konfigurationsdateien verwendet, um Modellparameter und die FSDP-Konfiguration zu definieren. Mit der FSDP-Fragmentierung (Fully Sharded Data Parallelism, vollständig fragmentierte Datenparallelität) können Sie beim Training eine größere Batchgröße verwenden, indem die Modellgewichtungen auf mehrere TPUs verteilt werden. Beim Training mit kleineren Modellen kann es ausreichen, Datenparallelität zu verwenden und die Gewichtungen auf jedem Gerät zu replizieren. Weitere Informationen zur Fragmentierung von Tensoren über mehrere Geräte in PyTorch/XLA finden Sie im PyTorch/XLA SPMD User Guide.

  1. Erstellen Sie die Modellparameter-Konfigurationsdatei. Im Folgenden finden Sie die Modellparameterkonfiguration für Llama-3-8B. Die Konfigurationsdatei für andere Modelle finden Sie auf Hugging Face. Ein Beispiel finden Sie in der Llama-2-7B-Konfiguration.

    cat > llama-config.json << EOF
    {
      "architectures": [
        "LlamaForCausalLM"
      ],
      "attention_bias": false,
      "attention_dropout": 0.0,
      "bos_token_id": 128000,
      "eos_token_id": 128001,
      "hidden_act": "silu",
      "hidden_size": 4096,
      "initializer_range": 0.02,
      "intermediate_size": 14336,
      "max_position_embeddings": 8192,
      "model_type": "llama",
      "num_attention_heads": 32,
      "num_hidden_layers": 32,
      "num_key_value_heads": 8,
      "pretraining_tp": 1,
      "rms_norm_eps": 1e-05,
      "rope_scaling": null,
      "rope_theta": 500000.0,
      "tie_word_embeddings": false,
      "torch_dtype": "bfloat16",
      "transformers_version": "4.40.0.dev0",
      "use_cache": false,
      "vocab_size": 128256
    }
    EOF
    
  2. Erstellen Sie die FSDP-Konfigurationsdatei:

    cat > fsdp-config.json << EOF
    {
      "fsdp_transformer_layer_cls_to_wrap": [
        "LlamaDecoderLayer"
      ],
      "xla": true,
      "xla_fsdp_v2": true,
      "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    Weitere Informationen zu FSDP finden Sie in der PyTorch-Dokumentation unter Fully Sharded Data Parallel using SPMD.

  3. Laden Sie die Konfigurationsdateien mit dem folgenden Befehl auf Ihre Cloud TPU-VMs hoch:

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
       --worker=all \
       --project=${PROJECT_ID} \
       --zone=${ZONE}

Modell ausführen

Führen Sie mit den Konfigurationsdateien, die Sie im vorherigen Abschnitt erstellt haben, das Script run_clm.py aus, um das Llama-3-8B-Modell mit dem WikiText-Dataset zu trainieren. Die Ausführung des Trainingsscripts dauert auf einer Cloud TPU v6e-8 etwa 10 Minuten.

  1. Melden Sie sich mit dem folgenden Befehl auf Ihrer Cloud TPU bei Hugging Face an:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       pip3 install "huggingface_hub[cli]"
       huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. Führen Sie das Modelltraining aus:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       export PJRT_DEVICE=TPU
       export XLA_USE_SPMD=1
       export ENABLE_PJRT_COMPATIBILITY=true
       # Optional variables for debugging:
       export XLA_IR_DEBUG=1
       export XLA_HLO_DEBUG=1
       export PROFILE_EPOCH=0
       export PROFILE_STEP=3
       export PROFILE_DURATION_MS=100000
       # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
       export PROFILE_LOGDIR=PROFILE_PATH
       python3 transformers/examples/pytorch/language-modeling/run_clm.py \
         --dataset_name wikitext \
         --dataset_config_name wikitext-2-raw-v1 \
         --per_device_train_batch_size 16 \
         --do_train \
         --output_dir /home/$USER/tmp/test-clm \
         --overwrite_output_dir \
         --config_name /home/$USER/llama-config.json \
         --cache_dir /home/$USER/cache \
         --tokenizer_name meta-llama/Meta-Llama-3-8B \
         --block_size 8192 \
         --optim adafactor \
         --save_strategy no \
         --logging_strategy no \
         --fsdp "full_shard" \
         --fsdp_config /home/$USER/fsdp-config.json \
         --torch_dtype bfloat16 \
         --dataloader_drop_last yes \
         --flash_attention \
         --max_steps 20'

Fehlerbehebung bei PyTorch/XLA

Wenn Sie die optionalen Variablen für das Debugging im vorherigen Abschnitt festgelegt haben, wird das Profil für das Modell am Speicherort gespeichert, der in der Variable PROFILE_LOGDIR angegeben ist. Sie können die Datei xplane.pb, die sich an diesem Speicherort befindet, extrahieren und tensorboard verwenden, um die Profile im Browser anzusehen. Beachten Sie dazu die TensorBoard-Anleitung.

Wenn PyTorch/XLA nicht wie erwartet funktioniert, finden Sie im Leitfaden zur Fehlerbehebung Vorschläge zum Debuggen, Profiling und Optimieren Ihres Modells.