本教程将介绍如何在多节点、多 GPU Slurm 集群上微调 mistralai/Mixtral-8x7B-v0.1
模型 Google Cloud。该集群使用两个 a4-highgpu-8g 虚拟机 (VM) 实例,每个实例都有 8 个 NVIDIA B200 GPU。
本教程中介绍的两个主要流程如下:
- 使用 Google Cloud Cluster Toolkit 部署生产级高性能 Slurm 集群。在此部署过程中,您将创建一个预先安装了必要软件的自定义虚拟机映像。您还将设置共享 Lustre 文件系统并配置高速网络。
- 集群部署完毕后,您可以使用本教程附带的一组脚本运行分布式微调作业。该作业利用 PyTorch Fully Sharded Data Parallel (FSDP),您可以通过 Hugging Face Transformer Reinforcement Learning (TRL) 库访问该作业。
本教程适用于机器学习 (ML) 工程师、研究人员、平台管理员和运维人员,以及对跨多个节点和 GPU 分布 AI 工作负载感兴趣的数据和 AI 专家。
目标
- 使用 Hugging Face 访问 Mixtral
- 准备环境
- 创建和部署生产级 A4 High-GPU Slurm 集群。
- 配置多节点环境以使用 FSDP 进行分布式训练。
- 使用 Hugging Face
trl.SFTTrainer类微调 Mixtral 模型。 - 将数据暂存到本地 SSD。
- 监控作业。
- 清理。
费用
在本文档中,您将使用的以下收费组件: Google Cloud
您可使用 价格计算器 根据您的预计使用情况来估算费用。
准备工作
- 登录您的 Google Cloud 账号。如果您是 Google Cloud新手, 请创建一个账号来评估我们的产品在 实际场景中的表现。新客户还可获享 $300 赠金,用于 运行、测试和部署工作负载。
-
安装 Google Cloud CLI。
-
如果您使用的是外部身份提供方 (IdP),则必须先使用联合身份登录 gcloud CLI。
-
如需初始化 gcloud CLI,请运行以下命令:
gcloud init -
选择或创建项目所需的角色
- 选择项目:选择项目不需要特定的 IAM 角色,您可以选择已获授予角色的任何项目。
-
创建项目:如需创建项目,您需要 Project Creator 角色
(
roles/resourcemanager.projectCreator),该角色包含resourcemanager.projects.create权限。了解如何授予 角色。
-
创建 Google Cloud 项目:
gcloud projects create PROJECT_ID
将
PROJECT_ID替换为您要创建的 Google Cloud 项目的名称。 -
选择您创建的 Google Cloud 项目:
gcloud config set project PROJECT_ID
将
PROJECT_ID替换为您的 Google Cloud 项目名称。
-
验证是否已为您的 Google Cloud 项目启用结算功能。
启用必需的 API:
启用 API 所需的角色
如需启用 API,您需要 Service Usage Admin IAM 角色 (
roles/serviceusage.serviceUsageAdmin),该角色包含serviceusage.services.enable权限。了解如何授予 角色。gcloud services enable compute.googleapis.com file.googleapis.com logging.googleapis.com cloudresourcemanager.googleapis.com servicenetworking.googleapis.com lustre.googleapis.com
-
安装 Google Cloud CLI。
-
如果您使用的是外部身份提供方 (IdP),则必须先使用联合身份登录 gcloud CLI。
-
如需初始化 gcloud CLI,请运行以下命令:
gcloud init -
选择或创建项目所需的角色
- 选择项目:选择项目不需要特定的 IAM 角色,您可以选择已获授予角色的任何项目。
-
创建项目:如需创建项目,您需要 Project Creator 角色
(
roles/resourcemanager.projectCreator),该角色包含resourcemanager.projects.create权限。了解如何授予 角色。
-
创建 Google Cloud 项目:
gcloud projects create PROJECT_ID
将
PROJECT_ID替换为您要创建的 Google Cloud 项目的名称。 -
选择您创建的 Google Cloud 项目:
gcloud config set project PROJECT_ID
将
PROJECT_ID替换为您的 Google Cloud 项目名称。
-
验证是否已为您的 Google Cloud 项目启用结算功能。
启用必需的 API:
启用 API 所需的角色
如需启用 API,您需要 Service Usage Admin IAM 角色 (
roles/serviceusage.serviceUsageAdmin),该角色包含serviceusage.services.enable权限。了解如何授予 角色。gcloud services enable compute.googleapis.com file.googleapis.com logging.googleapis.com cloudresourcemanager.googleapis.com servicenetworking.googleapis.com lustre.googleapis.com
-
向您的用户账号授予角色。对以下每个 IAM 角色运行以下命令一次:
roles/compute.admin, roles/iam.serviceAccountUser, roles/file.editor, roles/storage.admin, roles/serviceusage.serviceUsageAdmingcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE
替换以下内容:
PROJECT_ID:您的项目 ID。USER_IDENTIFIER:您的用户账号的标识符。 例如,myemail@example.com。ROLE:您授予用户账号的 IAM 角色。
- 为您的 Google Cloud 项目启用默认服务帐号:
gcloud iam service-accounts enable PROJECT_NUMBER-compute@developer.gserviceaccount.com \ --project=PROJECT_ID
将 PROJECT_NUMBER 替换为您的项目编号。如需查看您的 项目编号,请参阅 获取现有项目。
- 向默认服务账号授予 Editor 角色 (
roles/editor):gcloud projects add-iam-policy-binding PROJECT_ID \ --member="serviceAccount:PROJECT_NUMBER-compute@developer.gserviceaccount.com" \ --role=roles/editor
- 为您的用户账号创建本地身份验证凭据:
gcloud auth application-default login
- 为您的项目启用 OS Login:
gcloud compute project-info add-metadata --metadata=enable-oslogin=TRUE
- 登录或创建 Hugging Face 账号。
- 安装使用 Cluster Toolkit 所需的依赖项。
使用 Hugging Face 访问 Mixtral
如需使用 Hugging Face 访问 Mixtral,请执行以下操作:
- 创建 Hugging Face
read access令牌。 - 复制并保存
read访问令牌值。您将在本教程的后面部分使用它。
准备环境
您可以在本地机器上执行以下步骤,为集群部署做准备。
克隆 Google Cloud Cluster Toolkit 代码库:
git clone https://github.com/GoogleCloudPlatform/cluster-toolkit.git创建 Cloud Storage 存储桶:
export BUCKET_NAME="your-unique-bucket-name" gcloud storage buckets create gs://${BUCKET_NAME}
创建 A4 Slurm 集群
如需创建 A4 Slurm 集群,请执行以下操作:
转到克隆的
cluster-toolkit目录:cd cluster-toolkit如果您是首次使用 Cluster Toolkit,请构建
gcluster二进制文件:make转到
examples/machine-learning/a4-highgpu-8g目录。打开
a4high-slurm-deployment.yaml文件并按如下方式进行修改:terraform_backend_defaults: type: gcs configuration: bucket: BUCKET_NAME vars: deployment_name: DEPLOYMENT_NAME project_id: PROJECT_ID region: REGION zone: ZONE a4h_cluster_size: 2 a4h_reservation_name: RESERVATION_NAME替换以下内容:
BUCKET_NAME::您在上一步中创建的 Cloud Storage 存储桶的名称。PROJECT_ID:Cloud Storage 所在的项目的 ID,也是您要创建 Slurm 集群的项目。 Google CloudREGION:预留所在的区域。ZONE:预留所在的可用区。A4h_reservation_name:使用 A4 预留的名称。
打开
a4high-slurm-blueprint.yaml文件并按如下方式进行修改:- 移除
filestore_homefs模块。 - 启用
lustrefs和private-service-access模块。 - 在
vars块中,配置以下内容:Find slurm_vars并将install_managed_lustre设置为true。- 将
per_unit_storage_throughput参数设置为500。 - 将
size_gib参数设置为36000。
- 移除
部署集群:
./gcluster deploy -d examples/machine-learning/a4-highgpu-8g/a4high-slurm-deployment.yaml \ examples/machine-learning/a4-highgpu-8g/a4high-slurm-blueprint.yaml \ --auto-approve./gcluster deploy命令会启动一个两阶段流程,如下所示:- 第一阶段构建预先安装了所有软件的自定义映像,此过程最多可能需要 35 分钟才能完成。
- 第二阶段使用该自定义映像部署集群。此过程应比第一阶段更快完成。
准备工作负载
如需准备工作负载,请按以下步骤操作:
创建工作负载脚本
如需创建微调工作负载将使用的脚本,请按以下步骤操作:
如需设置 Python 虚拟环境,请创建包含以下内容的
install_environment.sh文件:#!/bin/bash # This script sets a reliable environment for FSDP training. # It is meant to be run on a compute node. set -e # --- 1. Create the Python virtual environment --- VENV_PATH="$HOME/.venv/venv-fsdp" if [ ! -d "$VENV_PATH" ]; then echo "--- Creating Python virtual environment at $VENV_PATH ---" python3 -m venv $VENV_PATH else echo "--- Virtual environment already exists at $VENV_PATH ---" fi source $VENV_PATH/bin/activate # --- 2. Install Dependencies --- echo "--- [STEP 2.1] Upgrading build toolchain ---" pip install --upgrade pip wheel packaging echo "--- [STEP 2.2] Installing PyTorch Nightly ---" pip install --force-reinstall --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 echo "--- [STEP 2.3] Installing application dependencies ---" if [ -f "requirements-fsdp.txt" ]; then pip install -r requirements-fsdp.txt else echo "ERROR: requirements-fsdp.txt not found!" exit 1 fi # --- [STEP 2.4] Build Flash Attention from Source --- echo "--- Building flash-attn from source... This will take a while. ---" # Use all available CPU cores to speed up the build MAX_JOBS=$(nproc) pip install flash-attn --no-build-isolation # --- 3. Download the Model --- echo "--- [STEP 2.5] Downloading Mixtral model ---" if [ -z "$HF_TOKEN" ]; then echo "ERROR: The HF_TOKEN environment variable is not set."; exit 1; fi pip install huggingface_hub[cli] # Execute the CLI using its full, explicit path $VENV_PATH/bin/huggingface-cli download mistralai/Mixtral-8x7B-v0.1 --local-dir ~/Mixtral-8x7B-v0.1 --token $HF_TOKEN echo "--- Environment setup complete. ---"如需为训练脚本指定 Python 依赖项,请创建包含以下内容的
requirements-fsdp.txt文件:transformers==4.55.0 datasets==4.0.0 peft==0.16.0 accelerate==1.9.0 trl==0.21.0 # Other dependencies sentencepiece==0.2.0 protobuf==6.31.1将
train-mixtral.py指定为主训练脚本:import torch from torch.distributed.fsdp import MixedPrecision from datasets import load_dataset import shutil import os import torch.distributed as dist from peft import LoraConfig, PeftModel, get_peft_model from transformers import ( AutoModelForCausalLM, AutoTokenizer, TrainingArguments, HfArgumentParser, ) from torch.distributed import get_rank, get_world_size from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer from trl import SFTTrainer from dataclasses import dataclass, field from typing import Optional @dataclass class ScriptArguments: model_id: str = field(default="mistralai/Mixtral-8x7B-v0.1", metadata={"help": "Hugging Face model ID from the Hub"}) dataset_name: str = field(default="philschmid/gretel-synthetic-text-to-sql", metadata={"help": "Dataset from the Hub"}) run_inference_after_training: bool = field(default=False, metadata={"help": "Run sample inference on rank 0 after training"}) dataset_subset_size: Optional[int] = field(default=None, metadata={"help": "Number of samples to use from the dataset for training. If None, uses the full dataset."}) @dataclass class PeftArguments: lora_r: int = field(default=16, metadata={"help": "LoRA attention dimension"}) lora_alpha: int = field(default=32, metadata={"help": "LoRA alpha scaling factor"}) lora_dropout: float = field(default=0.05, metadata={"help": "LoRA dropout probability"}) @dataclass class SftTrainingArguments(TrainingArguments): max_length: Optional[int] = field(default=2048, metadata={"help": "The maximum sequence length for SFTTrainer"}) packing: Optional[bool] = field(default=False, metadata={"help": "Enable packing for SFTTrainer"}) ddp_find_unused_parameters: Optional[bool] = field(default=False, metadata={"help": "When using FSDP activation checkpointing, this must be set to False for Mixtral"}) def formatting_prompts_func(example): system_message = "You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA." user_prompt = f"### SCHEMA:\n{example['sql_context']}\n\n### USER QUERY:\n{example['sql_prompt']}" response = f"\n\n### SQL QUERY:\n{example['sql']}" return f"{system_message}\n\n{user_prompt}{response}" def main(): parser = HfArgumentParser((ScriptArguments, PeftArguments, SftTrainingArguments)) script_args, peft_args, training_args = parser.parse_args_into_dataclasses() training_args.gradient_checkpointing = True training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} training_args.optim = "adamw_torch_fused" bf16_policy = MixedPrecision( param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.bfloat16, ) training_args.fsdp = "full_shard" training_args.fsdp_config = { "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", "fsdp_transformer_layer_cls_to_wrap": [MixtralDecoderLayer], "fsdp_state_dict_type": "SHARDED_STATE_DICT", "fsdp_offload_params": False, "fsdp_forward_prefetch": True, "fsdp_mixed_precision_policy": bf16_policy } tokenizer = AutoTokenizer.from_pretrained(script_args.model_id, trust_remote_code=True) tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "right" model = AutoModelForCausalLM.from_pretrained( script_args.model_id, torch_dtype=torch.bfloat16, trust_remote_code=True, attn_implementation="flash_attention_2", ) peft_config = LoraConfig( r=peft_args.lora_r, lora_alpha=peft_args.lora_alpha, lora_dropout=peft_args.lora_dropout, bias="none", task_type="CAUSAL_LM", target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], ) model = get_peft_model(model, peft_config) data_splits = load_dataset(script_args.dataset_name) dataset = data_splits["train"] eval_dataset = data_splits["test"] if script_args.dataset_subset_size is not None: dataset = dataset.select(range(script_args.dataset_subset_size)) dataset = dataset.shuffle(seed=training_args.seed) trainer = SFTTrainer( model=model, args=training_args, train_dataset=dataset, eval_dataset=eval_dataset, formatting_func=formatting_prompts_func, processing_class=tokenizer, ) trainer.train() dist.barrier() if trainer.is_world_process_zero(): best_model_path = trainer.state.best_model_checkpoint final_model_dir = os.path.join(training_args.output_dir, "final_best_model") print(f"Copying best model to: {final_model_dir}") if os.path.exists(final_model_dir): shutil.rmtree(final_model_dir) shutil.copytree(best_model_path, final_model_dir) if script_args.run_inference_after_training: del model, trainer torch.cuda.empty_cache() run_post_training_inference(script_args, final_model_dir, tokenizer) def run_post_training_inference(script_args, best_model_path, tokenizer): print("\n" + "="*50) print("=== RUNNING POST-TRAINING INFERENCE TEST ===") print("="*50 + "\n") base_model = AutoModelForCausalLM.from_pretrained( script_args.model_id, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto" ) model = PeftModel.from_pretrained(base_model, best_model_path) model = model.merge_and_unload() model.eval() # Define the test case schema = "CREATE TABLE artists (Name TEXT, Country TEXT, Genre TEXT)" system_message = "You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA." question = "Show me all artists from the Country just north of the USA." prompt = f"{system_message}\n\n### SCHEMA:\n{schema}\n\n### USER QUERY:\n{question}\n\n### SQL QUERY:\n" print(f"Test Prompt:\n{prompt}") inputs = tokenizer(prompt, return_tensors="pt").to("cuda") print("\n--- Generating SQL... ---") outputs = model.generate( **inputs, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id, do_sample=False, temperature=None, top_p=None, ) generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)[len(prompt):].strip() print(f"\n--- Generated SQL Query ---") print(generated_sql) print("\n" + "="*50) print("=== INFERENCE TEST COMPLETE ===") print("="*50 + "\n") if __name__ == "__main__": main()如需指定作业在 Slurm 集群上运行的任务,请创建包含以下内容的
train-mixtral.sh文件:#!/bin/bash #SBATCH --job-name=mixtral-fsdp #SBATCH --nodes=2 #SBATCH --ntasks-per-node=8 #SBATCH --gpus-per-node=8 #SBATCH --partition=a4high #SBATCH --output=mixtral-%j.out #SBATCH --error=mixtral-%j.err set -e set -x echo "--- Slurm Job Started ---" # --- Define Paths --- LOCAL_SSD_PATH="/mnt/localssd/job_${SLURM_JOB_ID}" VENV_PATH="${HOME}/.venv/venv-fsdp" MODEL_PATH="${HOME}/Mixtral-8x7B-v0.1" # --- STAGE 1: Stage Data to Local SSD on Each Node --- srun --ntasks=$SLURM_NNODES --ntasks-per-node=1 bash -c " echo '--- Staging on node: $(hostname) ---' mkdir -p ${LOCAL_SSD_PATH} echo 'Copying virtual environment...' rsync -a -q ${VENV_PATH}/ ${LOCAL_SSD_PATH}/venv/ echo 'Copying model weights...' rsync -a ${MODEL_PATH}/ ${LOCAL_SSD_PATH}/model/ mkdir -p ${LOCAL_SSD_PATH}/hf_cache echo '--- Staging on $(hostname) complete ---' " echo "--- Staging complete on all nodes ---" # --- STAGE 2: Run the Training Job --- echo "--- Launching Distributed Training with GIB NCCL Plugin ---" nodes=( $( scontrol show hostnames "$SLURM_JOB_NODELIST" ) ) head_node=${nodes[0]} head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) export MASTER_ADDR=$head_node_ip export MASTER_PORT=29500 export NCCL_SOCKET_IFNAME=enp0s19 export NCCL_NET=gIB # export NCCL_DEBUG=INFO # Un-comment to diagnose NCCL issues if needed srun --cpu-bind=none --accel-bind=g bash -c ' # Activate the environment from the local copy source '${LOCAL_SSD_PATH}'/venv/bin/activate # Point Hugging Face cache to the local SSD export HF_HOME='${LOCAL_SSD_PATH}'/hf_cache export RANK=$SLURM_PROCID export WORLD_SIZE=$SLURM_NTASKS export LOCAL_RANK=$SLURM_LOCALID export LD_LIBRARY_PATH=/usr/local/gib/lib64:$LD_LIBRARY_PATH source /usr/local/gib/scripts/set_nccl_env.sh # --- Launch the training --- python \ '${SLURM_SUBMIT_DIR}'/train-mixtral.py \ --model_id="'${LOCAL_SSD_PATH}'/model/" \ --output_dir="${HOME}/outputs/mixtral_job_${SLURM_JOB_ID}" \ --dataset_name="philschmid/gretel-synthetic-text-to-sql" \ --seed=900913 \ --bf16=True \ --num_train_epochs=3 \ --per_device_train_batch_size=32 \ --gradient_accumulation_steps=4 \ --learning_rate=4e-5 \ --logging_steps=3 \ --lora_r=32 \ --lora_alpha=32 \ --lora_dropout=0.05 \ --eval_strategy=steps \ --eval_steps=10 \ --save_strategy=steps \ --save_steps=10 \ --load_best_model_at_end=False \ --metric_for_best_model=eval_loss \ --run_inference_after_training \ --dataset_subset_size=67000 ' # --- STAGE 3: Cleanup --- echo "--- Cleaning up local SSD on all nodes ---" srun --ntasks=$SLURM_NNODES --ntasks-per-node=1 bash -c "rm -rf ${LOCAL_SSD_PATH}" echo "--- Slurm Job Finished ---"
将脚本上传到 Slurm 集群
如需将您在上一步中创建的脚本上传到 Slurm 集群,请执行以下操作:
如需标识登录节点,请列出项目中的所有虚拟机:
gcloud compute instances list登录节点的名称类似于
a4-high-login-001。将脚本上传到登录节点的主目录:
# Run this from your local machine where you created the files LOGIN_NODE_NAME="your-login-node-name" # e.g., a4high-login-001 PROJECT_ID="your-gcp-project-id" ZONE="your-cluster-zone" # e.g., us-west4-a gcloud compute scp --project="$PROJECT_ID" --zone="$ZONE" --tunnel-through-iap \ ./install_environment.sh \ ./requirements-fsdp.txt \ ./train-mixtral.py \ ./train-mixtral.sh \ "${LOGIN_NODE_NAME}":~/
连接到 Slurm 集群
通过 SSH 连接到登录节点,从而连接到 Slurm 集群:
gcloud compute ssh $LOGIN_NODE_NAME \
--project=$PROJECT_ID \
--tunnel-through-iap \
--zone=$ZONE
安装框架和工具
连接到登录节点后,安装框架和工具。
导出 Hugging Face 令牌:
# On the login node export HF_TOKEN="hf_..." # Replace with your token在计算节点上运行安装脚本。
# On the login node srun \ --job-name=env-setup \ --nodes=1 \ --ntasks=1 \ --gpus-per-node=1 \ --partition=a4high \ bash ./install_environment.sh此命令会设置虚拟环境,安装所有依赖项,并将 Mixtral 模型权重下载到
~/Mixtral-8x7B-v0.1中。此过程可能需要 30 多分钟才能完成。
启动微调工作负载
如需启动训练工作负载,请执行以下操作:
将作业提交给 Slurm 调度程序:
# On the login node sbatch train-mixtral.sh在 Slurm 集群中的登录节点上,您可以通过检查在
home目录中创建的输出文件来监控作业的进度:# On the login node tail -f mixtral-*.out如果作业成功启动,
.err文件会显示一个进度条,该进度条会随着作业的进行而更新。该作业有两个主要阶段:
- 将大型基础模型复制到每个计算节点的本地 SSD。
- 训练作业,该作业在模型复制完成后开始。
整个作业大约需要 40 分钟才能运行完毕。
清理
为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。
删除 Slurm 集群
如需删除 Slurm 集群,请按以下步骤操作:
转到
cluster-toolkit目录。销毁 Terraform 文件和所有已创建的资源:
./gcluster destroy DEPLOYMENT_NAME --auto-approve
删除项目
删除项目: Google Cloud
gcloud projects delete PROJECT_ID
后续步骤
- 重新部署 Slurm 集群
- 测试 Slurm 集群的网络性能
- 监控 Slurm 集群中的虚拟机
- 创建服务端点: 获得微调后的模型后,您可以使用 GKE 或 Vertex AI 将其部署到服务端点,使其可用于推理。