本教程介绍了如何在多节点 Slurm 集群上对 Gemma 3 大语言模型 (LLM) 进行微调,该集群使用两个 A4 虚拟机 (VM) 实例。在本教程中,您将执行以下操作:
创建一个自定义映像
配置 RDMA 网络。
运行分布式微调作业。为了实现高效的多节点训练,您可以使用具有完全分片数据并行处理 (FSDP) 功能的 Hugging Face Accelerate 库。
本教程适用于机器学习 (ML) 工程师、平台管理员和运维人员,以及对使用 Slurm 作业调度功能处理微调工作负载感兴趣的数据和 AI 专家。
目标
使用 Hugging Face 访问 Gemma 3。
准备环境。
创建 A4 Slurm 集群。
准备工作负载。
运行微调作业。
监控作业。
清理。
费用
在本文档中,您将使用 Google Cloud的以下收费组件:
如需根据您的预计使用情况来估算费用,请使用价格计算器。
准备工作
- 登录您的 Google Cloud 账号。如果您是 Google Cloud新手,请 创建一个账号来评估我们的产品在实际场景中的表现。新客户还可获享 $300 赠金,用于运行、测试和部署工作负载。
-
安装 Google Cloud CLI。
-
如果您使用的是外部身份提供方 (IdP),则必须先使用联合身份登录 gcloud CLI。
-
如需初始化 gcloud CLI,请运行以下命令:
gcloud init -
选择或创建项目所需的角色
- 选择项目:选择项目不需要特定的 IAM 角色,您可以选择已获授角色的任何项目。
-
创建项目:如需创建项目,您需要拥有 Project Creator 角色 (
roles/resourcemanager.projectCreator),该角色包含resourcemanager.projects.create权限。了解如何授予角色。
-
创建 Google Cloud 项目:
gcloud projects create PROJECT_ID
将
PROJECT_ID替换为您要创建的 Google Cloud 项目的名称。 -
选择您创建的 Google Cloud 项目:
gcloud config set project PROJECT_ID
将
PROJECT_ID替换为您的 Google Cloud 项目名称。
启用所需的 API:
启用 API 所需的角色
如需启用 API,您需要拥有 Service Usage Admin IAM 角色 (
roles/serviceusage.serviceUsageAdmin),该角色包含serviceusage.services.enable权限。了解如何授予角色。gcloud services enable compute.googleapis.com file.googleapis.com logging.googleapis.com cloudresourcemanager.googleapis.com servicenetworking.googleapis.com
-
安装 Google Cloud CLI。
-
如果您使用的是外部身份提供方 (IdP),则必须先使用联合身份登录 gcloud CLI。
-
如需初始化 gcloud CLI,请运行以下命令:
gcloud init -
选择或创建项目所需的角色
- 选择项目:选择项目不需要特定的 IAM 角色,您可以选择已获授角色的任何项目。
-
创建项目:如需创建项目,您需要拥有 Project Creator 角色 (
roles/resourcemanager.projectCreator),该角色包含resourcemanager.projects.create权限。了解如何授予角色。
-
创建 Google Cloud 项目:
gcloud projects create PROJECT_ID
将
PROJECT_ID替换为您要创建的 Google Cloud 项目的名称。 -
选择您创建的 Google Cloud 项目:
gcloud config set project PROJECT_ID
将
PROJECT_ID替换为您的 Google Cloud 项目名称。
启用所需的 API:
启用 API 所需的角色
如需启用 API,您需要拥有 Service Usage Admin IAM 角色 (
roles/serviceusage.serviceUsageAdmin),该角色包含serviceusage.services.enable权限。了解如何授予角色。gcloud services enable compute.googleapis.com file.googleapis.com logging.googleapis.com cloudresourcemanager.googleapis.com servicenetworking.googleapis.com
-
向您的用户账号授予角色。对以下每个 IAM 角色运行以下命令一次:
roles/compute.admin, roles/iam.serviceAccountUser, roles/file.editor, roles/storage.admin, roles/serviceusage.serviceUsageAdmingcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE
替换以下内容:
PROJECT_ID:您的项目 ID。USER_IDENTIFIER:您的用户 账号的标识符。例如,myemail@example.com。ROLE:您授予用户账号的 IAM 角色。
- 为您的 Google Cloud 项目启用默认服务账号:
gcloud iam service-accounts enable PROJECT_NUMBER-compute@developer.gserviceaccount.com \ --project=PROJECT_ID
将 PROJECT_NUMBER 替换为您的项目编号。如需查看项目编号,请参阅 获取现有项目。
- 向默认服务账号授予 Editor 角色 (
roles/editor):gcloud projects add-iam-policy-binding PROJECT_ID \ --member="serviceAccount:PROJECT_NUMBER-compute@developer.gserviceaccount.com" \ --role=roles/editor
- 为您的用户账号创建本地身份验证凭据:
gcloud auth application-default login
- 为项目启用 OS Login:
gcloud compute project-info add-metadata --metadata=enable-oslogin=TRUE
- 登录或创建 Hugging Face 账号。
使用 Hugging Face 访问 Gemma 3
如需使用 Hugging Face 访问 Gemma 3,请按以下步骤操作:
创建 Hugging Face
read访问令牌。 依次点击您的个人资料 > 设置 > 访问令牌 > +创建新令牌复制并保存
read access令牌值。您将在本教程的后面部分使用该地址。
准备环境
如需准备环境,请按照以下步骤操作:
克隆 Cluster Toolkit GitHub 代码库:
git clone https://github.com/GoogleCloudPlatform/cluster-toolkit.git创建 Cloud Storage 存储分区,请运行以下命令:
gcloud storage buckets create gs://BUCKET_NAME \ --project=PROJECT_ID替换以下内容:
BUCKET_NAME:符合存储桶命名要求的 Cloud Storage 存储桶的名称。PROJECT_ID:您要在其中创建 Cloud Storage 存储桶的Google Cloud 项目的 ID。
创建 A4 Slurm 集群
如需创建 A4 Slurm 集群,请按照以下步骤操作:
转到
cluster-toolkit目录:cd cluster-toolkit如果您是首次使用 Cluster Toolkit,请构建
gcluster二进制文件:make转到
examples/machine-learning/a4-highgpu-8g目录:cd examples/machine-learning/a4-highgpu-8g/打开
a4high-slurm-deployment.yaml文件,然后按如下方式进行修改:terraform_backend_defaults: type: gcs configuration: bucket: BUCKET_NAME vars: deployment_name: a4-high project_id: PROJECT_ID region: REGION zone: ZONE a4h_cluster_size: 2 a4h_reservation_name: RESERVATION_URL替换以下内容:
BUCKET_NAME:您在上一部分中创建的 Cloud Storage 存储桶的名称。PROJECT_ID:您的 Cloud Storage 所在的Google Cloud 项目的 ID,也是您要创建 Slurm 集群的项目。REGION:预留所在的区域。ZONE:预留所在的可用区。RESERVATION_URL:您要用于创建 Slurm 集群的预留的网址。根据预留所在的具体项目,指定以下某个值:预留存在于您的项目中:
RESERVATION_NAME预留存在于其他项目中,并且您的项目可以使用该预留:
projects/RESERVATION_PROJECT_ID/reservations/RESERVATION_NAME
部署集群:
./gcluster deploy -d examples/machine-learning/a4-highgpu-8g/a4high-slurm-deployment.yaml examples/machine-learning/a4-highgpu-8g/a4high-slurm-blueprint.yaml --auto-approve./gcluster deploy命令是一个两阶段流程,如下所示:第一阶段会构建一个预先安装了所有软件的自定义映像,这可能需要长达 35 分钟的时间才能完成。
第二阶段使用该自定义映像部署集群。此流程应比第一阶段更快完成。
如果第一阶段成功,但第二阶段失败,您可以尝试跳过第一阶段,再次部署 Slurm 集群:
./gcluster deploy -d examples/machine-learning/a4-highgpu-8g/a4high-slurm-deployment.yaml examples/machine-learning/a4-highgpu-8g/a4high-slurm-blueprint.yaml --auto-approve --skip "image" -w
准备工作负载
如需准备工作负载,请按以下步骤操作:
创建工作负载脚本
如需创建微调工作负载将使用的脚本,请按以下步骤操作:
如需设置 Python 虚拟环境,请创建包含以下内容的
install_environment.sh文件:#!/bin/bash # This script should be run ONCE on the login node to set up the # shared Python virtual environment. set -e echo "--- Creating Python virtual environment in /home ---" python3 -m venv ~/.venv echo "--- Activating virtual environment ---" source ~/.venv/bin/activate echo "--- Installing build dependencies ---" pip install --upgrade pip wheel packaging echo "--- Installing PyTorch for CUDA 12.8 ---" pip install torch --index-url https://download.pytorch.org/whl/cu128 echo "--- Installing application requirements ---" pip install -r requirements.txt echo "--- Environment setup complete. You can now submit jobs with sbatch. ---"如需为微调作业指定配置,请创建包含以下内容的
accelerate_config.yaml文件:# Default configuration for a 2-node, 8-GPU-per-node (16 total GPUs) FSDP training job. compute_environment: "LOCAL_MACHINE" distributed_type: "FSDP" downcast_bf16: "no" fsdp_config: fsdp_auto_wrap_policy: "TRANSFORMER_BASED_WRAP" fsdp_backward_prefetch: "BACKWARD_PRE" fsdp_cpu_ram_efficient_loading: true fsdp_forward_prefetch: false fsdp_offload_params: false fsdp_sharding_strategy: "FULL_SHARD" fsdp_state_dict_type: "FULL_STATE_DICT" fsdp_transformer_layer_cls_to_wrap: "Gemma3DecoderLayer" fsdp_use_orig_params: true machine_rank: 0 main_training_function: "main" mixed_precision: "bf16" num_machines: 2 num_processes: 16 rdzv_backend: "static" same_network: true tpu_env: [] use_cpu: false如需指定作业在 Slurm 集群上运行的任务,请创建包含以下内容的
submit.slurm文件:#!/bin/bash #SBATCH --job-name=gemma3-finetune #SBATCH --nodes=2 #SBATCH --ntasks-per-node=8 # 8 tasks per node #SBATCH --gpus-per-task=1 # 1 GPU per task #SBATCH --partition=a4high #SBATCH --output=slurm-%j.out #SBATCH --error=slurm-%j.err set -e echo "--- Slurm Job Started ---" # --- STAGE 1: Copy Environment to Local SSD on all nodes --- srun --ntasks=$SLURM_NNODES --ntasks-per-node=1 bash -c ' echo "Setting up local environment on $(hostname)..." LOCAL_VENV="/mnt/localssd/venv_job_${SLURM_JOB_ID}" LOCAL_CACHE="/mnt/localssd/hf_cache_job_${SLURM_JOB_ID}" rsync -a --info=progress2 ~/./.venv/ ${LOCAL_VENV}/ mkdir -p ${LOCAL_CACHE} echo "Setup on $(hostname) complete." ' # --- STAGE 2: Run the Training Job using the Local Environment --- echo "--- Starting Training ---" LOCAL_VENV="/mnt/localssd/venv_job_${SLURM_JOB_ID}" LOCAL_CACHE="/mnt/localssd/hf_cache_job_${SLURM_JOB_ID}" LOCAL_OUTPUT_DIR="/mnt/localssd/outputs_${SLURM_JOB_ID}" mkdir -p ${LOCAL_OUTPUT_DIR} # This is the main training command. srun --ntasks=$((SLURM_NNODES * 8)) --gpus-per-task=1 bash -c " source ${LOCAL_VENV}/bin/activate export HF_HOME=${LOCAL_CACHE} export HF_DATASETS_CACHE=${LOCAL_CACHE} # Run the Python script directly. # Accelerate will divide the work python ~/train.py \ --model_id google/gemma-3-12b-pt \ --output_dir ${LOCAL_OUTPUT_DIR} \ --per_device_train_batch_size 1 \ --gradient_accumulation_steps 8 \ --num_train_epochs 3 \ --learning_rate 1e-5 \ --save_strategy steps \ --save_steps 100 " # --- STAGE 3: Copy Final Model from Local SSD to Home Directory --- echo "--- Copying final model from local SSD to /home ---" # This command runs only on the first node of the job allocation # and copies the final model back to the persistent shared directory. srun --nodes=1 --ntasks=1 --ntasks-per-node=1 bash -c " rsync -a --info=progress2 ${LOCAL_OUTPUT_DIR}/ ~/gemma-12b-text-to-sql-finetuned/ " echo "--- Slurm Job Finished ---"如需为微调作业指定依赖项,请创建包含以下内容的
requirements.txt文件:# Hugging Face Libraries (Pinned to recent, stable versions for reproducibility) transformers==4.53.3 datasets==4.0.0 accelerate==1.9.0 evaluate==0.4.5 bitsandbytes==0.46.1 trl==0.19.1 peft==0.16.0 # Other dependencies tensorboard==2.20.0 protobuf==6.31.1 sentencepiece==0.2.0如需指定作业的说明,请创建
train.py文件,其中包含以下内容:import torch import argparse from datasets import load_dataset from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model from trl import SFTTrainer, SFTConfig from huggingface_hub import login def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--model_id", type=str, default="google/gemma-3-12b-pt", help="Hugging Face model ID") parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face token for private models") parser.add_argument("--dataset_name", type=str, default="philschmid/gretel-synthetic-text-to-sql", help="Hugging Face dataset name") parser.add_argument("--output_dir", type=str, default="gemma-12b-text-to-sql", help="Directory to save model checkpoints") # LoRA arguments parser.add_argument("--lora_r", type=int, default=16, help="LoRA attention dimension") parser.add_argument("--lora_alpha", type=int, default=16, help="LoRA alpha scaling factor") parser.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout probability") # SFTConfig arguments parser.add_argument("--max_seq_length", type=int, default=512, help="Maximum sequence length") parser.add_argument("--num_train_epochs", type=int, default=3, help="Number of training epochs") parser.add_argument("--per_device_train_batch_size", type=int, default=8, help="Batch size per device during training") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps") parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate") parser.add_argument("--logging_steps", type=int, default=10, help="Log every X steps") parser.add_argument("--save_strategy", type=str, default="steps", help="Checkpoint save strategy") parser.add_argument("--save_steps", type=int, default=100, help="Save checkpoint every X steps") return parser.parse_args() def main(): args = get_args() # --- 1. Setup and Login --- if args.hf_token: login(args.hf_token) # --- 2. Create and prepare the fine-tuning dataset --- # The SFTTrainer will use the `formatting_func` to apply the chat template. dataset = load_dataset(args.dataset_name, split="train") dataset = dataset.shuffle().select(range(12500)) dataset = dataset.train_test_split(test_size=2500/12500) # --- 3. Configure Model and Tokenizer --- if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: torch_dtype_obj = torch.bfloat16 torch_dtype_str = "bfloat16" else: torch_dtype_obj = torch.float16 torch_dtype_str = "float16" tokenizer = AutoTokenizer.from_pretrained(args.model_id) tokenizer.pad_token = tokenizer.eos_token gemma_chat_template = ( "" "" ) tokenizer.chat_template = gemma_chat_template # --- 4. Define the Formatting Function --- # This function will be used by the SFTTrainer to format each sample # from the dataset into the correct chat template format. def formatting_func(example): # The create_conversation logic is now implicitly handled by this. # We need to construct the messages list here. system_message = "You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA." user_prompt = "Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.\n\n<SCHEMA>\n{context}\n</SCHEMA>\n\n<USER_QUERY>\n{question}\n</USER_QUERY>\n" messages = [ {"role": "user", "content": user_prompt.format(question=example["sql_prompt"][0], context=example["sql_context"][0])}, {"role": "assistant", "content": example["sql"][0]} ] return tokenizer.apply_chat_template(messages, tokenize=False) # --- 5. Load Quantized Model and Apply PEFT --- # Define the quantization configuration quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type='nf4', bnb_4bit_compute_dtype=torch_dtype_obj, bnb_4bit_use_double_quant=True, ) config = AutoConfig.from_pretrained(args.model_id) config.use_cache = False # Load the base model with quantization print("Loading base model...") model = AutoModelForCausalLM.from_pretrained( args.model_id, config=config, quantization_config=quantization_config, attn_implementation="eager", torch_dtype=torch_dtype_obj, ) # Prepare the model for k-bit training model = prepare_model_for_kbit_training(model) # Configure LoRA. peft_config = LoraConfig( lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout, r=args.lora_r, bias="none", target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], task_type="CAUSAL_LM", ) # Apply the PEFT config to the model print("Applying PEFT configuration...") model = get_peft_model(model, peft_config) model.print_trainable_parameters() # --- 6. Configure Training Arguments --- training_args = SFTConfig( output_dir=args.output_dir, max_seq_length=args.max_seq_length, num_train_epochs=args.num_train_epochs, per_device_train_batch_size=args.per_device_train_batch_size, gradient_accumulation_steps=args.gradient_accumulation_steps, learning_rate=args.learning_rate, logging_steps=args.logging_steps, save_strategy=args.save_strategy, save_steps=args.save_steps, packing=True, gradient_checkpointing=True, gradient_checkpointing_kwargs={"use_reentrant": False}, optim="adamw_torch", fp16=True if torch_dtype_obj == torch.float16 else False, bf16=True if torch_dtype_obj == torch.bfloat16 else False, max_grad_norm=0.3, warmup_ratio=0.03, lr_scheduler_type="constant", push_to_hub=False, report_to="tensorboard", dataset_kwargs={ "add_special_tokens": False, "append_concat_token": True, } ) # --- 7. Create Trainer and Start Training --- trainer = SFTTrainer( model=model, args=training_args, train_dataset=dataset["train"], eval_dataset=dataset["test"], formatting_func=formatting_func, ) print("Starting training...") trainer.train() print("Training finished.") # --- 8. Save the final model --- print(f"Saving final model to {args.output_dir}") trainer.save_model() if __name__ == "__main__": main()
将脚本上传到 Slurm 集群
如需将您在上一部分中创建的脚本上传到 Slurm 集群,请按以下步骤操作:
如需确定登录节点,请列出项目中的所有 A4 虚拟机:
gcloud compute instances list --filter="machineType:a4-highgpu-8g"登录节点的名称类似于
a4-high-login-001。将脚本上传到登录节点的主目录:
gcloud compute scp \ --project=PROJECT_ID \ --zone=ZONE \ --tunnel-through-iap \ ./train.py \ ./requirements.txt \ ./submit.slurm \ ./install_environment.sh \ ./accelerate_config.yaml \ "LOGIN_NODE_NAME":~/将
LOGIN_NODE_NAME替换为登录节点的名称。
连接到 Slurm 集群
通过 SSH 连接到登录节点,从而连接到 Slurm 集群:
gcloud compute ssh LOGIN_NODE_NAME \
--project=PROJECT_ID \
--tunnel-through-iap \
--zone=ZONE
安装框架和工具
连接到登录节点后,请按照以下步骤安装框架和工具:
为 Hugging Face 访问令牌创建环境变量:
export HUGGING_FACE_TOKEN="HUGGING_FACE_TOKEN"设置包含所有必需依赖项的 Python 虚拟环境:
chmod +x install_environment.sh ./install_environment.sh
启动微调工作负载
如需开始微调工作负载,请按以下步骤操作:
将作业提交给 Slurm 调度程序:
sbatch submit.slurm在 Slurm 集群的登录节点上,您可以通过检查
home目录中创建的输出文件来监控作业的进度:tail -f slurm-gemma3-finetune.err如果作业成功开始,
.err文件会显示一个进度条,该进度条会随着作业的进展而更新。
监控工作负载
您可以监控 Slurm 集群中 GPU 的使用情况,以验证微调作业是否在高效运行。为此,请在浏览器中打开以下链接:
https://console.cloud.google.com/monitoring/metrics-explorer?project=PROJECT_ID&pageState=%7B%22xyChart%22%3A%7B%22dataSets%22%3A%5B%7B%22timeSeriesFilter%22%3A%7B%22filter%22%3A%22metric.type%3D%5C%22agent.googleapis.com%2Fgpu%2Futilization%5C%22%20resource.type%3D%5C%22gce_instance%5C%22%22%2C%22perSeriesAligner%22%3A%22ALIGN_MEAN%22%7D%2C%22plotType%22%3A%22LINE%22%7D%5D%7D%7D
监控工作负载时,您可以看到以下内容:
GPU 使用情况:对于运行正常的微调作业,您应该会看到所有 16 个 GPU(集群中每个虚拟机有 8 个 GPU)的使用率在整个训练过程中上升并稳定在特定水平。
作业时长:此作业大约需要 1 小时才能完成。
清理
为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。
删除项目
删除 Google Cloud 项目:
gcloud projects delete PROJECT_ID
删除 Slurm 集群
如需删除 Slurm 集群,请按以下步骤操作:
转到
cluster-toolkit目录。销毁 Terraform 文件和所有已创建的资源:
./gcluster destroy a4-high --auto-approve