Melatih LLM menggunakan JAX, Ray Train, dan TPU Trillium di GKE

Tutorial ini menunjukkan cara melatih model bahasa besar (LLM) Llama 3 8B di Google Kubernetes Engine (GKE) menggunakan MaxText, Ray Train, dan TPU.

Tutorial ini memberikan panduan lengkap menyeluruh, mulai dari mengonfigurasi infrastruktur cloud yang diperlukan hingga mengirimkan dan berhasil menjalankan workload pelatihan di TPU multi-host.

Tutorial ini ditujukan bagi admin dan operator Platform serta spesialis Data dan AI yang ingin mempelajari cara melatih model besar pada slice TPU multi-host terdistribusi.

Latar belakang

Kombinasi GKE, KubeRay, MaxText, dan TPU menyediakan platform yang canggih dan skalabel untuk pelatihan model berskala besar. Bagian ini menjelaskan teknologi utama yang digunakan dalam panduan ini:

JAX

JAX adalah library Python untuk komputasi array dan transformasi program yang berorientasi pada akselerator, yang dirancang untuk komputasi numerik berperforma tinggi dan machine learning skala besar.

JAX menyediakan sistem yang dapat di-extend untuk mengubah fungsi numerik seperti jax.grad, jax.jit, dan jax.vmap, dengan memanfaatkan compiler XLA untuk membuat kode yang sangat dioptimalkan dan dapat diskalakan secara efisien pada akselerator seperti GPU dan TPU. Kekuatan inti JAX terletak pada kemampuannya untuk menyusun, yang memungkinkan pengguna menggabungkan transformasi ini untuk membangun program numerik berperforma tinggi yang kompleks untuk eksekusi terdistribusi.

MaxText

MaxText adalah model bahasa besar (LLM) open source berperforma tinggi yang dirancang untuk skalabilitas dan kemampuan penyesuaian. MaxText dibuat di atas JAX dan dioptimalkan agar berjalan secara efisien di Cloud TPU dan GPU.

TPU

Tensor Processing Unit (TPU) adalah akselerator yang dirancang khusus dan dibuat oleh Google untuk mengoptimalkan beban kerja machine learning. Tidak seperti CPU serbaguna atau GPU pemrosesan paralel, TPU sangat terspesialisasi untuk komputasi matriks dan tensor masif yang menjadi dasar deep learning, sehingga TPU efisien dalam tugas khusus ini. Keuntungan utama TPU adalah performa dalam skala besar.

Tutorial ini menggunakan TPU Trillium, yang merupakan TPU generasi keenam. Untuk mengetahui informasi selengkapnya, lihat Manfaat menggunakan TPU Trillium.

KubeRay

KubeRay adalah operator Kubernetes yang menyediakan cara terpadu untuk men-deploy, mengelola, dan memantau aplikasi Ray di Kubernetes. Operator KubeRay diinstal dan dikelola melalui add-on Ray di GKE, yang merupakan cara yang direkomendasikan untuk men-deploy dan mengelola cluster Ray di GKE.

Tujuan

Tutorial ini menunjukkan kepada Anda cara melakukan hal berikut:

  1. Siapkan cluster GKE dengan node pool TPU multi-host.
  2. Konfigurasi KubeRay untuk mengelola lingkungan pelatihan terdistribusi.
  3. Bangun image Docker kustom yang berisi dependensi MaxText, Ray, dan JAX.
  4. Buat skrip pelatihan Python yang menggunakan JaxTrainer Ray Train untuk mengatur loop pelatihan MaxText di seluruh slice TPU.
  5. Tentukan RayCluster resource kustom untuk menyediakan node head dan worker dengan resource TPU yang diperlukan.
  6. Kirim Tugas pelatihan ke RayCluster dan pantau progresnya.
  7. Gunakan Cloud Storage untuk menyimpan checkpoint model.

Sebelum memulai

  • Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  • Install the Google Cloud CLI.

  • Jika Anda menggunakan penyedia identitas (IdP) eksternal, Anda harus login ke gcloud CLI dengan identitas gabungan Anda terlebih dahulu.

  • Untuk melakukan inisialisasi gcloud CLI, jalankan perintah berikut:

    gcloud init
  • Create or select a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.
    • Create a Google Cloud project:

      gcloud projects create PROJECT_ID

      Replace PROJECT_ID with a name for the Google Cloud project you are creating.

    • Select the Google Cloud project that you created:

      gcloud config set project PROJECT_ID

      Replace PROJECT_ID with your Google Cloud project name.

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API:

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    gcloud services enable container.googleapis.com
  • Install the Google Cloud CLI.

  • Jika Anda menggunakan penyedia identitas (IdP) eksternal, Anda harus login ke gcloud CLI dengan identitas gabungan Anda terlebih dahulu.

  • Untuk melakukan inisialisasi gcloud CLI, jalankan perintah berikut:

    gcloud init
  • Create or select a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.
    • Create a Google Cloud project:

      gcloud projects create PROJECT_ID

      Replace PROJECT_ID with a name for the Google Cloud project you are creating.

    • Select the Google Cloud project that you created:

      gcloud config set project PROJECT_ID

      Replace PROJECT_ID with your Google Cloud project name.

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API:

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    gcloud services enable container.googleapis.com
  • Grant roles to your user account. Run the following command once for each of the following IAM roles: roles/container.admin, roles/iam.serviceAccountAdmin

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    Replace the following:

    • PROJECT_ID: Your project ID.
    • USER_IDENTIFIER: The identifier for your user account. For example, myemail@example.com.
    • ROLE: The IAM role that you grant to your user account.
  • Karena tutorial ini menggunakan TPU Trillium (v6e), pilih region atau zona dengan ketersediaan. Untuk mengetahui informasi selengkapnya, lihat Kuota Cloud TPU.

Menyiapkan lingkungan Anda

Dalam tutorial ini, Anda akan menggunakan Cloud Shell. Cloud Shell telah diinstal sebelumnya dengan alat command line gcloud, helm, dan kubectl yang digunakan dalam tutorial ini.

  1. Buka Google Cloud console.

  2. Di bagian atas jendela konsol, klik tombol Activate Cloud Shell Tombol Aktifkan Shell
. Google Cloud

    Sesi Cloud Shell akan terbuka di dalam frame baru di konsolGoogle Cloud dan menampilkan perintah command line.

  3. Membuat dan mengaktifkan lingkungan virtual Python:

    python3 -m venv ray-env
    source ray-env/bin/activate
    
  4. Instal Ray CLI dan dependensi lainnya:

    pip install "ray[default]==2.49.1"
    
  5. Tetapkan variabel lingkungan berikut:

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export ARTIFACT_REGISTRY=ARTIFACT_REGISTRY
    

    Ganti kode berikut:

    • GS_BUCKET: nama bucket Cloud Storage.
    • KSA_NAME: nama Akun Layanan Kubernetes.
    • CLUSTER_NAME: nama cluster baru.
    • REGION: region tempat kapasitas TPU Trillium Anda tersedia.
    • ZONE: zona tempat kapasitas TPU Trillium Anda tersedia. Untuk mengetahui informasi selengkapnya, lihat Ketersediaan TPU di GKE.
    • ARTIFACT_REGISTRY: nama repositori Artifact Registry.

Membuat cluster GKE

Anda dapat mengonfigurasi KubeRay di TPU dalam cluster GKE Autopilot atau Standard. Sebaiknya gunakan cluster Autopilot untuk mendapatkan pengalaman Kubernetes yang terkelola sepenuhnya. Untuk memilih mode operasi GKE yang paling sesuai untuk workload Anda, lihat Tentang mode operasi GKE.

Autopilot

  1. Jalankan perintah berikut di Cloud Shell:

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --machine-type=n1-standard-16 \
        --location=$REGION
    
  2. Untuk berkomunikasi dengan cluster Anda, konfigurasi kubectl :

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$ZONE
    

Standar

  1. Di Cloud Shell, buat cluster Standard yang mengaktifkan add-on Ray operator dengan menjalankan perintah berikut:

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator \
        --addons GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE
    

    Perintah ini juga mengaktifkan GcsFuseCsiDriver, yang memungkinkan Pod memasang bucket Cloud Storage sebagai sistem file lokal. Pembuatan cluster mungkin memerlukan waktu beberapa menit.

  2. Untuk berkomunikasi dengan cluster Anda, konfigurasi kubectl:

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=LOCATION
    
  3. Buat node pool slice TPU multi-host:

    gcloud container node-pools create v6e-16 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4
    

GKE menyediakan node pool yang terdiri dari empat VM TPU Trillium (v6e), yang dikonfigurasi bersama sebagai slice TPU multi-host, dengan topologi 4x4, yang siap untuk workload pelatihan terdistribusi.

Cluster GKE yang mendukung operator Ray akan otomatis menginstal KubeRay dan webhook TPU KubeRay di cluster Anda.

Mengonfigurasi bucket Cloud Storage dan akun layanan

  1. Buat bucket Cloud Storage untuk checkpoint bersama antara node TPU multi-host.

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
    
  2. Untuk mengaktifkan akses ke bucket Cloud Storage, buat Akun Layanan Kubernetes:

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
    
  3. Untuk mengaktifkan akses ke bucket Cloud Storage, tambahkan binding kebijakan IAM yang diperlukan ke akun layanan:

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
    

Membuat skrip pelatihan

Skrip berikut menggunakan JaxTrainer Ray Train untuk menjalankan tugas pelatihan MaxText terdistribusi. Skrip mengonfigurasi lingkungan pelatihan untuk node pool slice TPU multi-host dan menjalankan tugas pelatihan MaxText di setiap node pekerja. Fungsi train_loop_per_worker membungkus titik entri utama MaxText, dan menggunakan penjadwal terdistribusi Ray untuk menjalankan pelatih MaxText pada slice TPU multi-host.

  1. Simpan skrip Python berikut sebagai maxtext_ray_trainer.py:

    import os
    from absl import app
    import logging
    from typing import Sequence
    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig
    from ray.train.v2.jax import JaxTrainer
    
    def train_loop_per_worker(config):
        from MaxText.train import main as maxtext_main
    
        argv = config["argv"]
        maxtext_main(argv)
    
    def main(argv: Sequence[str]):
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                num_workers=4,
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                worker_runtime_env={
                    "env_vars": {
                        "JAX_PLATFORMS": "tpu",
                        "ENABLE_PJRT_COMPATIBILITY": "true",
                        "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
                        "TPU_SLICE_BUILDER_DUMP_ICI": "true",
                        "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",
                    }
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        app.run(main)
  2. Untuk menghosting image kustom, buat repositori Artifact Registry:

    gcloud artifacts repositories create ${ARTIFACT_REGISTRY} \
        --repository-format=docker --location=${REGION} && \
    gcloud auth configure-docker ${REGION}-docker.pkg.dev
    
  3. Untuk membuat image yang menyertakan dependensi Ray dan MaxText untuk pelatihan, buat Dockerfile:

    # Start from a Ray base image which includes JaxTrainer API.
    # Maxtext with TPU requires Python 3.12.
    FROM rayproject/ray:2.49.1-py312
    
    USER root
    RUN groupadd -r ray 2>/dev/null || true && usermod -g ray ray
    
    RUN sudo apt-get update -y \
      && sudo apt-get install --no-install-recommends -y git \
      && sudo rm -rf /var/lib/apt/lists/*
    
    WORKDIR /app
    
    # Clone the Maxtext repo and build from source, installing TPU dependencies.
    RUN git clone https://github.com/AI-Hypercomputer/maxtext.git
    
    RUN pip install --no-cache-dir uv
    
    RUN cd maxtext && \
        uv pip install --no-cache --system -e .[tpu] --resolution=lowest && \
        install_maxtext_github_deps
    
    # Copy the Ray Maxtext trainer to run on the remote container.
    COPY maxtext_ray_trainer.py .
    
    RUN chown -R ray:ray .
    ENV PYTHONPATH=/app/maxtext/src:/app/maxtext:/app
    USER ray
  4. Bangun, beri tag, dan kirim image Docker ke Artifact Registry:

    export DOCKER_IMAGE=${REGION}-docker.pkg.dev/${PROJECT_ID}/${ARTIFACT_REGISTRY}/ray-maxtext:latest
    gcloud builds submit --tag ${DOCKER_IMAGE}
    

Melatih model

  1. Simpan manifes contoh berikut sebagai maxtext-tpu-cluster.yaml:

    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: ${DOCKER_IMAGE}
                imagePullPolicy: IfNotPresent
                ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
            volumes:
            - name: gcs-fuse-cache
              emptyDir:
                medium: Memory
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs"
      workerGroupSpecs:
        - replicas: 1
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: {}
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              containers:
                - name: ray-worker
                  image: ${DOCKER_IMAGE}
                  imagePullPolicy: IfNotPresent
                  resources:
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: JAX_PLATFORMS
                      value: tpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                  volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
              volumes:
              - name: gcs-fuse-cache
                emptyDir:
                  medium: Memory
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs"
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4

    Spesifikasi RayCluster sebelumnya membuat grup pekerja TPU dengan empat pekerja (numOfHosts: 4) per replika. Setiap pekerja meminta empat chip TPU (google.com/tpu: "4"). Pekerja akan dijadwalkan di node yang menjalankan TPU Trillium (tpu-v6e-slice), dan merupakan bagian dari slice multi-host yang sama dan diletakkan bersama. KubeRay menskalakan keempat pekerja secara atomik, dan variabel lingkungan JAX yang diperlukan, serta Afinitas Pod untuk penjadwalan, di-bootstrap oleh GKE melalui webhook mutasi.

  2. Untuk mengonfigurasi nilai yang diperlukan dalam file YAML, buat RayCluster menggunakan envsubst:

    envsubst < maxtext-tpu-cluster.yaml | kubectl apply -f -
    
  3. Pastikan cluster sudah siap dan berjalan:

    kubectl get rayclusters maxtext-tpu-cluster
    

    Outputnya akan mirip dengan berikut ini:

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY        GPUS   STATUS   AGE
    maxtext-tpu-cluster   4                 4                   40     798027216Ki   0      ready    11m
    
  4. Untuk mengakses Dasbor Ray melalui layanan head Ray, buat sesi penerusan port:

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
    
  5. Verifikasi bahwa RayCluster dapat dijangkau dari lingkungan lokal Anda:

    ray list nodes --address http://localhost:8265
    

    Outputnya akan mirip dengan berikut ini:

    ======== List: 2025-09-13 03:53:16.988269 ========
    Stats:
    ------------------------------
    Total: 5
    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP    IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                  LABELS
    0  92c79d04c34b659c1e3044f7642ad3fd47eb16f290785237149fab56  10.84.0.9
    (...)
    
  6. Kirimkan skrip JaxTrainer ke RayCluster dan periksa apakah RayJob berhasil diselesaikan:

    ray job submit \
      --address http://localhost:8265 \
      -- python /app/maxtext_ray_trainer.py \
          /app/maxtext/src/MaxText/configs/base.yml \
           base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=1 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-8b-4096-tp4-4x4
    

    Perintah sebelumnya mengirimkan skrip Python, yang memanggil kode JaxTrainer Ray ke RayCluster. Perintah ray job submit menyertakan beberapa argumen khusus MaxText untuk diteruskan ke konfigurasi model.

    Di terminal, Anda akan melihat output yang mirip dengan berikut ini:

    (RayTrainWorker pid=21663, ip=10.12.3.6) completed step: 99, seconds: 1.100, TFLOP/s/device: 179.739, Tokens/s/device: 3725.218, total_weights: 65536, loss: 0.000 [repeated 3x across cluster]
    
    ------------------------------------------
    Job 'raysubmit_zCrJcWnuymMQv4C3' succeeded
    ------------------------------------------
    

Pembersihan

Agar akun Google Cloud Anda tidak dikenai biaya untuk resource yang digunakan dalam tutorial ini, hapus project yang berisi resource tersebut, atau simpan project dan hapus setiap resource.

  1. Hapus RayCluster:

    kubectl delete raycluster maxtext-tpu-cluster
    
  2. Hapus cluster GKE:

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
    
  3. Hapus bucket Cloud Storage:

    gsutil rm -r gs://${GS_BUCKET}
    
  4. Hapus repositori Artifact Registry:

    gcloud artifacts repositories delete ${ARTIFACT_REGISTRY} --location=${REGION} --quiet
    

Langkah berikutnya