使用 Ray 在 GKE 上微调 Gemma 3 以执行视觉任务

本教程介绍了如何在多节点 GKE 集群上使用 Ray 框架对 Gemma 3 模型进行微调。该集群使用两个 A4 虚拟机 (VM) 实例,每个实例都挂接了八个 NVIDIA B200 GPU。

本教程的内容分为两部分:

  1. 准备在 GKE Autopilot 集群上运行的 Ray 集群。
  2. 运行分布式训练作业,利用 2 个 A4 实例,每个实例配备 8 个 B200 GPU。

本教程适用于机器学习 (ML) 工程师、研究人员、平台管理员和运维人员,以及对在多个节点和 GPU 之间分配 AI 工作负载感兴趣的数据和 AI 专家。

目标

  • 使用 Hugging Face 访问 Gemma 3 模型。

  • 准备环境。

  • 创建已安装 Ray Operator 的 GKE Autopilot 集群。

  • 配置 GKE 集群上的 Ray 集群以接受 Ray 作业。

  • 配置并运行一个 Ray 作业,该作业可根据视觉输入调整 Gemma 3 模型。

  • 监控工作负载。

  • 清理。

费用

在本文档中,您将使用 Google Cloud的以下收费组件:

如需根据您的预计使用情况来估算费用,请使用价格计算器

新 Google Cloud 用户可能有资格申请免费试用

准备工作

  1. 登录您的 Google Cloud 账号。如果您是 Google Cloud新手,请 创建一个账号来评估我们的产品在实际场景中的表现。新客户还可获享 $300 赠金,用于运行、测试和部署工作负载。
  2. 安装 Google Cloud CLI。

  3. 如果您使用的是外部身份提供方 (IdP),则必须先使用联合身份登录 gcloud CLI

  4. 如需初始化 gcloud CLI,请运行以下命令:

    gcloud init
  5. 创建或选择 Google Cloud 项目

    选择或创建项目所需的角色

    • 选择项目:选择项目不需要特定的 IAM 角色,您可以选择已获授角色的任何项目。
    • 创建项目:如需创建项目,您需要拥有 Project Creator 角色 (roles/resourcemanager.projectCreator),该角色包含 resourcemanager.projects.create 权限。了解如何授予角色
    • 创建 Google Cloud 项目:

      gcloud projects create PROJECT_ID

      PROJECT_ID 替换为您要创建的 Google Cloud 项目的名称。

    • 选择您创建的 Google Cloud 项目:

      gcloud config set project PROJECT_ID

      PROJECT_ID 替换为您的 Google Cloud 项目名称。

  6. 验证是否已为您的 Google Cloud 项目启用结算功能

  7. 启用所需的 API:

    启用 API 所需的角色

    如需启用 API,您需要拥有 Service Usage Admin IAM 角色 (roles/serviceusage.serviceUsageAdmin),该角色包含 serviceusage.services.enable 权限。了解如何授予角色

    gcloud services enable gcloud services enable compute.googleapis.com logging.googleapis.com cloudresourcemanager.googleapis.com servicenetworking.googleapis.com container.googleapis.com
  8. 安装 Google Cloud CLI。

  9. 如果您使用的是外部身份提供方 (IdP),则必须先使用联合身份登录 gcloud CLI

  10. 如需初始化 gcloud CLI,请运行以下命令:

    gcloud init
  11. 创建或选择 Google Cloud 项目

    选择或创建项目所需的角色

    • 选择项目:选择项目不需要特定的 IAM 角色,您可以选择已获授角色的任何项目。
    • 创建项目:如需创建项目,您需要拥有 Project Creator 角色 (roles/resourcemanager.projectCreator),该角色包含 resourcemanager.projects.create 权限。了解如何授予角色
    • 创建 Google Cloud 项目:

      gcloud projects create PROJECT_ID

      PROJECT_ID 替换为您要创建的 Google Cloud 项目的名称。

    • 选择您创建的 Google Cloud 项目:

      gcloud config set project PROJECT_ID

      PROJECT_ID 替换为您的 Google Cloud 项目名称。

  12. 验证是否已为您的 Google Cloud 项目启用结算功能

  13. 启用所需的 API:

    启用 API 所需的角色

    如需启用 API,您需要拥有 Service Usage Admin IAM 角色 (roles/serviceusage.serviceUsageAdmin),该角色包含 serviceusage.services.enable 权限。了解如何授予角色

    gcloud services enable gcloud services enable compute.googleapis.com logging.googleapis.com cloudresourcemanager.googleapis.com servicenetworking.googleapis.com container.googleapis.com
  14. 向您的用户账号授予角色。对以下每个 IAM 角色运行以下命令一次: roles/compute.admin, roles/iam.serviceAccountUser, roles/file.editor, roles/storage.admin, roles/container.clusterAdmin, roles/serviceusage.serviceUsageAdmin

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    替换以下内容:

    • PROJECT_ID:您的项目 ID。
    • USER_IDENTIFIER:您的用户 账号的标识符。例如,myemail@example.com
    • ROLE:您授予用户账号的 IAM 角色。
  15. 为您的 Google Cloud 项目启用默认服务账号:
    gcloud iam service-accounts enable PROJECT_NUMBER-compute@developer.gserviceaccount.com \
        --project=PROJECT_ID

    PROJECT_NUMBER 替换为您的项目编号。如需查看项目编号,请参阅 获取现有项目

  16. 向默认服务账号授予 Editor 角色 (roles/editor):
    gcloud projects add-iam-policy-binding PROJECT_ID \
        --member="serviceAccount:PROJECT_NUMBER-compute@developer.gserviceaccount.com" \
        --role=roles/editor
  17. 为您的用户账号创建本地身份验证凭据:
    gcloud auth application-default login
  18. 登录或创建 Hugging Face 账号

使用 Hugging Face 访问 Gemma 3

如需使用 Hugging Face 访问 Gemma 3,请执行以下操作:

  1. 签署同意协议,以使用 Gemma 3

  2. 创建 Hugging Face read access 令牌

  3. 复制并保存 read access 令牌值。您将在本教程的后面部分使用该地址。

准备环境

通过配置必要的设置和设置环境变量来准备环境。

运行以下命令:

gcloud config set billing/quota_project $PROJECT_ID
export RESERVATION=RESERVATION_URL
export REGION=REGION
export CLUSTER_NAME=CLUSTER_NAME
export HF_TOKEN=HF_TOKEN
export NETWORK=default
export GCS_BUCKET=GCS_BUCKET

替换以下内容:

  • RESERVATION_URL:您要用于创建集群的预留的网址。根据预留所在的项目,指定以下值之一:
    • 预留存在于您的项目中:RESERVATION_NAME
    • 预留存在于其他项目中,并且您的项目可以使用该预留:projects/RESERVATION_PROJECT_ID/reservations/RESERVATION_NAME。 系统接受完整网址和部分网址。例如,您可以使用 projects/RESERVATION_PROJECT_ID/reservations/RESERVATION_NAME
  • REGION:您要在其中创建 GKE 集群的区域。 您只能在预留所在的区域中创建集群。
  • CLUSTER_NAME:要创建的 GKE 集群的名称。
  • HF_TOKEN:您在之前的步骤中创建的 Hugging Face 令牌。
  • GCS_BUCKET:用于存储训练检查点结果的存储桶的名称。

在 Autopilot 模式下创建 GKE 集群

如需在 Autopilot 模式下创建 GKE 集群,请运行以下命令:

gcloud container clusters create-auto $CLUSTER_NAME \
    --enable-ray-operator \
    --enable-ray-cluster-monitoring \
    --enable-ray-cluster-logging \
    --location=$REGION

GKE 集群的创建可能需要一些时间才能完成。如需验证 Google Cloud 是否已完成集群创建,请前往 Google Cloud 控制台中的 Kubernetes 集群

为 Hugging Face 凭据创建 Kubernetes Secret

在 Cloud Shell 中,如需为 Hugging Face 凭据创建 Kubernetes Secret,请执行以下操作:

  1. 配置 kubectl 以连接到您的集群:

    gcloud container clusters get-credentials $CLUSTER_NAME \
        --region=$REGION
    
  2. 创建一个 Kubernetes Secret 来存储您的 Hugging Face 令牌:

    kubectl create secret generic hf-secret \
        --from-literal=hf_api_token=${HF_TOKEN} \
        --dry-run=client -o yaml | kubectl apply -f -
    

创建 Google Cloud Storage 存储桶

如果您想使用新的存储桶来存储训练制品,请运行以下命令:

gcloud storage buckets create gs://$GCS_BUCKET --location=$REGION

如果您想使用现有存储桶,可以跳过此步骤。不过,您必须确保存储桶与集群位于同一区域。

将训练代码保存为 ConfigMap

为避免将训练脚本嵌入到容器映像中,您可以将其作为 ConfigMap 存储在集群中。此 ConfigMap 会装载到 Pod 文件系统,这样一来,您无需重新创建整个 Ray 集群即可更新训练脚本。

  1. 前往 code 文件夹,然后创建一个新文件。

    将以下 code/vision_train.py 代码复制到这个新文件中:

    import argparse
    import datetime
    import ray
    import ray.train.huggingface.transformers
    import torch
    from PIL import Image
    from datasets import load_dataset
    from peft import LoraConfig
    from ray.train import ScalingConfig, RunConfig
    from ray.train.torch import TorchTrainer
    from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
    from trl import SFTConfig
    from trl import SFTTrainer
    
    # System message for the assistant
    system_message = "You are an expert product description writer for Amazon."
    
    # User prompt that combines the user query and the schema
    user_prompt = """Create a Short Product description based on the provided <PRODUCT> and <CATEGORY> and image.
    Only return description. The description should be SEO optimized and for a better mobile search experience.
    
    <PRODUCT>
    {product}
    </PRODUCT>
    
    <CATEGORY>
    {category}
    </CATEGORY>
    """
    
    def get_args():
        parser = argparse.ArgumentParser()
        parser.add_argument("--model_id", type=str, default="google/gemma-3-4b-it", help="Hugging Face model ID")
        # parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face token for private models")
        parser.add_argument("--dataset_name", type=str, default="philschmid/amazon-product-descriptions-vlm", help="Hugging Face dataset name")
        parser.add_argument("--output_dir", type=str, default="gemma-3-4b-seo-optimized", help="Directory to save model checkpoints")
        parser.add_argument("--gcs_bucket", type=str, required=True, help="storage bucket name used to synchronize tasks and save checkpoints")
        parser.add_argument("--push_to_hub", help="Push model to Hugging Face hub", action="store_true")
    
        # LoRA arguments
        parser.add_argument("--lora_r", type=int, default=16, help="LoRA attention dimension")
        parser.add_argument("--lora_alpha", type=int, default=16, help="LoRA alpha scaling factor")
        parser.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout probability")
    
        # SFTConfig arguments
        parser.add_argument("--max_seq_length", type=int, default=512, help="Maximum sequence length")
        parser.add_argument("--num_train_epochs", type=int, default=3, help="Number of training epochs")
        parser.add_argument("--per_device_train_batch_size", type=int, default=1, help="Batch size per device during training")
        parser.add_argument("--gradient_accumulation_steps", type=int, default=4, help="Gradient accumulation steps")
        parser.add_argument("--learning_rate", type=float, default=2e-4, help="Learning rate")
        parser.add_argument("--logging_steps", type=int, default=10, help="Log every X steps")
        parser.add_argument("--save_strategy", type=str, default="epoch", help="Checkpoint save strategy")
        parser.add_argument("--save_steps", type=int, default=100, help="Save checkpoint every X steps")
    
        return parser.parse_args()
    
    # Convert dataset to OAI messages
    def format_data(sample):
        return {
            "messages": [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": system_message}],
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": user_prompt.format(
                                product=sample["Product Name"],
                                category=sample["Category"],
                            ),
                        },
                        {
                            "type": "image",
                            "image": sample["image"],
                        },
                    ],
                },
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": sample["description"]}],
                },
            ],
        }
    
    def process_vision_info(messages: list[dict]) -> list[Image.Image]:
        image_inputs = []
        # Iterate through each conversation
        for msg in messages:
            # Get content (ensure it's a list)
            content = msg.get("content", [])
            if not isinstance(content, list):
                content = [content]
    
            # Check each content element for images
            for element in content:
                if isinstance(element, dict) and ("image" in element or element.get("type") == "image"):
                    # Get the image and convert to RGB
                    if "image" in element:
                        image = element["image"]
                    else:
                        image = element
                    image_inputs.append(image.convert("RGB"))
        return image_inputs
    
    def train(args):
        # Load dataset from the hub
        dataset = load_dataset(args.dataset_name, split="train", streaming=True)
    
        # Convert dataset to OAI messages
        # need to use list comprehension to keep Pil.Image type, .mape convert image to bytes
        dataset = [format_data(sample) for sample in dataset]
    
        # Hugging Face model id
        model_id = args.model_id
    
        # Check if GPU benefits from bfloat16
        if torch.cuda.get_device_capability()[0] < 8:
            raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")
    
        # Define model init arguments
        model_kwargs = dict(
            attn_implementation="eager",  # Use "flash_attention_2" when running on Ampere or newer GPU
            torch_dtype=torch.bfloat16,  # What torch dtype to use, defaults to auto
            # device_map="auto",  # Let torch decide how to load the model
        )
    
        # BitsAndBytesConfig int-4 config
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
            bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
        )
    
        # Load model and tokenizer
        model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
        processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
    
        peft_config = LoraConfig(
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            r=args.lora_r,
            bias="none",
            target_modules="all-linear",
            task_type="CAUSAL_LM",
            modules_to_save=[
                "lm_head",
                "embed_tokens",
            ],
        )
    
        args = SFTConfig(
            output_dir=args.output_dir,  # directory to save and repository id
            num_train_epochs=args.num_train_epochs,  # number of training epochs
            per_device_train_batch_size=args.per_device_train_batch_size,  # batch size per device during training
            gradient_accumulation_steps=args.gradient_accumulation_steps,  # number of steps before performing a backward/update pass
            gradient_checkpointing=True,  # use gradient checkpointing to save memory
            optim="adamw_torch_fused",  # use fused adamw optimizer
            logging_steps=args.logging_steps,  # log every N steps
            save_strategy=args.save_strategy,  # save checkpoint every epoch
            learning_rate=args.learning_rate,  # learning rate, based on QLoRA paper
            bf16=True,  # use bfloat16 precision
            max_grad_norm=0.3,  # max gradient norm based on QLoRA paper
            warmup_ratio=0.03,  # warmup ratio based on QLoRA paper
            lr_scheduler_type="constant",  # use constant learning rate scheduler
            push_to_hub=args.push_to_hub,  # push model to hub
            report_to="tensorboard",  # report metrics to tensorboard
            gradient_checkpointing_kwargs={
                "use_reentrant": False
            },  # use reentrant checkpointing
            dataset_text_field="",  # need a dummy field for collator
            dataset_kwargs={"skip_prepare_dataset": True},  # important for collator
        )
        args.remove_unused_columns = False  # important for collator
    
        # Create a data collator to encode text and image pairs
        def collate_fn(examples):
            texts = []
            images = []
            for example in examples:
                image_inputs = process_vision_info(example["messages"])
                text = processor.apply_chat_template(
                    example["messages"], add_generation_prompt=False, tokenize=False
                )
                texts.append(text.strip())
                images.append(image_inputs)
    
            # Tokenize the texts and process the images
            batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
    
            # The labels are the input_ids, and we mask the padding tokens and image tokens in the loss computation
            labels = batch["input_ids"].clone()
    
            # Mask image tokens
            image_token_id = [
                processor.tokenizer.convert_tokens_to_ids(
                    processor.tokenizer.special_tokens_map["boi_token"]
                )
            ]
            # Mask tokens for not being used in the loss computation
            labels[labels == processor.tokenizer.pad_token_id] = -100
            labels[labels == image_token_id] = -100
            labels[labels == 262144] = -100
    
            batch["labels"] = labels
            return batch
    
        trainer = SFTTrainer(
            model=model,
            args=args,
            train_dataset=dataset,
            peft_config=peft_config,
            processing_class=processor,
            data_collator=collate_fn,
        )
    
        callback = ray.train.huggingface.transformers.RayTrainReportCallback()
        trainer.add_callback(callback)
        trainer = ray.train.huggingface.transformers.prepare_trainer(trainer)
    
        # Start training, the model will be automatically saved to the Hub and the output directory
        trainer.train()
    
        # Save the final model again to the Hugging Face Hub
        trainer.save_model()
    
    if __name__ == "__main__":
        args = get_args()
        print("Starting training task!")
        training_name = f"gemma_vision_train_{datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}"
    
        gcs_bucket = args.gcs_bucket
        if not gcs_bucket.startswith("gs://"):
            gcs_bucket = "gs://" + gcs_bucket
    
        run_config = RunConfig(
            storage_path=gcs_bucket,
            name=training_name,
        )
        scaling_config = ScalingConfig(num_workers=16, use_gpu=True, accelerator_type="B200")
        ray_trainer = TorchTrainer(train, train_loop_config=args, scaling_config=scaling_config, run_config=run_config)
        print("Commencing training!")
        result = ray_trainer.fit()
    
  2. 保存文件。

  3. 在集群中创建 ConfigMap 对象:

    kubectl create cm ray-job-cm --from-file=code -o yaml --dry-run=client | kubectl apply -f -
    

    如需更新训练脚本,请重新运行上述命令。任何更改可能需要一分钟才能传播到所有 pod。

配置 Ray 集群

  1. 如需在 GKE 集群中创建 Ray 集群,请将以下 YAML 保存为 ray_cluster.yaml 文件。

    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: gemma3-tuning
    spec:
      rayVersion: '2.48.0'
      headGroupSpec:
        rayStartParams:
          dashboard-host: '0.0.0.0'
        template:
          metadata:
          spec:
            containers:
            - name: ray-head
              image: rayproject/ray:2.48.0
              ports:
              - containerPort: 6379
                name: gcs
              - containerPort: 8265
                name: dashboard
              - containerPort: 10001
                name: client
              resources:
                limits:
                  cpu: "24"
                  ephemeral-storage: "9Gi"
                  memory: "64Gi"
                requests:
                  cpu: "24"
                  ephemeral-storage: "9Gi"
                  memory: "64Gi"
              env:
                - name: HF_TOKEN
                  valueFrom:
                    secretKeyRef:
                      name: hf-secret
                      key: hf_api_token
              volumeMounts:
                - name: job-code
                  mountPath: /code/
                - mountPath: /mnt/local-ssd/
                  name: local-storage
            volumes:
              - name: job-code
                configMap:
                  name: ray-job-cm
              - name: local-storage
                emptyDir: { }
      workerGroupSpecs:
      - replicas: 2
        minReplicas: 1
        maxReplicas: 5
        groupName: gpu-group
        rayStartParams: {}
        template:
          spec:
            containers:
            - name: ray-worker
              image: rayproject/ray:2.48.0-gpu
              resources:
                limits:
                  nvidia.com/gpu: "8"
                requests:
                  nvidia.com/gpu: "8"
              env:
                - name: HF_TOKEN
                  valueFrom:
                    secretKeyRef:
                      name: hf-secret
                      key: hf_api_token
              volumeMounts:
                - name: job-code
                  mountPath: /code/
                - mountPath: /mnt/local-ssd/
                  name: local-storage
            volumes:
              - name: job-code
                configMap:
                  name: ray-job-cm
              - name: local-storage
                emptyDir: { }
            nodeSelector:
              cloud.google.com/gke-accelerator: nvidia-b200
              cloud.google.com/reservation-name: $RESERVATION
              cloud.google.com/reservation-affinity: "specific"
              cloud.google.com/gke-gpu-driver-version: latest
    
  2. 使用以下命令将此 YAML 定义应用于您的集群:

    envsubst < ray_cluster.yaml | kubectl apply -f -
    

    $RESERVATION 标志会自动替换为您配置为环境变量的名称。

    Ray Operator 会创建 raylet Pod,这会触发集群自动扩缩,以便为这些 Pod 提供合适的节点。系统会在集群中创建三个 pod:一个头节点和两个工作器节点。工作器节点配备了 B200 GPU。

  3. 如需验证这三个 pod 是否都已准备就绪,请运行以下命令:

    kubectl get pods
    

    就绪的 Ray 集群的 pod 列表类似于以下内容:

    NAME                                   READY   STATUS    RESTARTS   AGE
    gemma3-tuning-gpu-group-worker-s4h8f   2/2     Running   0          16m
    gemma3-tuning-gpu-group-worker-stg5f   2/2     Running   0          5m34s
    gemma3-tuning-head-zbdvp               2/2     Running   0          16m
    

安排训练作业

  1. 将以下内容保存为 ray_job.yaml 文件:

    apiVersion: ray.io/v1
    kind: RayJob
    metadata:
      name: test-ray-job
    spec:
      entrypoint: python /code/vision_train.py --gcs_bucket $GCS_BUCKET
      runtimeEnvYAML: |
        pip:
          - torch==2.8.0
          - torchvision==0.23.0
          - ray==2.48.0
          - transformers==4.55.2
          - datasets==4.0.0
          - evaluate==0.4.5
          - accelerate==1.10.0
          - pillow==11.3.0
          - bitsandbytes==0.47.0
          - trl==0.21.0
          - peft==0.17.0
      clusterSelector:
        ray.io/cluster: gemma3-tuning
    
  2. 将 RayJob 定义提交到 RayCluster:

    envsubst < ray_job.yaml | kubectl apply -f -
    
  3. 检查集群中是否有新的 Pod:

    kubectl get pods
    

    记下您在输出中看到的 test-ray-job- Pod 的全名。此名称是您的作业独有的。

  4. 检查训练进度。将 gemma-training-ray-job-UNIQUE_ID 替换为您在上一步中记下的唯一 Pod 名称。

    kubectl logs -f <gemma-training-ray-job-UNIQUE_ID>
    

    您看到的输出类似于以下内容:

    2025-08-20 08:29:34,966 INFO cli.py:41 -- Job submission server address: http://gemma3-tuning-head-svc.default.svc.cluster.local:8265
    2025-08-20 08:29:34,991 SUCC cli.py:65 -- -----------------------------------------------
    2025-08-20 08:29:34,991 SUCC cli.py:66 -- Job 'test-ray-job-82mm7' submitted successfully
    2025-08-20 08:29:34,991 SUCC cli.py:67 -- -----------------------------------------------
    2025-08-20 08:29:34,992 INFO cli.py:291 -- Next steps
    2025-08-20 08:29:34,992 INFO cli.py:292 -- Query the logs of the job:
    2025-08-20 08:29:34,992 INFO cli.py:294 -- ray job logs test-ray-job-82mm7
    2025-08-20 08:29:34,992 INFO cli.py:296 -- Query the status of the job:
    2025-08-20 08:29:34,992 INFO cli.py:298 -- ray job status test-ray-job-82mm7
    2025-08-20 08:29:34,992 INFO cli.py:300 -- Request the job to be stopped:
    2025-08-20 08:29:34,992 INFO cli.py:302 -- ray job stop test-ray-job-82mm7
    2025-08-20 08:29:35,003 INFO cli.py:312 -- Tailing logs until the job exits (disable with --no-wait):
    2025-08-20 08:29:34,982 INFO job_manager.py:531 -- Runtime env is setting up.
    Starting training task!
    Commencing training!
    2025-08-20 08:30:08,498 INFO worker.py:1606 -- Using address 10.76.0.17:6379 set in the environment variable RAY_ADDRESS
    2025-08-20 08:30:08,506 INFO worker.py:1747 -- Connecting to existing Ray cluster at address: 10.76.0.17:6379...
    2025-08-20 08:30:08,527 INFO worker.py:1918 -- Connected to Ray cluster. View the dashboard at 10.76.0.17:8265
    2025-08-20 08:30:08,701 INFO tune.py:253 -- Initializing Ray automatically. For cluster usage or custom Ray initialization, call `ray.init(...)` before `<FrameworkTrainer>(...)`.
    2025-08-20 08:30:08,951 WARNING tune_controller.py:2132 -- The maximum number of pending trials has been automatically set to the number of available cluster CPUs, which is high (519 CPUs/pending trials). If you're running an experiment with a large number of trials, this could lead to scheduling overhead. In this case, consider setting the `TUNE_MAX_PENDING_TRIALS_PG` environment variable to the desired maximum number of concurrent pending trials.
    2025-08-20 08:30:08,953 WARNING tune_controller.py:2132 -- The maximum number of pending trials has been automatically set to the number of available cluster CPUs, which is high (519 CPUs/pending trials). If you're running an experiment with a large number of trials, this could lead to scheduling overhead. In this case, consider setting the `TUNE_MAX_PENDING_TRIALS_PG` environment variable to the desired maximum number of concurrent pending trials.
    
    View detailed results here: YOUR_GCS_BUCKET/gemma_vision_train_2025_08_20_08_30_07
    To visualize your results with TensorBoard, run: `tensorboard --logdir /tmp/ray/session_2025-08-20_04-43-14_215096_1/artifacts/2025-08-20_08-30-08/gemma_vision_train_2025_08_20_08_30_07/driver_artifacts`
    
    Training started with configuration:
    ╭──────────────────────────────────────────────────────────────────────╮
    │ Training config                                                      │
    ├──────────────────────────────────────────────────────────────────────┤
    │ train_loop_config/dataset_name                  ...-descriptions-vlm │
    │ train_loop_config/gcs_bucket                    ...-bucket-yooo-west │
    │ train_loop_config/gradient_accumulation_steps                      4 │
    │ train_loop_config/learning_rate                               0.0002 │
    │ train_loop_config/logging_steps                                   10 │
    │ train_loop_config/lora_alpha                                      16 │
    │ train_loop_config/lora_dropout                                  0.05 │
    │ train_loop_config/lora_r                                          16 │
    │ train_loop_config/max_seq_length                                 512 │
    │ train_loop_config/model_id                      google/gemma-3-4b-it │
    │ train_loop_config/num_train_epochs                                 3 │
    │ train_loop_config/output_dir                    ...-4b-seo-optimized │
    │ train_loop_config/per_device_train_batch_size                      1 │
    │ train_loop_config/push_to_hub                                  False │
    │ train_loop_config/save_steps                                     100 │
    │ train_loop_config/save_strategy                                epoch │
    ╰──────────────────────────────────────────────────────────────────────╯
    (RayTrainWorker pid=45455, ip=10.76.0.71) Setting up process group for: env:// [rank=0, world_size=16]
    (TorchTrainer pid=45197, ip=10.76.0.71) Started distributed worker processes:
    (TorchTrainer pid=45197, ip=10.76.0.71) - (node_id=4c934ab2f646a578b03cc335586f30b943e811b645526a74c50bfca1, ip=10.76.0.71, pid=45455) world_rank=0, local_rank=0, node_rank=0
    (TorchTrainer pid=45197, ip=10.76.0.71) - (node_id=4c934ab2f646a578b03cc335586f30b943e811b645526a74c50bfca1, ip=10.76.0.71, pid=45450) world_rank=1, local_rank=1, node_rank=0
    (TorchTrainer pid=45197, ip=10.76.0.71) - (node_id=4c934ab2f646a578b03cc335586f30b943e811b645526a74c50bfca1, ip=10.76.0.71, pid=45454) world_rank=2, local_rank=2, node_rank=0
    (TorchTrainer pid=45197, ip=10.76.0.71) - (node_id=4c934ab2f646a578b03cc335586f30b943e811b645526a74c50bfca1, ip=10.76.0.71, pid=45448) world_rank=3, local_rank=3, node_rank=0
    (TorchTrainer pid=45197, ip=10.76.0.71) - (node_id=4c934ab2f646a578b03cc335586f30b943e811b645526a74c50bfca1, ip=10.76.0.71, pid=45453) world_rank=4, local_rank=4, node_rank=0
    (TorchTrainer pid=45197, ip=10.76.0.71) - (node_id=4c934ab2f646a578b03cc335586f30b943e811b645526a74c50bfca1, ip=10.76.0.71, pid=45452) world_rank=5, local_rank=5, node_rank=0
    (TorchTrainer pid=45197, ip=10.76.0.71) - (node_id=4c934ab2f646a578b03cc335586f30b943e811b645526a74c50bfca1, ip=10.76.0.71, pid=45451) world_rank=6, local_rank=6, node_rank=0
    (TorchTrainer pid=45197, ip=10.76.0.71) - (node_id=4c934ab2f646a578b03cc335586f30b943e811b645526a74c50bfca1, ip=10.76.0.71, pid=45449) world_rank=7, local_rank=7, node_rank=0
    (TorchTrainer pid=45197, ip=10.76.0.71) - (node_id=c0db52b44f891f3d6a1cedcbea4c6beb2c8434c66ef414dc15e65743, ip=10.76.0.135, pid=45729) world_rank=8, local_rank=0, node_rank=1
    (TorchTrainer pid=45197, ip=10.76.0.71) - (node_id=c0db52b44f891f3d6a1cedcbea4c6beb2c8434c66ef414dc15e65743, ip=10.76.0.135, pid=45726) world_rank=9, local_rank=1, node_rank=1
    (TorchTrainer pid=45197, ip=10.76.0.71) - (node_id=c0db52b44f891f3d6a1cedcbea4c6beb2c8434c66ef414dc15e65743, ip=10.76.0.135, pid=45728) world_rank=10, local_rank=2, node_rank=1
    (TorchTrainer pid=45197, ip=10.76.0.71) - (node_id=c0db52b44f891f3d6a1cedcbea4c6beb2c8434c66ef414dc15e65743, ip=10.76.0.135, pid=45727) world_rank=11, local_rank=3, node_rank=1
    (TorchTrainer pid=45197, ip=10.76.0.71) - (node_id=c0db52b44f891f3d6a1cedcbea4c6beb2c8434c66ef414dc15e65743, ip=10.76.0.135, pid=45725) world_rank=12, local_rank=4, node_rank=1
    (TorchTrainer pid=45197, ip=10.76.0.71) - (node_id=c0db52b44f891f3d6a1cedcbea4c6beb2c8434c66ef414dc15e65743, ip=10.76.0.135, pid=45724) world_rank=13, local_rank=5, node_rank=1
    (TorchTrainer pid=45197, ip=10.76.0.71) - (node_id=c0db52b44f891f3d6a1cedcbea4c6beb2c8434c66ef414dc15e65743, ip=10.76.0.135, pid=45723) world_rank=14, local_rank=6, node_rank=1
    (TorchTrainer pid=45197, ip=10.76.0.71) - (node_id=c0db52b44f891f3d6a1cedcbea4c6beb2c8434c66ef414dc15e65743, ip=10.76.0.135, pid=45722) world_rank=15, local_rank=7, node_rank=1
    
    ...
    
    Training finished iteration 3 at 2025-08-20 08:40:43. Total running time: 10min 34s
    ╭─────────────────────────────────────────╮
    │ Training result                         │
    ├─────────────────────────────────────────┤
    │ checkpoint_dir_name   checkpoint_000002 │
    │ time_this_iter_s               152.6374 │
    │ time_total_s                  525.88585 │
    │ training_iteration                    3 │
    │ epoch                           2.75294 │
    │ grad_norm                      47.27161 │
    │ learning_rate                    0.0002 │
    │ loss                            22.5275 │
    │ mean_token_accuracy             0.90325 │
    │ num_tokens                     1583017. │
    │ step                                 60 │
    ╰─────────────────────────────────────────╯
    
    ...
    
    Training completed after 3 iterations at 2025-08-20 08:40:52. Total running time: 10min 43s
    2025-08-20 08:40:53,113 INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to 'YOUR_GCS_BUCKET/gemma_vision_train_2025_08_20_08_30_07' in 0.1663s.
    
    2025-08-20 08:40:58,304 SUCC cli.py:65 -- ----------------------------------
    2025-08-20 08:40:58,305 SUCC cli.py:66 -- Job 'test-ray-job-82mm7' succeeded
    2025-08-20 08:40:58,305 SUCC cli.py:67 -- ----------------------------------
    

    监控工作负载

您可以使用 Ray 中的信息中心来监控集群中已调度的工作负载。

如需访问此信息中心,您需要在新的终端窗口中运行以下命令,以设置端口转发到集群:

kubectl port-forward service/gemma3-tuning-head-svc 8265:8265 > fwd.log 2>&1 &
  1. 在浏览器中打开以下链接:[http://localhost:8265](http://localhost:8265)

  2. (可选)如果您使用的是 Cloud Shell,则在运行上一步中的命令后,可以点击网页预览按钮,如下图所示:

    “网页预览”按钮。

    选择更改端口选项,输入 8265,然后点击更改并预览。 Ray 信息中心会在新标签页中打开。

清理

为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。

删除项目

删除 Google Cloud 项目:

gcloud projects delete PROJECT_ID

删除您的资源

  1. 如需删除 Ray 集群并释放 GPU 赋能的节点,请运行以下命令:

    kubectl delete -f ray_cluster.yaml
    

    GKE 会自动缩减集群规模,并释放 Ray 使用的 A4 机器。

  2. 如需删除整个 GKE 集群,请运行以下命令:

    gcloud container clusters delete $CLUSTER_NAME \
    --region=$REGION
    

后续步骤