Run supervised fine-tuning on a TPU VM by using MaxText

This tutorial provides a step-by-step guide for running supervised fine-tuning (SFT) on a single v6e-8 Tensor Processing Unit (TPU) virtual machine (VM) instance on Google Cloud by using MaxText, a high-performance JAX-based training stack for large language models (LLMs).

Objectives

  • Set up a Cloud TPU VM instance.
  • Install MaxText and its dependencies.
  • Convert a Hugging Face model to MaxText format.
  • Run an SFT training workload on the TPU.
  • Convert the fine-tuned model back to Hugging Face format for serving.

Costs

In this document, you use the following billable components of Google Cloud:

  • Cloud TPU

To generate a cost estimate based on your projected usage, use the pricing calculator.

New Google Cloud users might be eligible for a free trial.

When you finish the tasks that are described in this document, you can avoid continued billing by deleting the resources that you created. For more information, see Clean up.

Before you begin

  • You need a Hugging Face access token to use this tutorial. You can sign up for a free account at Hugging Face. Once you have an account, generate an access token:

    1. On the Welcome to Hugging Face page, click your account avatar and select Access tokens.
    2. On the Access tokens page, click Create new token.
    3. Select the Read token type and enter a name for your token.
    4. Your access token is displayed. Save the token in a safe place.

  • On the Hugging Face website, accept the license agreement for the model that you plan to train. This tutorial uses the model gemma3-4b.

To get the permissions that you need to complete this tutorial, ask your administrator to grant you the required IAM roles on your project.

For more information about granting roles, see Manage access to projects, folders, and organizations.

You might also be able to get the required permissions through custom roles or other predefined roles.

Set up the environment

Set up your environment variables by running the following script:

export PROJECT="YOUR_PROJECT_ID"
export ZONE="ZONE_NAME"
export RESERVATION="RESERVATION_NAME"
export NAME="TPU_MACHINE_NAME"

Replace the following:

  • YOUR_PROJECT_ID: your Google Cloud project ID
  • ZONE_NAME: the zone that you want to use
  • RESERVATION_NAME: your capacity reservation
  • TPU_MACHINE_NAME: the name of your Cloud TPU VM instance
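
Because every later command depends on these four variables, it can help to confirm that none of them is empty or still set to its placeholder before you continue. The following is a minimal sketch, not part of the gcloud CLI or MaxText; the function name check_tpu_env is hypothetical.

```shell
# Hypothetical helper: report variables that are unset, empty, or still
# hold one of this tutorial's placeholder values.
check_tpu_env() {
  missing=0
  for var in PROJECT ZONE RESERVATION NAME; do
    # Indirect lookup that works in POSIX sh as well as bash.
    eval "value=\${$var:-}"
    case "$value" in
      ""|YOUR_PROJECT_ID|ZONE_NAME|RESERVATION_NAME|TPU_MACHINE_NAME)
        echo "Variable $var is not set to a real value" >&2
        missing=1
        ;;
    esac
  done
  return "$missing"
}

if check_tpu_env; then
  echo "All environment variables are set"
else
  echo "Fix the variables listed above before you continue"
fi
```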

Authenticate with Google Cloud by running the following command:

gcloud auth login

Create your Cloud TPU VM

Create a Cloud TPU VM instance with 8 v6e TPU chips, bound to your capacity reservation.

gcloud alpha compute tpus tpu-vm create $NAME \
    --zone=$ZONE \
    --project=$PROJECT \
    --accelerator-type=v6e-8 \
    --version=v2-alpha-tpuv6e \
    --provisioning-model=reservation-bound \
    --reservation=$RESERVATION

After the VM instance has been created, connect to it by using SSH.

gcloud compute tpus tpu-vm ssh $NAME --zone $ZONE --project $PROJECT

Complete the steps that follow within your TPU VM instance.

Install MaxText

Update the system packages within the TPU VM instance.

sudo apt update && sudo apt upgrade -y --fix-missing

Install Python 3.12, which MaxText requires, and its virtual environment package.

sudo apt install -y python3.12 python3.12-venv

Install uv, a fast Python package manager, to speed up the installation of the Python packages.

curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env

Create a virtual environment named maxtext_venv and activate it.

uv venv --python 3.12 --seed maxtext_venv
source maxtext_venv/bin/activate

Install MaxText and the dependencies that it requires for post-training tasks.

uv pip install "maxtext[tpu-post-train]==0.2.2" --resolution=lowest

Install the remaining required dependencies by running the following command:

#install_maxtext_tpu_post_train_extra_deps
install_tpu_post_train_extra_deps

Convert the model to MaxText format

MaxText trains from checkpoints in its own format, so you must first convert the model from Hugging Face format to MaxText format.

Specify your environment variables, such as your Hugging Face access token, the name of the model that you want to use, and the directory where you want to save the model in MaxText format.

export HF_TOKEN="YOUR_HF_TOKEN"
export MODEL_NAME='gemma3-4b'
export MODEL_CHECKPOINT_DIRECTORY=/dev/shm/$MODEL_NAME/mt-format/
export USE_PATHWAYS=0 # Set to 1 for Pathways, 0 for McJAX
export LAZY_LOAD_TENSORS=False # True to use lazy load, False to use eager load.

Replace YOUR_HF_TOKEN with the Hugging Face access token that you previously created.
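
The checkpoint directory above is under /dev/shm, which on Linux is a memory-backed file system, so the converted model occupies RAM rather than disk. The sketch below reports how much space is available there. The helper name available_kib is hypothetical, and this tutorial doesn't state a minimum size, so treat the number as informational only.

```shell
# Hypothetical helper: print the available space, in KiB, of the file
# system that holds the given directory.
available_kib() {
  # -P forces the portable one-line-per-file-system layout; column 4 is
  # the available space, and -k reports it in 1024-byte blocks.
  df -Pk "$1" | awk 'NR == 2 { print $4 }'
}

# Fall back to the current directory if /dev/shm does not exist here.
target=/dev/shm
[ -d "$target" ] || target=.

echo "Available in $target: $(available_kib "$target") KiB"
```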

To convert the model from Hugging Face format to MaxText format, run the following script. This conversion takes about five minutes to complete.

python3 -m maxtext.checkpoint_conversion.to_maxtext \
    model_name=${MODEL_NAME?} \
    hf_access_token=${HF_TOKEN?} \
    base_output_directory=${MODEL_CHECKPOINT_DIRECTORY?} \
    scan_layers=True \
    use_multimodal=False \
    hardware=cpu \
    skip_jax_distributed_system=true \
    checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) \
    checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \
    --lazy_load_tensors=${LAZY_LOAD_TENSORS?}
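
The command above relies on two shell features: `${VAR?}` stops the command with an error if VAR is unset, and `$((1 - USE_PATHWAYS))` turns the 0-or-1 Pathways switch into its opposite. The following sketch shows what the two checkpoint flags expand to for each setting. It only prints values and doesn't run the converter; the names pathways, flag, guard_result, and DEMO_UNSET_VAR are introduced here for illustration.

```shell
# Show how the checkpoint storage flags expand for each USE_PATHWAYS value.
# A separate loop variable is used so that USE_PATHWAYS itself is untouched.
for pathways in 0 1; do
  flag=$((1 - pathways))
  echo "USE_PATHWAYS=$pathways -> checkpoint_storage_use_zarr3=$flag checkpoint_storage_use_ocdbt=$flag"
done

# ${VAR?} fails when VAR is unset. Run the expansion in a subshell so that
# the failure doesn't end the current session.
unset DEMO_UNSET_VAR
if (: "${DEMO_UNSET_VAR?}") 2>/dev/null; then
  guard_result="expanded"
else
  guard_result="stopped"
fi
echo "Expansion of an unset variable: $guard_result"
```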

Start the training workload

After the conversion process has completed, you can start the SFT workload.

  1. Configure the SFT workload training parameters.

    # -- MaxText configuration --
    export BASE_OUTPUT_DIRECTORY=/dev/shm/$MODEL_NAME/post-train/
    export RUN_NAME=$(date +%Y-%m-%d-%H-%M-%S)
    export STEPS=1000
    export PER_DEVICE_BATCH_SIZE=1
    
    # -- Dataset configuration --
    export DATASET_NAME="HuggingFaceH4/ultrachat_200k"
    export TRAIN_SPLIT="train_sft"
    export TRAIN_DATA_COLUMNS="['messages']"
    
    export MAXTEXT_CKPT_PATH=$MODEL_CHECKPOINT_DIRECTORY/0/items

  2. Start the training job. This takes about 10 minutes on a v6e-8 VM instance.

    python3 -m maxtext.trainers.post_train.sft.train_sft \
        run_name="${RUN_NAME?}" \
        base_output_directory="${BASE_OUTPUT_DIRECTORY?}" \
        model_name="${MODEL_NAME?}" \
        load_parameters_path="${MAXTEXT_CKPT_PATH?}" \
        per_device_batch_size="${PER_DEVICE_BATCH_SIZE?}" \
        steps="${STEPS?}" \
        hf_path="${DATASET_NAME?}" \
        train_split="${TRAIN_SPLIT?}" \
        train_data_columns="${TRAIN_DATA_COLUMNS?}" \
        profiler=xplane
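
To get a feel for the size of this run, you can work out how many training sequences it processes. The sketch below assumes that the v6e-8 slice exposes eight devices (one per chip) and that no gradient accumulation is configured. NUM_DEVICES, global_batch, and total_sequences are names introduced here for illustration, not MaxText settings.

```shell
# Values from the configuration step, with the tutorial's values as
# defaults if the variables are not set in this shell.
STEPS="${STEPS:-1000}"
PER_DEVICE_BATCH_SIZE="${PER_DEVICE_BATCH_SIZE:-1}"

# Assumption: a v6e-8 slice exposes 8 devices.
NUM_DEVICES=8

global_batch=$((PER_DEVICE_BATCH_SIZE * NUM_DEVICES))
total_sequences=$((global_batch * STEPS))

echo "Sequences per step: $global_batch"
echo "Sequences over $STEPS steps: $total_sequences"
```

Under these assumptions, the tutorial's values give 8 sequences per step and 8,000 sequences over the whole run.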

Convert the trained model back into Hugging Face format

After the training workload has completed, convert the model back to Hugging Face format.

  1. Set the paths for export and the trained parameters.

    export HF_EXPORT=/dev/shm/$MODEL_NAME/hf-trained/
    export POST_TRAIN_PATH=$BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints/$STEPS/model_params

  2. Run the conversion back to Hugging Face format.

    python3 -m maxtext.checkpoint_conversion.to_huggingface \
        model_name=$MODEL_NAME \
        load_parameters_path=$POST_TRAIN_PATH \
        base_output_directory=$HF_EXPORT \
        scan_layers=True \
        use_multimodal=False \
        weight_dtype=bfloat16
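
If the converter can't find POST_TRAIN_PATH, the saved step number might differ from STEPS. The sketch below scans a checkpoints directory for subdirectories with purely numeric names and prints the highest one. It assumes the layout implied by the path in step 1 (one subdirectory per saved step); the helper name latest_step is hypothetical.

```shell
# Hypothetical helper: print the highest numeric subdirectory name in the
# given directory, or an empty line if there is none.
latest_step() {
  best=""
  for path in "$1"/*/; do
    name=$(basename "$path")
    # Skip anything whose name is not made up entirely of digits.
    case "$name" in
      ""|*[!0-9]*) continue ;;
    esac
    if [ -z "$best" ] || [ "$name" -gt "$best" ]; then
      best=$name
    fi
  done
  echo "$best"
}

ckpt_dir="${BASE_OUTPUT_DIRECTORY:-}/${RUN_NAME:-}/checkpoints"
step=$(latest_step "$ckpt_dir")
if [ -n "$step" ]; then
  echo "Latest saved step: $step"
else
  echo "No step directories found in $ckpt_dir"
fi
```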

After the conversion completes, the tuned model is available in /dev/shm/gemma3-4b/hf-trained. Because /dev/shm is a memory-backed file system, its contents are lost when the VM reboots, so move the tuned model to persistent storage or upload it to the Hugging Face Hub.
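
One way to keep the model across reboots is to copy the export directory to the VM's boot disk. The boot disk is still removed when you delete the VM, so copy the model off the VM before you clean up. The sketch below wraps the copy in a small check; the helper name persist_model and the destination under $HOME are hypothetical, so adjust them to your setup.

```shell
# Hypothetical helper: copy the contents of a source directory into a
# destination directory, creating the destination if needed.
persist_model() {
  src=$1
  dest=$2
  if [ ! -d "$src" ]; then
    echo "Source directory $src does not exist" >&2
    return 1
  fi
  mkdir -p "$dest" && cp -R "$src"/. "$dest"/
}

# Example call. The destination path is an illustrative choice.
persist_model "${HF_EXPORT:-/dev/shm/gemma3-4b/hf-trained/}" "$HOME/gemma3-4b-hf-trained" \
  || echo "Nothing was copied"

# To copy to Cloud Storage instead, a command along these lines can be
# used, where YOUR_BUCKET is a placeholder for a bucket that you own:
#   gcloud storage cp --recursive "$HF_EXPORT" gs://YOUR_BUCKET/
```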

Clean up

To avoid incurring additional charges, delete the resources created during this tutorial.

Delete your TPU VM instance

Exit your Cloud TPU VM instance, and then delete it.

gcloud alpha compute tpus tpu-vm delete $NAME --zone=$ZONE --project=$PROJECT --quiet

What's next