GKE에서 JAX, Ray Train, TPU Trillium을 사용하여 LLM 학습

이 튜토리얼에서는 MaxText, Ray Train, TPU를 사용하여 Google Kubernetes Engine (GKE)에서 Llama 3 8B 대규모 언어 모델 (LLM)을 학습하는 방법을 보여줍니다.

이 튜토리얼에서는 필요한 클라우드 인프라를 구성하는 것부터 멀티 호스트 TPU에서 학습 워크로드를 제출하고 성공적으로 실행하는 것까지 전체 과정을 안내합니다.

이 튜토리얼은 분산된 멀티 호스트 TPU 슬라이스에서 대규모 모델을 학습하는 방법을 알아보려는 플랫폼 관리자 및 운영자, 데이터 및 AI 전문가를 대상으로 합니다.

배경

GKE, KubeRay, MaxText, TPU를 결합하면 대규모 모델 학습을 위한 강력하고 확장 가능한 플랫폼을 제공할 수 있습니다. 이 섹션에서는 이 가이드에서 사용되는 주요 기술을 설명합니다.

JAX

JAX는 고성능 수치 계산과 대규모 머신러닝을 위해 설계된 가속기 지향 배열 계산과 프로그램 변환에 사용할 수 있는 Python 라이브러리입니다.

JAX는 jax.grad, jax.jit, jax.vmap와 같은 수치 함수를 변환하는 확장 가능한 시스템을 제공하며, XLA 컴파일러를 활용하여 GPU 및 TPU와 같은 가속기에서 효율적으로 확장되는 고도로 최적화된 코드를 생성합니다. JAX의 핵심 기능은 구성 가능성에 있으며, 이를 통해 사용자는 이러한 변환을 결합하여 분산 실행을 위한 복잡한 고성능 수치 프로그램을 빌드할 수 있습니다.

MaxText

MaxText는 확장성과 맞춤설정을 위해 설계된 고성능 오픈소스 대규모 언어 모델 (LLM)입니다. MaxText는 JAX를 기반으로 빌드되었으며 Cloud TPU 및 GPU에서 효율적으로 실행되도록 최적화되어 있습니다.

TPU

Tensor Processing Unit (TPU)은 머신러닝 워크로드를 최적화하기 위해 Google에서 만든 맞춤 설계 가속기입니다. 범용 CPU 또는 병렬 처리 GPU와 달리 TPU는 딥 러닝의 기반이 되는 대규모 행렬 및 텐서 연산에 특화되어 있어 이 특정 작업에 효율적입니다. TPU의 주요 이점은 대규모 성능입니다.

이 튜토리얼에서는 6세대 TPU인 TPU Trillium을 사용합니다. 자세한 내용은 TPU Trillium 사용의 이점을 참고하세요.

KubeRay

KubeRay는 Kubernetes에서 Ray 애플리케이션을 배포, 관리, 모니터링하는 통합 방법을 제공하는 Kubernetes 연산자입니다. KubeRay 연산자는 GKE에서 Ray 클러스터를 배포하고 관리하는 데 권장되는 방법인 GKE의 Ray 부가기능을 통해 설치되고 관리됩니다.

목표

이 튜토리얼에서는 다음 작업을 처리하는 방법을 보여줍니다.

  1. 멀티 호스트 TPU 노드 풀이 있는 GKE 클러스터를 설정합니다.
  2. 분산 학습 환경을 관리하도록 KubeRay를 구성합니다.
  3. MaxText, Ray, JAX 종속 항목이 포함된 맞춤 Docker 이미지를 빌드합니다.
  4. Ray Train의 JaxTrainer를 사용하여 TPU 슬라이스 전반에서 MaxText 학습 루프를 오케스트레이션하는 Python 학습 스크립트를 만듭니다.
  5. 필요한 TPU 리소스로 헤드 및 워커 노드를 프로비저닝하는 RayCluster 커스텀 리소스를 정의합니다.
  6. RayCluster에 학습 작업을 제출하고 진행 상황을 모니터링합니다.
  7. Cloud Storage를 사용하여 모델 체크포인트를 저장합니다.

시작하기 전에

  • Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  • Install the Google Cloud CLI.

  • 외부 ID 공급업체(IdP)를 사용하는 경우 먼저 제휴 ID로 gcloud CLI에 로그인해야 합니다.

  • gcloud CLI를 초기화하려면, 다음 명령어를 실행합니다.

    gcloud init
  • Create or select a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.
    • Create a Google Cloud project:

      gcloud projects create PROJECT_ID

      Replace PROJECT_ID with a name for the Google Cloud project you are creating.

    • Select the Google Cloud project that you created:

      gcloud config set project PROJECT_ID

      Replace PROJECT_ID with your Google Cloud project name.

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API:

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    gcloud services enable container.googleapis.com
  • Install the Google Cloud CLI.

  • 외부 ID 공급업체(IdP)를 사용하는 경우 먼저 제휴 ID로 gcloud CLI에 로그인해야 합니다.

  • gcloud CLI를 초기화하려면, 다음 명령어를 실행합니다.

    gcloud init
  • Create or select a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.
    • Create a Google Cloud project:

      gcloud projects create PROJECT_ID

      Replace PROJECT_ID with a name for the Google Cloud project you are creating.

    • Select the Google Cloud project that you created:

      gcloud config set project PROJECT_ID

      Replace PROJECT_ID with your Google Cloud project name.

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API:

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    gcloud services enable container.googleapis.com
  • Grant roles to your user account. Run the following command once for each of the following IAM roles: roles/container.admin, roles/iam.serviceAccountAdmin

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    Replace the following:

    • PROJECT_ID: Your project ID.
    • USER_IDENTIFIER: The identifier for your user account. For example, myemail@example.com.
    • ROLE: The IAM role that you grant to your user account.
  • 이 튜토리얼에서는 TPU Trillium (v6e)을 사용하므로 사용 가능한 리전 또는 영역을 선택하세요. 자세한 내용은 Cloud TPU 할당량을 참고하세요.

개발 환경 준비

이 튜토리얼에서는 Cloud Shell을 사용합니다. Cloud Shell에는 이 튜토리얼에서 사용되는 gcloud, helm, kubectl 명령줄 도구가 사전 설치되어 있습니다.

  1. Google Cloud 콘솔로 이동합니다.

  2. Google Cloud 콘솔 창 상단에서 Cloud Shell 활성화 셸 활성화 버튼 버튼을 클릭합니다.

    Google Cloud 콘솔의 새 프레임 내에 Cloud Shell 세션이 열리면서 명령줄 프롬프트가 표시됩니다.

  3. Python 가상 환경을 만들고 활성화합니다.

    python3 -m venv ray-env
    source ray-env/bin/activate
    
  4. Ray CLI 및 기타 종속 항목을 설치합니다.

    pip install "ray[default]==2.49.1"
    
  5. 다음 환경 변수를 설정합니다.

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export ARTIFACT_REGISTRY=ARTIFACT_REGISTRY
    

    다음을 바꿉니다.

    • GS_BUCKET: Cloud Storage 버킷 이름
    • KSA_NAME: Kubernetes 서비스 계정의 이름입니다.
    • CLUSTER_NAME: 새 클러스터의 이름
    • REGION: TPU Trillium 용량을 사용할 수 있는 리전입니다.
    • ZONE: TPU Trillium 용량을 사용할 수 있는 영역입니다. 자세한 내용은 GKE의 TPU 가용성을 참고하세요.
    • ARTIFACT_REGISTRY: Artifact Registry 저장소의 이름입니다.

GKE 클러스터 만들기

GKE Autopilot 또는 Standard 클러스터의 TPU에서 KubeRay를 구성할 수 있습니다. 완전 관리형 Kubernetes 환경을 위해서는 Autopilot 클러스터를 사용하는 것이 좋습니다. 워크로드에 가장 적합한 GKE 작업 모드를 선택하려면 GKE 작업 모드 정보를 참고하세요.

Autopilot

  1. Cloud Shell에서 다음 명령어를 실행합니다.

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --machine-type=n1-standard-16 \
        --location=$REGION
    
  2. 클러스터와 통신하려면 kubectl을 구성하세요.

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$ZONE
    

표준

  1. Cloud Shell에서 다음 명령어를 실행하여 Ray operator 부가기능을 사용 설정하는 Standard 클러스터를 만듭니다.

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator \
        --addons GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE
    

    이 명령어는 포드가 Cloud Storage 버킷을 로컬 파일 시스템으로 마운트할 수 있도록 하는 GcsFuseCsiDriver도 사용 설정합니다. 클러스터를 만드는 데 몇 분 정도 걸릴 수 있습니다.

  2. 클러스터와 통신하려면 kubectl을 구성하세요.

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=LOCATION
    
  3. 멀티 호스트 TPU 슬라이스 노드 풀을 만듭니다.

    gcloud container node-pools create v6e-16 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4
    

GKE는 분산 학습 워크로드에 적합한 4x4 토폴로지를 사용하여 멀티 호스트 TPU 슬라이스로 함께 구성된 4개의 TPU Trillium (v6e) VM으로 구성된 노드 풀을 프로비저닝합니다.

Ray 연산자가 사용 설정된 GKE 클러스터는 클러스터에 KubeRay와 KubeRay TPU 웹훅을 자동으로 설치합니다.

Cloud Storage 버킷 및 서비스 계정 구성

  1. 다중 호스트 TPU 노드 간에 공유되는 체크포인트를 위한 Cloud Storage 버킷을 만듭니다.

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
    
  2. Cloud Storage 버킷에 대한 액세스를 사용 설정하려면 Kubernetes 서비스 계정을 만드세요.

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
    
  3. Cloud Storage 버킷에 대한 액세스를 사용 설정하려면 서비스 계정에 필요한 IAM 정책 바인딩을 추가하세요.

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
    

학습 스크립트 만들기

다음 스크립트는 Ray Train의 JaxTrainer를 사용하여 분산 MaxText 학습 작업을 실행합니다. 스크립트는 멀티 호스트 TPU 슬라이스 노드 풀의 학습 환경을 구성하고 각 작업자 노드에서 MaxText 학습 작업을 실행합니다. train_loop_per_worker 함수는 MaxText 기본 진입점을 래핑하고 Ray의 분산 스케줄러를 사용하여 멀티 호스트 TPU 슬라이스에서 MaxText 트레이너를 실행합니다.

  1. 다음 Python 스크립트를 maxtext_ray_trainer.py로 저장합니다.

    import os
    from absl import app
    import logging
    from typing import Sequence
    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig
    from ray.train.v2.jax import JaxTrainer
    
    def train_loop_per_worker(config):
        from MaxText.train import main as maxtext_main
    
        argv = config["argv"]
        maxtext_main(argv)
    
    def main(argv: Sequence[str]):
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                num_workers=4,
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                worker_runtime_env={
                    "env_vars": {
                        "JAX_PLATFORMS": "tpu",
                        "ENABLE_PJRT_COMPATIBILITY": "true",
                        "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
                        "TPU_SLICE_BUILDER_DUMP_ICI": "true",
                        "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",
                    }
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        app.run(main)
  2. 커스텀 이미지를 호스팅하려면 Artifact Registry 저장소를 만드세요.

    gcloud artifacts repositories create ${ARTIFACT_REGISTRY} \
        --repository-format=docker --location=${REGION} && \
    gcloud auth configure-docker ${REGION}-docker.pkg.dev
    
  3. 학습을 위해 Ray 및 MaxText 종속 항목이 포함된 이미지를 빌드하려면 Dockerfile를 만드세요.

    # Start from a Ray base image which includes JaxTrainer API.
    # Maxtext with TPU requires Python 3.12.
    FROM rayproject/ray:2.49.1-py312
    
    USER root
    RUN groupadd -r ray 2>/dev/null || true && usermod -g ray ray
    
    RUN sudo apt-get update -y \
      && sudo apt-get install --no-install-recommends -y git \
      && sudo rm -rf /var/lib/apt/lists/*
    
    WORKDIR /app
    
    # Clone the Maxtext repo and build from source, installing TPU dependencies.
    RUN git clone https://github.com/AI-Hypercomputer/maxtext.git
    
    RUN pip install --no-cache-dir uv
    
    RUN cd maxtext && \
        uv pip install --no-cache --system -e .[tpu] --resolution=lowest && \
        install_maxtext_github_deps
    
    # Copy the Ray Maxtext trainer to run on the remote container.
    COPY maxtext_ray_trainer.py .
    
    RUN chown -R ray:ray .
    ENV PYTHONPATH=/app/maxtext/src:/app/maxtext:/app
    USER ray
  4. Docker 이미지를 빌드하고 태그를 지정한 후 Artifact Registry에 푸시합니다.

    export DOCKER_IMAGE=${REGION}-docker.pkg.dev/${PROJECT_ID}/${ARTIFACT_REGISTRY}/ray-maxtext:latest
    gcloud builds submit --tag ${DOCKER_IMAGE}
    

모델 학습

  1. 다음 샘플 매니페스트를 maxtext-tpu-cluster.yaml로 저장합니다.

    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: ${DOCKER_IMAGE}
                imagePullPolicy: IfNotPresent
                ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
            volumes:
            - name: gcs-fuse-cache
              emptyDir:
                medium: Memory
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs"
      workerGroupSpecs:
        - replicas: 1
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: {}
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              containers:
                - name: ray-worker
                  image: ${DOCKER_IMAGE}
                  imagePullPolicy: IfNotPresent
                  resources:
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: JAX_PLATFORMS
                      value: tpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                  volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
              volumes:
              - name: gcs-fuse-cache
                emptyDir:
                  medium: Memory
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs"
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4

    앞의 RayCluster 사양은 복제본당 작업자가 4명(numOfHosts: 4)인 TPU 작업자 그룹을 만듭니다. 각 작업자는 TPU 칩 4개(google.com/tpu: "4")를 요청합니다. 작업자는 TPU Trillium (tpu-v6e-slice)을 실행하고 동일한 공동 배치 멀티 호스트 슬라이스의 일부인 노드에 예약됩니다. KubeRay는 네 개의 작업자를 모두 원자적으로 확장하며, 필요한 JAX 환경 변수와 일정 예약용 포드 어피니티는 변형 웹훅을 통해 GKE에 의해 부트스트랩됩니다.

  2. YAML 파일에서 필수 값을 구성하려면 envsubst를 사용하여 RayCluster를 만듭니다.

    envsubst < maxtext-tpu-cluster.yaml | kubectl apply -f -
    
  3. 클러스터가 준비되었고 실행 중인지 확인합니다.

    kubectl get rayclusters maxtext-tpu-cluster
    

    출력은 다음과 비슷하게 표시됩니다.

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY        GPUS   STATUS   AGE
    maxtext-tpu-cluster   4                 4                   40     798027216Ki   0      ready    11m
    
  4. Ray 헤드 서비스를 통해 Ray 대시보드에 액세스하려면 포트 전달 세션을 설정하세요.

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
    
  5. 로컬 환경에서 RayCluster에 연결할 수 있는지 확인합니다.

    ray list nodes --address http://localhost:8265
    

    출력은 다음과 비슷하게 표시됩니다.

    ======== List: 2025-09-13 03:53:16.988269 ========
    Stats:
    ------------------------------
    Total: 5
    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP    IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                  LABELS
    0  92c79d04c34b659c1e3044f7642ad3fd47eb16f290785237149fab56  10.84.0.9
    (...)
    
  6. JaxTrainer 스크립트를 RayCluster에 제출하고 RayJob이 성공적으로 완료되었는지 확인합니다.

    ray job submit \
      --address http://localhost:8265 \
      -- python /app/maxtext_ray_trainer.py \
          /app/maxtext/src/MaxText/configs/base.yml \
           base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=1 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-8b-4096-tp4-4x4
    

    위 명령어는 JaxTrainer Ray 코드를 호출하는 Python 스크립트를 RayCluster에 제출합니다. ray job submit 명령어에는 모델 구성에 전달할 일부 MaxText 관련 인수가 포함되어 있습니다.

    터미널에 다음과 비슷한 출력이 표시됩니다.

    (RayTrainWorker pid=21663, ip=10.12.3.6) completed step: 99, seconds: 1.100, TFLOP/s/device: 179.739, Tokens/s/device: 3725.218, total_weights: 65536, loss: 0.000 [repeated 3x across cluster]
    
    ------------------------------------------
    Job 'raysubmit_zCrJcWnuymMQv4C3' succeeded
    ------------------------------------------
    

삭제

이 튜토리얼에서 사용된 리소스 비용이 Google Cloud 계정에 청구되지 않도록 하려면 리소스가 포함된 프로젝트를 삭제하거나 프로젝트는 유지하되 개별 리소스를 삭제하세요.

  1. RayCluster를 삭제합니다.

    kubectl delete raycluster maxtext-tpu-cluster
    
  2. GKE 클러스터를 삭제합니다.

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
    
  3. Cloud Storage 버킷을 삭제합니다.

    gsutil rm -r gs://${GS_BUCKET}
    
  4. Artifact Registry 저장소를 삭제합니다.

    gcloud artifacts repositories delete ${ARTIFACT_REGISTRY} --location=${REGION} --quiet
    

다음 단계