FSDP を使用して A4 Slurm クラスタで Llama 4 をファインチューニングする

このチュートリアルでは、 Google Cloudのマルチノード、マルチ GPU Slurm クラスタで Llama-4-Scout-17 大規模言語モデル(LLM)をファインチューニングする方法について説明します。クラスタは、それぞれ 8 個の NVIDIA B200 GPU を搭載した 2 つの A4 仮想マシン(VM)インスタンスを使用します。

このチュートリアルで説明する主なプロセスは次の 2 つです。

  1. Google Cloud Cluster Toolkit を使用して、本番環境グレードの高性能 Slurm クラスタをデプロイします。このデプロイの一環として、必要なソフトウェアがプリインストールされたカスタム VM イメージを作成します。また、共有 Filestore インスタンスを設定し、高速 RDMA ネットワーキングを構成します。
  2. クラスタをデプロイしたら、このチュートリアルに付属のスクリプト セットを使用して、分散ファインチューニング ジョブを実行します。このジョブは、Hugging Face の Transformer Reinforcement Learning を介してアクセスする PyTorch Fully Sharded Data Parallel(FSDP)を活用します。

このチュートリアルは、Slurm ジョブ スケジューリング機能を使用してファインチューニング ワークロードを処理することに関心がある ML エンジニア、プラットフォーム管理者、オペレーター、データおよび AI スペシャリストを対象としています。

目標

  • Hugging Face を使用して Llama 4 にアクセスする

  • 環境を準備する

  • 本番環境グレードの A4 High-GPU Slurm クラスタを作成してデプロイします。

  • FSDP を使用した分散トレーニング用にマルチノード環境を構成します。

  • Hugging Face trl.SFTTrainer を使用して Llama 4 モデルをファインチューニングします。

  • データをローカル SSD にステージングします。

  • ジョブをモニタリングします。

  • クリーンアップする。

費用

このドキュメントでは、課金対象である次の Google Cloudコンポーネントを使用します。

料金計算ツールを使うと、予想使用量に基づいて費用の見積もりを生成できます。

新規の Google Cloud ユーザーは無料トライアルをご利用いただける場合があります。

始める前に

  1. Google Cloud アカウントにログインします。 Google Cloudを初めて使用する場合は、 アカウントを作成して、実際のシナリオでの Google プロダクトのパフォーマンスを評価してください。新規のお客様には、ワークロードの実行、テスト、デプロイができる無料クレジット $300 分を差し上げます。
  2. Google Cloud CLI をインストールします。

  3. 外部 ID プロバイダ(IdP)を使用している場合は、まず連携 ID を使用して gcloud CLI にログインする必要があります。

  4. gcloud CLI を初期化するには、次のコマンドを実行します。

    gcloud init
  5. Google Cloud プロジェクトを作成または選択します

    プロジェクトの選択または作成に必要なロール

    • プロジェクトを選択する: プロジェクトの選択に特定の IAM ロールは必要ありません。ロールが付与されているプロジェクトであれば、どのプロジェクトでも選択できます。
    • プロジェクトを作成する: プロジェクトを作成するには、resourcemanager.projects.create 権限を含むプロジェクト作成者ロール(roles/resourcemanager.projectCreator)が必要です。ロールを付与する方法を確認する
    • Google Cloud プロジェクトを作成します。

      gcloud projects create PROJECT_ID

      PROJECT_ID は、作成する Google Cloud プロジェクトの名前に置き換えます。

    • 作成した Google Cloud プロジェクトを選択します。

      gcloud config set project PROJECT_ID

      PROJECT_ID は、 Google Cloud プロジェクトの名前に置き換えます。

  6. Google Cloud プロジェクトに対して課金が有効になっていることを確認します

  7. 必要な API を有効にします。

    API を有効にするために必要なロール

    API を有効にするには、serviceusage.services.enable 権限を含む Service Usage 管理者 IAM ロール(roles/serviceusage.serviceUsageAdmin)が必要です。ロールを付与する方法を確認する

    gcloud services enable compute.googleapis.com file.googleapis.com logging.googleapis.com cloudresourcemanager.googleapis.com servicenetworking.googleapis.com
  8. Google Cloud CLI をインストールします。

  9. 外部 ID プロバイダ(IdP)を使用している場合は、まず連携 ID を使用して gcloud CLI にログインする必要があります。

  10. gcloud CLI を初期化するには、次のコマンドを実行します。

    gcloud init
  11. Google Cloud プロジェクトを作成または選択します

    プロジェクトの選択または作成に必要なロール

    • プロジェクトを選択する: プロジェクトの選択に特定の IAM ロールは必要ありません。ロールが付与されているプロジェクトであれば、どのプロジェクトでも選択できます。
    • プロジェクトを作成する: プロジェクトを作成するには、resourcemanager.projects.create 権限を含むプロジェクト作成者ロール(roles/resourcemanager.projectCreator)が必要です。ロールを付与する方法を確認する
    • Google Cloud プロジェクトを作成します。

      gcloud projects create PROJECT_ID

      PROJECT_ID は、作成する Google Cloud プロジェクトの名前に置き換えます。

    • 作成した Google Cloud プロジェクトを選択します。

      gcloud config set project PROJECT_ID

      PROJECT_ID は、 Google Cloud プロジェクトの名前に置き換えます。

  12. Google Cloud プロジェクトに対して課金が有効になっていることを確認します

  13. 必要な API を有効にします。

    API を有効にするために必要なロール

    API を有効にするには、serviceusage.services.enable 権限を含む Service Usage 管理者 IAM ロール(roles/serviceusage.serviceUsageAdmin)が必要です。ロールを付与する方法を確認する

    gcloud services enable compute.googleapis.com file.googleapis.com logging.googleapis.com cloudresourcemanager.googleapis.com servicenetworking.googleapis.com
  14. ユーザー アカウントにロールを付与します。次の IAM ロールごとに次のコマンドを 1 回実行します。 roles/compute.admin, roles/iam.serviceAccountUser, roles/file.editor, roles/storage.admin, roles/serviceusage.serviceUsageAdmin

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    次のように置き換えます。

    • PROJECT_ID: プロジェクト ID。
    • USER_IDENTIFIER: ユーザー アカウントの識別子。例: myemail@example.com
    • ROLE: ユーザー アカウントに付与する IAM ロール。
  15. Google Cloud プロジェクトのデフォルトのサービス アカウントを有効にします。
    gcloud iam service-accounts enable PROJECT_NUMBER-compute@developer.gserviceaccount.com 
    --project=PROJECT_ID

    PROJECT_NUMBER は、使用するプロジェクト番号に置き換えます。プロジェクト番号を確認するには、 既存のプロジェクトを取得するをご覧ください。

  16. デフォルトのサービス アカウントに編集者ロール(roles/editor)を付与します。
    gcloud projects add-iam-policy-binding PROJECT_ID 
    --member="serviceAccount:PROJECT_NUMBER-compute@developer.gserviceaccount.com"
    --role=roles/editor
  17. ユーザー アカウントのローカル認証情報を作成します。
    gcloud auth application-default login
  18. プロジェクトで OS Login を有効にします。
    gcloud compute project-info add-metadata --metadata=enable-oslogin=TRUE
  19. Hugging Face アカウントにログインするか、アカウントを作成します
  20. Cluster Toolkit の使用に必要な依存関係をインストールします

Hugging Face を使用して Llama 4 にアクセスする

Hugging Face を使用して Llama 4 にアクセスする手順は次のとおりです。

  1. Llama 4 を使用するための同意契約に署名します

  2. Hugging Face read アクセス トークンを作成します

    [Your Profile] > [Settings] > [Access tokens] > [+Create new token] の順にクリックします。

  3. read access トークンの値をコピーして保存します。これは、このチュートリアルの後半で使用します。

環境を準備する

環境を準備する手順は次のとおりです。

  1. Cluster Toolkit GitHub リポジトリのクローンを作成します。

    git clone https://github.com/GoogleCloudPlatform/cluster-toolkit.git
    
  2. Cloud Storage バケットを作成します。

    gcloud storage buckets create gs://BUCKET_NAME \
        --project=PROJECT_ID
    

    次のように置き換えます。

    • BUCKET_NAME: バケットの命名要件に沿った Cloud Storage バケットの名前。

    • PROJECT_ID: Cloud Storage バケットを作成するGoogle Cloud プロジェクトの ID。

A4 Slurm クラスタを作成する

A4 Slurm クラスタを作成する手順は次のとおりです。

  1. cluster-toolkit ディレクトリに移動します。

    cd cluster-toolkit
    
  2. Cluster Toolkit を初めて使用する場合は、gcluster バイナリをビルドします。

    make
    
  3. examples/machine-learning/a4-highgpu-8g ディレクトリに移動します。

    cd examples/machine-learning/a4-highgpu-8g/
    
  4. a4high-slurm-deployment.yaml ファイルを開き、次のように編集します。

    terraform_backend_defaults:
      type: gcs
      configuration:
        bucket: BUCKET_NAME
    
    vars:
      deployment_name: a4-high
      project_id: PROJECT_ID
      region: REGION
      zone: ZONE
      a4h_cluster_size: 2
      a4h_reservation_name: RESERVATION_URL
    

    次のように置き換えます。

    • BUCKET_NAME: 前のセクションで作成した Cloud Storage バケットの名前。

    • PROJECT_ID: Cloud Storage が存在し、Slurm クラスタを作成するGoogle Cloud プロジェクトの ID。

    • REGION: 予約が存在するリージョン。

    • ZONE: 予約が存在するゾーン。

    • RESERVATION_URL: Slurm クラスタの作成に使用する予約の URL。予約が存在するプロジェクトに基づいて、次のいずれかの値を指定します。

      • 予約がプロジェクトに存在する場合: RESERVATION_NAME

      • 予約が別のプロジェクトにあり、プロジェクトで予約を使用できる場合: projects/RESERVATION_PROJECT_ID/reservations/RESERVATION_NAME

  5. クラスタをデプロイします。

    ./gcluster deploy -d examples/machine-learning/a4-highgpu-8g/a4high-slurm-deployment.yaml examples/machine-learning/a4-highgpu-8g/a4high-slurm-blueprint.yaml --auto-approve
    

    ./gcluster deploy コマンドは 2 フェーズのプロセスです。

    • 最初のフェーズでは、すべてのソフトウェアがプリインストールされたカスタム イメージがビルドされます。この処理には最大 35 分かかることがあります。

    • 第 2 フェーズでは、そのカスタム イメージを使用してクラスタをデプロイします。このプロセスは、最初のフェーズよりも早く完了します。

    第 1 フェーズは成功したが第 2 フェーズが失敗した場合は、第 1 フェーズをスキップして Slurm クラスタのデプロイを再試行できます。

    ./gcluster deploy -d examples/machine-learning/a4-highgpu-8g/a4high-slurm-deployment.yaml examples/machine-learning/a4-highgpu-8g/a4high-slurm-blueprint.yaml --auto-approve --skip "image" -w
    

ワークロードを準備する

ワークロードを準備する手順は次のとおりです。

  1. ワークロード スクリプトを作成する

  2. スクリプトを Slurm クラスタにアップロードします

  3. Slurm クラスタに接続します

  4. フレームワークとツールをインストールします

ワークロード スクリプトを作成する

ファインチューニング ワークロードで使用するスクリプトを作成する手順は次のとおりです。

  1. Python 仮想環境を設定するには、次の内容で install_environment.sh ファイルを作成します。

    #!/bin/bash
    # This script sets up a consistent environment for FSDP training.
    # It is meant to be run once on the login node of your Slurm cluster
    set -e
    
    # --- 1. Create the Python virtual environment ---
    VENV_PATH="$HOME/.venv/venv-fsdp"
    if [ ! -d "$VENV_PATH" ]; then
      echo "--- Creating Python virtual environment at $VENV_PATH ---"
      python3 -m venv $VENV_PATH
    else
      echo "--- Virtual environment already exists at $VENV_PATH ---"
    fi
    
    source $VENV_PATH/bin/activate
    
    # --- 2. Install Dependencies ---
    echo "--- [STEP 2.1] Upgrading build toolchain ---"
    pip install --upgrade pip wheel packaging
    
    echo "--- [STEP 2.2] Installing PyTorch Nightly ---"
    pip install --force-reinstall --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
    
    echo "--- [STEP 2.3] Installing application dependencies ---"
    if [ -f "requirements-fsdp.txt" ]; then
        pip install -r requirements-fsdp.txt
    else
        echo "ERROR: requirements-fsdp.txt not found!"
        exit 1
    fi
    
    # --- 3. Download the Model ---
    echo "--- [STEP 2.4] Downloading Llama4 model ---"
    if [ -z "$HF_TOKEN" ]; then
      echo "ERROR: The HF_TOKEN environment variable is not set."; exit 1;
    fi
    pip install huggingface_hub[cli]
    
    # Execute the CLI using its full, explicit path
    $VENV_PATH/bin/huggingface-cli download meta-llama/Llama-4-Scout-17B-16E-Instruct --local-dir ~/Llama-4-Scout-17B-16E-Instruct --token $HF_TOKEN
    
    echo "--- Environment setup complete. ---"
    

    このスクリプトは、信頼性の高い Python 仮想環境を設定し、PyTorch ナイトリー ビルドをインストールして、Llama 4 モデルをダウンロードします。

  2. トレーニング スクリプトの Python 依存関係を指定するには、次の内容の requirements-fsdp.txt ファイルを作成します。

    transformers==4.55.0
    datasets==4.0.0
    peft==0.16.0
    accelerate==1.9.0
    trl==0.21.0
    
    # Other dependencies
    sentencepiece==0.2.0
    
  3. メインのトレーニング スクリプトとして llama4-train-distributed.py を指定します。

    import torch
    from datasets import load_dataset
    from peft import LoraConfig, PeftModel
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TrainingArguments,
        HfArgumentParser,
    )
    
    from torch.distributed import get_rank, get_world_size
    
    from transformers.models.llama4.modeling_llama4 import Llama4TextDecoderLayer
    from trl import SFTTrainer
    from dataclasses import dataclass, field
    from typing import Optional
    
    @dataclass
    class ScriptArguments:
        model_id: str = field(metadata={"help": "Hugging Face model ID from the Hub"})
        dataset_name: str = field(default="philschmid/gretel-synthetic-text-to-sql", metadata={"help": "Dataset from the Hub"})
        run_inference_after_training: bool = field(default=False, metadata={"help": "Run sample inference on rank 0 after training"})
        dataset_subset_size: Optional[int] = field(default=None, metadata={"help": "Number of samples to use from the dataset for training. If None, uses the full dataset."})
    
    @dataclass
    class PeftArguments:
        lora_r: int = field(default=16, metadata={"help": "LoRA attention dimension"})
        lora_alpha: int = field(default=32, metadata={"help": "LoRA alpha scaling factor"})
        lora_dropout: float = field(default=0.05, metadata={"help": "LoRA dropout probability"})
    
    @dataclass
    class SftTrainingArguments(TrainingArguments):
        max_length: Optional[int] = field(default=2048, metadata={"help": "The maximum sequence length for SFTTrainer"})
        packing: Optional[bool] = field(default=False, metadata={"help": "Enable packing for SFTTrainer"})
        ddp_find_unused_parameters: Optional[bool] = field(default=True, metadata={"help": "When using FSDP activation checkpointing, this must be set to True"})
    
    def formatting_prompts_func(example):
        system_message = "You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."
        user_prompt = f"### SCHEMA:\n{example['sql_context']}\n\n### USER QUERY:\n{example['sql_prompt']}"
        response = f"\n\n### SQL QUERY:\n{example['sql']}"
        return f"{system_message}\n\n{user_prompt}{response}"
    
    def main():
        parser = HfArgumentParser((ScriptArguments, PeftArguments, SftTrainingArguments))
        script_args, peft_args, training_args = parser.parse_args_into_dataclasses()
    
        training_args.gradient_checkpointing = True
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
    
        training_args.optim = "adamw_torch_fused"
    
        training_args.fsdp = "full_shard"
        training_args.fsdp_config = {
            "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
            "fsdp_transformer_layer_cls_to_wrap": [Llama4TextDecoderLayer],
            "fsdp_state_dict_type": "FULL_STATE_DICT",
            "fsdp_offload_params": False,
            "fsdp_forward_prefetch": True,
        }
    
        tokenizer = AutoTokenizer.from_pretrained(script_args.model_id, trust_remote_code=True)
    
        model = AutoModelForCausalLM.from_pretrained(
            script_args.model_id,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            attn_implementation="sdpa",
        )
    
        peft_config = LoraConfig(
            r=peft_args.lora_r,
            lora_alpha=peft_args.lora_alpha,
            lora_dropout=peft_args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        )
        rank = get_rank()
        world_size = get_world_size()
    
        dataset = load_dataset(script_args.dataset_name, split="train")
    
        if script_args.dataset_subset_size is not None:
            dataset = dataset.select(range(script_args.dataset_subset_size))
        else:
            print(f"Using the full dataset with {len(dataset)} samples.")
    
        dataset = dataset.shuffle(seed=training_args.seed)
        print(f"Dataset shuffled with seed: {training_args.seed}.")
    
        if world_size > 1:
            print(f"Sharding dataset for Rank {rank} of {world_size}.")
            dataset = dataset.shard(num_shards=world_size, index=rank)
    
        print("Initializing SFTTrainer...")
        trainer = SFTTrainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
            peft_config=peft_config,
            formatting_func=formatting_prompts_func,
            processing_class=tokenizer,
        )
    
        trainer.train()
    
        trainer.save_model(training_args.output_dir)
    
        if script_args.run_inference_after_training and trainer.is_world_process_zero():
            del model
            del trainer
            torch.cuda.empty_cache()
            run_post_training_inference(script_args, training_args, tokenizer)
    
    def run_post_training_inference(script_args, training_args, tokenizer):
        """
        Loads the fine-tuned PEFT adapter from the local output directory and runs inference.
        This should only be called on rank 0 after training is complete.
        """
        print("\n" + "="*50)
        print("=== RUNNING POST-TRAINING INFERENCE TEST ===")
        print("="*50 + "\n")
    
        # Load the base model and merge the adapter.
        base_model = AutoModelForCausalLM.from_pretrained(
            script_args.model_id,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto"
        )
        # Load the PEFT adapter and merge it into the base model
        model = PeftModel.from_pretrained(base_model, training_args.output_dir)
        model = model.merge_and_unload() # Merge weights for faster inference
        model.eval()
    
        # Define the test case
        schema = "CREATE TABLE artists (Name TEXT, Country TEXT, Genre TEXT)"
        system_message = "You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."
        question = "Show me all artists from the Country just north of the USA."
    
        # This must match the formatting_func exactly
        prompt = f"{system_message}\n\n### SCHEMA:\n{schema}\n\n### USER QUERY:\n{question}\n\n### SQL QUERY:\n"
    
        print(f"Test Prompt:\n{prompt}")
    
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    
        print("\n--- Generating SQL... ---")
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=False,
            temperature=None,
            top_p=None,
        )
    
        generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)[len(prompt):].strip()
    
        print(f"\n--- Generated SQL Query ---")
        print(generated_sql)
        print("\n" + "="*50)
        print("=== INFERENCE TEST COMPLETE ===")
        print("="*50 + "\n")
    
    if __name__ == "__main__":
        main()
    

    このスクリプトは、TRL 教師ありファインチューニング(SFT)トレーナーを使用して、FSDP トレーニング ループ、Low-Rank Adaptation(LoRA)構成、データ形式を管理します。

  4. Slurm クラスタで実行するジョブのタスクを指定するには、次のコンテンツを含む submit.slurm ファイルを作成します。

    #!/bin/bash
    #SBATCH --job-name=llama4-fsdp-fixed
    #SBATCH --nodes=2
    #SBATCH --ntasks-per-node=8
    #SBATCH --gpus-per-node=8
    #SBATCH --partition=a4high
    #SBATCH --output=llama4-%j.out
    #SBATCH --error=llama4-%j.err
    
    set -e
    set -x
    
    echo "--- Slurm Job Started ---"
    echo "Job ID: $SLURM_JOB_ID"
    echo "Node List: $SLURM_JOB_NODELIST"
    
    # --- Define Paths ---
    LOCAL_SSD_PATH="/mnt/localssd/job_${SLURM_JOB_ID}"
    VENV_PATH="${HOME}/.venv/venv-fsdp"
    MODEL_PATH="${HOME}/Llama-4-Scout-17B-16E-Instruct"
    
    # --- STAGE 1: Stage Data to Local SSD on Each Node ---
    srun --ntasks=$SLURM_NNODES --ntasks-per-node=1 bash -c "
      echo '--- Staging on node: $(hostname) ---'
    
      mkdir -p ${LOCAL_SSD_PATH}
    
      echo 'Copying virtual environment...'
      rsync -a -q ${VENV_PATH}/ ${LOCAL_SSD_PATH}/venv/
    
      echo 'Copying model weights...'
      rsync -a --info=progress2 ${MODEL_PATH}/ ${LOCAL_SSD_PATH}/model/
    
      mkdir -p ${LOCAL_SSD_PATH}/hf_cache
    
      echo '--- Staging on $(hostname) complete ---'
    "
    echo "--- Staging complete on all nodes ---"
    
    # --- STAGE 2: Run the Training Job ---
    echo "--- Launching Distributed Training with GIB NCCL Plugin ---"
    nodes=( $( scontrol show hostnames "$SLURM_JOB_NODELIST" ) )
    head_node=${nodes[0]}
    head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
    
    export MASTER_ADDR=$head_node_ip
    export MASTER_PORT=29500
    
    export NCCL_SOCKET_IFNAME=enp0s19
    export NCCL_NET=gIB
    
    # export NCCL_DEBUG=INFO # Un-comment to diagnose NCCL issues if needed
    
    srun --cpu-bind=none --accel-bind=g bash -c '
      # Activate the environment from the local copy
      source '${LOCAL_SSD_PATH}'/venv/bin/activate
    
      # Point Hugging Face cache to the local SSD
      export HF_HOME='${LOCAL_SSD_PATH}'/hf_cache
    
      export RANK=$SLURM_PROCID
      export WORLD_SIZE=$SLURM_NTASKS
      export LOCAL_RANK=$SLURM_LOCALID
    
      export LD_LIBRARY_PATH=/usr/local/gib/lib64:$LD_LIBRARY_PATH
      source /usr/local/gib/scripts/set_nccl_env.sh
    
      # --- Launch the training ---
      python \
        '${SLURM_SUBMIT_DIR}'/llama4-train-distributed.py \
          --model_id="'${LOCAL_SSD_PATH}'/model/" \
          --output_dir="'${LOCAL_SSD_PATH}'/outputs/" \
          --dataset_name="philschmid/gretel-synthetic-text-to-sql" \
          --seed=900913 \
          --bf16=True \
          --num_train_epochs=1 \
          --per_device_train_batch_size=2 \
          --gradient_accumulation_steps=4 \
          --learning_rate=2e-5 \
          --logging_steps=10 \
          --lora_r=16 \
          --lora_alpha=32 \
          --lora_dropout=0.05 \
          --run_inference_after_training
    '
    
    # --- STAGE 3: Copy Final Results Back to Persistent Storage ---
    echo "--- Copying final results from local SSD to shared storage ---"
    PERSISTENT_OUTPUT_DIR="${HOME}/outputs/llama4_job_${SLURM_JOB_ID}"
    mkdir -p "$PERSISTENT_OUTPUT_DIR"
    
    # Only copy from the head node where trl has combined the results
    srun --nodes=1 --ntasks=1 -w "$head_node" \
      rsync -a --info=progress2 "${LOCAL_SSD_PATH}/outputs/" "${PERSISTENT_OUTPUT_DIR}/"
    
    # --- STAGE 4: Cleanup ---
    echo "--- Cleaning up local SSD on all nodes ---"
    srun --ntasks=$SLURM_NNODES --ntasks-per-node=1 bash -c "rm -rf ${LOCAL_SSD_PATH}"
    
    echo "--- Slurm Job Finished ---"
    

スクリプトを Slurm クラスタにアップロードする

前のセクションで作成したスクリプトを Slurm クラスタにアップロードする手順は次のとおりです。

  1. ログインノードを特定するには、プロジェクト内のすべての A4 VM を一覧表示します。

    gcloud compute instances list --filter="machineType:a4-highgpu-8g"
    

    ログインノードの名前は a4-high-login-001 のようになります。

  2. スクリプトをログインノードのホーム ディレクトリにアップロードします。

    gcloud compute scp --project="$PROJECT_ID" --zone="$ZONE" --tunnel-through-iap \
      ./install_environment.sh \
      ./requirements-fsdp.txt \
      ./llama4-train-distributed.py \
      ./submit.slurm \
      "${LOGIN_NODE_NAME}":~/
    

    LOGIN_NODE_NAME は、ログインノードの名前に置き換えます。

Slurm クラスタに接続する

SSH を使用してログインノードに接続し、Slurm クラスタに接続します。

gcloud compute ssh LOGIN_NODE_NAME \
    --project=PROJECT_ID \
    --tunnel-through-iap \
    --zone=ZONE

フレームワークとツールをインストールする

ログインノードに接続したら、次の手順でフレームワークとツールをインストールします。

  1. Hugging Face トークンをエクスポートします。

    # On the login node
    export HF_TOKEN="hf_..." # Replace with your token
    
  2. インストール スクリプトを実行します。

    # On the login node
    chmod +x install_environment.sh
    ./install_environment.sh
    

    このコマンドは、必要なすべての依存関係を含む仮想環境をセットアップし、モデルの重みを ~/Llama-4-Scout-17B-16E-Instruct ファイルにダウンロードします。

    モデルのダウンロードは非常に大きいため(約 200 GB)、ネットワークの状況によってはこのプロセスに 30 分ほどかかります。

ファインチューニング ワークロードを開始する

ワークロードのトレーニングを開始するには、次の操作を行います。

  1. Slurm スケジューラにジョブを送信します。

    sbatch submit.slurm
    
  2. Slurm クラスタのログインノードで、home ディレクトリに作成された出力ファイルを確認することで、ジョブの進行状況をモニタリングできます。

    # On the login node
    tail -f llama4-*.out
    

    ジョブが正常に開始されると、.err ファイルにプログレスバーが表示され、ジョブの進行状況に応じて更新されます。

    このジョブは、Slurm クラスタで 1 時間強で完了します。ジョブには、次の 2 つの主要なフェーズがあります。

    • 大規模なベースモデルを各コンピューティング ノードのローカル SSD にコピーする。
    • モデルのコピーが完了すると開始されるトレーニング ジョブ。このジョブの実行には約 35 分かかります。

クリーンアップ

このチュートリアルで使用したリソースについて、Google Cloud アカウントに課金されないようにするには、リソースを含むプロジェクトを削除するか、プロジェクトを維持して個々のリソースを削除します。

プロジェクトの削除

Google Cloud プロジェクトを削除します。

gcloud projects delete PROJECT_ID

Slurm クラスタを削除する

Slurm クラスタを削除する手順は次のとおりです。

  1. cluster-toolkit ディレクトリに移動します。

  2. Terraform ファイルと作成されたすべてのリソースを破棄します。

    ./gcluster destroy a4-high --auto-approve
    

Filestore インスタンスを削除する

デフォルトでは、Filestore インスタンスの cluster-toolkit ブループリントで deletion_protection 設定が true に設定されています。この設定により、環境を変更する際のデータ損失を防止できます。Filestore インスタンスを削除するには、削除保護を手動で無効にする必要があります。

次のステップ