使用 FSDP 在 A4 Slurm 集群上对 Mixtral-8x7B 进行微调

本教程将介绍如何在多节点、多 GPU Slurm 集群上微调 mistralai/Mixtral-8x7B-v0.1 模型 Google Cloud。该集群使用两个 a4-highgpu-8g 虚拟机 (VM) 实例,每个实例都有 8 个 NVIDIA B200 GPU。

本教程中介绍的两个主要流程如下:

  1. 使用 Google Cloud Cluster Toolkit 部署生产级高性能 Slurm 集群。在此部署过程中,您将创建一个预先安装了必要软件的自定义虚拟机映像。您还将设置共享 Lustre 文件系统并配置高速网络。
  2. 集群部署完毕后,您可以使用本教程附带的一组脚本运行分布式微调作业。该作业利用 PyTorch Fully Sharded Data Parallel (FSDP),您可以通过 Hugging Face Transformer Reinforcement Learning (TRL) 库访问该作业。

本教程适用于机器学习 (ML) 工程师、研究人员、平台管理员和运维人员,以及对跨多个节点和 GPU 分布 AI 工作负载感兴趣的数据和 AI 专家。

目标

  • 使用 Hugging Face 访问 Mixtral
  • 准备环境
  • 创建和部署生产级 A4 High-GPU Slurm 集群。
  • 配置多节点环境以使用 FSDP 进行分布式训练。
  • 使用 Hugging Face trl.SFTTrainer 类微调 Mixtral 模型。
  • 将数据暂存到本地 SSD。
  • 监控作业。
  • 清理。

费用

在本文档中,您将使用的以下收费组件: Google Cloud

您可使用 价格计算器 根据您的预计使用情况来估算费用。

新 Google Cloud 用户可能有资格申请免费试用

准备工作

  1. 登录您的 Google Cloud 账号。如果您是 Google Cloud新手, 请创建一个账号来评估我们的产品在 实际场景中的表现。新客户还可获享 $300 赠金,用于 运行、测试和部署工作负载。
  2. 安装 Google Cloud CLI。

  3. 如果您使用的是外部身份提供方 (IdP),则必须先使用联合身份登录 gcloud CLI

  4. 如需初始化 gcloud CLI,请运行以下命令:

    gcloud init
  5. 创建或选择 Google Cloud 项目

    选择或创建项目所需的角色

    • 选择项目:选择项目不需要特定的 IAM 角色,您可以选择已获授予角色的任何项目。
    • 创建项目:如需创建项目,您需要 Project Creator 角色 (roles/resourcemanager.projectCreator),该角色包含 resourcemanager.projects.create 权限。了解如何授予 角色
    • 创建 Google Cloud 项目:

      gcloud projects create PROJECT_ID

      PROJECT_ID 替换为您要创建的 Google Cloud 项目的名称。

    • 选择您创建的 Google Cloud 项目:

      gcloud config set project PROJECT_ID

      PROJECT_ID 替换为您的 Google Cloud 项目名称。

  6. 验证是否已为您的 Google Cloud 项目启用结算功能。

  7. 启用必需的 API:

    启用 API 所需的角色

    如需启用 API,您需要 Service Usage Admin IAM 角色 (roles/serviceusage.serviceUsageAdmin),该角色包含 serviceusage.services.enable 权限。了解如何授予 角色

    gcloud services enable compute.googleapis.com file.googleapis.com logging.googleapis.com cloudresourcemanager.googleapis.com servicenetworking.googleapis.com lustre.googleapis.com
  8. 安装 Google Cloud CLI。

  9. 如果您使用的是外部身份提供方 (IdP),则必须先使用联合身份登录 gcloud CLI

  10. 如需初始化 gcloud CLI,请运行以下命令:

    gcloud init
  11. 创建或选择 Google Cloud 项目

    选择或创建项目所需的角色

    • 选择项目:选择项目不需要特定的 IAM 角色,您可以选择已获授予角色的任何项目。
    • 创建项目:如需创建项目,您需要 Project Creator 角色 (roles/resourcemanager.projectCreator),该角色包含 resourcemanager.projects.create 权限。了解如何授予 角色
    • 创建 Google Cloud 项目:

      gcloud projects create PROJECT_ID

      PROJECT_ID 替换为您要创建的 Google Cloud 项目的名称。

    • 选择您创建的 Google Cloud 项目:

      gcloud config set project PROJECT_ID

      PROJECT_ID 替换为您的 Google Cloud 项目名称。

  12. 验证是否已为您的 Google Cloud 项目启用结算功能。

  13. 启用必需的 API:

    启用 API 所需的角色

    如需启用 API,您需要 Service Usage Admin IAM 角色 (roles/serviceusage.serviceUsageAdmin),该角色包含 serviceusage.services.enable 权限。了解如何授予 角色

    gcloud services enable compute.googleapis.com file.googleapis.com logging.googleapis.com cloudresourcemanager.googleapis.com servicenetworking.googleapis.com lustre.googleapis.com
  14. 向您的用户账号授予角色。对以下每个 IAM 角色运行以下命令一次: roles/compute.admin, roles/iam.serviceAccountUser, roles/file.editor, roles/storage.admin, roles/serviceusage.serviceUsageAdmin

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    替换以下内容:

    • PROJECT_ID:您的项目 ID。
    • USER_IDENTIFIER:您的用户账号的标识符。 例如,myemail@example.com
    • ROLE:您授予用户账号的 IAM 角色。
  15. 为您的 Google Cloud 项目启用默认服务帐号:
    gcloud iam service-accounts enable PROJECT_NUMBER-compute@developer.gserviceaccount.com \
        --project=PROJECT_ID

    PROJECT_NUMBER 替换为您的项目编号。如需查看您的 项目编号,请参阅 获取现有项目

  16. 向默认服务账号授予 Editor 角色 (roles/editor):
    gcloud projects add-iam-policy-binding PROJECT_ID \
      --member="serviceAccount:PROJECT_NUMBER-compute@developer.gserviceaccount.com" \
      --role=roles/editor
  17. 为您的用户账号创建本地身份验证凭据:
    gcloud auth application-default login
  18. 为您的项目启用 OS Login:
    gcloud compute project-info add-metadata --metadata=enable-oslogin=TRUE
  19. 登录或创建 Hugging Face 账号
  20. 安装使用 Cluster Toolkit 所需的依赖项

使用 Hugging Face 访问 Mixtral

如需使用 Hugging Face 访问 Mixtral,请执行以下操作:

  1. 创建 Hugging Face read access 令牌
  2. 复制并保存 read 访问令牌值。您将在本教程的后面部分使用它。

准备环境

您可以在本地机器上执行以下步骤,为集群部署做准备。

  1. 克隆 Google Cloud Cluster Toolkit 代码库:

    git clone https://github.com/GoogleCloudPlatform/cluster-toolkit.git
    
  2. 创建 Cloud Storage 存储桶:

    export BUCKET_NAME="your-unique-bucket-name"
    gcloud storage buckets create gs://${BUCKET_NAME}
    

创建 A4 Slurm 集群

如需创建 A4 Slurm 集群,请执行以下操作:

  1. 转到克隆的 cluster-toolkit 目录:

    cd cluster-toolkit
    
  2. 如果您是首次使用 Cluster Toolkit,请构建 gcluster 二进制文件:

    make
    
  3. 转到 examples/machine-learning/a4-highgpu-8g 目录。

    打开 a4high-slurm-deployment.yaml 文件并按如下方式进行修改:

    terraform_backend_defaults:
      type: gcs
      configuration:
        bucket: BUCKET_NAME
    
    vars:
      deployment_name: DEPLOYMENT_NAME
      project_id: PROJECT_ID
      region: REGION
      zone: ZONE
      a4h_cluster_size: 2
      a4h_reservation_name: RESERVATION_NAME
    

    替换以下内容:

    • BUCKET_NAME::您在上一步中创建的 Cloud Storage 存储桶的名称。
    • PROJECT_ID:Cloud Storage 所在的项目的 ID,也是您要创建 Slurm 集群的项目。 Google Cloud
    • REGION:预留所在的区域。
    • ZONE:预留所在的可用区。
    • A4h_reservation_name:使用 A4 预留的名称。
  4. 打开 a4high-slurm-blueprint.yaml 文件并按如下方式进行修改:

    • 移除 filestore_homefs 模块。
    • 启用 lustrefsprivate-service-access 模块。
    • vars 块中,配置以下内容:
      1. Find slurm_vars 并将 install_managed_lustre 设置为 true
      2. per_unit_storage_throughput 参数设置为 500
      3. size_gib 参数设置为 36000
  5. 部署集群:

    ./gcluster deploy -d examples/machine-learning/a4-highgpu-8g/a4high-slurm-deployment.yaml \
      examples/machine-learning/a4-highgpu-8g/a4high-slurm-blueprint.yaml \
      --auto-approve
    

    ./gcluster deploy 命令会启动一个两阶段流程,如下所示:

    • 第一阶段构建预先安装了所有软件的自定义映像,此过程最多可能需要 35 分钟才能完成。
    • 第二阶段使用该自定义映像部署集群。此过程应比第一阶段更快完成。

准备工作负载

如需准备工作负载,请按以下步骤操作:

  1. 创建工作负载脚本

  2. 将脚本上传到 Slurm 集群

  3. 连接到 Slurm 集群

  4. 安装框架和工具

创建工作负载脚本

如需创建微调工作负载将使用的脚本,请按以下步骤操作:

  1. 如需设置 Python 虚拟环境,请创建包含以下内容的 install_environment.sh 文件:

    #!/bin/bash
    # This script sets a reliable environment for FSDP training.
    # It is meant to be run on a compute node.
    set -e
    
    # --- 1. Create the Python virtual environment ---
    VENV_PATH="$HOME/.venv/venv-fsdp"
    if [ ! -d "$VENV_PATH" ]; then
      echo "--- Creating Python virtual environment at $VENV_PATH ---"
      python3 -m venv $VENV_PATH
    else
      echo "--- Virtual environment already exists at $VENV_PATH ---"
    fi
    
    source $VENV_PATH/bin/activate
    
    # --- 2. Install Dependencies ---
    echo "--- [STEP 2.1] Upgrading build toolchain ---"
    pip install --upgrade pip wheel packaging
    
    echo "--- [STEP 2.2] Installing PyTorch Nightly ---"
    pip install --force-reinstall --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
    
    echo "--- [STEP 2.3] Installing application dependencies ---"
    if [ -f "requirements-fsdp.txt" ]; then
        pip install -r requirements-fsdp.txt
    else
        echo "ERROR: requirements-fsdp.txt not found!"
        exit 1
    fi
    
    # --- [STEP 2.4] Build Flash Attention from Source ---
    echo "--- Building flash-attn from source... This will take a while. ---"
    # Use all available CPU cores to speed up the build
    MAX_JOBS=$(nproc) pip install flash-attn --no-build-isolation
    
    # --- 3. Download the Model ---
    echo "--- [STEP 2.5] Downloading Mixtral model ---"
    if [ -z "$HF_TOKEN" ]; then
      echo "ERROR: The HF_TOKEN environment variable is not set."; exit 1;
    fi
    pip install huggingface_hub[cli]
    
    # Execute the CLI using its full, explicit path
    $VENV_PATH/bin/huggingface-cli download mistralai/Mixtral-8x7B-v0.1 --local-dir ~/Mixtral-8x7B-v0.1 --token $HF_TOKEN
    
    echo "--- Environment setup complete. ---"
    
  2. 如需为训练脚本指定 Python 依赖项,请创建包含以下内容的 requirements-fsdp.txt 文件:

    transformers==4.55.0
    datasets==4.0.0
    peft==0.16.0
    accelerate==1.9.0
    trl==0.21.0
    
    # Other dependencies
    sentencepiece==0.2.0
    protobuf==6.31.1
    
  3. train-mixtral.py 指定为主训练脚本:

    import torch
    from torch.distributed.fsdp import MixedPrecision
    from datasets import load_dataset
    import shutil
    import os
    import torch.distributed as dist
    
    from peft import LoraConfig, PeftModel, get_peft_model
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TrainingArguments,
        HfArgumentParser,
    )
    
    from torch.distributed import get_rank, get_world_size
    
    from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
    from trl import SFTTrainer
    from dataclasses import dataclass, field
    from typing import Optional
    
    @dataclass
    class ScriptArguments:
        model_id: str = field(default="mistralai/Mixtral-8x7B-v0.1", metadata={"help": "Hugging Face model ID from the Hub"})
        dataset_name: str = field(default="philschmid/gretel-synthetic-text-to-sql", metadata={"help": "Dataset from the Hub"})
        run_inference_after_training: bool = field(default=False, metadata={"help": "Run sample inference on rank 0 after training"})
        dataset_subset_size: Optional[int] = field(default=None, metadata={"help": "Number of samples to use from the dataset for training. If None, uses the full dataset."})
    
    @dataclass
    class PeftArguments:
        lora_r: int = field(default=16, metadata={"help": "LoRA attention dimension"})
        lora_alpha: int = field(default=32, metadata={"help": "LoRA alpha scaling factor"})
        lora_dropout: float = field(default=0.05, metadata={"help": "LoRA dropout probability"})
    
    @dataclass
    class SftTrainingArguments(TrainingArguments):
        max_length: Optional[int] = field(default=2048, metadata={"help": "The maximum sequence length for SFTTrainer"})
        packing: Optional[bool] = field(default=False, metadata={"help": "Enable packing for SFTTrainer"})
        ddp_find_unused_parameters: Optional[bool] = field(default=False, metadata={"help": "When using FSDP activation checkpointing, this must be set to False for Mixtral"})
    
    def formatting_prompts_func(example):
        system_message = "You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."
        user_prompt = f"### SCHEMA:\n{example['sql_context']}\n\n### USER QUERY:\n{example['sql_prompt']}"
        response = f"\n\n### SQL QUERY:\n{example['sql']}"
        return f"{system_message}\n\n{user_prompt}{response}"
    
    def main():
        parser = HfArgumentParser((ScriptArguments, PeftArguments, SftTrainingArguments))
        script_args, peft_args, training_args = parser.parse_args_into_dataclasses()
    
        training_args.gradient_checkpointing = True
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
    
        training_args.optim = "adamw_torch_fused"
    
        bf16_policy = MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        )
    
        training_args.fsdp = "full_shard"
        training_args.fsdp_config = {
            "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
            "fsdp_transformer_layer_cls_to_wrap": [MixtralDecoderLayer],
            "fsdp_state_dict_type": "SHARDED_STATE_DICT",
            "fsdp_offload_params": False,
            "fsdp_forward_prefetch": True,
            "fsdp_mixed_precision_policy": bf16_policy
        }
    
        tokenizer = AutoTokenizer.from_pretrained(script_args.model_id, trust_remote_code=True)
    
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"
    
        model = AutoModelForCausalLM.from_pretrained(
            script_args.model_id,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            attn_implementation="flash_attention_2",
        )
    
        peft_config = LoraConfig(
            r=peft_args.lora_r,
            lora_alpha=peft_args.lora_alpha,
            lora_dropout=peft_args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        )
    
        model = get_peft_model(model, peft_config)
    
        data_splits = load_dataset(script_args.dataset_name)
    
        dataset = data_splits["train"]
        eval_dataset = data_splits["test"]
    
        if script_args.dataset_subset_size is not None:
            dataset = dataset.select(range(script_args.dataset_subset_size))
    
        dataset = dataset.shuffle(seed=training_args.seed)
    
        trainer = SFTTrainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
            eval_dataset=eval_dataset,
            formatting_func=formatting_prompts_func,
            processing_class=tokenizer,
        )
    
        trainer.train()
    
        dist.barrier()
        if trainer.is_world_process_zero():
            best_model_path = trainer.state.best_model_checkpoint
    
            final_model_dir = os.path.join(training_args.output_dir, "final_best_model")
            print(f"Copying best model to: {final_model_dir}")
    
            if os.path.exists(final_model_dir):
                shutil.rmtree(final_model_dir)
            shutil.copytree(best_model_path, final_model_dir)
    
            if script_args.run_inference_after_training:
                del model, trainer
                torch.cuda.empty_cache()
                run_post_training_inference(script_args, final_model_dir, tokenizer)
    
    def run_post_training_inference(script_args, best_model_path, tokenizer):
        print("\n" + "="*50)
        print("=== RUNNING POST-TRAINING INFERENCE TEST ===")
        print("="*50 + "\n")
    
        base_model = AutoModelForCausalLM.from_pretrained(
            script_args.model_id,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto"
        )
        model = PeftModel.from_pretrained(base_model, best_model_path)
        model = model.merge_and_unload()
        model.eval()
    
        # Define the test case
        schema = "CREATE TABLE artists (Name TEXT, Country TEXT, Genre TEXT)"
        system_message = "You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."
        question = "Show me all artists from the Country just north of the USA."
    
        prompt = f"{system_message}\n\n### SCHEMA:\n{schema}\n\n### USER QUERY:\n{question}\n\n### SQL QUERY:\n"
    
        print(f"Test Prompt:\n{prompt}")
    
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    
        print("\n--- Generating SQL... ---")
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=False,
            temperature=None,
            top_p=None,
        )
    
        generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)[len(prompt):].strip()
    
        print(f"\n--- Generated SQL Query ---")
        print(generated_sql)
        print("\n" + "="*50)
        print("=== INFERENCE TEST COMPLETE ===")
        print("="*50 + "\n")
    
    if __name__ == "__main__":
        main()
    
  4. 如需指定作业在 Slurm 集群上运行的任务,请创建包含以下内容的 train-mixtral.sh 文件:

    #!/bin/bash
    #SBATCH --job-name=mixtral-fsdp
    #SBATCH --nodes=2
    #SBATCH --ntasks-per-node=8
    #SBATCH --gpus-per-node=8
    #SBATCH --partition=a4high
    #SBATCH --output=mixtral-%j.out
    #SBATCH --error=mixtral-%j.err
    
    set -e
    set -x
    
    echo "--- Slurm Job Started ---"
    
    # --- Define Paths ---
    LOCAL_SSD_PATH="/mnt/localssd/job_${SLURM_JOB_ID}"
    VENV_PATH="${HOME}/.venv/venv-fsdp"
    MODEL_PATH="${HOME}/Mixtral-8x7B-v0.1"
    
    # --- STAGE 1: Stage Data to Local SSD on Each Node ---
    srun --ntasks=$SLURM_NNODES --ntasks-per-node=1 bash -c "
    echo '--- Staging on node: $(hostname) ---'
    mkdir -p ${LOCAL_SSD_PATH}
    
    echo 'Copying virtual environment...'
    rsync -a -q ${VENV_PATH}/ ${LOCAL_SSD_PATH}/venv/
    
    echo 'Copying model weights...'
    rsync -a ${MODEL_PATH}/ ${LOCAL_SSD_PATH}/model/
    
    mkdir -p ${LOCAL_SSD_PATH}/hf_cache
    
    echo '--- Staging on $(hostname) complete ---'
    "
    echo "--- Staging complete on all nodes ---"
    
    # --- STAGE 2: Run the Training Job ---
    echo "--- Launching Distributed Training with GIB NCCL Plugin ---"
    nodes=( $( scontrol show hostnames "$SLURM_JOB_NODELIST" ) )
    head_node=${nodes[0]}
    head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
    
    export MASTER_ADDR=$head_node_ip
    export MASTER_PORT=29500
    
    export NCCL_SOCKET_IFNAME=enp0s19
    
    export NCCL_NET=gIB
    
    # export NCCL_DEBUG=INFO # Un-comment to diagnose NCCL issues if needed
    
    srun --cpu-bind=none --accel-bind=g bash -c '
    # Activate the environment from the local copy
    source '${LOCAL_SSD_PATH}'/venv/bin/activate
    
    # Point Hugging Face cache to the local SSD
    export HF_HOME='${LOCAL_SSD_PATH}'/hf_cache
    
    export RANK=$SLURM_PROCID
    export WORLD_SIZE=$SLURM_NTASKS
    export LOCAL_RANK=$SLURM_LOCALID
    
    export LD_LIBRARY_PATH=/usr/local/gib/lib64:$LD_LIBRARY_PATH
    source /usr/local/gib/scripts/set_nccl_env.sh
    
    # --- Launch the training ---
    python \
        '${SLURM_SUBMIT_DIR}'/train-mixtral.py \
        --model_id="'${LOCAL_SSD_PATH}'/model/" \
        --output_dir="${HOME}/outputs/mixtral_job_${SLURM_JOB_ID}" \
        --dataset_name="philschmid/gretel-synthetic-text-to-sql" \
        --seed=900913 \
        --bf16=True \
        --num_train_epochs=3 \
        --per_device_train_batch_size=32 \
        --gradient_accumulation_steps=4 \
        --learning_rate=4e-5 \
        --logging_steps=3 \
        --lora_r=32 \
        --lora_alpha=32 \
        --lora_dropout=0.05 \
        --eval_strategy=steps \
        --eval_steps=10 \
        --save_strategy=steps \
        --save_steps=10 \
        --load_best_model_at_end=False \
        --metric_for_best_model=eval_loss \
        --run_inference_after_training \
        --dataset_subset_size=67000
    '
    
    # --- STAGE 3: Cleanup ---
    echo "--- Cleaning up local SSD on all nodes ---"
    srun --ntasks=$SLURM_NNODES --ntasks-per-node=1 bash -c "rm -rf ${LOCAL_SSD_PATH}"
    
    echo "--- Slurm Job Finished ---"
    

将脚本上传到 Slurm 集群

如需将您在上一步中创建的脚本上传到 Slurm 集群,请执行以下操作:

  1. 如需标识登录节点,请列出项目中的所有虚拟机:

    gcloud compute instances list
    

    登录节点的名称类似于 a4-high-login-001

  2. 将脚本上传到登录节点的主目录:

    # Run this from your local machine where you created the files
    LOGIN_NODE_NAME="your-login-node-name" # e.g., a4high-login-001
    PROJECT_ID="your-gcp-project-id"
    ZONE="your-cluster-zone" # e.g., us-west4-a
    
    gcloud compute scp --project="$PROJECT_ID" --zone="$ZONE" --tunnel-through-iap \
      ./install_environment.sh \
      ./requirements-fsdp.txt \
      ./train-mixtral.py \
      ./train-mixtral.sh \
      "${LOGIN_NODE_NAME}":~/
    

连接到 Slurm 集群

通过 SSH 连接到登录节点,从而连接到 Slurm 集群:

gcloud compute ssh $LOGIN_NODE_NAME \
    --project=$PROJECT_ID \
    --tunnel-through-iap \
    --zone=$ZONE

安装框架和工具

连接到登录节点后,安装框架和工具。

  1. 导出 Hugging Face 令牌:

    # On the login node
    export HF_TOKEN="hf_..." # Replace with your token
    
  2. 在计算节点上运行安装脚本。

    # On the login node
    srun \
      --job-name=env-setup \
      --nodes=1 \
      --ntasks=1 \
      --gpus-per-node=1 \
      --partition=a4high \
      bash ./install_environment.sh
    

    此命令会设置虚拟环境,安装所有依赖项,并将 Mixtral 模型权重下载到 ~/Mixtral-8x7B-v0.1 中。此过程可能需要 30 多分钟才能完成。

启动微调工作负载

如需启动训练工作负载,请执行以下操作:

  1. 将作业提交给 Slurm 调度程序:

    # On the login node
    sbatch train-mixtral.sh
    
  2. 在 Slurm 集群中的登录节点上,您可以通过检查在 home 目录中创建的输出文件来监控作业的进度:

    # On the login node
    tail -f mixtral-*.out
    

    如果作业成功启动,.err 文件会显示一个进度条,该进度条会随着作业的进行而更新。

    该作业有两个主要阶段:

    • 将大型基础模型复制到每个计算节点的本地 SSD。
    • 训练作业,该作业在模型复制完成后开始。

    整个作业大约需要 40 分钟才能运行完毕。

清理

为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。

删除 Slurm 集群

如需删除 Slurm 集群,请按以下步骤操作:

  1. 转到 cluster-toolkit 目录。

  2. 销毁 Terraform 文件和所有已创建的资源:

    ./gcluster destroy DEPLOYMENT_NAME --auto-approve
    

删除项目

删除项目: Google Cloud

gcloud projects delete PROJECT_ID

后续步骤