本教學課程說明如何使用兩個 A4 虛擬機器 (VM) 執行個體,在多節點 Slurm 叢集上微調 Gemma 3 大型語言模型 (LLM)。在本教學課程中,您將執行下列操作:
建立自訂映像檔。
設定 RDMA 網路。
執行分散式微調工作。如要有效率地進行多節點訓練,請使用搭配 Fully Sharded Data Parallel (FSDP) 的 Hugging Face Accelerate 程式庫。
本教學課程適合機器學習 (ML) 工程師、平台管理員和營運人員,以及有興趣使用 Slurm 工作排程功能處理微調工作負載的資料和 AI 專家。
目標
使用 Hugging Face 存取 Gemma 3。
準備環境。
建立 A4 Slurm 叢集。
準備工作負載。
執行微調工作。
監控工作。
清除所用資源。
費用
在本文件中,您會使用下列 Google Cloud的計費元件:
如要根據預測用量估算費用,請使用 Pricing Calculator。
事前準備
- 登入 Google Cloud 帳戶。如果您是 Google Cloud新手,歡迎 建立帳戶,親自評估產品在實際工作環境中的成效。新客戶還能獲得價值 $300 美元的免費抵免額,可用於執行、測試及部署工作負載。
-
安裝 Google Cloud CLI。
-
若您採用的是外部識別資訊提供者 (IdP),請先使用聯合身分登入 gcloud CLI。
-
執行下列指令,初始化 gcloud CLI:
gcloud init -
選取或建立專案所需的角色
- 選取專案:選取專案時,不需要具備特定 IAM 角色,只要您已獲授角色,即可選取任何專案。
-
建立專案:如要建立專案,您需要具備專案建立者角色 (
roles/resourcemanager.projectCreator),其中包含resourcemanager.projects.create權限。瞭解如何授予角色。
-
建立 Google Cloud 專案:
gcloud projects create PROJECT_ID
將
PROJECT_ID替換為您要建立的 Google Cloud 專案名稱。 -
選取您建立的 Google Cloud 專案:
gcloud config set project PROJECT_ID
將
PROJECT_ID替換為 Google Cloud 專案名稱。
啟用必要的 API:
啟用 API 時所需的角色
如要啟用 API,您需要具備服務使用情形管理員 IAM 角色 (
roles/serviceusage.serviceUsageAdmin),其中包含serviceusage.services.enable權限。瞭解如何授予角色。gcloud services enable compute.googleapis.com file.googleapis.com logging.googleapis.com cloudresourcemanager.googleapis.com servicenetworking.googleapis.com
-
安裝 Google Cloud CLI。
-
若您採用的是外部識別資訊提供者 (IdP),請先使用聯合身分登入 gcloud CLI。
-
執行下列指令,初始化 gcloud CLI:
gcloud init -
選取或建立專案所需的角色
- 選取專案:選取專案時,不需要具備特定 IAM 角色,只要您已獲授角色,即可選取任何專案。
-
建立專案:如要建立專案,您需要具備專案建立者角色 (
roles/resourcemanager.projectCreator),其中包含resourcemanager.projects.create權限。瞭解如何授予角色。
-
建立 Google Cloud 專案:
gcloud projects create PROJECT_ID
將
PROJECT_ID替換為您要建立的 Google Cloud 專案名稱。 -
選取您建立的 Google Cloud 專案:
gcloud config set project PROJECT_ID
將
PROJECT_ID替換為 Google Cloud 專案名稱。
啟用必要的 API:
啟用 API 時所需的角色
如要啟用 API,您需要具備服務使用情形管理員 IAM 角色 (
roles/serviceusage.serviceUsageAdmin),其中包含serviceusage.services.enable權限。瞭解如何授予角色。gcloud services enable compute.googleapis.com file.googleapis.com logging.googleapis.com cloudresourcemanager.googleapis.com servicenetworking.googleapis.com
-
將角色授予使用者帳戶。針對下列每個 IAM 角色,執行一次下列指令:
roles/compute.admin, roles/iam.serviceAccountUser, roles/file.editor, roles/storage.admin, roles/serviceusage.serviceUsageAdmingcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE
更改下列內容:
PROJECT_ID:專案 ID。USER_IDENTIFIER:使用者帳戶的 ID。 例如:myemail@example.com。ROLE:授予使用者帳戶的 IAM 角色。
- 為 Google Cloud 專案啟用預設服務帳戶:
gcloud iam service-accounts enable PROJECT_NUMBER-compute@developer.gserviceaccount.com \ --project=PROJECT_ID
將 PROJECT_NUMBER 替換為專案編號。如要查看專案編號,請參閱「 取得現有專案」。
- 將編輯者角色 (
roles/editor) 授予預設服務帳戶:gcloud projects add-iam-policy-binding PROJECT_ID \ --member="serviceAccount:PROJECT_NUMBER-compute@developer.gserviceaccount.com" \ --role=roles/editor
- 為使用者帳戶建立本機驗證憑證:
gcloud auth application-default login
- 為專案啟用 OS 登入功能:
gcloud compute project-info add-metadata --metadata=enable-oslogin=TRUE
- 登入或建立 Hugging Face 帳戶。
使用 Hugging Face 存取 Gemma 3
如要使用 Hugging Face 存取 Gemma 3,請按照下列步驟操作:
建立 Hugging Face
read存取權杖。 依序點選「你的個人資料」>「設定」>「存取權杖」>「+ 建立新權杖」複製並儲存
read access權杖值。您會在稍後的教學課程中用到這項資訊。
準備環境
如要準備環境,請按照下列步驟操作:
複製 Cluster Toolkit GitHub 存放區:
git clone https://github.com/GoogleCloudPlatform/cluster-toolkit.git建立 Cloud Storage bucket:
gcloud storage buckets create gs://BUCKET_NAME \ --project=PROJECT_ID更改下列內容:
BUCKET_NAME:Cloud Storage bucket 的名稱,必須符合bucket 命名規定。PROJECT_ID:您要建立 Cloud Storage bucket 的Google Cloud 專案 ID。
建立 A4 Slurm 叢集
如要建立 A4 Slurm 叢集,請按照下列步驟操作:
前往
cluster-toolkit目錄:cd cluster-toolkit如果是首次使用 Cluster Toolkit,請建構
gcluster二進位檔:make前往
examples/machine-learning/a4-highgpu-8g目錄:cd examples/machine-learning/a4-highgpu-8g/開啟
a4high-slurm-deployment.yaml檔案,然後按照下列方式編輯:terraform_backend_defaults: type: gcs configuration: bucket: BUCKET_NAME vars: deployment_name: a4-high project_id: PROJECT_ID region: REGION zone: ZONE a4h_cluster_size: 2 a4h_reservation_name: RESERVATION_URL更改下列內容:
BUCKET_NAME:您在上一個章節中建立的 Cloud Storage bucket 名稱。PROJECT_ID:Cloud Storage 所在的Google Cloud 專案 ID,也是您要建立 Slurm 叢集的位置。REGION:預訂項目所在的區域。ZONE:預訂項目所在的可用區。RESERVATION_URL:您要用來建立 Slurm 叢集的預訂網址。根據保留項目所在的專案,指定下列其中一個值:專案中已有預留項目:
RESERVATION_NAME預留項目位於其他專案,且您的專案可以使用該預留項目:
projects/RESERVATION_PROJECT_ID/reservations/RESERVATION_NAME
部署叢集:
./gcluster deploy -d examples/machine-learning/a4-highgpu-8g/a4high-slurm-deployment.yaml examples/machine-learning/a4-highgpu-8g/a4high-slurm-blueprint.yaml --auto-approve./gcluster deploy指令包含兩個階段,如下所示:第一階段會建立預先安裝所有軟體的自訂映像檔,最多可能需要 35 分鐘才能完成。
第二階段會使用該自訂映像檔部署叢集。這個程序應該會比第一階段更快完成。
如果第一階段成功,但第二階段失敗,您可以嘗試略過第一階段,再次部署 Slurm 叢集:
./gcluster deploy -d examples/machine-learning/a4-highgpu-8g/a4high-slurm-deployment.yaml examples/machine-learning/a4-highgpu-8g/a4high-slurm-blueprint.yaml --auto-approve --skip "image" -w
準備工作負載
如要準備工作負載,請按照下列步驟操作:
建立工作負載指令碼
如要建立微調工作負載使用的指令碼,請按照下列步驟操作:
如要設定 Python 虛擬環境,請建立
install_environment.sh檔案,並加入下列內容:#!/bin/bash # This script should be run ONCE on the login node to set up the # shared Python virtual environment. set -e echo "--- Creating Python virtual environment in /home ---" python3 -m venv ~/.venv echo "--- Activating virtual environment ---" source ~/.venv/bin/activate echo "--- Installing build dependencies ---" pip install --upgrade pip wheel packaging echo "--- Installing PyTorch for CUDA 12.8 ---" pip install torch --index-url https://download.pytorch.org/whl/cu128 echo "--- Installing application requirements ---" pip install -r requirements.txt echo "--- Environment setup complete. You can now submit jobs with sbatch. ---"如要指定微調工作的設定,請建立
accelerate_config.yaml檔案,並加入下列內容:# Default configuration for a 2-node, 8-GPU-per-node (16 total GPUs) FSDP training job. compute_environment: "LOCAL_MACHINE" distributed_type: "FSDP" downcast_bf16: "no" fsdp_config: fsdp_auto_wrap_policy: "TRANSFORMER_BASED_WRAP" fsdp_backward_prefetch: "BACKWARD_PRE" fsdp_cpu_ram_efficient_loading: true fsdp_forward_prefetch: false fsdp_offload_params: false fsdp_sharding_strategy: "FULL_SHARD" fsdp_state_dict_type: "FULL_STATE_DICT" fsdp_transformer_layer_cls_to_wrap: "Gemma3DecoderLayer" fsdp_use_orig_params: true machine_rank: 0 main_training_function: "main" mixed_precision: "bf16" num_machines: 2 num_processes: 16 rdzv_backend: "static" same_network: true tpu_env: [] use_cpu: false如要指定工作在 Slurm 叢集上執行的工作,請建立
submit.slurm檔案,並加入下列內容:#!/bin/bash #SBATCH --job-name=gemma3-finetune #SBATCH --nodes=2 #SBATCH --ntasks-per-node=8 # 8 tasks per node #SBATCH --gpus-per-task=1 # 1 GPU per task #SBATCH --partition=a4high #SBATCH --output=slurm-%j.out #SBATCH --error=slurm-%j.err set -e echo "--- Slurm Job Started ---" # --- STAGE 1: Copy Environment to Local SSD on all nodes --- srun --ntasks=$SLURM_NNODES --ntasks-per-node=1 bash -c ' echo "Setting up local environment on $(hostname)..." LOCAL_VENV="/mnt/localssd/venv_job_${SLURM_JOB_ID}" LOCAL_CACHE="/mnt/localssd/hf_cache_job_${SLURM_JOB_ID}" rsync -a --info=progress2 ~/./.venv/ ${LOCAL_VENV}/ mkdir -p ${LOCAL_CACHE} echo "Setup on $(hostname) complete." ' # --- STAGE 2: Run the Training Job using the Local Environment --- echo "--- Starting Training ---" LOCAL_VENV="/mnt/localssd/venv_job_${SLURM_JOB_ID}" LOCAL_CACHE="/mnt/localssd/hf_cache_job_${SLURM_JOB_ID}" LOCAL_OUTPUT_DIR="/mnt/localssd/outputs_${SLURM_JOB_ID}" mkdir -p ${LOCAL_OUTPUT_DIR} # This is the main training command. srun --ntasks=$((SLURM_NNODES * 8)) --gpus-per-task=1 bash -c " source ${LOCAL_VENV}/bin/activate export HF_HOME=${LOCAL_CACHE} export HF_DATASETS_CACHE=${LOCAL_CACHE} # Run the Python script directly. # Accelerate will divide the work python ~/train.py \ --model_id google/gemma-3-12b-pt \ --output_dir ${LOCAL_OUTPUT_DIR} \ --per_device_train_batch_size 1 \ --gradient_accumulation_steps 8 \ --num_train_epochs 3 \ --learning_rate 1e-5 \ --save_strategy steps \ --save_steps 100 " # --- STAGE 3: Copy Final Model from Local SSD to Home Directory --- echo "--- Copying final model from local SSD to /home ---" # This command runs only on the first node of the job allocation # and copies the final model back to the persistent shared directory. srun --nodes=1 --ntasks=1 --ntasks-per-node=1 bash -c " rsync -a --info=progress2 ${LOCAL_OUTPUT_DIR}/ ~/gemma-12b-text-to-sql-finetuned/ " echo "--- Slurm Job Finished ---"如要指定微調工作的依附元件,請建立
requirements.txt檔案,並加入下列內容:# Hugging Face Libraries (Pinned to recent, stable versions for reproducibility) transformers==4.53.3 datasets==4.0.0 accelerate==1.9.0 evaluate==0.4.5 bitsandbytes==0.46.1 trl==0.19.1 peft==0.16.0 # Other dependencies tensorboard==2.20.0 protobuf==6.31.1 sentencepiece==0.2.0如要指定工作指令,請建立
train.py檔案,並加入下列內容:import torch import argparse from datasets import load_dataset from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model from trl import SFTTrainer, SFTConfig from huggingface_hub import login def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--model_id", type=str, default="google/gemma-3-12b-pt", help="Hugging Face model ID") parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face token for private models") parser.add_argument("--dataset_name", type=str, default="philschmid/gretel-synthetic-text-to-sql", help="Hugging Face dataset name") parser.add_argument("--output_dir", type=str, default="gemma-12b-text-to-sql", help="Directory to save model checkpoints") # LoRA arguments parser.add_argument("--lora_r", type=int, default=16, help="LoRA attention dimension") parser.add_argument("--lora_alpha", type=int, default=16, help="LoRA alpha scaling factor") parser.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout probability") # SFTConfig arguments parser.add_argument("--max_seq_length", type=int, default=512, help="Maximum sequence length") parser.add_argument("--num_train_epochs", type=int, default=3, help="Number of training epochs") parser.add_argument("--per_device_train_batch_size", type=int, default=8, help="Batch size per device during training") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps") parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate") parser.add_argument("--logging_steps", type=int, default=10, help="Log every X steps") parser.add_argument("--save_strategy", type=str, default="steps", help="Checkpoint save strategy") parser.add_argument("--save_steps", type=int, default=100, help="Save checkpoint every X steps") return parser.parse_args() def main(): args = get_args() # --- 1. Setup and Login --- if args.hf_token: login(args.hf_token) # --- 2. Create and prepare the fine-tuning dataset --- # The SFTTrainer will use the `formatting_func` to apply the chat template. dataset = load_dataset(args.dataset_name, split="train") dataset = dataset.shuffle().select(range(12500)) dataset = dataset.train_test_split(test_size=2500/12500) # --- 3. Configure Model and Tokenizer --- if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: torch_dtype_obj = torch.bfloat16 torch_dtype_str = "bfloat16" else: torch_dtype_obj = torch.float16 torch_dtype_str = "float16" tokenizer = AutoTokenizer.from_pretrained(args.model_id) tokenizer.pad_token = tokenizer.eos_token gemma_chat_template = ( "" "" ) tokenizer.chat_template = gemma_chat_template # --- 4. Define the Formatting Function --- # This function will be used by the SFTTrainer to format each sample # from the dataset into the correct chat template format. def formatting_func(example): # The create_conversation logic is now implicitly handled by this. # We need to construct the messages list here. system_message = "You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA." user_prompt = "Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.\n\n<SCHEMA>\n{context}\n</SCHEMA>\n\n<USER_QUERY>\n{question}\n</USER_QUERY>\n" messages = [ {"role": "user", "content": user_prompt.format(question=example["sql_prompt"][0], context=example["sql_context"][0])}, {"role": "assistant", "content": example["sql"][0]} ] return tokenizer.apply_chat_template(messages, tokenize=False) # --- 5. Load Quantized Model and Apply PEFT --- # Define the quantization configuration quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type='nf4', bnb_4bit_compute_dtype=torch_dtype_obj, bnb_4bit_use_double_quant=True, ) config = AutoConfig.from_pretrained(args.model_id) config.use_cache = False # Load the base model with quantization print("Loading base model...") model = AutoModelForCausalLM.from_pretrained( args.model_id, config=config, quantization_config=quantization_config, attn_implementation="eager", torch_dtype=torch_dtype_obj, ) # Prepare the model for k-bit training model = prepare_model_for_kbit_training(model) # Configure LoRA. peft_config = LoraConfig( lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout, r=args.lora_r, bias="none", target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], task_type="CAUSAL_LM", ) # Apply the PEFT config to the model print("Applying PEFT configuration...") model = get_peft_model(model, peft_config) model.print_trainable_parameters() # --- 6. Configure Training Arguments --- training_args = SFTConfig( output_dir=args.output_dir, max_seq_length=args.max_seq_length, num_train_epochs=args.num_train_epochs, per_device_train_batch_size=args.per_device_train_batch_size, gradient_accumulation_steps=args.gradient_accumulation_steps, learning_rate=args.learning_rate, logging_steps=args.logging_steps, save_strategy=args.save_strategy, save_steps=args.save_steps, packing=True, gradient_checkpointing=True, gradient_checkpointing_kwargs={"use_reentrant": False}, optim="adamw_torch", fp16=True if torch_dtype_obj == torch.float16 else False, bf16=True if torch_dtype_obj == torch.bfloat16 else False, max_grad_norm=0.3, warmup_ratio=0.03, lr_scheduler_type="constant", push_to_hub=False, report_to="tensorboard", dataset_kwargs={ "add_special_tokens": False, "append_concat_token": True, } ) # --- 7. Create Trainer and Start Training --- trainer = SFTTrainer( model=model, args=training_args, train_dataset=dataset["train"], eval_dataset=dataset["test"], formatting_func=formatting_func, ) print("Starting training...") trainer.train() print("Training finished.") # --- 8. Save the final model --- print(f"Saving final model to {args.output_dir}") trainer.save_model() if __name__ == "__main__": main()
將指令碼上傳至 Slurm 叢集
如要將上一節建立的指令碼上傳至 Slurm 叢集,請按照下列步驟操作:
如要找出登入節點,請列出專案中的所有 A4 VM:
gcloud compute instances list --filter="machineType:a4-highgpu-8g"登入節點的名稱類似於
a4-high-login-001。將指令碼上傳至登入節點的主目錄:
gcloud compute scp \ --project=PROJECT_ID \ --zone=ZONE \ --tunnel-through-iap \ ./train.py \ ./requirements.txt \ ./submit.slurm \ ./install_environment.sh \ ./accelerate_config.yaml \ "LOGIN_NODE_NAME":~/將
LOGIN_NODE_NAME替換為登入節點的名稱。
連線至 Slurm 叢集
透過 SSH 連線至登入節點,藉此連線至 Slurm 叢集:
gcloud compute ssh LOGIN_NODE_NAME \
--project=PROJECT_ID \
--tunnel-through-iap \
--zone=ZONE
安裝架構和工具
連線至登入節點後,請按照下列步驟安裝架構和工具:
為 Hugging Face 存取權杖建立環境變數:
export HUGGING_FACE_TOKEN="HUGGING_FACE_TOKEN"設定 Python 虛擬環境和所有必要依附元件:
chmod +x install_environment.sh ./install_environment.sh
開始微調工作負載
如要啟動微調工作負載,請按照下列步驟操作:
將工作提交至 Slurm 排程器:
sbatch submit.slurm在 Slurm 叢集的登入節點上,您可以檢查
home目錄中建立的輸出檔案,監控工作進度:tail -f slurm-gemma3-finetune.err如果工作順利啟動,
.err檔案會顯示進度列,並隨著工作進度更新。
監控工作負載
您可以監控 Slurm 叢集中的 GPU 使用情形,確認微調工作是否有效率地執行。如要這麼做,請在瀏覽器中開啟下列連結:
https://console.cloud.google.com/monitoring/metrics-explorer?project=PROJECT_ID&pageState=%7B%22xyChart%22%3A%7B%22dataSets%22%3A%5B%7B%22timeSeriesFilter%22%3A%7B%22filter%22%3A%22metric.type%3D%5C%22agent.googleapis.com%2Fgpu%2Futilization%5C%22%20resource.type%3D%5C%22gce_instance%5C%22%22%2C%22perSeriesAligner%22%3A%22ALIGN_MEAN%22%7D%2C%22plotType%22%3A%22LINE%22%7D%5D%7D%7D
監控工作負載時,您可以查看下列資訊:
GPU 使用率:如果微調工作正常運作,您應該會看到所有 16 個 GPU (叢集中每個 VM 有 8 個 GPU) 的使用率在訓練期間上升並穩定在特定程度。
工作時間:這項工作大約需要一小時才能完成。
清除所用資源
為避免因為本教學課程所用資源,導致系統向 Google Cloud 帳戶收取費用,請刪除含有相關資源的專案,或者保留專案但刪除個別資源。
刪除專案
刪除 Google Cloud 專案:
gcloud projects delete PROJECT_ID
刪除 Slurm 叢集
如要刪除 Slurm 叢集,請按照下列步驟操作:
前往
cluster-toolkit目錄。刪除 Terraform 檔案和所有已建立的資源:
./gcluster destroy a4-high --auto-approve