Train a model using TPU v6e
This document guides you through training models on Cloud TPU v6e (also called Trillium), covering environment setup, performance optimization, and practical training examples using JAX and PyTorch/XLA.
TPU v6e, also called Trillium, is Google's 6th generation of TPUs. On all technical surfaces, such as the API and logs, and throughout this document, Trillium is referred to as v6e. With 256 chips per Pod, the architecture of TPU v6e shares many similarities with v5e. TPU v6e is optimized for transformer, text-to-image, and convolutional neural network (CNN) training, fine-tuning, and serving. For more information about the TPU v6e system architecture and configurations, see TPU v6e.
For information about running inference on Cloud TPU v6e, see the v6e inference tutorials.
Before you begin
Before you begin, you need to:
- Create a Google Cloud account and project with billing enabled
- Install Google Cloud CLI alpha components
- Enable the Cloud TPU API
- Create a Cloud TPU service agent
- Create a Cloud TPU service account and grant permissions
For more information, see Set up the Cloud TPU environment.
Verify quota and permissions
Verify that your project has the following quotas:
- TPU v6e preemptible or on-demand quota
- IP address quota
- Quota for Hyperdisk Balanced and for any other disk types you want to use
If you're using Google Kubernetes Engine (GKE) with XPK (Accelerated Processing Kit), you need additional permissions in the Google Cloud console. For more information, see Permissions needed on Google Cloud console.
Provisioning options
You can provision and manage TPU v6e using the following methods:
- GKE: You can use GKE to provision and manage TPUs as a pool of accelerators for your containerized machine learning workloads. For more information, see About TPUs in GKE.
- GKE and XPK: XPK is a command-line tool that simplifies cluster creation and workload execution on GKE. It's designed for ML practitioners to provision TPUs and run training jobs without needing deep Kubernetes expertise. For more information, see the XPK GitHub repository.
- Cloud TPU queued resources: Queued resources let you request TPU capacity that is provisioned when it becomes available. It's ideal for batch jobs and fault-tolerant workloads that can wait in a queue. You can specify a time window for your request. For more information, see Manage queued resources.
Provision v6e TPUs with GKE and XPK
If you are using GKE with v6e TPUs, you can use Kubernetes commands or XPK to provision TPUs and train or serve models. For more information about using GKE with TPUs, see About TPUs in GKE.
Cloud TPU v6e supports network interface card (NIC) configurations that let you scale throughput across multiple networks. The following sections provide commands to create a GKE cluster with single-NIC support or multi-NIC support using XPK. For most single-slice workloads, single-NIC provides sufficient performance with less configuration. For Multislice workloads and workloads that require high data ingestion speeds, use multi-NIC.
Create a cluster with single-NIC support using XPK
The following sections show how to create a GKE cluster with single-NIC support using XPK.
Install XPK and set up environment variables
Install XPK. Follow the instructions in the XPK GitHub repository.
Set up environment variables for your cluster:
export CLUSTER_NAME=XPK_CLUSTER_NAME
export ZONE=us-east1-d
export PROJECT_ID=PROJECT_ID
export ACCELERATOR_TYPE=ACCELERATOR_TYPE
export NUM_SLICES=1
Environment variable descriptions
- CLUSTER_NAME: A name for your cluster.
- ZONE: The zone where the TPU cluster will be created. For more information about supported zones, see Regions and zones.
- PROJECT_ID: Your Google Cloud project ID.
- ACCELERATOR_TYPE: The TPU type, also called accelerator type, specifies the version and size of the Cloud TPU you want to create. For example, v6e-256. For more information about supported accelerator types for each TPU version, see TPU versions.
- NUM_SLICES: The number of TPU slices for your cluster. Each slice has the number of chips specified in ACCELERATOR_TYPE. For a single-slice cluster, set NUM_SLICES to 1. For a Multislice cluster, specify the number of slices based on your workload's scalability requirements. The total number of chips in the cluster is the number of chips in ACCELERATOR_TYPE multiplied by NUM_SLICES.
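For example, a v6e-256 slice contains 256 chips, so a two-slice Multislice cluster would contain 512 chips in total. The following is an illustrative calculation only; the values are hypothetical, not additional required settings:

# Illustrative only: chips per slice (256 for v6e-256) multiplied by the number of slices.
echo $(( 256 * 2 ))   # prints 512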
Create the cluster
Choose one of the following options to create your cluster. Using a custom network with 8,896 MTU is recommended for optimal performance. For more information, see Configure MTU.
Custom network
To create a custom network with 8,896 MTU and use it for your cluster, follow these steps:
Set environment variables for the network and firewall names:
export NETWORK_NAME=NETWORK_NAME
export NETWORK_FW_NAME=FIREWALL_NAME
Replace the following:
- NETWORK_NAME: A name for the network.
- FIREWALL_NAME: A name for the network firewall rule.
Create a custom network with an MTU of 8,896:
gcloud compute networks create ${NETWORK_NAME} \
  --mtu=8896 \
  --project=${PROJECT_ID} \
  --subnet-mode=auto \
  --bgp-routing-mode=regional
Create a firewall rule that allows TCP, ICMP, and UDP traffic on your network:
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
  --network=${NETWORK_NAME} \
  --allow tcp,icmp,udp \
  --project=${PROJECT_ID}
Set an environment variable for the XPK cluster arguments to use the network you created:
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
Create the XPK cluster. The following command provisions on-demand capacity:
xpk cluster create --cluster=${CLUSTER_NAME} \
  --cluster-cpu-machine-type=e2-standard-8 \
  --num-slices=${NUM_SLICES} \
  --tpu-type=${ACCELERATOR_TYPE} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --on-demand \
  --custom-cluster-arguments="${CLUSTER_ARGUMENTS}"
To use reserved capacity, replace --on-demand with --reservation=RESERVATION_NAME. To use TPU Spot VMs, replace --on-demand with --spot.
Default network
If you don't require a high-MTU network, you can create a cluster that uses the default VPC network. The following command provisions on-demand capacity:
xpk cluster create --cluster=${CLUSTER_NAME} \
  --cluster-cpu-machine-type=e2-standard-8 \
  --num-slices=${NUM_SLICES} \
  --tpu-type=${ACCELERATOR_TYPE} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --on-demand
To use reserved capacity, replace --on-demand with
--reservation=RESERVATION_NAME. To use TPU
Spot VMs, replace --on-demand with --spot.
Create a cluster with multi-NIC support using XPK
For Multislice workloads or other workloads that require high network bandwidth, such as for data ingestion, you can use multi-NIC to improve performance. When you use multi-NIC, each TPU VM is allocated additional network interfaces, each connected to a unique VPC network, increasing overall network throughput. For most single-slice workloads, single-NIC provides sufficient performance with less configuration.
The following sections show how to create a GKE cluster with multi-NIC support using XPK.
Install XPK and set up environment variables
Install XPK. Follow the instructions in the XPK GitHub repository.
Set up environment variables for your cluster and primary network:
export CLUSTER_NAME=XPK_CLUSTER_NAME
export REGION=REGION
export ZONE=us-east1-d
export PROJECT_ID=PROJECT_ID
export ACCELERATOR_TYPE=ACCELERATOR_TYPE
export NUM_SLICES=2
export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export FIREWALL_RULE_NAME_1=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME_1=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG_1=${CLUSTER_NAME}-natconfig-1-${ZONE}
export NETWORK_NAME_2=${CLUSTER_NAME}-mtu9k-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME_2=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME_2=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG_2=${CLUSTER_NAME}-natconfig-2-${ZONE}
Environment variable descriptions
- CLUSTER_NAME: A name for your cluster.
- REGION: The region where your TPU cluster will be created.
- ZONE: The zone where the TPU cluster will be created. For more information about supported zones, see Regions and zones.
- PROJECT_ID: Your Google Cloud project ID.
- ACCELERATOR_TYPE: The accelerator type specifies the version and size of the Cloud TPU you want to create. For example, v6e-256. For more information about supported accelerator types for each TPU version, see TPU versions.
- NUM_SLICES: The number of TPU slices for your cluster. Each slice has the number of chips specified in ACCELERATOR_TYPE. For a single-slice cluster, set NUM_SLICES to 1. For a Multislice cluster, specify the number of slices based on your workload's scalability requirements. The total number of chips in the cluster is the number of chips in ACCELERATOR_TYPE multiplied by NUM_SLICES.
Create the primary network resources
Create the primary network with a maximum transmission unit (MTU) of 8,896:
gcloud compute networks create ${NETWORK_NAME_1} \
  --mtu=8896 \
  --bgp-routing-mode=regional \
  --subnet-mode=custom \
  --project=${PROJECT_ID}
Using a custom network with an MTU of 8,896 provides better performance. For more information, see Configure MTU.
Create the primary subnetwork:
gcloud compute networks subnets create ${SUBNET_NAME_1} \
  --network=${NETWORK_NAME_1} \
  --range=10.11.0.0/18 \
  --region=${REGION} \
  --project=${PROJECT_ID}
Create a firewall rule that allows tcp, icmp, and udp traffic on the primary network:

gcloud compute firewall-rules create ${FIREWALL_RULE_NAME_1} \
  --network=${NETWORK_NAME_1} \
  --allow tcp,icmp,udp \
  --project=${PROJECT_ID}
Create a Cloud Router for the primary network:
gcloud compute routers create ${ROUTER_NAME_1} \
  --project=${PROJECT_ID} \
  --network=${NETWORK_NAME_1} \
  --region=${REGION}
Configure NAT for the primary network. The following command allows traffic from your cluster to reach the internet:
gcloud compute routers nats create ${NAT_CONFIG_1} \
  --router=${ROUTER_NAME_1} \
  --region=${REGION} \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project=${PROJECT_ID} \
  --enable-logging
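Optionally, you can confirm that the primary network was created with the expected MTU. This is a hedged sketch rather than a required step; it assumes the environment variables set earlier in this section:

gcloud compute networks describe ${NETWORK_NAME_1} \
  --project=${PROJECT_ID} \
  --format="value(mtu)"
# Expected output: 8896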
Create the secondary network resources
Create the secondary network:
gcloud compute networks create ${NETWORK_NAME_2} \
  --mtu=8896 \
  --bgp-routing-mode=regional \
  --subnet-mode=custom \
  --project=${PROJECT_ID}

Create a subnetwork for the secondary network:

gcloud compute networks subnets create ${SUBNET_NAME_2} \
  --network=${NETWORK_NAME_2} \
  --range=10.10.0.0/18 \
  --region=${REGION} \
  --project=${PROJECT_ID}

Create a firewall rule to allow traffic within the new network:

gcloud compute firewall-rules create ${FIREWALL_RULE_NAME_2} \
  --network=${NETWORK_NAME_2} \
  --allow tcp,icmp,udp \
  --source-ranges 10.10.0.0/18 \
  --project=${PROJECT_ID}

Create a Cloud Router for the secondary network:

gcloud compute routers create ${ROUTER_NAME_2} \
  --project=${PROJECT_ID} \
  --network=${NETWORK_NAME_2} \
  --region=${REGION}

Create a NAT configuration for the Cloud Router:

gcloud compute routers nats create ${NAT_CONFIG_2} \
  --router=${ROUTER_NAME_2} \
  --region=${REGION} \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project=${PROJECT_ID} \
  --enable-logging
Create the cluster
Set an environment variable for the cluster and node pool arguments to use the networks and subnetworks you created:
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"

These arguments configure the cluster to use the two networks you created for multi-NIC support.
Create the cluster. The following command provisions on-demand capacity:
xpk cluster create \
  --cluster=${CLUSTER_NAME} \
  --cluster-cpu-machine-type=e2-standard-8 \
  --num-slices=${NUM_SLICES} \
  --tpu-type=${ACCELERATOR_TYPE} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --on-demand \
  --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
  --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
  --create-vertex-tensorboard
To use reserved capacity, replace --on-demand with --reservation=RESERVATION_NAME. To use TPU Spot VMs, replace --on-demand with --spot.
Validate multi-NIC setup
After you create a cluster with multi-NIC support, you can validate that both NICs are being used by creating an XPK workload and adding the --command "ifconfig" flag.
Use the following command to display the output of the ifconfig command in Google Cloud console logs. You must either specify the --base-docker-image maxtext_base_image flag to use the MaxText base image, as in the following example, or specify the --docker-image flag and the image you want to use.

xpk workload create \
  --cluster ${CLUSTER_NAME} \
  --base-docker-image maxtext_base_image \
  --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
  --tpu-type=${ACCELERATOR_TYPE} \
  --num-slices=${NUM_SLICES} \
  --on-demand \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --command "ifconfig"
If you want to enable debug logs or use Vertex AI TensorBoard, add the following optional arguments to the command:
--enable-debug-logs \
--use-vertex-tensorboard
Verify that both eth0 and eth1 have MTU set to 8,896 by checking the output of the XPK workload in Google Cloud console logs.
Set up JAX or PyTorch
The following resources show how to set up JAX or PyTorch on your TPU, depending on which provisioning and management method you use:
- GKE Autopilot: Prepare your TPU application
- GKE Standard: Prepare your workloads
- GKE and XPK: XPK README
- Single-host Cloud TPU using JAX: Run a calculation on a Cloud TPU VM using JAX
- Multi-host Cloud TPU using JAX: Run JAX code on TPU slices
- Single-host Cloud TPU using PyTorch: Run a calculation on a Cloud TPU VM using PyTorch
- Multi-host Cloud TPU using PyTorch: Run PyTorch code on TPU slices
To set up and run XPK with MaxText, see Running MaxText at Scale with XPK.
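After setting up JAX with one of the resources above, you can optionally run a quick sanity check to confirm that the runtime sees your TPU devices. This is a minimal sketch, assuming JAX is already installed on a TPU VM; it isn't part of the official setup steps:

# Prints the number of TPU devices visible to JAX (typically the number of chips in the slice).
python3 -c 'import jax; print("JAX devices:", jax.device_count())'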
Improve TCP settings
If you provisioned your v6e TPUs using queued resources, you can run the following command to improve network performance by increasing TCP receive buffer limits.
gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
  --project "${PROJECT_ID}" \
  --zone "${ZONE}" \
  --node=all \
  --worker=all \
  --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'
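Optionally, you can read the setting back on each worker to confirm that it was applied. The following sketch reuses the same SSH command shown above and assumes the same environment variables are set:

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
  --project "${PROJECT_ID}" \
  --zone "${ZONE}" \
  --node=all \
  --worker=all \
  --command='cat /proc/sys/net/ipv4/tcp_rmem'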
Use SkyPilot
You can use Cloud TPU v6e with SkyPilot. SkyPilot is an open-source framework that simplifies the process of running, managing, and scaling AI workloads. You can add v6e-related location and pricing information to SkyPilot. For more information, see the SkyPilot TPU v6e example.
Training examples
The following sections provide examples for training MaxText, MaxDiffusion, and PyTorch models on Cloud TPU v6e.
These examples have been tested with the following software versions:
- Python 3.10 or later
- Nightly software versions:
  - Nightly JAX 0.4.32.dev20240912
  - Nightly LibTPU 0.1.dev20240912+nightly
- Stable software versions:
  - JAX + JAX Lib of v0.4.37
Train MaxText and MaxDiffusion on Cloud TPU v6e
The following sections cover the training lifecycle of the MaxText and MaxDiffusion models.
In general, the high-level steps are:
- Build the workload base image.
- Run your workload using XPK.
- Build the training command for the workload.
- Deploy the workload.
- Follow the workload and view metrics.
- Delete the XPK workload if it isn't needed.
- Delete the cluster when it's no longer needed.
Build base image
Install MaxText or MaxDiffusion and build the Docker image:
Clone the repository you want to use and change to the directory for the repository:
MaxText:
git clone https://github.com/google/maxtext.git && cd maxtext

MaxDiffusion:

git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103

Configure Docker to use the Google Cloud CLI:

gcloud auth configure-docker

Build the Docker image using the following command or using a JAX AI image. For more information about JAX AI images, see JAX AI images.

MaxText:

bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35

MaxDiffusion:

bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=${PROJECT_ID} LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest

Set your project ID in your active gcloud CLI configuration:

gcloud config set project ${PROJECT_ID}

If you're launching the workload from a machine that doesn't have the image built locally, upload the image.

Set the CLOUD_IMAGE_NAME environment variable:

export CLOUD_IMAGE_NAME=${USER}_runner

Upload the image:
bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
Run your workload using XPK
Set the following environment variables if you're not using the default values set by MaxText or MaxDiffusion:
export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
export PER_DEVICE_BATCH_SIZE=2
export NUM_STEPS=30
export MAX_TARGET_LENGTH=8192
Build your model script. This script will be copied as a training command in a later step.
Don't execute the model script yet.
MaxText
MaxText is a high performance, highly scalable, open-source LLM written in pure Python and JAX and targeting Google Cloud TPUs and GPUs for training and inference.
JAX_PLATFORMS=tpu,cpu \
ENABLE_PJRT_COMPATIBILITY=true \
TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
TPU_SLICE_BUILDER_DUMP_ICI=true && \
python3 -m MaxText.train MaxText/configs/base.yml \
  base_output_directory=${BASE_OUTPUT_DIR} \
  dataset_type=synthetic \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  enable_checkpointing=false \
  gcs_metrics=true \
  profiler=xplane \
  skip_first_n_steps_for_profiler=5 \
  steps=${NUM_STEPS}  # attention='dot_product'

Gemma2
Gemma is a family of open-weights LLMs developed by Google DeepMind, based on Gemini research and technology.
python3 -m MaxText.train MaxText/configs/base.yml \
  model_name=gemma2-27b \
  run_name=gemma2-27b-run \
  base_output_directory=${BASE_OUTPUT_DIR} \
  max_target_length=${MAX_TARGET_LENGTH} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  steps=${NUM_STEPS} \
  enable_checkpointing=false \
  use_iota_embed=true \
  gcs_metrics=true \
  dataset_type=synthetic \
  profiler=xplane \
  attention=flash

Mixtral 8x7b
Mixtral is a state-of-the-art AI model developed by Mistral AI, utilizing a sparse mixture-of-experts (MoE) architecture.
python3 -m MaxText.train MaxText/configs/base.yml \
  base_output_directory=${BASE_OUTPUT_DIR} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  model_name=mixtral-8x7b \
  steps=${NUM_STEPS} \
  max_target_length=${MAX_TARGET_LENGTH} \
  tokenizer_path=assets/tokenizer.mistral-v1 \
  attention=flash \
  dtype=bfloat16 \
  dataset_type=synthetic \
  profiler=xplane

Llama3-8b
Llama is a family of open-weights LLMs developed by Meta.
For an example of how to run Llama3 on PyTorch, see torch_xla models in the torchprime repository.
MaxDiffusion
MaxDiffusion is a collection of reference implementations of various latent diffusion models written in pure Python and JAX that run on XLA devices including Cloud TPUs and GPUs. Stable Diffusion is a latent text-to-image model that generates photo-realistic images from any text input.
You need to install a specific Git branch to run MaxDiffusion as shown in the following training script.
git clone https://github.com/google/maxdiffusion.git && \
cd maxdiffusion && \
git checkout 4a8155ec0129512812b31930f0a91c6d5a141103 && \
pip install -r requirements.txt && \
pip install . && \
pip install huggingface_hub==0.30.2 && \
OUT_DIR=${BASE_OUTPUT_DIR} && \
python src/maxdiffusion/train_sdxl.py \
  src/maxdiffusion/configs/base_xl.yml \
  revision=refs/pr/95 \
  activations_dtype=bfloat16 \
  weights_dtype=bfloat16 \
  resolution=1024 \
  per_device_batch_size=1 \
  output_dir=${OUT_DIR} \
  jax_cache_dir=${OUT_DIR}/cache_dir/ \
  max_train_steps=200 \
  attention=flash \
  run_name=sdxl-ddp-v6e

Export the following variables:
export CLUSTER_NAME=CLUSTER_NAME
export ACCELERATOR_TYPE=ACCELERATOR_TYPE
export NUM_SLICES=NUM_SLICES
export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT
Environment variable descriptions
- CLUSTER_NAME: The name of your cluster.
- ACCELERATOR_TYPE: The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
- NUM_SLICES: The number of TPU slices.
- YOUR_MODEL_SCRIPT: The model script to execute as a training command.
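For example, the following hedged sketch shows one way to assign the MaxText script from the previous step to YOUR_MODEL_SCRIPT. The flags shown are illustrative; use the script you built for your model:

export YOUR_MODEL_SCRIPT="python3 -m MaxText.train MaxText/configs/base.yml \
  base_output_directory=${BASE_OUTPUT_DIR} \
  dataset_type=synthetic \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  enable_checkpointing=false \
  steps=${NUM_STEPS}"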
Run the model using the script you created in the previous step. You must either specify the --base-docker-image flag to use the MaxText base image or specify the --docker-image flag and the image you want to use.

You can choose to add the following optional flags:

- You can enable debug logging by including the --enable-debug-logs flag. For more information, see Debug JAX on MaxText.
- You can create a Vertex AI Experiment to upload data to Vertex AI TensorBoard by including the --use-vertex-tensorboard flag. For more information, see Monitor JAX on MaxText using Vertex AI.

xpk workload create \
  --cluster ${CLUSTER_NAME} \
  {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \
  --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
  --tpu-type=${ACCELERATOR_TYPE} \
  --num-slices=${NUM_SLICES} \
  --on-demand \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --command="${YOUR_MODEL_SCRIPT}"
The output includes a link to follow your workload. Open the link and click the Logs tab to track your workload in real time.
Debug JAX on MaxText
Use supplemental XPK commands to diagnose why the cluster or workload isn't running:
- XPK workload list
- XPK inspector
- Enable verbose logging in your workload logs using the --enable-debug-logs flag when you create the XPK workload
Monitor JAX on MaxText using Vertex AI
To use TensorBoard, your Google Cloud user account must have the aiplatform.user
role. Run the following command to grant this role:
gcloud projects add-iam-policy-binding your-project-id \
  --member='user:your-email' \
  --role='roles/aiplatform.user'
View scalar and profile data through the Vertex AI managed TensorBoard.
Increase resource management (CRUD) requests for the zone you're using from 600 to 5000. This might not be an issue for small workloads using less than 16 VMs.
Install dependencies such as cloud-accelerator-diagnostics for Vertex AI:

# xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
cd ~/xpk
pip install .
Create your cluster using the --create-vertex-tensorboard flag, as documented in Create Vertex AI TensorBoard. You can also run this command on existing clusters.

Create your Vertex AI experiment when running your XPK workload using the --use-vertex-tensorboard flag and the optional --experiment-name flag. For the full list of steps, see Create Vertex AI Experiment to upload data to Vertex AI TensorBoard.
The logs include a link to a Vertex AI TensorBoard, similar to the following:
View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name
You can also find the Vertex AI TensorBoard link in the Google Cloud console. Go to Vertex AI Experiments in the Google Cloud console. Select the appropriate region from the drop-down.
The TensorBoard directory is also written to the Cloud Storage bucket that
you specified with ${BASE_OUTPUT_DIR}.
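As an optional, hedged sketch (not part of the documented workflow), you can inspect that output directly in Cloud Storage or point a locally installed TensorBoard at the bucket, assuming your credentials can read it:

# List the TensorBoard output written under the base output directory.
gsutil ls -r "${BASE_OUTPUT_DIR}"
# Optionally view it with a local TensorBoard installation.
tensorboard --logdir="${BASE_OUTPUT_DIR}"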
Delete your XPK workload
Use the xpk workload delete command
to delete one or more workloads based on the job prefix or job status. This
command might be useful if you sent XPK workloads that no longer need to be run,
or if you have jobs that are stuck in the queue.
Delete your cluster
Use the xpk cluster delete command to delete your cluster:
xpk cluster delete --cluster ${CLUSTER_NAME} \
  --zone=${ZONE} --project=${PROJECT_ID}
MaxDiffusion benchmarking results
We ran the training script for MaxDiffusion on a v6e-4, a v6e-16, and two v6e-16 slices. The following table shows the measured throughput.
| | v6e-4 | v6e-16 | Two v6e-16 |
|---|---|---|---|
| Training step time (seconds) | 0.069 | 0.073 | 0.13 |
| Global batch size | 8 | 32 | 64 |
| Throughput (examples/sec) | 115.9 | 438.4 | 492.3 |
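The rows are related as throughput = global batch size / step time. For example, a quick check of the v6e-4 column (illustrative arithmetic only):

# 8 examples per step divided by 0.069 seconds per step ≈ 115.9 examples/sec
python3 -c 'print(8 / 0.069)'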
Train Llama models using PyTorch/XLA on Cloud TPU v6e
This section describes how to train Llama models using PyTorch/XLA on Cloud TPU v6e using the WikiText dataset.
Get access to Hugging Face and the Llama 3 model
You need a Hugging Face user access token for this example. For information about creating user access tokens, see the Hugging Face documentation on user access tokens.
You also need permission to access the Llama-3-8B model on Hugging Face. To get access, go to the Meta-Llama-3-8B model on HuggingFace and request access.
Create a Cloud TPU VM
Create a Cloud TPU v6e with 8 chips for this example.
Set up environment variables:
export PROJECT_ID=your-project-id
export TPU_NAME=your-tpu-name
export ZONE=us-east1-d
export ACCELERATOR_TYPE=v6e-8
export RUNTIME_VERSION=v2-alpha-tpuv6e
Environment variable descriptions
- PROJECT_ID: Your Google Cloud project ID. Use an existing project or create a new one.
- TPU_NAME: The name of the TPU.
- ZONE: The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
- ACCELERATOR_TYPE: The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
- RUNTIME_VERSION: The Cloud TPU software version.
Create a Cloud TPU VM:
gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \
  --version=${RUNTIME_VERSION} \
  --accelerator-type=${ACCELERATOR_TYPE} \
  --zone=${ZONE} \
  --project=${PROJECT_ID}
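Optionally, you can check that the TPU VM reaches the READY state before continuing. This is a hedged sketch using the same gcloud alpha TPU surface as the create command; adjust it if your gcloud version exposes a different describe syntax:

gcloud alpha compute tpus tpu-vm describe ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --format="value(state)"
# Expected output: READY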
Installation
Install the pytorch-tpu/transformers fork
of Hugging Face transformers and dependencies. This example was tested with the
following dependency versions:
- torch: compatible with 2.5.0
- torch_xla[tpu]: compatible with 2.5.0
- jax: 0.4.33
- jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
  --project=${PROJECT_ID} \
  --zone ${ZONE} \
  --worker=all \
  --command='
  git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
  cd transformers
  sudo pip3 install -e .
  pip3 install datasets
  pip3 install evaluate
  pip3 install scikit-learn
  pip3 install accelerate
  pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
  pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'
Set up model configuration files
The training command in the next section, Run the model, uses two JSON configuration files to define model parameters and Fully Sharded Data Parallel (FSDP) configuration. FSDP sharding lets you use a bigger batch size while training by sharding your model weights across multiple TPUs. When training with smaller models, it might be sufficient to use data parallelism and replicate the weights on each device. For more information about how to shard tensors across devices in PyTorch/XLA, see PyTorch/XLA SPMD user guide.
Create the model parameter configuration file. The following is the model parameter configuration for Llama-3-8B. For other models, find the configuration file on Hugging Face. For example, see the Llama-2-7B config.
cat > llama-config.json << EOF
{
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.40.0.dev0",
  "use_cache": false,
  "vocab_size": 128256
}
EOF

Create the FSDP configuration file:
cat > fsdp-config.json << EOF
{
  "fsdp_transformer_layer_cls_to_wrap": [
    "LlamaDecoderLayer"
  ],
  "xla": true,
  "xla_fsdp_v2": true,
  "xla_fsdp_grad_ckpt": true
}
EOF

For more information about FSDP, see Fully Sharded Data Parallel using SPMD.
Upload the configuration files to your Cloud TPU VMs using the following command:
gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
  --worker=all \
  --project=${PROJECT_ID} \
  --zone=${ZONE}
Run the model
Using the configuration files you created in the previous section, run the
run_clm.py script to train the Llama-3-8B model on the WikiText dataset. The
training script takes approximately 10 minutes to run on a Cloud TPU v6e-8.
Sign in to Hugging Face on your Cloud TPU using the following command:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
  --project=${PROJECT_ID} \
  --zone ${ZONE} \
  --worker=all \
  --command='
  pip3 install "huggingface_hub[cli]"
  huggingface-cli login --token HUGGING_FACE_TOKEN'
Run the model training:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
  --project=${PROJECT_ID} \
  --zone ${ZONE} \
  --worker=all \
  --command='
  export PJRT_DEVICE=TPU
  export XLA_USE_SPMD=1
  export ENABLE_PJRT_COMPATIBILITY=true
  # Optional variables for debugging:
  export XLA_IR_DEBUG=1
  export XLA_HLO_DEBUG=1
  export PROFILE_EPOCH=0
  export PROFILE_STEP=3
  export PROFILE_DURATION_MS=100000
  # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
  export PROFILE_LOGDIR=PROFILE_PATH
  python3 transformers/examples/pytorch/language-modeling/run_clm.py \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --per_device_train_batch_size 16 \
    --do_train \
    --output_dir /home/$USER/tmp/test-clm \
    --overwrite_output_dir \
    --config_name /home/$USER/llama-config.json \
    --cache_dir /home/$USER/cache \
    --tokenizer_name meta-llama/Meta-Llama-3-8B \
    --block_size 8192 \
    --optim adafactor \
    --save_strategy no \
    --logging_strategy no \
    --fsdp "full_shard" \
    --fsdp_config /home/$USER/fsdp-config.json \
    --torch_dtype bfloat16 \
    --dataloader_drop_last yes \
    --flash_attention \
    --max_steps 20'
Troubleshooting PyTorch/XLA
If you set the optional variables for debugging in the previous section, the profile for the model is stored at the location specified by the PROFILE_LOGDIR variable. You can extract the xplane.pb file stored at this location and use TensorBoard to view the profiles in your browser by following the TensorBoard instructions.
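As a hedged sketch (assuming the profile was written to a path you can access and that you install the TensorBoard profiler plugin), you could view it like this:

# Install TensorBoard and the profiler plugin, then point TensorBoard at the profile directory.
pip install tensorboard tensorboard-plugin-profile
tensorboard --logdir=PROFILE_PATH
# Open the URL that TensorBoard prints and select the Profile tab.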
If PyTorch/XLA isn't performing as expected, see the Troubleshooting guide, which has suggestions for debugging, profiling, and optimizing your model.