GKE에서 Ray Train을 사용하여 TPU에서 멀티슬라이스 및 탄력적 학습

이 튜토리얼에서는 MaxText, Ray Train, 멀티슬라이스 Trillium TPU를 사용하여 Google Kubernetes Engine (GKE)에서 Llama 3 70B와 같은 대규모 언어 모델 (LLM)을 학습하는 방법을 보여줍니다. 이 튜토리얼에서는 필요한 보조 데이터 센터 네트워킹 구성부터 32개의 실제 TPU 칩에 분산된 학습 워크로드를 제출하고 성공적으로 실행하는 것까지 완전한 엔드 투 엔드 연습을 제공합니다.

이 튜토리얼은 분산된 멀티 호스트 TPU 슬라이스에서 700억 개의 매개변수 모델을 학습할 때 발생하는 메모리 및 네트워킹 문제를 해결하는 방법을 알아보려는 플랫폼 관리자, 운영자, AI 전문가를 대상으로 합니다.

배경

GKE, KubeRay, MaxText, TPU를 함께 사용하면 대규모 모델 학습을 위한 강력하고 확장 가능한 플랫폼을 제공할 수 있습니다. 이 섹션에서는 이 가이드에서 사용되는 주요 기술을 설명합니다.

JAX

JAX는 가속기 지향 배열 계산과 프로그램 변환에 사용할 수 있는 Python 라이브러리로, XLA 컴파일러를 활용하여 가속기에서 효율적으로 확장되는 고도로 최적화된 코드를 만듭니다.

MaxText

MaxText는 확장성과 맞춤설정을 위해 설계된 고성능 오픈소스 LLM 프레임워크입니다. MaxText는 JAX를 기반으로 빌드되었으며 Cloud TPU에서 효율적으로 실행되도록 최적화되어 있습니다.

TPU

Tensor Processing Unit (TPU)은 머신러닝 워크로드를 최적화하기 위해 Google에서 만든 맞춤 설계된 가속기입니다. 범용 CPU 또는 병렬 처리 GPU와 달리 TPU는 딥 러닝의 기반이 되는 대규모 행렬 및 텐서 연산에 매우 특화되어 있어 이 특정 작업에 효율적입니다. TPU의 주요 이점은 대규모 성능입니다.

이 튜토리얼에서는 멀티슬라이스 배포 패턴에서 6세대 TPU인 TPU Trillium을 사용합니다. Cloud TPU 멀티슬라이스는 2개 이상의 Cloud TPU 슬라이스가 데이터 센터 네트워크 (DCN)를 통해 통신하는 환경입니다. 멀티슬라이스는 TPU 칩을 최대 10,000개까지 선형에 가깝게 수직 확장하여 경제적인 풀 스택 대규모 학습을 지원합니다. 멀티슬라이스에 대한 자세한 내용은 Cloud TPU 멀티슬라이스 개요를 참고하세요.

KubeRay

KubeRay는 Kubernetes에서 Ray 애플리케이션을 배포, 관리, 모니터링하는 통합 방법을 제공하는 Kubernetes 연산자입니다. KubeRay 연산자는 GKE에서 Ray 클러스터를 배포하고 관리하는 데 권장되는 방법인 GKE의 Ray 부가기능을 통해 설치되고 관리됩니다.

GKE 동적 리소스 할당 네트워크 (DRANET)

GKE DRANET (동적 리소스 할당 네트워크)은 표준 Kubernetes 네트워킹을 우회하고 DCN을 통해 고성능을 지원하여 고성능 네트워크 기기를 포드에 동적으로 연결하는 기능입니다.

목표

이 튜토리얼에서는 다음 작업을 처리하는 방법을 보여줍니다.

  1. 멀티 호스트 TPU 노드 풀이 2개인 GKE 클러스터를 설정합니다.
  2. 슬라이스 간 TPU 통신을 위해 보조 DCN을 구성합니다.
  3. 분산 학습 환경을 관리하도록 KubeRay를 구성합니다.
  4. 네트워크 연결에 동적 리소스 할당 (DRA)을 사용하여 RayCluster 커스텀 리소스를 배포합니다.
  5. Ray Train의 JaxTrainer를 활용하여 TPU 슬라이스 전반에서 MaxText 학습 루프를 오케스트레이션하는 Python 학습 스크립트를 만듭니다.
  6. 기준 Llama 3 8B 학습 작업을 실행합니다.
  7. DCN을 통해 2D 샤딩 (텐서 병렬 처리 및 FSDP)을 활용하여 Llama 3 70B로 수직 확장합니다.

시작하기 전에

  • Google Cloud 계정에 로그인합니다. Google Cloud를 처음 사용하는 경우 계정을 만들고 Google 제품의 실제 성능을 평가해 보세요. 신규 고객에게는 워크로드를 실행, 테스트, 배포하는 데 사용할 수 있는 $300의 무료 크레딧이 제공됩니다.
  • Google Cloud CLI를 설치합니다.

  • 외부 ID 공급업체(IdP)를 사용하는 경우 먼저 제휴 ID로 gcloud CLI에 로그인해야 합니다.

  • gcloud CLI를 초기화하려면, 다음 명령어를 실행합니다.

    gcloud init
  • Google Cloud 프로젝트를 만들거나 선택합니다.

    프로젝트를 선택하거나 만드는 데 필요한 역할

    • 프로젝트 선택: 프로젝트를 선택하는 데는 특정 IAM 역할이 필요하지 않습니다. 역할이 부여된 프로젝트를 선택하면 됩니다.
    • 프로젝트 만들기: 프로젝트를 만들려면 resourcemanager.projects.create 권한이 포함된 프로젝트 생성자 역할(roles/resourcemanager.projectCreator)이 필요합니다. 역할 부여 방법 알아보기
    • Google Cloud 프로젝트를 만듭니다.

      gcloud projects create PROJECT_ID

      PROJECT_ID를 만들려는 Google Cloud 프로젝트의 이름으로 바꿉니다.

    • 생성한 Google Cloud 프로젝트를 선택합니다.

      gcloud config set project PROJECT_ID

      PROJECT_ID을 Google Cloud 프로젝트 이름으로 바꿉니다.

  • Google Cloud 프로젝트에 결제가 사용 설정되어 있는지 확인합니다.

  • 필요한 API를 사용 설정합니다.

    API 사용 설정에 필요한 역할

    API를 사용 설정하려면 serviceusage.services.enable 권한이 포함된 서비스 사용량 관리자 IAM 역할 (roles/serviceusage.serviceUsageAdmin)이 필요합니다. 역할 부여 방법 알아보기

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • Google Cloud CLI를 설치합니다.

  • 외부 ID 공급업체(IdP)를 사용하는 경우 먼저 제휴 ID로 gcloud CLI에 로그인해야 합니다.

  • gcloud CLI를 초기화하려면, 다음 명령어를 실행합니다.

    gcloud init
  • Google Cloud 프로젝트를 만들거나 선택합니다.

    프로젝트를 선택하거나 만드는 데 필요한 역할

    • 프로젝트 선택: 프로젝트를 선택하는 데는 특정 IAM 역할이 필요하지 않습니다. 역할이 부여된 프로젝트를 선택하면 됩니다.
    • 프로젝트 만들기: 프로젝트를 만들려면 resourcemanager.projects.create 권한이 포함된 프로젝트 생성자 역할(roles/resourcemanager.projectCreator)이 필요합니다. 역할 부여 방법 알아보기
    • Google Cloud 프로젝트를 만듭니다.

      gcloud projects create PROJECT_ID

      PROJECT_ID를 만들려는 Google Cloud 프로젝트의 이름으로 바꿉니다.

    • 생성한 Google Cloud 프로젝트를 선택합니다.

      gcloud config set project PROJECT_ID

      PROJECT_ID을 Google Cloud 프로젝트 이름으로 바꿉니다.

  • Google Cloud 프로젝트에 결제가 사용 설정되어 있는지 확인합니다.

  • 필요한 API를 사용 설정합니다.

    API 사용 설정에 필요한 역할

    API를 사용 설정하려면 serviceusage.services.enable 권한이 포함된 서비스 사용량 관리자 IAM 역할 (roles/serviceusage.serviceUsageAdmin)이 필요합니다. 역할 부여 방법 알아보기

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • 사용자 계정에 역할을 부여합니다. 다음 IAM 역할마다 다음 명령어를 1회 실행합니다. roles/container.admin, roles/iam.serviceAccountAdmin, roles/cloudbuild.builds.editor

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    다음을 바꿉니다.

    • PROJECT_ID: 프로젝트 ID입니다.
    • USER_IDENTIFIER: 사용자 계정의 식별자입니다. 예를 들면 myemail@example.com입니다.
    • ROLE: 사용자 계정에 부여하는 IAM 역할입니다.
  • 이 튜토리얼에서는 TPU Trillium (v6e)을 사용하므로 사용 가능한 리전 또는 영역을 선택하세요. 자세한 내용은 Cloud TPU 할당량을 참고하세요.

개발 환경 준비

이 튜토리얼에서는 Cloud Shell을 사용합니다. Cloud Shell에는 이 튜토리얼에서 사용되는 gcloud, helm, kubectl 명령줄 도구가 사전 설치되어 있습니다.

  1. Google Cloud 콘솔로 이동합니다.

  2. Google Cloud 콘솔 창 상단에서 Cloud Shell 활성화 셸 활성화 버튼 버튼을 클릭합니다.

    Google Cloud 콘솔의 새 프레임 내에 Cloud Shell 세션이 열리면서 명령줄 프롬프트가 표시됩니다.

  3. 터미널에서 kubernetes-engine-samples 저장소를 클론합니다.

    git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git
    
  4. 샘플 파일이 포함된 디렉터리로 변경합니다.

    cd kubernetes-engine-samples/ai-ml/gke-ray/raytrain/maxtext
    
  5. Python 가상 환경을 만들고 활성화합니다.

    python3 -m venv ray-env
    source ray-env/bin/activate
    
  6. Ray CLI를 설치합니다.

    pip install "ray[default]==2.55.0"
    
  7. 다음 환경 변수를 설정합니다.

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export CLUSTER_VERSION=1.35.2-gke.1842000
    

    다음을 바꿉니다.

    • GS_BUCKET: Cloud Storage 버킷 이름
    • KSA_NAME: Kubernetes 서비스 계정의 이름입니다.
    • CLUSTER_NAME: 새 클러스터의 이름
    • REGION: TPU Trillium 용량을 사용할 수 있는 리전입니다.
    • ZONE: TPU Trillium 용량을 사용할 수 있는 영역입니다. 자세한 내용은 GKE의 TPU 가용성을 참고하세요.

Cloud TPU 멀티슬라이스용 클러스터 네트워킹 구성

멀티 호스트 TPU 슬라이스 내에서 TPU 기기는 고속 칩 간 상호 연결을 통해 통신합니다. 하지만 멀티슬라이스 작업을 실행할 때는 TPU 슬라이스가 DCN을 통해 서로 통신해야 합니다. 표준 Kubernetes 포드 네트워크는 이 트래픽의 병목 현상을 일으킬 수 있습니다. ct6e-standard-4t 머신 유형은 여러 물리적 네트워크 인터페이스 카드 (NIC)로 지원됩니다. 최상의 성능을 위해 두 개의 추가 VPC 네트워크를 만들고 GKE DRANET을 사용하여 Ray 포드에 직접 연결합니다.

  1. 큰 최대 학습 단위 (MTU)를 사용하여 두 개의 추가 VPC 네트워크를 만듭니다.

    gcloud compute networks create ${CLUSTER_NAME}-net-1 \
      --subnet-mode=custom \
      --mtu=8896
    
    gcloud compute networks create ${CLUSTER_NAME}-net-2 \
      --subnet-mode=custom \
      --mtu=8896
    
  2. 전용 서브넷을 만듭니다.

    gcloud compute networks subnets create tpu-subnet-1 \
      --network=${CLUSTER_NAME}-net-1 \
      --region=${REGION} \
      --range=10.50.0.0/16
    
    gcloud compute networks subnets create tpu-subnet-2 \
      --network=${CLUSTER_NAME}-net-2 \
      --region=${REGION} \
      --range=10.60.0.0/16
    

GKE 클러스터 만들기

GKE Autopilot 또는 Standard 클러스터의 TPU에서 KubeRay를 구성할 수 있습니다. 완전 관리형 Kubernetes 환경을 위해서는 Autopilot 클러스터를 사용하는 것이 좋습니다. 워크로드에 가장 적합한 GKE 작업 모드를 선택하려면 GKE 작업 모드 정보를 참고하세요.

GKE 관리 DRANET을 사용하려면 클러스터에서 Autopilot 모드의 경우 버전 1.35.2-gke.1842000 이상을, Standard 모드의 경우 1.34.1-gke.1829001 이상을 사용해야 합니다. 이 튜토리얼에서는 버전 1.35.2-gke.1842000을 사용합니다.

Autopilot

  1. Cloud Shell에서 다음 명령어를 실행합니다.

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --machine-type=n1-standard-16 \
        --location=$REGION \
        --cluster-version=${CLUSTER_VERSION}
    
  2. 클러스터와 통신하려면 kubectl을 구성하세요.

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$REGION
    

표준

  1. Cloud Shell에서 다음 명령어를 실행하여 Ray operator 부가기능을 사용 설정하는 Standard 클러스터를 만듭니다.

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator,GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --enable-dataplane-v2 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE \
        --cluster-version=${CLUSTER_VERSION}
    

    이 명령어를 사용하면 포드가 Cloud Storage 버킷을 로컬 파일 시스템으로 마운트할 수 있는 GcsFuseCsiDriver도 사용 설정됩니다. 클러스터를 만드는 데 몇 분 정도 걸릴 수 있습니다.

  2. 클러스터와 통신하려면 kubectl을 구성하세요.

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$ZONE
    
  3. GKE DRANET이 사용 설정된 첫 번째 멀티 호스트 TPU 슬라이스 노드 풀을 만듭니다.

    gcloud container node-pools create v6e-16-0 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4 \
        --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
        --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
        --node-labels=cloud.google.com/gke-networking-dra-driver=true \
        --enable-gvnic \
        --scopes=https://www.googleapis.com/auth/cloud-platform
    
  4. 두 번째 TPU 슬라이스 노드 풀을 만듭니다.

    gcloud container node-pools create v6e-16-1 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4 \
        --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
        --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
        --node-labels=cloud.google.com/gke-networking-dra-driver=true \
        --enable-gvnic \
        --scopes=https://www.googleapis.com/auth/cloud-platform
    

GKE는 4x4 토폴로지가 있는 멀티 호스트 TPU 슬라이스로 함께 구성된 4개의 TPU Trillium (v6e) VM으로 구성된 노드 풀을 프로비저닝합니다. 이 노드 풀은 분산형 학습 워크로드에 사용할 수 있습니다.

Ray 연산자가 사용 설정된 GKE 클러스터는 클러스터에 KubeRay와 KubeRay TPU 웹훅을 자동으로 설치합니다.

Cloud Storage 버킷 및 서비스 계정 구성

  1. 다중 호스트 TPU 노드 간에 공유되는 체크포인트를 위한 Cloud Storage 버킷을 만듭니다.

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
    
  2. Cloud Storage 버킷에 대한 액세스를 사용 설정하려면 Kubernetes 서비스 계정을 만드세요.

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
    
  3. Cloud Storage 버킷에 대한 액세스를 사용 설정하려면 서비스 계정에 필요한 IAM 정책 바인딩을 추가하세요.

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
    

학습 스크립트 만들기

maxtext_multi_slice_trainer.py 스크립트는 Ray Train의 JaxTrainer를 사용하여 두 개의 TPU 슬라이스에서 분산 MaxText 학습 작업을 실행합니다. 이 스크립트는 8개의 멀티 호스트 TPU 워커의 학습 환경을 구성하고 각 워커 노드에서 MaxText 학습 작업을 실행합니다. train_loop_per_worker 함수는 MaxText 기본 진입점을 래핑하고 Ray의 분산 스케줄러를 사용하여 멀티 호스트 TPU 슬라이스에서 MaxText 트레이너를 실행합니다.

import os
from absl import app
import logging
from typing import Sequence
import ray
from ray.train.v2.api.config import ScalingConfig, RunConfig
from ray.train.v2.jax import JaxTrainer

def train_loop_per_worker(config):
    import maxtext
    from maxtext.trainers.pre_train.train import main as maxtext_main

    argv = config["argv"]
    maxtext_main(argv)

def main(argv: Sequence[str]):
    # Convert the config file path to an absolute path
    argv = list(argv)
    if len(argv) > 1:
        argv[1] = os.path.abspath(argv[1])

    trainer = JaxTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"argv": argv},
        scaling_config=ScalingConfig(
            use_tpu=True,
            num_workers=8,
            topology="4x4",
            accelerator_type="TPU-V6E",
            resources_per_worker={"TPU": 4},
            placement_strategy="SPREAD",
        ),
        run_config=RunConfig(
            name="maxtext_jaxtrainer",
            worker_runtime_env={
                "uv": {
                    # maxtext requires some additional deps
                    "packages": ["maxtext[tpu]==0.2.1"],
                    "uv_pip_install_options": ["--resolution=lowest"]
                },
            },
        ),
    )
    result = trainer.fit()
    logging.info("Training complete!")
    ray.shutdown()

if __name__ == "__main__":
    app.run(main)

위 스크립트는 작업자 8개와 4x4 토폴로지를 요청하는 JaxTrainer 인스턴스를 정의합니다. 내부적으로 Ray는 두 TPU 슬라이스에 걸쳐 SlicePlacementGroup을 프로비저닝하고 Ray Train 작업자가 호스트당 하나의 작업자로 두 슬라이스에 걸쳐 원자적으로 실행되도록 지원합니다.

모델 학습

  1. 현재 디렉터리의 ray-cluster.tpu-multi-slice.yaml 매니페스트는 RayCluster 커스텀 리소스를 정의합니다. 이 매니페스트에는 GKE DRANET 및 멀티슬라이스의 네트워크 기기를 프로비저닝하는 DRANET ResourceClaimTemplate가 포함되어 있습니다.

    apiVersion: resource.k8s.io/v1
    kind: ResourceClaimTemplate
    metadata:
      name: two-netdev
    spec:
      spec:
        devices:
          requests:
          - name: req-netdev
            exactly:
              deviceClassName: netdev.google.com
              allocationMode: ExactCount
              count: 2
    ---
    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: rayproject/ray:nightly-py312-tpu
                imagePullPolicy: Always
                ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
            volumes:
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs,uid=1000,gid=1000,dir-mode=775,file-mode=664,file-cache:max-size-mb:-1"
            nodeSelector:
              iam.gke.io/gke-metadata-server-enabled: "true"
      workerGroupSpecs:
        - replicas: 2
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: 
            metrics-export-port: "8082"
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              resourceClaims:
              - name: netdev
                resourceClaimTemplateName: two-netdev
              containers:
                - name: ray-worker
                  image: rayproject/ray:nightly-py312-tpu
                  imagePullPolicy: Always
                  resources:
                    claims:
                    - name: netdev
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: MEGASCALE_NUM_SLICES
                      value: "2"
                    - name: MEGASCALE_PORT
                      value: "9915"
                    - name: JAX_PLATFORMS
                      value: tpu,cpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                    - name: LIBTPU_INIT_ARGS
                      value: "--xla_tpu_scoped_vmem_limit_kib=122880 --xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --megascale_grpc_interface_prefixes=eth1,eth2,lo"
                  securityContext:
                    privileged: true
                  volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
              volumes:
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs,uid=1000,gid=1000,dir-mode=775,file-mode=664,file-cache:max-size-mb:-1"
              nodeSelector:
                iam.gke.io/gke-metadata-server-enabled: "true"
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
    

    위의 RayCluster 사양은 복제본당 8개의 TPU 워커(numOfHosts: 4)가 있는 TPU 작업자 그룹을 2개의 복제본으로 만듭니다. 각 작업자는 TPU 칩 4개(google.com/tpu: "4")를 요청합니다. 각 작업자는 동일한 공동 배치 멀티 호스트 슬라이스의 일부인 TPU Trillium 노드(tpu-v6e-slice)에 예약됩니다. KubeRay는 슬라이스의 네 가지 작업자를 모두 원자적으로 확장합니다. 필요한 JAX 환경 변수와 일정 예약용 포드 선호도는 GKE에서 변형 웹훅을 통해 부트스트랩됩니다.

  2. RayCluster를 만들려면 매니페스트를 적용합니다.

    envsubst < ray-cluster.tpu-multi-slice.yaml | kubectl apply -f -
    
  3. 클러스터가 준비되어 실행 중인지 확인합니다.

    kubectl get rayclusters maxtext-tpu-cluster
    

    출력은 다음과 비슷하게 표시됩니다.

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY         GPUS   STATUS   AGE
    maxtext-tpu-cluster   8                 8                   72     1579277216Ki   0      ready    2m11s
    
  4. Ray 헤드 서비스를 통해 Ray 대시보드에 액세스하려면 포트 전달 세션을 설정하세요.

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
    
  5. 로컬 환경에서 RayCluster에 연결할 수 있는지 확인합니다.

    ray list nodes --address http://localhost:8265
    

    출력은 다음과 비슷하게 표시됩니다.

    ray list nodes --address http://localhost:8265
    2026-04-21 10:20:20,080 - INFO - Note: NumExpr detected 64 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
    2026-04-21 10:20:20,080 - INFO - NumExpr defaulting to 8 threads.
    
    ======== List: 2026-04-21 10:20:20.945431 ========
    Stats:
    ------------------------------
    Total: 9
    
    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP     IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                   LABELS
    0  4f0e4d742de5375047c7688f4d2bc64a42d1e5c77c2d8344b3b375a1  10.68.9.5   False           ALIVE                     10.68.9.5    CPU: 8.0                          ray.io/accelerator-type: TPU-V6E
                                                                                                                                    TPU: 4.0                          ray.io/node-group: tpu-group
                                                                                                                                    accelerator_type:TPU-V6E: 1.0     ray.io/node-id: 4f0e4d742...
                                                                                                                                    memory: 186.265 GiB               ray.io/tpu-pod-type: v6e-16
                                                                                                                                    node:10.68.9.5: 1.0               ray.io/tpu-slice-name: tpu-group-0
                                                                                                                                    object_store_memory: 186.265 GiB  ray.io/tpu-topology: 4x4
                                                                                                                                    tpu-group-0: 1.0                  ray.io/tpu-worker-id: '1'
    ...
    6  ce7056807b95831ce107ba1951dac34b80635e6fdbb312e7f9649938  10.68.2.9   True            ALIVE                     10.68.2.9    CPU: 8.0                          ray.io/node-group: headgroup
                                                                                                                                    memory: 16.000 GiB                ray.io/node-id: ce7056807...
                                                                                                                                    node:10.68.2.9: 1.0
                                                                                                                                    node:__internal_head__: 1.0
                                                                                                                                    object_store_memory: 4.765 GiB
    ...
    
  6. 기본 MaxText 구성 파일을 다운로드합니다. 이 파일은 학습 스크립트가 모델의 기본 초매개변수를 설정하는 데 필요합니다.

    curl -O https://raw.githubusercontent.com/google/maxtext/maxtext-v0.2.1/src/maxtext/configs/base.yml
    
  7. JaxTrainer 스크립트를 RayCluster에 제출하고 RayJob이 성공적으로 완료되었는지 확인합니다.

Llama 3 8B

ray job submit \
  --address http://localhost:8265 \
  --working-dir . \
  --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
  -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=4 \
      max_target_length=4096 \
      model_name=llama3-8b \
      steps=100 \
      ici_fsdp_parallelism=4 \
      ici_tensor_parallelism=4 \
      run_name=rayjob-multi-slice

Llama 3 70B

ray job submit \
  --address http://localhost:8265 \
  --working-dir . \
  --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
  -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=2 \
      max_target_length=4096 \
      model_name=llama3-70b \
      steps=100 \
      ici_tensor_parallelism=4 \
      ici_fsdp_parallelism=4 \
      dcn_fsdp_parallelism=2 \
      dcn_data_parallelism=1 \
      remat_policy=full \
      run_name=rayjob-multi-slice-70b-fsdp

위 명령어는 JaxTrainer Ray 코드를 호출하는 Python 스크립트를 RayCluster에 제출합니다. ray job submit 명령어에는 모델 구성에 전달할 몇 가지 MaxText 관련 인수가 포함되어 있습니다.

터미널에 Llama 3 70B 작업에 대해 다음과 비슷한 출력이 표시됩니다.

[process=5][thread=save_finalize][step=99] CheckpointManager Save Finalize is done on all hosts. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][step=99][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=99. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][wait_until_finished] No Save Finalize thread to wait for. Returning. [repeated 6x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) completed step: 99, seconds: 0.693, TFLOP/s/device: 83.171, Tokens/s/device: 11819.175, total_weights: 262144, loss: 0.334 [repeated 6x across cluster]

------------------------------------------
Job 'raysubmit_XwUdZMrhsYRKvjqs' succeeded
------------------------------------------

스팟 VM에서 멀티슬라이스 탄력적 학습 실행

TPU와 같이 수요가 많은 액셀러레이터를 사용하는 경우 스팟 VM을 활용하면 비용을 크게 절감할 수 있습니다. 하지만 스팟 VM은 예기치 않게 선점될 수 있습니다.

Ray Train은 탄력적 학습을 지원하므로 작업이 실패하지 않고 참여하는 TPU 슬라이스 수를 동적으로 확장하거나 축소할 수 있습니다. 슬라이스가 선점되면 Ray는 학습 루프를 일시중지하고, 나머지 작업자가 재구성되기를 기다리고, 최신 MaxText 체크포인트에서 복원하고, 더 작은 설치 공간에서 학습을 재개합니다.

탄력적 학습을 사용 설정하려면 ScalingConfignum_workers 매개변수를 정적 정수에서 (minimum_workers, maximum_workers)을 나타내는 튜플로 변경합니다. 또한 작업자가 선점될 때 작업을 완전히 실패하는 대신 Ray Train이 학습 루프를 최대 3회까지 다시 시도하도록 지시하는 FailureConfig(max_failures=3)RunConfig에 추가합니다.

Ray Train 스크립트 업데이트

  1. 현재 디렉터리의 maxtext_elastic_trainer.py 스크립트는 탄력적 학습을 지원합니다. num_workers=(4,8)가 설정되어 있습니다. 이는 16칩 슬라이스 (작업자 4명)가 하나 이상 있으면 Ray가 진행하도록 하지만 가능한 경우 슬라이스 2개 (작업자 8명)로 확장하도록 지시합니다. 여기에는 탄력적 학습을 사용 설정하고, 재시도 횟수를 정의하고, 작업이 선점되지 않도록 하는 데 도움이 되는 FailureConfig가 포함됩니다.

    import os
    from absl import app
    import logging
    from typing import Sequence
    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig, FailureConfig
    from ray.train.v2.jax import JaxTrainer
    
    def train_loop_per_worker(config):
        import maxtext
        from maxtext.trainers.pre_train.train import main as maxtext_main
    
        argv = config["argv"]
        maxtext_main(argv)
    
    def main(argv: Sequence[str]):
        # Convert the config file path to an absolute path
        argv = list(argv)
        if len(argv) > 1:
            argv[1] = os.path.abspath(argv[1])
    
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                # This tells Ray to scale the number of workers between 4 and 8 (i.e. 1 to 2 TPU slices).
                num_workers=(4,8),
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                # Define a FailureConfig to enable fault tolerance by automatically restarting failed workers.
                failure_config=FailureConfig(max_failures=3),
                worker_runtime_env={
                    "uv": {
                        # maxtext requires some additional deps
                        "packages": ["maxtext[tpu]==0.2.1"],
                        "uv_pip_install_options": ["--resolution=lowest"]
                    },
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        app.run(main)
    
  2. Ray Job CLI를 사용하여 작업을 제출합니다. 체크포인트가 이전 실행과 충돌하지 않도록 고유한 run_name를 제공해야 합니다.

    ray job submit \
      --address http://localhost:8265 \
      --working-dir . \
      --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
      -- python maxtext_elastic_trainer.py \
          base.yml \
          base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=4 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-elastic-8b
    
  3. 학습 중에 노드 종료 또는 선점을 시뮬레이션하려면 포드를 삭제합니다.

    kubectl delete pod $(kubectl get pods -l ray.io/node-type=worker -o jsonpath='{.items[0].metadata.name}')
    

터미널에 작업자 오류가 로깅되지만 오케스트레이션 컨트롤러는 작업을 계속 실행하고 최소 토폴로지가 제공된 후 /data/rayjob-elastic-8b/checkpoints 체크포인트에서 자동으로 다시 시작합니다.

MaxText는 재개 시 기기 메시를 동적으로 다시 계산하므로 토폴로지가 축소될 때 체크포인트 재샤딩을 처리하는 맞춤 로직을 작성할 필요가 없습니다. JAX의 Orbax 체크포인터는 학습 루프를 계속하기 전에 저장된 가중치를 새로운 물리적 레이아웃으로 자동 재샤드합니다. 다음 출력은 Ray Train 컨트롤러가 클러스터에서 새로 사용할 수 있는 TPU 리소스를 감지하고 학습 중에 슬라이스 1개 (작업자 4명)에서 슬라이스 2개 (작업자 8명)로 확장 작업을 실행하는 것을 보여줍니다.

...
(pid=, ip=10.68.9.5) W0421 04:19:07.570048   20579 grpc_transport.cc:1930] GetMultiSliceTopology returned with status: UNAVAILABLE: failed to connect to all addresses; last error: UNKNOWN: ipv4:10.68.8.5:9915: connect endpoint failed (Failed to connect to remote host: Connection refused)
...
(TrainController pid=23150) Detected changes in the cluster resources. Deciding to resize the worker group from 4 -> 8 workers.
(TrainController pid=23150) Using SlicePlacementGroup utility to reserve 2 slice(s) with topology '4x4'...
(TrainController pid=23150) Attempting to start training worker group of size 8 with the following resources: [{'TPU': 4, 'accelerator_type:TPU-V6E': 0.001}] * 8

삭제

이 튜토리얼에서 사용된 리소스 비용이 Google Cloud 계정에 청구되지 않도록 하려면 리소스가 포함된 프로젝트를 삭제하거나 프로젝트는 유지하되 개별 리소스를 삭제하세요.

  1. RayCluster를 삭제합니다.

    kubectl delete raycluster maxtext-tpu-cluster
    
  2. GKE 클러스터를 삭제합니다.

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
    
  3. Cloud Storage 버킷을 삭제합니다.

    gsutil rm -r gs://${GS_BUCKET}
    

다음 단계