FSDP を使用して A4 Slurm クラスタで Mixtral-8x7B をファインチューニングする

このチュートリアルでは、マルチノード、マルチ GPU の Slurm クラスタで mistralai/Mixtral-8x7B-v0.1 モデルを微調整する方法について説明します Google Cloud。このクラスタは、それぞれ 8 個の NVIDIA B200 GPU を搭載した 2 つの a4-highgpu-8g 仮想マシン(VM)インスタンスを使用します。

このチュートリアルで説明する主なプロセスは次のとおりです。

  1. Cluster Toolkit を使用して、本番環境グレードの高性能 Slurm クラスタをデプロイします。Google Cloud このデプロイの一環として、必要なソフトウェアがプリインストールされたカスタム VM イメージを作成します。また、共有 Lustre ファイル システムを設定し、高速ネットワークを構成します。
  2. クラスタがデプロイされたら、このチュートリアルに付属のスクリプト セットを使用して、分散型微調整ジョブを実行します。このジョブでは、 Hugging Face Transformer Reinforcement Learning (TRL) ライブラリを介してアクセスする PyTorch Fully Sharded Data Parallel(FSDP)を利用します。

このチュートリアルは、AI ワークロードを複数のノードと GPU に分散することに関心のある ML エンジニア、研究者、プラットフォーム管理者、オペレーター、データおよび AI スペシャリストを対象としています。

目標

  • Hugging Face を使用して Mixtral にアクセスする
  • 環境を準備する
  • 本番環境グレードの A4 High-GPU Slurm クラスタを作成してデプロイする。
  • FSDP を使用した分散トレーニング用にマルチノード環境を構成する。
  • Hugging Face trl.SFTTrainer クラスを使用して Mixtral モデルを微調整する。
  • データをローカル SSD にステージングする。
  • ジョブをモニタリングする。
  • クリーンアップする。

費用

このドキュメントでは、課金対象である次のコンポーネントを使用します。 Google Cloud

料金計算ツールを使うと、予想使用量に基づいて費用の見積もりを生成できます。

新規の Google Cloud ユーザーは無料トライアルをご利用いただける場合があります。

始める前に

  1. アカウントにログインします。 Google Cloud を初めて使用する場合は、 アカウントを作成して、実際のシナリオで Google プロダクトのパフォーマンスを評価してください。 Google Cloud新規のお客様には、ワークロードの実行、テスト、デプロイができる無料クレジット $300 分を差し上げます。
  2. Google Cloud CLI をインストールします。

  3. 外部 ID プロバイダ(IdP)を使用している場合は、まず連携 ID を使用して gcloud CLI にログインする必要があります。

  4. gcloud CLI を初期化するには、次のコマンドを実行します:

    gcloud init
  5. プロジェクトを作成または選択します Google Cloud

    プロジェクトを選択または作成するために必要なロール

    • プロジェクトを選択する: プロジェクトの選択には特定の IAM ロールは必要ありません。ロールが付与されているプロジェクトを選択できます。
    • プロジェクトを作成する: プロジェクトを作成するには、プロジェクト作成者ロール (roles/resourcemanager.projectCreator)が必要です。これには resourcemanager.projects.create 権限が含まれています。ロールを付与する方法を確認する
    • プロジェクトを作成する: Google Cloud

      gcloud projects create PROJECT_ID

      PROJECT_ID は、作成する Google Cloud プロジェクトの名前に置き換えます。

    • 作成した Google Cloud プロジェクトを選択します。

      gcloud config set project PROJECT_ID

      PROJECT_ID は、 Google Cloud プロジェクトの名前に置き換えます。

  6. プロジェクトに対して課金が有効になっていることを確認します Google Cloud 。

  7. 必要な API を有効にします。

    API を有効にするために必要なロール

    API を有効にするには、 権限を含む Service Usage 管理者 IAM ロール(roles/serviceusage.serviceUsageAdmin)が必要です。serviceusage.services.enableロールを付与する方法を確認する

    gcloud services enable compute.googleapis.com file.googleapis.com logging.googleapis.com cloudresourcemanager.googleapis.com servicenetworking.googleapis.com lustre.googleapis.com
  8. Google Cloud CLI をインストールします。

  9. 外部 ID プロバイダ(IdP)を使用している場合は、まず連携 ID を使用して gcloud CLI にログインする必要があります。

  10. gcloud CLI を初期化するには、次のコマンドを実行します:

    gcloud init
  11. プロジェクトを作成または選択します Google Cloud

    プロジェクトを選択または作成するために必要なロール

    • プロジェクトを選択する: プロジェクトの選択には特定の IAM ロールは必要ありません。ロールが付与されているプロジェクトを選択できます。
    • プロジェクトを作成する: プロジェクトを作成するには、プロジェクト作成者ロール (roles/resourcemanager.projectCreator)が必要です。これには resourcemanager.projects.create 権限が含まれています。ロールを付与する方法を確認する
    • プロジェクトを作成する: Google Cloud

      gcloud projects create PROJECT_ID

      PROJECT_ID は、作成する Google Cloud プロジェクトの名前に置き換えます。

    • 作成した Google Cloud プロジェクトを選択します。

      gcloud config set project PROJECT_ID

      PROJECT_ID は、 Google Cloud プロジェクトの名前に置き換えます。

  12. プロジェクトに対して課金が有効になっていることを確認します Google Cloud 。

  13. 必要な API を有効にします。

    API を有効にするために必要なロール

    API を有効にするには、 権限を含む Service Usage 管理者 IAM ロール(roles/serviceusage.serviceUsageAdmin)が必要です。serviceusage.services.enableロールを付与する方法を確認する

    gcloud services enable compute.googleapis.com file.googleapis.com logging.googleapis.com cloudresourcemanager.googleapis.com servicenetworking.googleapis.com lustre.googleapis.com
  14. ユーザー アカウントにロールを付与します。次の IAM ロールごとに次のコマンドを 1 回実行します。 roles/compute.admin, roles/iam.serviceAccountUser, roles/file.editor, roles/storage.admin, roles/serviceusage.serviceUsageAdmin

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    次のように置き換えます。

    • PROJECT_ID: プロジェクト ID。
    • USER_IDENTIFIER: ユーザー アカウントの識別子。例: myemail@example.com
    • ROLE: ユーザー アカウントに付与する IAM ロール。
  15. プロジェクトのデフォルト サービス アカウントを有効にします: Google Cloud
    gcloud iam service-accounts enable PROJECT_NUMBER-compute@developer.gserviceaccount.com \
        --project=PROJECT_ID

    PROJECT_NUMBER は、使用するプロジェクト番号に置き換えます。プロジェクト番号を確認するには、 既存のプロジェクトを取得するをご覧ください。

  16. 編集者ロール(roles/editor)をデフォルトのサービス アカウントに付与します。
    gcloud projects add-iam-policy-binding PROJECT_ID \
      --member="serviceAccount:PROJECT_NUMBER-compute@developer.gserviceaccount.com" \
      --role=roles/editor
  17. ユーザー アカウントのローカル認証情報を作成します。
    gcloud auth application-default login
  18. プロジェクトの OS Login を有効にします。
    gcloud compute project-info add-metadata --metadata=enable-oslogin=TRUE
  19. Hugging Face アカウントにログインするか、アカウントを作成します
  20. Cluster Toolkit の使用に必要な依存関係をインストールします。

Hugging Face を使用して Mixtral にアクセスする

Hugging Face を使用して Mixtral にアクセスする手順は次のとおりです。

  1. Hugging Face read access トークンを作成します。
  2. read アクセス トークン値をコピーして保存します。このチュートリアルの後半で使用します。

環境を準備する

クラスタのデプロイを準備するには、ローカルマシンで次の手順を行います。

  1. Cluster Toolkit リポジトリのクローンを作成します。 Google Cloud

    git clone https://github.com/GoogleCloudPlatform/cluster-toolkit.git
    
  2. Cloud Storage バケットを作成します。

    export BUCKET_NAME="your-unique-bucket-name"
    gcloud storage buckets create gs://${BUCKET_NAME}
    

A4 Slurm クラスタを作成する

A4 Slurm クラスタを作成する手順は次のとおりです。

  1. クローン作成した cluster-toolkit ディレクトリに移動します。

    cd cluster-toolkit
    
  2. Cluster Toolkit を初めて使用する場合は、gcluster バイナリをビルドします。

    make
    
  3. examples/machine-learning/a4-highgpu-8g ディレクトリに移動します。

    a4high-slurm-deployment.yaml ファイルを開き、次のように編集します。

    terraform_backend_defaults:
      type: gcs
      configuration:
        bucket: BUCKET_NAME
    
    vars:
      deployment_name: DEPLOYMENT_NAME
      project_id: PROJECT_ID
      region: REGION
      zone: ZONE
      a4h_cluster_size: 2
      a4h_reservation_name: RESERVATION_NAME
    

    次のように置き換えます。

    • BUCKET_NAME: : 前のセクションで作成した Cloud Storage バケットの名前。
    • PROJECT_ID: Cloud Storage が存在し、Slurm クラスタを作成する Google Cloud プロジェクトの ID。
    • REGION: 予約が存在するリージョン。
    • ZONE: 予約が存在するゾーン。
    • A4h_reservation_name: A4 予約の名前を使用します。
  4. a4high-slurm-blueprint.yaml ファイルを開き、次のように編集します。

    • filestore_homefs モジュールを削除します。
    • lustrefs モジュールと private-service-access モジュールを有効にします。
    • vars ブロックで、次のように構成します。
      1. Find slurm_vars を選択し、install_managed_lustretrue に設定します。
      2. per_unit_storage_throughput パラメータを 500 に設定します。
      3. size_gib パラメータを 36000 に設定します。
  5. クラスタをデプロイします。

    ./gcluster deploy -d examples/machine-learning/a4-highgpu-8g/a4high-slurm-deployment.yaml \
      examples/machine-learning/a4-highgpu-8g/a4high-slurm-blueprint.yaml \
      --auto-approve
    

    ./gcluster deploy コマンドは、次の 2 段階のプロセスを開始します。

    • 最初のフェーズでは、すべてのソフトウェアがプリインストールされたカスタム イメージをビルドします。完了までに最大 35 分かかることがあります。
    • 2 番目のフェーズでは、そのカスタム イメージを使用してクラスタをデプロイします。このプロセスは、最初のフェーズよりも早く完了します。

ワークロードを準備する

ワークロードを準備する手順は次のとおりです。

  1. ワークロード スクリプトを作成します

  2. スクリプトを Slurm クラスタにアップロードします

  3. Slurm クラスタに接続します

  4. フレームワークとツールをインストールします

ワークロード スクリプトを作成する

微調整ワークロードで使用するスクリプトを作成する手順は次のとおりです。

  1. Python 仮想環境を設定するには、次の内容で install_environment.sh ファイルを作成します。

    #!/bin/bash
    # This script sets a reliable environment for FSDP training.
    # It is meant to be run on a compute node.
    set -e
    
    # --- 1. Create the Python virtual environment ---
    VENV_PATH="$HOME/.venv/venv-fsdp"
    if [ ! -d "$VENV_PATH" ]; then
      echo "--- Creating Python virtual environment at $VENV_PATH ---"
      python3 -m venv $VENV_PATH
    else
      echo "--- Virtual environment already exists at $VENV_PATH ---"
    fi
    
    source $VENV_PATH/bin/activate
    
    # --- 2. Install Dependencies ---
    echo "--- [STEP 2.1] Upgrading build toolchain ---"
    pip install --upgrade pip wheel packaging
    
    echo "--- [STEP 2.2] Installing PyTorch Nightly ---"
    pip install --force-reinstall --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
    
    echo "--- [STEP 2.3] Installing application dependencies ---"
    if [ -f "requirements-fsdp.txt" ]; then
        pip install -r requirements-fsdp.txt
    else
        echo "ERROR: requirements-fsdp.txt not found!"
        exit 1
    fi
    
    # --- [STEP 2.4] Build Flash Attention from Source ---
    echo "--- Building flash-attn from source... This will take a while. ---"
    # Use all available CPU cores to speed up the build
    MAX_JOBS=$(nproc) pip install flash-attn --no-build-isolation
    
    # --- 3. Download the Model ---
    echo "--- [STEP 2.5] Downloading Mixtral model ---"
    if [ -z "$HF_TOKEN" ]; then
      echo "ERROR: The HF_TOKEN environment variable is not set."; exit 1;
    fi
    pip install huggingface_hub[cli]
    
    # Execute the CLI using its full, explicit path
    $VENV_PATH/bin/huggingface-cli download mistralai/Mixtral-8x7B-v0.1 --local-dir ~/Mixtral-8x7B-v0.1 --token $HF_TOKEN
    
    echo "--- Environment setup complete. ---"
    
  2. トレーニング スクリプトの Python 依存関係を指定するには、次の内容で requirements-fsdp.txt ファイルを作成します。

    transformers==4.55.0
    datasets==4.0.0
    peft==0.16.0
    accelerate==1.9.0
    trl==0.21.0
    
    # Other dependencies
    sentencepiece==0.2.0
    protobuf==6.31.1
    
  3. メイン トレーニング スクリプトとして train-mixtral.py を指定します。

    import torch
    from torch.distributed.fsdp import MixedPrecision
    from datasets import load_dataset
    import shutil
    import os
    import torch.distributed as dist
    
    from peft import LoraConfig, PeftModel, get_peft_model
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TrainingArguments,
        HfArgumentParser,
    )
    
    from torch.distributed import get_rank, get_world_size
    
    from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
    from trl import SFTTrainer
    from dataclasses import dataclass, field
    from typing import Optional
    
    @dataclass
    class ScriptArguments:
        model_id: str = field(default="mistralai/Mixtral-8x7B-v0.1", metadata={"help": "Hugging Face model ID from the Hub"})
        dataset_name: str = field(default="philschmid/gretel-synthetic-text-to-sql", metadata={"help": "Dataset from the Hub"})
        run_inference_after_training: bool = field(default=False, metadata={"help": "Run sample inference on rank 0 after training"})
        dataset_subset_size: Optional[int] = field(default=None, metadata={"help": "Number of samples to use from the dataset for training. If None, uses the full dataset."})
    
    @dataclass
    class PeftArguments:
        lora_r: int = field(default=16, metadata={"help": "LoRA attention dimension"})
        lora_alpha: int = field(default=32, metadata={"help": "LoRA alpha scaling factor"})
        lora_dropout: float = field(default=0.05, metadata={"help": "LoRA dropout probability"})
    
    @dataclass
    class SftTrainingArguments(TrainingArguments):
        max_length: Optional[int] = field(default=2048, metadata={"help": "The maximum sequence length for SFTTrainer"})
        packing: Optional[bool] = field(default=False, metadata={"help": "Enable packing for SFTTrainer"})
        ddp_find_unused_parameters: Optional[bool] = field(default=False, metadata={"help": "When using FSDP activation checkpointing, this must be set to False for Mixtral"})
    
    def formatting_prompts_func(example):
        system_message = "You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."
        user_prompt = f"### SCHEMA:\n{example['sql_context']}\n\n### USER QUERY:\n{example['sql_prompt']}"
        response = f"\n\n### SQL QUERY:\n{example['sql']}"
        return f"{system_message}\n\n{user_prompt}{response}"
    
    def main():
        parser = HfArgumentParser((ScriptArguments, PeftArguments, SftTrainingArguments))
        script_args, peft_args, training_args = parser.parse_args_into_dataclasses()
    
        training_args.gradient_checkpointing = True
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
    
        training_args.optim = "adamw_torch_fused"
    
        bf16_policy = MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        )
    
        training_args.fsdp = "full_shard"
        training_args.fsdp_config = {
            "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
            "fsdp_transformer_layer_cls_to_wrap": [MixtralDecoderLayer],
            "fsdp_state_dict_type": "SHARDED_STATE_DICT",
            "fsdp_offload_params": False,
            "fsdp_forward_prefetch": True,
            "fsdp_mixed_precision_policy": bf16_policy
        }
    
        tokenizer = AutoTokenizer.from_pretrained(script_args.model_id, trust_remote_code=True)
    
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"
    
        model = AutoModelForCausalLM.from_pretrained(
            script_args.model_id,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            attn_implementation="flash_attention_2",
        )
    
        peft_config = LoraConfig(
            r=peft_args.lora_r,
            lora_alpha=peft_args.lora_alpha,
            lora_dropout=peft_args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        )
    
        model = get_peft_model(model, peft_config)
    
        data_splits = load_dataset(script_args.dataset_name)
    
        dataset = data_splits["train"]
        eval_dataset = data_splits["test"]
    
        if script_args.dataset_subset_size is not None:
            dataset = dataset.select(range(script_args.dataset_subset_size))
    
        dataset = dataset.shuffle(seed=training_args.seed)
    
        trainer = SFTTrainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
            eval_dataset=eval_dataset,
            formatting_func=formatting_prompts_func,
            processing_class=tokenizer,
        )
    
        trainer.train()
    
        dist.barrier()
        if trainer.is_world_process_zero():
            best_model_path = trainer.state.best_model_checkpoint
    
            final_model_dir = os.path.join(training_args.output_dir, "final_best_model")
            print(f"Copying best model to: {final_model_dir}")
    
            if os.path.exists(final_model_dir):
                shutil.rmtree(final_model_dir)
            shutil.copytree(best_model_path, final_model_dir)
    
            if script_args.run_inference_after_training:
                del model, trainer
                torch.cuda.empty_cache()
                run_post_training_inference(script_args, final_model_dir, tokenizer)
    
    def run_post_training_inference(script_args, best_model_path, tokenizer):
        print("\n" + "="*50)
        print("=== RUNNING POST-TRAINING INFERENCE TEST ===")
        print("="*50 + "\n")
    
        base_model = AutoModelForCausalLM.from_pretrained(
            script_args.model_id,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto"
        )
        model = PeftModel.from_pretrained(base_model, best_model_path)
        model = model.merge_and_unload()
        model.eval()
    
        # Define the test case
        schema = "CREATE TABLE artists (Name TEXT, Country TEXT, Genre TEXT)"
        system_message = "You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."
        question = "Show me all artists from the Country just north of the USA."
    
        prompt = f"{system_message}\n\n### SCHEMA:\n{schema}\n\n### USER QUERY:\n{question}\n\n### SQL QUERY:\n"
    
        print(f"Test Prompt:\n{prompt}")
    
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    
        print("\n--- Generating SQL... ---")
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=False,
            temperature=None,
            top_p=None,
        )
    
        generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)[len(prompt):].strip()
    
        print(f"\n--- Generated SQL Query ---")
        print(generated_sql)
        print("\n" + "="*50)
        print("=== INFERENCE TEST COMPLETE ===")
        print("="*50 + "\n")
    
    if __name__ == "__main__":
        main()
    
  4. Slurm クラスタで実行するジョブのタスクを指定するには、次の内容で train-mixtral.sh ファイルを作成します。

    #!/bin/bash
    #SBATCH --job-name=mixtral-fsdp
    #SBATCH --nodes=2
    #SBATCH --ntasks-per-node=8
    #SBATCH --gpus-per-node=8
    #SBATCH --partition=a4high
    #SBATCH --output=mixtral-%j.out
    #SBATCH --error=mixtral-%j.err
    
    set -e
    set -x
    
    echo "--- Slurm Job Started ---"
    
    # --- Define Paths ---
    LOCAL_SSD_PATH="/mnt/localssd/job_${SLURM_JOB_ID}"
    VENV_PATH="${HOME}/.venv/venv-fsdp"
    MODEL_PATH="${HOME}/Mixtral-8x7B-v0.1"
    
    # --- STAGE 1: Stage Data to Local SSD on Each Node ---
    srun --ntasks=$SLURM_NNODES --ntasks-per-node=1 bash -c "
    echo '--- Staging on node: $(hostname) ---'
    mkdir -p ${LOCAL_SSD_PATH}
    
    echo 'Copying virtual environment...'
    rsync -a -q ${VENV_PATH}/ ${LOCAL_SSD_PATH}/venv/
    
    echo 'Copying model weights...'
    rsync -a ${MODEL_PATH}/ ${LOCAL_SSD_PATH}/model/
    
    mkdir -p ${LOCAL_SSD_PATH}/hf_cache
    
    echo '--- Staging on $(hostname) complete ---'
    "
    echo "--- Staging complete on all nodes ---"
    
    # --- STAGE 2: Run the Training Job ---
    echo "--- Launching Distributed Training with GIB NCCL Plugin ---"
    nodes=( $( scontrol show hostnames "$SLURM_JOB_NODELIST" ) )
    head_node=${nodes[0]}
    head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
    
    export MASTER_ADDR=$head_node_ip
    export MASTER_PORT=29500
    
    export NCCL_SOCKET_IFNAME=enp0s19
    
    export NCCL_NET=gIB
    
    # export NCCL_DEBUG=INFO # Un-comment to diagnose NCCL issues if needed
    
    srun --cpu-bind=none --accel-bind=g bash -c '
    # Activate the environment from the local copy
    source '${LOCAL_SSD_PATH}'/venv/bin/activate
    
    # Point Hugging Face cache to the local SSD
    export HF_HOME='${LOCAL_SSD_PATH}'/hf_cache
    
    export RANK=$SLURM_PROCID
    export WORLD_SIZE=$SLURM_NTASKS
    export LOCAL_RANK=$SLURM_LOCALID
    
    export LD_LIBRARY_PATH=/usr/local/gib/lib64:$LD_LIBRARY_PATH
    source /usr/local/gib/scripts/set_nccl_env.sh
    
    # --- Launch the training ---
    python \
        '${SLURM_SUBMIT_DIR}'/train-mixtral.py \
        --model_id="'${LOCAL_SSD_PATH}'/model/" \
        --output_dir="${HOME}/outputs/mixtral_job_${SLURM_JOB_ID}" \
        --dataset_name="philschmid/gretel-synthetic-text-to-sql" \
        --seed=900913 \
        --bf16=True \
        --num_train_epochs=3 \
        --per_device_train_batch_size=32 \
        --gradient_accumulation_steps=4 \
        --learning_rate=4e-5 \
        --logging_steps=3 \
        --lora_r=32 \
        --lora_alpha=32 \
        --lora_dropout=0.05 \
        --eval_strategy=steps \
        --eval_steps=10 \
        --save_strategy=steps \
        --save_steps=10 \
        --load_best_model_at_end=False \
        --metric_for_best_model=eval_loss \
        --run_inference_after_training \
        --dataset_subset_size=67000
    '
    
    # --- STAGE 3: Cleanup ---
    echo "--- Cleaning up local SSD on all nodes ---"
    srun --ntasks=$SLURM_NNODES --ntasks-per-node=1 bash -c "rm -rf ${LOCAL_SSD_PATH}"
    
    echo "--- Slurm Job Finished ---"
    

スクリプトを Slurm クラスタにアップロードする

前のセクションで作成したスクリプトを Slurm クラスタにアップロードする手順は次のとおりです。

  1. ログインノードを特定するには、 プロジェクト内のすべての VM を一覧表示します。

    gcloud compute instances list
    

    ログインノードの名前は a4-high-login-001 のようになります。

  2. スクリプトをログインノードのホーム ディレクトリにアップロードします。

    # Run this from your local machine where you created the files
    LOGIN_NODE_NAME="your-login-node-name" # e.g., a4high-login-001
    PROJECT_ID="your-gcp-project-id"
    ZONE="your-cluster-zone" # e.g., us-west4-a
    
    gcloud compute scp --project="$PROJECT_ID" --zone="$ZONE" --tunnel-through-iap \
      ./install_environment.sh \
      ./requirements-fsdp.txt \
      ./train-mixtral.py \
      ./train-mixtral.sh \
      "${LOGIN_NODE_NAME}":~/
    

Slurm クラスタに接続する

SSH を使用してログインノードに接続し、Slurm クラスタに接続します。

gcloud compute ssh $LOGIN_NODE_NAME \
    --project=$PROJECT_ID \
    --tunnel-through-iap \
    --zone=$ZONE

フレームワークとツールをインストールする

ログインノードに接続したら、フレームワークとツールをインストールします。

  1. Hugging Face トークンをエクスポートします。

    # On the login node
    export HF_TOKEN="hf_..." # Replace with your token
    
  2. コンピューティング ノードでインストール スクリプトを実行します。

    # On the login node
    srun \
      --job-name=env-setup \
      --nodes=1 \
      --ntasks=1 \
      --gpus-per-node=1 \
      --partition=a4high \
      bash ./install_environment.sh
    

    このコマンドは、仮想環境を設定し、すべての依存関係をインストールして、Mixtral モデルの重みを ~/Mixtral-8x7B-v0.1 にダウンロードします。このプロセスが完了するまでに 30 分以上かかることがあります。

微調整ワークロードを開始する

ワークロードのトレーニングを開始する手順は次のとおりです。

  1. ジョブを Slurm スケジューラに送信します。

    # On the login node
    sbatch train-mixtral.sh
    
  2. Slurm クラスタのログインノードで、home ディレクトリに作成された出力ファイルを確認して、ジョブの進行状況をモニタリングできます。

    # On the login node
    tail -f mixtral-*.out
    

    ジョブが正常に開始されると、.err ファイルにプログレスバーが表示され、ジョブの進行状況に応じて更新されます。

    ジョブには主に次の 2 つのフェーズがあります。

    • 各コンピューティング ノードのローカル SSD に大きなベースモデルをコピーする。
    • モデルのコピーが完了すると開始されるトレーニング ジョブ。

    ジョブ全体の実行時間は約 40 分です。

クリーンアップする。

このチュートリアルで使用したリソースについて、Google Cloud アカウントに課金されないようにするには、リソースを含むプロジェクトを削除するか、プロジェクトを維持して個々のリソースを削除します。

Slurm クラスタを削除する

Slurm クラスタを削除する手順は次のとおりです。

  1. cluster-toolkit ディレクトリに移動します。

  2. Terraform ファイルと作成したすべてのリソースを破棄します。

    ./gcluster destroy DEPLOYMENT_NAME --auto-approve
    

プロジェクトの削除

プロジェクトを削除する: Google Cloud

gcloud projects delete PROJECT_ID

次のステップ