在 A4 Slurm 集群上使用 FSDP 对 Llama 4 进行微调

本教程介绍如何在 Google Cloud 上的多节点、多 GPU Slurm 集群上微调 Llama-4-Scout-17 大语言 模型 (LLM) Google Cloud。该集群使用两个 A4 虚拟机 (VM) 实例,每个实例都有 8 个 NVIDIA B200 GPU。

本教程中介绍的两个主要流程如下:

  1. 使用 Google Cloud Cluster Toolkit 部署生产级高性能 Slurm 集群。在此部署过程中,您将创建一个预先安装了必要软件的自定义虚拟机映像。 您还可以设置共享 Filestore 实例,并配置高速 RDMA 网络。
  2. 部署集群后,您可以使用本教程附带的一组脚本运行分布式微调作业。该作业利用 PyTorch 完全分片数据并行处理 (FSDP),您可以通过 Hugging Face Transformer 强化学习访问该作业

本教程适用于机器学习 (ML) 工程师、平台管理员和运维人员,以及对使用 Slurm 作业调度功能处理微调工作负载感兴趣的数据和 AI 专家。

目标

  • 使用 Hugging Face 访问 Llama 4

  • 准备环境

  • 创建和部署生产级 A4 高 GPU Slurm 集群。

  • 配置多节点环境以使用 FSDP 进行分布式训练。

  • 使用 Hugging Face trl.SFTTrainer 微调 Llama 4 模型。

  • 将数据暂存到本地 SSD。

  • 监控作业。

  • 清理。

费用

在本文档中,您将使用的以下收费组件: Google Cloud

您可使用 价格计算器 根据您的预计使用情况来估算费用。

新 Google Cloud 用户可能有资格申请免费试用

准备工作

  1. 登录您的 Google Cloud 账号。如果您是 Google Cloud新手, 请创建一个账号来评估我们的产品在 实际场景中的表现。新客户还可获享 $300 赠金,用于 运行、测试和部署工作负载。
  2. 安装 Google Cloud CLI。

  3. 如果您使用的是外部身份提供方 (IdP),则必须先使用联合身份登录 gcloud CLI

  4. 如需初始化 gcloud CLI,请运行以下命令:

    gcloud init
  5. 创建或选择 Google Cloud 项目

    选择或创建项目所需的角色

    • 选择项目:选择项目不需要特定的 IAM 角色,您可以选择已获授予角色的任何项目。
    • 创建项目:如需创建项目,您需要 Project Creator 角色 (roles/resourcemanager.projectCreator),其中包含 resourcemanager.projects.create 权限。了解如何授予 角色
    • 创建 Google Cloud 项目:

      gcloud projects create PROJECT_ID

      PROJECT_ID 替换为您要创建的 Google Cloud 项目名称。

    • 选择您创建的 Google Cloud 项目:

      gcloud config set project PROJECT_ID

      PROJECT_ID 替换为您的 Google Cloud 项目名称。

  6. 验证是否已为您的 Google Cloud 项目启用结算功能。

  7. 启用必需的 API:

    启用 API 所需的角色

    如需启用 API,您需要 Service Usage Admin IAM 角色 (roles/serviceusage.serviceUsageAdmin),其中包含 serviceusage.services.enable 权限。了解如何授予 角色

    gcloud services enable compute.googleapis.com file.googleapis.com logging.googleapis.com cloudresourcemanager.googleapis.com servicenetworking.googleapis.com
  8. 安装 Google Cloud CLI。

  9. 如果您使用的是外部身份提供方 (IdP),则必须先使用联合身份登录 gcloud CLI

  10. 如需初始化 gcloud CLI,请运行以下命令:

    gcloud init
  11. 创建或选择 Google Cloud 项目

    选择或创建项目所需的角色

    • 选择项目:选择项目不需要特定的 IAM 角色,您可以选择已获授予角色的任何项目。
    • 创建项目:如需创建项目,您需要 Project Creator 角色 (roles/resourcemanager.projectCreator),其中包含 resourcemanager.projects.create 权限。了解如何授予 角色
    • 创建 Google Cloud 项目:

      gcloud projects create PROJECT_ID

      PROJECT_ID 替换为您要创建的 Google Cloud 项目名称。

    • 选择您创建的 Google Cloud 项目:

      gcloud config set project PROJECT_ID

      PROJECT_ID 替换为您的 Google Cloud 项目名称。

  12. 验证是否已为您的 Google Cloud 项目启用结算功能。

  13. 启用必需的 API:

    启用 API 所需的角色

    如需启用 API,您需要 Service Usage Admin IAM 角色 (roles/serviceusage.serviceUsageAdmin),其中包含 serviceusage.services.enable 权限。了解如何授予 角色

    gcloud services enable compute.googleapis.com file.googleapis.com logging.googleapis.com cloudresourcemanager.googleapis.com servicenetworking.googleapis.com
  14. 将角色授予您的用户账号。对以下每个 IAM 角色运行以下命令一次: roles/compute.admin, roles/iam.serviceAccountUser, roles/file.editor, roles/storage.admin, roles/serviceusage.serviceUsageAdmin

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    替换以下内容:

    • PROJECT_ID:您的项目 ID。
    • USER_IDENTIFIER:您的用户账号的标识符。 例如,myemail@example.com
    • ROLE:您授予用户账号的 IAM 角色。
  15. 为您的 Google Cloud 项目启用默认服务帐号:
    gcloud iam service-accounts enable PROJECT_NUMBER-compute@developer.gserviceaccount.com 
    --project=PROJECT_ID

    PROJECT_NUMBER 替换为您的项目编号。如需查看您的 项目编号,请参阅 获取现有项目

  16. 向默认服务账号授予 Editor 角色 (roles/editor):
    gcloud projects add-iam-policy-binding PROJECT_ID 
    --member="serviceAccount:PROJECT_NUMBER-compute@developer.gserviceaccount.com"
    --role=roles/editor
  17. 为您的用户账号创建本地身份验证凭据:
    gcloud auth application-default login
  18. 为您的项目启用 OS Login:
    gcloud compute project-info add-metadata --metadata=enable-oslogin=TRUE
  19. 登录或创建 Hugging Face 账号
  20. 安装使用 Cluster Toolkit 所需的依赖项

使用 Hugging Face 访问 Llama 4

如需使用 Hugging Face 访问 Llama 4,请执行以下操作:

  1. 签署同意协议以使用 Llama 4

  2. 创建 Hugging Face read 访问令牌

    依次点击您的个人资料 > 设置 > 访问令牌 > +创建新令牌

  3. 复制并保存 read access 令牌值。您将在本教程的后面部分使用它。

准备环境

如需准备环境,请按照以下步骤操作:

  1. 克隆 Cluster Toolkit GitHub 代码库:

    git clone https://github.com/GoogleCloudPlatform/cluster-toolkit.git
    
  2. 创建 Cloud Storage 存储桶:

    gcloud storage buckets create gs://BUCKET_NAME \
        --project=PROJECT_ID
    

    替换以下内容:

    • BUCKET_NAME:Cloud Storage 存储桶的名称,该名称必须符合 存储桶命名要求

    • PROJECT_ID:要在其中创建 Cloud Storage 存储桶的 Google Cloud 项目的 ID。

创建 A4 Slurm 集群

如需创建 A4 Slurm 集群,请按照以下步骤操作:

  1. 转到 cluster-toolkit 目录:

    cd cluster-toolkit
    
  2. 如果您是首次使用 Cluster Toolkit,请构建 gcluster 二进制文件:

    make
    
  3. 转到 examples/machine-learning/a4-highgpu-8g 目录:

    cd examples/machine-learning/a4-highgpu-8g/
    
  4. 打开 a4high-slurm-deployment.yaml 文件,然后按如下方式进行修改:

    terraform_backend_defaults:
      type: gcs
      configuration:
        bucket: BUCKET_NAME
    
    vars:
      deployment_name: a4-high
      project_id: PROJECT_ID
      region: REGION
      zone: ZONE
      a4h_cluster_size: 2
      a4h_reservation_name: RESERVATION_URL
    

    替换以下内容:

    • BUCKET_NAME:您在上一部分中创建的 Cloud Storage 存储桶的名称。

    • PROJECT_ID:Cloud Storage 所在的 Google Cloud 项目的 ID,也是您要在其中创建 Slurm 集群的项目的 ID。

    • REGION:预留所在的区域。

    • ZONE:预留所在的可用区。

    • RESERVATION_URL:您要用于创建 Slurm 集群的预留的网址。根据预留所在的项目的不同,指定以下某个值:

      • 预留存在于您的项目中: RESERVATION_NAME

      • 预留存在于其他项目中,并且您的项目可以使用该预留: projects/RESERVATION_PROJECT_ID/reservations/RESERVATION_NAME

  5. 部署集群:

    ./gcluster deploy -d examples/machine-learning/a4-highgpu-8g/a4high-slurm-deployment.yaml examples/machine-learning/a4-highgpu-8g/a4high-slurm-blueprint.yaml --auto-approve
    

    ./gcluster deploy 命令是一个两阶段的过程,如下所示:

    • 第一阶段构建预先安装了所有软件的自定义映像,此过程可能需要长达 35 分钟才能完成。

    • 第二阶段使用该自定义映像部署集群。此过程应比第一阶段更快完成。

    如果第一阶段成功,但第二阶段失败,您可以尝试跳过第一阶段再次部署 Slurm 集群:

    ./gcluster deploy -d examples/machine-learning/a4-highgpu-8g/a4high-slurm-deployment.yaml examples/machine-learning/a4-highgpu-8g/a4high-slurm-blueprint.yaml --auto-approve --skip "image" -w
    

准备工作负载

如需准备工作负载,请执行以下操作:

  1. 创建工作负载脚本

  2. 将脚本上传到 Slurm 集群

  3. 连接到 Slurm 集群

  4. 安装框架和工具

创建工作负载脚本

如需创建微调工作负载将使用的脚本,请按照以下步骤操作:

  1. 如需设置 Python 虚拟环境,请创建包含以下内容的 install_environment.sh 文件:

    #!/bin/bash
    # This script sets up a consistent environment for FSDP training.
    # It is meant to be run once on the login node of your Slurm cluster
    set -e
    
    # --- 1. Create the Python virtual environment ---
    VENV_PATH="$HOME/.venv/venv-fsdp"
    if [ ! -d "$VENV_PATH" ]; then
      echo "--- Creating Python virtual environment at $VENV_PATH ---"
      python3 -m venv $VENV_PATH
    else
      echo "--- Virtual environment already exists at $VENV_PATH ---"
    fi
    
    source $VENV_PATH/bin/activate
    
    # --- 2. Install Dependencies ---
    echo "--- [STEP 2.1] Upgrading build toolchain ---"
    pip install --upgrade pip wheel packaging
    
    echo "--- [STEP 2.2] Installing PyTorch Nightly ---"
    pip install --force-reinstall --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
    
    echo "--- [STEP 2.3] Installing application dependencies ---"
    if [ -f "requirements-fsdp.txt" ]; then
        pip install -r requirements-fsdp.txt
    else
        echo "ERROR: requirements-fsdp.txt not found!"
        exit 1
    fi
    
    # --- 3. Download the Model ---
    echo "--- [STEP 2.4] Downloading Llama4 model ---"
    if [ -z "$HF_TOKEN" ]; then
      echo "ERROR: The HF_TOKEN environment variable is not set."; exit 1;
    fi
    pip install huggingface_hub[cli]
    
    # Execute the CLI using its full, explicit path
    $VENV_PATH/bin/huggingface-cli download meta-llama/Llama-4-Scout-17B-16E-Instruct --local-dir ~/Llama-4-Scout-17B-16E-Instruct --token $HF_TOKEN
    
    echo "--- Environment setup complete. ---"
    

    此脚本会设置可靠的 Python 虚拟环境,安装 PyTorch 每夜版,并下载 Llama 4 模型。

  2. 如需为训练脚本指定 Python 依赖项,请创建包含以下内容的 requirements-fsdp.txt 文件:

    transformers==4.55.0
    datasets==4.0.0
    peft==0.16.0
    accelerate==1.9.0
    trl==0.21.0
    
    # Other dependencies
    sentencepiece==0.2.0
    
  3. llama4-train-distributed.py 指定为主训练脚本:

    import torch
    from datasets import load_dataset
    from peft import LoraConfig, PeftModel
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TrainingArguments,
        HfArgumentParser,
    )
    
    from torch.distributed import get_rank, get_world_size
    
    from transformers.models.llama4.modeling_llama4 import Llama4TextDecoderLayer
    from trl import SFTTrainer
    from dataclasses import dataclass, field
    from typing import Optional
    
    @dataclass
    class ScriptArguments:
        model_id: str = field(metadata={"help": "Hugging Face model ID from the Hub"})
        dataset_name: str = field(default="philschmid/gretel-synthetic-text-to-sql", metadata={"help": "Dataset from the Hub"})
        run_inference_after_training: bool = field(default=False, metadata={"help": "Run sample inference on rank 0 after training"})
        dataset_subset_size: Optional[int] = field(default=None, metadata={"help": "Number of samples to use from the dataset for training. If None, uses the full dataset."})
    
    @dataclass
    class PeftArguments:
        lora_r: int = field(default=16, metadata={"help": "LoRA attention dimension"})
        lora_alpha: int = field(default=32, metadata={"help": "LoRA alpha scaling factor"})
        lora_dropout: float = field(default=0.05, metadata={"help": "LoRA dropout probability"})
    
    @dataclass
    class SftTrainingArguments(TrainingArguments):
        max_length: Optional[int] = field(default=2048, metadata={"help": "The maximum sequence length for SFTTrainer"})
        packing: Optional[bool] = field(default=False, metadata={"help": "Enable packing for SFTTrainer"})
        ddp_find_unused_parameters: Optional[bool] = field(default=True, metadata={"help": "When using FSDP activation checkpointing, this must be set to True"})
    
    def formatting_prompts_func(example):
        system_message = "You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."
        user_prompt = f"### SCHEMA:\n{example['sql_context']}\n\n### USER QUERY:\n{example['sql_prompt']}"
        response = f"\n\n### SQL QUERY:\n{example['sql']}"
        return f"{system_message}\n\n{user_prompt}{response}"
    
    def main():
        parser = HfArgumentParser((ScriptArguments, PeftArguments, SftTrainingArguments))
        script_args, peft_args, training_args = parser.parse_args_into_dataclasses()
    
        training_args.gradient_checkpointing = True
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
    
        training_args.optim = "adamw_torch_fused"
    
        training_args.fsdp = "full_shard"
        training_args.fsdp_config = {
            "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
            "fsdp_transformer_layer_cls_to_wrap": [Llama4TextDecoderLayer],
            "fsdp_state_dict_type": "FULL_STATE_DICT",
            "fsdp_offload_params": False,
            "fsdp_forward_prefetch": True,
        }
    
        tokenizer = AutoTokenizer.from_pretrained(script_args.model_id, trust_remote_code=True)
    
        model = AutoModelForCausalLM.from_pretrained(
            script_args.model_id,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            attn_implementation="sdpa",
        )
    
        peft_config = LoraConfig(
            r=peft_args.lora_r,
            lora_alpha=peft_args.lora_alpha,
            lora_dropout=peft_args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        )
        rank = get_rank()
        world_size = get_world_size()
    
        dataset = load_dataset(script_args.dataset_name, split="train")
    
        if script_args.dataset_subset_size is not None:
            dataset = dataset.select(range(script_args.dataset_subset_size))
        else:
            print(f"Using the full dataset with {len(dataset)} samples.")
    
        dataset = dataset.shuffle(seed=training_args.seed)
        print(f"Dataset shuffled with seed: {training_args.seed}.")
    
        if world_size > 1:
            print(f"Sharding dataset for Rank {rank} of {world_size}.")
            dataset = dataset.shard(num_shards=world_size, index=rank)
    
        print("Initializing SFTTrainer...")
        trainer = SFTTrainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
            peft_config=peft_config,
            formatting_func=formatting_prompts_func,
            processing_class=tokenizer,
        )
    
        trainer.train()
    
        trainer.save_model(training_args.output_dir)
    
        if script_args.run_inference_after_training and trainer.is_world_process_zero():
            del model
            del trainer
            torch.cuda.empty_cache()
            run_post_training_inference(script_args, training_args, tokenizer)
    
    def run_post_training_inference(script_args, training_args, tokenizer):
        """
        Loads the fine-tuned PEFT adapter from the local output directory and runs inference.
        This should only be called on rank 0 after training is complete.
        """
        print("\n" + "="*50)
        print("=== RUNNING POST-TRAINING INFERENCE TEST ===")
        print("="*50 + "\n")
    
        # Load the base model and merge the adapter.
        base_model = AutoModelForCausalLM.from_pretrained(
            script_args.model_id,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto"
        )
        # Load the PEFT adapter and merge it into the base model
        model = PeftModel.from_pretrained(base_model, training_args.output_dir)
        model = model.merge_and_unload() # Merge weights for faster inference
        model.eval()
    
        # Define the test case
        schema = "CREATE TABLE artists (Name TEXT, Country TEXT, Genre TEXT)"
        system_message = "You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."
        question = "Show me all artists from the Country just north of the USA."
    
        # This must match the formatting_func exactly
        prompt = f"{system_message}\n\n### SCHEMA:\n{schema}\n\n### USER QUERY:\n{question}\n\n### SQL QUERY:\n"
    
        print(f"Test Prompt:\n{prompt}")
    
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    
        print("\n--- Generating SQL... ---")
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=False,
            temperature=None,
            top_p=None,
        )
    
        generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)[len(prompt):].strip()
    
        print(f"\n--- Generated SQL Query ---")
        print(generated_sql)
        print("\n" + "="*50)
        print("=== INFERENCE TEST COMPLETE ===")
        print("="*50 + "\n")
    
    if __name__ == "__main__":
        main()
    

    此脚本利用 TRL 监督式微调 (SFT) Trainer 来管理 FSDP 训练循环、低秩适应 (LoRA) 配置和数据格式。

  4. 如需指定作业在 Slurm 集群上运行的任务,请创建包含以下内容的 submit.slurm 文件:

    #!/bin/bash
    #SBATCH --job-name=llama4-fsdp-fixed
    #SBATCH --nodes=2
    #SBATCH --ntasks-per-node=8
    #SBATCH --gpus-per-node=8
    #SBATCH --partition=a4high
    #SBATCH --output=llama4-%j.out
    #SBATCH --error=llama4-%j.err
    
    set -e
    set -x
    
    echo "--- Slurm Job Started ---"
    echo "Job ID: $SLURM_JOB_ID"
    echo "Node List: $SLURM_JOB_NODELIST"
    
    # --- Define Paths ---
    LOCAL_SSD_PATH="/mnt/localssd/job_${SLURM_JOB_ID}"
    VENV_PATH="${HOME}/.venv/venv-fsdp"
    MODEL_PATH="${HOME}/Llama-4-Scout-17B-16E-Instruct"
    
    # --- STAGE 1: Stage Data to Local SSD on Each Node ---
    srun --ntasks=$SLURM_NNODES --ntasks-per-node=1 bash -c "
      echo '--- Staging on node: $(hostname) ---'
    
      mkdir -p ${LOCAL_SSD_PATH}
    
      echo 'Copying virtual environment...'
      rsync -a -q ${VENV_PATH}/ ${LOCAL_SSD_PATH}/venv/
    
      echo 'Copying model weights...'
      rsync -a --info=progress2 ${MODEL_PATH}/ ${LOCAL_SSD_PATH}/model/
    
      mkdir -p ${LOCAL_SSD_PATH}/hf_cache
    
      echo '--- Staging on $(hostname) complete ---'
    "
    echo "--- Staging complete on all nodes ---"
    
    # --- STAGE 2: Run the Training Job ---
    echo "--- Launching Distributed Training with GIB NCCL Plugin ---"
    nodes=( $( scontrol show hostnames "$SLURM_JOB_NODELIST" ) )
    head_node=${nodes[0]}
    head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
    
    export MASTER_ADDR=$head_node_ip
    export MASTER_PORT=29500
    
    export NCCL_SOCKET_IFNAME=enp0s19
    export NCCL_NET=gIB
    
    # export NCCL_DEBUG=INFO # Un-comment to diagnose NCCL issues if needed
    
    srun --cpu-bind=none --accel-bind=g bash -c '
      # Activate the environment from the local copy
      source '${LOCAL_SSD_PATH}'/venv/bin/activate
    
      # Point Hugging Face cache to the local SSD
      export HF_HOME='${LOCAL_SSD_PATH}'/hf_cache
    
      export RANK=$SLURM_PROCID
      export WORLD_SIZE=$SLURM_NTASKS
      export LOCAL_RANK=$SLURM_LOCALID
    
      export LD_LIBRARY_PATH=/usr/local/gib/lib64:$LD_LIBRARY_PATH
      source /usr/local/gib/scripts/set_nccl_env.sh
    
      # --- Launch the training ---
      python \
        '${SLURM_SUBMIT_DIR}'/llama4-train-distributed.py \
          --model_id="'${LOCAL_SSD_PATH}'/model/" \
          --output_dir="'${LOCAL_SSD_PATH}'/outputs/" \
          --dataset_name="philschmid/gretel-synthetic-text-to-sql" \
          --seed=900913 \
          --bf16=True \
          --num_train_epochs=1 \
          --per_device_train_batch_size=2 \
          --gradient_accumulation_steps=4 \
          --learning_rate=2e-5 \
          --logging_steps=10 \
          --lora_r=16 \
          --lora_alpha=32 \
          --lora_dropout=0.05 \
          --run_inference_after_training
    '
    
    # --- STAGE 3: Copy Final Results Back to Persistent Storage ---
    echo "--- Copying final results from local SSD to shared storage ---"
    PERSISTENT_OUTPUT_DIR="${HOME}/outputs/llama4_job_${SLURM_JOB_ID}"
    mkdir -p "$PERSISTENT_OUTPUT_DIR"
    
    # Only copy from the head node where trl has combined the results
    srun --nodes=1 --ntasks=1 -w "$head_node" \
      rsync -a --info=progress2 "${LOCAL_SSD_PATH}/outputs/" "${PERSISTENT_OUTPUT_DIR}/"
    
    # --- STAGE 4: Cleanup ---
    echo "--- Cleaning up local SSD on all nodes ---"
    srun --ntasks=$SLURM_NNODES --ntasks-per-node=1 bash -c "rm -rf ${LOCAL_SSD_PATH}"
    
    echo "--- Slurm Job Finished ---"
    

将脚本上传到 Slurm 集群

如需将您在上一部分中创建的脚本上传到 Slurm 集群,请按照以下步骤操作:

  1. 如需标识登录节点,请列出项目中的所有 A4 虚拟机:

    gcloud compute instances list --filter="machineType:a4-highgpu-8g"
    

    登录节点的名称类似于 a4-high-login-001

  2. 将脚本上传到登录节点的主目录:

    gcloud compute scp --project="$PROJECT_ID" --zone="$ZONE" --tunnel-through-iap \
      ./install_environment.sh \
      ./requirements-fsdp.txt \
      ./llama4-train-distributed.py \
      ./submit.slurm \
      "${LOGIN_NODE_NAME}":~/
    

    LOGIN_NODE_NAME 替换为登录节点的名称。

连接到 Slurm 集群

通过 SSH 连接到登录节点,从而连接到 Slurm 集群:

gcloud compute ssh LOGIN_NODE_NAME \
    --project=PROJECT_ID \
    --tunnel-through-iap \
    --zone=ZONE

安装框架和工具

连接到登录节点后,请执行以下操作来安装框架和工具:

  1. 导出 Hugging Face 令牌:

    # On the login node
    export HF_TOKEN="hf_..." # Replace with your token
    
  2. 运行安装脚本:

    # On the login node
    chmod +x install_environment.sh
    ./install_environment.sh
    

    此命令会设置包含所有必需依赖项的虚拟环境,并将模型权重下载到 ~/Llama-4-Scout-17B-16E-Instruct 文件中。

    由于模型下载非常大(约 200 GB),因此此过程大约需要 30 分钟,具体取决于网络状况。

启动微调工作负载

如需开始训练工作负载,请执行以下操作:

  1. 将作业提交给 Slurm 调度程序:

    sbatch submit.slurm
    
  2. 在 Slurm 集群中的登录节点上,您可以通过检查在 home 目录中创建的输出文件来监控作业的进度:

    # On the login node
    tail -f llama4-*.out
    

    如果作业成功启动,.err 文件会显示一个进度条,该进度条会随着作业的进行而更新。

    此作业大约需要一个多小时才能在 Slurm 集群上完成。该作业分为两个主要阶段:

    • 将大型基础模型复制到每个计算节点的本地 SSD。
    • 训练作业,该作业在模型复制完成后开始。 此作业大约需要 35 分钟才能运行。

清理

为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。

删除项目

删除项目: Google Cloud

gcloud projects delete PROJECT_ID

删除 Slurm 集群

如需删除 Slurm 集群,请按照以下步骤操作:

  1. 转到 cluster-toolkit 目录。

  2. 销毁 Terraform 文件和所有已创建的资源:

    ./gcluster destroy a4-high --auto-approve
    

删除 Filestore 实例

默认情况下,Filestore 实例在 cluster-toolkit 蓝图中将 deletion_protection 设置为 true。此设置可防止您在修改环境时意外丢失数据。如需删除 Filestore 实例, 您必须 手动停用防删除保护

后续步骤