XLA を使用して高効率の大規模言語モデル(LLM)のサービスを提供する Hex-LLM は、Cloud TPU ハードウェア用に設計および最適化された Vertex AI LLM サービング フレームワークです。Hex-LLM は、継続的なバッチ処理や PagedAttention などの LLM サービング テクノロジーと、XLA と Cloud TPU 向けに調整された Vertex AI の最適化を組み合わせています。これは、オープンソース モデル用の Cloud TPU で、効率的で低コストの LLM サービスを提供します。
Hex-LLM は、モデル プレイグラウンド、ワンクリック デプロイ、ノートブックを通じて Model Garden で利用できます。
機能
Hex-LLM は、XLA と Cloud TPU 向けの Google 独自の最適化を備えたオープンソース プロジェクトに基づいています。Hex-LLM は、頻繁に使用される LLM のサービスを提供する際に、高いスループットと低レイテンシを実現します。
Hex-LLM には次の最適化が含まれています。
- 多数の同時リクエストでモデルがハードウェアを十分活用できるようにするトークンベースの連続バッチ処理アルゴリズム。
- XLA 用に最適化されたアテンション カーネルの完全な置き換え。
- 複数の Cloud TPU チップで LLM を効率的に実行するために、高度に最適化された重みシャーディング手法による柔軟でコンポーザブルなデータ並列処理とテンソル並列処理の戦略。
Hex-LLM は、高密度の LLM からスパース LLM まで幅広くサポートしています。
- Gemma 2B および 7B
- Gemma-2 9B および 27B
- Llama-2 7B、13B、70B
- Llama-3 8B および 70B
- Llama-3.1 8B および 70B
- Llama-3.2 1B および 3B
- Llama-3.3 70B
- Llama-Guard-3 1B および 8B
- Llama-4 Scout-17B-16E
- Mistral 7B
- Mixtral 8x7B および 8x22B
- Phi-3 mini および Phi-3 medium
- Phi-4、Phi-4 reasoning、reasoning plus
- Qwen-2 0.5B、1.5B、7B
- Qwen-2.5 0.5B、1.5B、7B、14B、32B
Hex-LLM には、次のようなさまざまな機能も用意されています。
- Hex-LLM は単一のコンテナに含まれています。Hex-LLM は、API サーバー、推論エンジン、サポートされているモデルを単一の Docker イメージにパッケージ化してデプロイします。
- Hugging Face モデル形式に対応しています。Hex-LLM は、ローカル ディスク、Hugging Face Hub、Cloud Storage バケットから Hugging Face モデルを読み込むことができます。
- bitsandbytes と AWQ を使用した量子化。
- 動的 LoRA 読み込み。Hex-LLM は、サービング中にリクエスト引数を読み取ることで LoRA 重みを読み込むことができます。
高度な機能
Hex-LLM は、次の高度な機能をサポートしています。
- マルチホスト サービング
- 分離型サービング [試験運用版]
- 接頭辞のキャッシュ保存
- 4 ビット量子化のサポート
マルチホスト サービング
Hex-LLM で、マルチホスト TPU スライスを使用したモデルのサービングがサポートされるようになりました。この機能を使用すると、単一のホスト TPU VM に読み込むことができない大規模なモデルをサービングできます。この VM には最大 8 個の v5e コアが含まれています。
この機能を有効にするには、Hex-LLM コンテナ引数で --num_hosts
を設定し、Vertex AI SDK モデル アップロード リクエストで --tpu_topology
を設定します。次の例は、Llama 3.1 70B bfloat16 モデルを提供する TPU 4x4 v5e トポロジで Hex-LLM コンテナをデプロイする方法を示しています。
hexllm_args = [
"--host=0.0.0.0",
"--port=7080",
"--model=meta-llama/Meta-Llama-3.1-70B",
"--data_parallel_size=1",
"--tensor_parallel_size=16",
"--num_hosts=4",
"--hbm_utilization_factor=0.9",
]
model = aiplatform.Model.upload(
display_name=model_name,
serving_container_image_uri=HEXLLM_DOCKER_URI,
serving_container_command=["python", "-m", "hex_llm.server.api_server"],
serving_container_args=hexllm_args,
serving_container_ports=[7080],
serving_container_predict_route="/generate",
serving_container_health_route="/ping",
serving_container_environment_variables=env_vars,
serving_container_shared_memory_size_mb=(16 * 1024), # 16 GB
serving_container_deployment_timeout=7200,
location=TPU_DEPLOYMENT_REGION,
)
model.deploy(
endpoint=endpoint,
machine_type=machine_type,
tpu_topology="4x4",
deploy_request_timeout=1800,
service_account=service_account,
min_replica_count=min_replica_count,
max_replica_count=max_replica_count,
)
マルチホスト TPU トポロジで Hex-LLM コンテナをデプロイするエンドツーエンドのチュートリアルについては、Vertex AI Model Garden - Llama 3.1(デプロイ)ノートブックをご覧ください。
一般に、マルチホスト配信を有効にするために必要な変更は次のとおりです。
- 引数
--tensor_parallel_size
を TPU トポロジ内のコアの合計数に設定します。 - 引数
--num_hosts
を TPU トポロジ内のホスト数に設定します。 - Vertex AI SDK モデル アップロード API を使用して
--tpu_topology
を設定します。
分離型サービング [試験運用版]
Hex-LLM で、試験運用版機能として分離型サービングがサポートされるようになりました。これは単一ホスト設定でのみ有効にでき、パフォーマンスは調整中です。
分離されたサービングは、各リクエストの最初のトークンまでの時間(TTFT)と出力トークンあたりの時間(TPOT)、およびサービングのスループット全体をバランスさせる効果的な方法です。プリフィル フェーズとデコード フェーズを別々のワークロードに分離し、互いに干渉しないようにします。この方法は、レイテンシ要件が厳しいシナリオで特に有用です。
この機能を有効にするには、Hex-LLM コンテナ引数で --disagg_topo
を設定します。次の例は、Llama 3.1 8B bfloat16 モデルを提供する TPU v5e-8 に Hex-LLM コンテナをデプロイする方法を示しています。
hexllm_args = [
"--host=0.0.0.0",
"--port=7080",
"--model=meta-llama/Llama-3.1-8B",
"--data_parallel_size=1",
"--tensor_parallel_size=2",
"--disagg_topo=3,1",
"--hbm_utilization_factor=0.9",
]
model = aiplatform.Model.upload(
display_name=model_name,
serving_container_image_uri=HEXLLM_DOCKER_URI,
serving_container_command=["python", "-m", "hex_llm.server.api_server"],
serving_container_args=hexllm_args,
serving_container_ports=[7080],
serving_container_predict_route="/generate",
serving_container_health_route="/ping",
serving_container_environment_variables=env_vars,
serving_container_shared_memory_size_mb=(16 * 1024), # 16 GB
serving_container_deployment_timeout=7200,
location=TPU_DEPLOYMENT_REGION,
)
model.deploy(
endpoint=endpoint,
machine_type=machine_type,
deploy_request_timeout=1800,
service_account=service_account,
min_replica_count=min_replica_count,
max_replica_count=max_replica_count,
)
--disagg_topo
引数は、"number_of_prefill_workers,number_of_decode_workers"
形式の文字列を受け入れます。前の例では、"3,1"
に設定して、3 つのプリフィル ワーカーと 1 つのデコード ワーカーを構成しています。各ワーカーは 2 つの TPU v5e コアを使用します。
接頭辞のキャッシュ保存
プレフィックス キャッシュ保存により、プロンプトの先頭に同じコンテンツ(会社全体の序文、共通のシステム指示、マルチターンの会話履歴など)があるプロンプトの最初のトークンまでの時間(TTFT)が短縮されます。Hex-LLM は、同じ入力トークンを繰り返し処理するのではなく、処理された入力トークン計算の一時キャッシュを保持して TTFT を改善できます。
この機能を有効にするには、Hex-LLM コンテナの引数で --enable_prefix_cache_hbm
を設定します。次の例は、Llama 3.1 8B bfloat16 モデルを提供する TPU v5e-8 に Hex-LLM コンテナをデプロイする方法を示しています。
hexllm_args = [
"--host=0.0.0.0",
"--port=7080",
"--model=meta-llama/Llama-3.1-8B",
"--data_parallel_size=1",
"--tensor_parallel_size=4",
"--hbm_utilization_factor=0.9",
"--enable_prefix_cache_hbm",
]
model = aiplatform.Model.upload(
display_name=model_name,
serving_container_image_uri=HEXLLM_DOCKER_URI,
serving_container_command=["python", "-m", "hex_llm.server.api_server"],
serving_container_args=hexllm_args,
serving_container_ports=[7080],
serving_container_predict_route="/generate",
serving_container_health_route="/ping",
serving_container_environment_variables=env_vars,
serving_container_shared_memory_size_mb=(16 * 1024), # 16 GB
serving_container_deployment_timeout=7200,
location=TPU_DEPLOYMENT_REGION,
)
model.deploy(
endpoint=endpoint,
machine_type=machine_type,
deploy_request_timeout=1800,
service_account=service_account,
min_replica_count=min_replica_count,
max_replica_count=max_replica_count,
)
Hex-LLM は、接頭辞のキャッシュ保存を使用して、特定の長さ(デフォルトでは 512 トークン。prefill_len_padding
を使用して構成可能)を超えるプロンプトのパフォーマンスを最適化します。キャッシュ ヒットはこの値の増分で発生し、キャッシュに保存されたトークン数が常に prefill_len_padding
の倍数になるようにします。チャット補完 API レスポンスの usage.prompt_tokens_details
の cached_tokens
フィールドは、プロンプト トークンのうちキャッシュ ヒットしたトークンの数を示します。
"usage": {
"prompt_tokens": 643,
"total_tokens": 743,
"completion_tokens": 100,
"prompt_tokens_details": {
"cached_tokens": 512
}
}
チャンク化されたプレフィル
チャンク化されたプリフィルは、リクエスト プリフィルをより小さなチャンクに分割し、プリフィルとデコードを 1 つのバッチステップに統合します。Hex-LLM は、チャンク化されたプリフィルを実装して、最初のトークンまでの時間(TTFT)と出力トークンあたりの時間(TPOT)のバランスを取り、スループットを向上させます。
この機能を有効にするには、Hex-LLM コンテナの引数で --enable_chunked_prefill
を設定します。次の例は、Llama 3.1 8B モデルを提供する TPU v5e-8 に Hex-LLM コンテナをデプロイする方法を示しています。
hexllm_args = [
"--host=0.0.0.0",
"--port=7080",
"--model=meta-llama/Llama-3.1-8B",
"--data_parallel_size=1",
"--tensor_parallel_size=4",
"--hbm_utilization_factor=0.9",
"--enable_chunked_prefill",
]
model = aiplatform.Model.upload(
display_name=model_name,
serving_container_image_uri=HEXLLM_DOCKER_URI,
serving_container_command=["python", "-m", "hex_llm.server.api_server"],
serving_container_args=hexllm_args,
serving_container_ports=[7080],
serving_container_predict_route="/generate",
serving_container_health_route="/ping",
serving_container_environment_variables=env_vars,
serving_container_shared_memory_size_mb=(16 * 1024), # 16 GB
serving_container_deployment_timeout=7200,
location=TPU_DEPLOYMENT_REGION,
)
model.deploy(
endpoint=endpoint,
machine_type=machine_type,
deploy_request_timeout=1800,
service_account=service_account,
min_replica_count=min_replica_count,
max_replica_count=max_replica_count,
)
4 ビット量子化のサポート
量子化は、通常の BF16 や FP32 ではなく、INT8 や INT4 などの低精度のデータ型で重みやアクティベーションを表すことで、推論の実行に必要な計算コストとメモリコストを削減する手法です。
Hex-LLM は INT8 の重みのみの量子化をサポートしています。拡張サポートには、AWQ ゼロポイント量子化を使用して量子化された INT4 重みを持つモデルが含まれます。Hex-LLM は、Mistral、Mixtral、Llama モデル ファミリーの INT4 バリアントをサポートしています。
量子化モデルのサービングに、追加のフラグは必要ありません。
Model Garden を使ってみる
Hex-LLM Cloud TPU サービング コンテナは Model Garden に統合されています。このサービング テクノロジーは、さまざまなモデルのプレイグラウンド、ワンクリック デプロイ、Colab Enterprise ノートブックの例で利用できます。
プレイグラウンドを使用する
Model Garden のプレイグラウンドは、事前にデプロイされた Vertex AI エンドポイントであり、モデルカードでリクエストを送信することでアクセスできます。
プロンプトを入力し、必要に応じてリクエストの引数を指定します。
[送信] をクリックして、モデルのレスポンスをすばやく取得します。
ワンクリック デプロイを使用する
モデルカードを使用して、Hex-LLM を備えたカスタム Vertex AI エンドポイントをデプロイできます。
モデルカード ページに移動し、[デプロイ] をクリックします。
使用するモデルのバリエーションに対して、デプロイするための Cloud TPU v5e マシンタイプを選択します。
下部にある [デプロイ] をクリックして、デプロイ プロセスを開始します。2 通のメール通知が届きます。1 通はモデルがアップロードされたとき、もう 1 通はエンドポイントの準備が整ったときです。
Colab Enterprise ノートブックを使用する
柔軟性とカスタマイズのために、Colab Enterprise ノートブックの例を使用して、Vertex AI SDK for Python を使用して Hex-LLM で Vertex AI エンドポイントをデプロイできます。
モデルカードのページに移動し、[ノートブックを開く] をクリックします。
Vertex Serving ノートブックを選択します。ノートブックが Colab Enterprise で開きます。
ノートブックを実行して、Hex-LLM を使用してモデルをデプロイし、予測リクエストをエンドポイントに送信します。デプロイのコード スニペットは次のとおりです。
hexllm_args = [
f"--model=google/gemma-2-9b-it",
f"--tensor_parallel_size=4",
f"--hbm_utilization_factor=0.8",
f"--max_running_seqs=512",
]
hexllm_envs = {
"PJRT_DEVICE": "TPU",
"MODEL_ID": "google/gemma-2-9b-it",
"DEPLOY_SOURCE": "notebook",
}
model = aiplatform.Model.upload(
display_name="gemma-2-9b-it",
serving_container_image_uri=HEXLLM_DOCKER_URI,
serving_container_command=[
"python", "-m", "hex_llm.server.api_server"
],
serving_container_args=hexllm_args,
serving_container_ports=[7080],
serving_container_predict_route="/generate",
serving_container_health_route="/ping",
serving_container_environment_variables=hexllm_envs,
serving_container_shared_memory_size_mb=(16 * 1024),
serving_container_deployment_timeout=7200,
)
endpoint = aiplatform.Endpoint.create(display_name="gemma-2-9b-it-endpoint")
model.deploy(
endpoint=endpoint,
machine_type="ct5lp-hightpu-4t",
deploy_request_timeout=1800,
service_account="<your-service-account>",
min_replica_count=1,
max_replica_count=1,
)
Colab Enterprise ノートブックの例:
サーバー引数と環境変数を構成する
次の引数を設定して、Hex-LLM サーバーを起動できます。引数を調整して、目的のユースケースと要件に最適なものにすることができます。引数は、最も簡単なデプロイ エクスペリエンスを実現するために、ワンクリック デプロイ用に事前定義されています。引数をカスタマイズするには、ノートブックの例を参考にして、引数を適宜設定します。
モデル
--model
: 読み込むモデル。Hugging Face モデル ID、Cloud Storage バケットパス(gs://my-bucket/my-model
)、ローカルパスを指定できます。モデル アーティファクトは、Hugging Face 形式に従い、モデルの重み付けに safetensors ファイルを使用することが想定されています。Llama、Gemma 2、Mistral/Mixtral では、BitsAndBytes int8 と AWQ 量子化モデル アーティファクトがサポートされています。--tokenizer
: 読み込むトークナイザ。Hugging Face モデル ID、Cloud Storage バケットパス(gs://my-bucket/my-model
)、ローカルパスを指定できます。この引数が設定されていない場合、デフォルトで--model
の値になります。--tokenizer_mode
: トークナイザ モード。選択肢は["auto", "slow"]
です。デフォルト値は"auto"
です。"auto"
に設定すると、高速トークナイザーが使用可能であれば使用されます。遅いトークナイザーは Python で記述され、Transformers ライブラリで提供されます。パフォーマンスの改善を提供する高速トークナイザーは Rust で記述され、Tokenizers ライブラリで提供されます。詳細については、Hugging Face のドキュメントをご覧ください。--trust_remote_code
: Hugging Face モデル リポジトリで定義されたリモートコード ファイルを許可するかどうか。デフォルト値はFalse
です。--load_format
: 読み込むモデル チェックポイントの形式。選択肢は["auto", "dummy"]
です。デフォルト値は"auto"
です。"auto"
に設定すると、モデルの重みが safetensors 形式で読み込まれます。"dummy"
に設定すると、モデルの重みがランダムに初期化されます。これを"dummy"
に設定すると、テストに役立ちます。--max_model_len
: モデルで提供するコンテキストの最大長(入力長と出力長の合計)。デフォルト値は、Hugging Face 形式のモデル構成ファイルconfig.json
から読み取られます。最大コンテキスト長が大きいほど、必要な TPU メモリが増えます。--sliding_window
: 設定すると、この引数はスライディング ウィンドウ アテンションのモデルのウィンドウ サイズをオーバーライドします。この引数を大きな値に設定すると、注意機構はより多くのトークンを含み、標準のセルフ アテンションの効果に近づきます。この引数は試験運用でのみ使用してください。一般的なユースケースでは、モデルの元のウィンドウ サイズを使用することをおすすめします。--seed
: すべての乱数生成ツールを初期化するためのシード。この引数を変更すると、次のトークンとしてサンプリングされるトークンが変更され、同じプロンプトで生成される出力に影響する可能性があります。デフォルト値は0
です。
推論エンジン
--num_hosts
: 実行するホストの数。デフォルト値は1
です。詳細については、TPU v5e の構成に関するドキュメントをご覧ください。--disagg_topo
: 試験運用版の機能である分離型サービングを使用して、プリフィル ワーカーとデコード ワーカーの数を定義します。デフォルト値はNone
です。引数の形式は"number_of_prefill_workers,number_of_decode_workers"
です。--data_parallel_size
: データ並列レプリカの数。デフォルト値は1
です。これを1
からN
に設定すると、レイテンシを維持しながらスループットが約N
向上します。--tensor_parallel_size
: テンソル並列レプリカの数。デフォルト値は1
です。一般に、テンソル並列レプリカの数を増やすと、行列のサイズが縮小され、行列乗算が高速化されるため、レイテンシが改善されます。--worker_distributed_method
: ワーカーを起動する分散メソッド。マルチプロセッシング モジュールの場合はmp
、Ray ライブラリの場合はray
を使用します。デフォルト値はmp
です。--enable_jit
: JIT(ジャストインタイム コンパイル)モードを有効にするかどうか。デフォルト値はTrue
です。--no-enable_jit
に設定すると、無効になります。JIT モードを有効にすると、推論のパフォーマンスが向上しますが、初期コンパイルに追加の時間がかかります。一般に、推論パフォーマンスのメリットはオーバーヘッドを上回ります。--warmup
: 初期化中にサンプル リクエストでサーバーをウォームアップするかどうか。デフォルト値はTrue
です。--no-warmup
を設定すると、無効になります。最初のリクエストはコンパイルの負荷が大きいため、ウォームアップをおすすめします。--max_prefill_seqs
: 1 回の反復でプリフィル用にスケジュールできるシーケンスの最大数。デフォルト値は1
です。この値を大きくすると、サーバーが達成できるスループットが高くなりますが、レイテンシに悪影響を及ぼす可能性があります。--prefill_seqs_padding
: サーバーは、この値の倍数になるまでプリフィル バッチサイズをパディングします。デフォルト値は8
です。この値を大きくすると、モデルの再コンパイル時間は短縮されますが、無駄な計算と推論のオーバーヘッドが増加します。最適な設定はリクエスト トラフィックによって異なります。--prefill_len_padding
: サーバーは、この値の倍数になるまでシーケンス長をパディングします。デフォルト値は512
です。この値を大きくすると、モデルの再コンパイル時間は短縮されますが、無駄な計算と推論のオーバーヘッドが増加します。最適な設定は、リクエストのデータ分布によって異なります。--max_decode_seqs
/--max_running_seqs
: 反復ごとにデコード用にスケジュールできるシーケンスの最大数。デフォルト値は256
です。この値を大きくすると、サーバーが達成できるスループットが高くなりますが、レイテンシに悪影響を及ぼす可能性があります。--decode_seqs_padding
: サーバーは、デコード バッチサイズをこの値の倍数になるまでパディングします。デフォルト値は8
です。この値を増やすと、モデルの再コンパイル時間が短縮されますが、無駄な計算と推論のオーバーヘッドが増加します。最適な設定はリクエスト トラフィックによって異なります。--decode_blocks_padding
: デコード中に、シーケンスの Key-Value キャッシュ(KV キャッシュ)に使用されるメモリブロックの数が、この値の倍数になるまでパディングされます。デフォルト値は128
です。この値を大きくすると、モデルの再コンパイル時間は短縮されますが、無駄な計算と推論のオーバーヘッドが増加します。最適な設定は、リクエストのデータ分布によって異なります。--enable_prefix_cache_hbm
: HBM で接頭辞キャッシュ保存を有効にするかどうか。デフォルト値はFalse
です。この引数を設定すると、以前のリクエストの共有プレフィックスの計算を再利用して、パフォーマンスを向上させることができます。--enable_chunked_prefill
: チャンク化されたプリフィルを有効にするかどうか。デフォルト値はFalse
です。この引数を設定すると、コンテキストの長さを長くしてパフォーマンスを向上させることができます。
メモリ管理
--hbm_utilization_factor
: モデルの重みが読み込まれた後に KV キャッシュに割り当てることができる無料の Cloud TPU 高帯域幅メモリ(HBM)の割合。デフォルト値は0.9
です。この引数を大きな値に設定すると、KV キャッシュサイズが増加し、スループットが向上する可能性がありますが、初期化時と実行時に Cloud TPU HBM が不足するリスクが高まります。--num_blocks
: KV キャッシュに割り当てるデバイスブロックの数。この引数が設定されている場合、サーバーは--hbm_utilization_factor
を無視します。この引数が設定されていない場合、サーバーは HBM 使用量をプロファイリングし、--hbm_utilization_factor
に基づいて割り当てるデバイス ブロックの数を計算します。この引数を大きな値に設定すると、KV キャッシュサイズが増加し、スループットが向上する可能性がありますが、初期化時と実行時に Cloud TPU HBM が不足するリスクが高まります。--block_size
: ブロックに保存されているトークンの数。選択肢は[8, 16, 32, 2048, 8192]
です。デフォルト値は32
です。この引数を大きな値に設定すると、メモリの無駄遣いが増える代わりに、ブロック管理のオーバーヘッドが削減されます。パフォーマンスへの正確な影響は、経験的に判断する必要があります。
動的 LoRA
--enable_lora
: Cloud Storage からの動的 LoRA アダプタの読み込みを有効にするかどうか。デフォルト値はFalse
です。これは、Llama モデル ファミリーでサポートされています。--max_lora_rank
: リクエストで定義された LoRA アダプターでサポートされている最大 LoRA ランク。デフォルト値は16
です。この引数を大きい値に設定すると、サーバーで使用できる LoRA アダプターの柔軟性が高まりますが、LoRA 重みに割り当てられる Cloud TPU HBM の量が増加し、スループットが低下します。--enable_lora_cache
: 動的 LoRA アダプタのキャッシュ保存を有効にするかどうか。デフォルト値はTrue
です。--no-enable_lora_cache
に設定すると、無効になります。キャッシュを使用すると、以前に使用した LoRA アダプタ ファイルを再ダウンロードする必要がなくなるため、パフォーマンスが向上します。--max_num_mem_cached_lora
: TPU メモリ キャッシュに保存される LoRA アダプターの最大数。デフォルト値は16
です。この引数を大きい値に設定すると、キャッシュ ヒットの可能性が高くなりますが、Cloud TPU HBM の使用量が増加します。
次の環境変数を使用してサーバーを構成することもできます。
HEX_LLM_LOG_LEVEL
: 生成されるロギング情報の量を制御します。デフォルト値はINFO
です。これは、ロギング モジュールで定義されている標準の Python ロギングレベルのいずれかに設定します。HEX_LLM_VERBOSE_LOG
: 詳細なロギング出力を有効にするかどうか。指定できる値はtrue
またはfalse
です。デフォルト値はfalse
です。
サーバー引数を調整する
サーバー引数は相互に関連しており、サービング パフォーマンスに総合的な影響を与えます。たとえば、--max_model_len=4096
の設定を大きくすると、TPU メモリの使用量が増えるため、メモリ割り当てを大きくし、バッチ処理を減らす必要があります。また、ユースケースによって決まる引数もあれば、チューニング可能な引数もあります。Hex-LLM サーバーを構成するワークフローは次のとおりです。
- 対象のモデル ファミリーとモデル バリアントを特定します。たとえば、Llama 3.1 8B Instruct などです。
- モデルのサイズと精度に基づいて、必要な TPU メモリの下限を見積もります(
model_size * (num_bits / 8)
)。8B モデルと bfloat16 精度の場合、必要な TPU メモリの下限は8 * (16 / 8) = 16 GB
になります。 - 必要な TPU v5e チップの数を推定します。各 v5e チップは 16 GB を提供します。
tpu_memory / 16
。8B モデルと bfloat16 精度の場合、1 つ以上のチップが必要です。1 チップ、4 チップ、8 チップの構成のうち、1 チップを超える最小の構成は 4 チップ構成(ct5lp-hightpu-4t
)です。--tensor_parallel_size=4
は後で設定できます。 - 目的のユースケースの最大コンテキスト長(入力長 + 出力長)を決定します。たとえば、4096 です。
--max_model_len=4096
は後で設定できます。 - モデル、ハードウェア、サーバーの構成(
--hbm_utilization_factor
)で実現可能な最大値になるように、KV キャッシュに割り当てられる空き TPU メモリの量を調整します。0.95
から始めます。Hex-LLM サーバーをデプロイし、長いプロンプトと高い同時実行性でサーバーをテストします。サーバーでメモリ不足が発生した場合は、使用率係数を適宜減らします。
Llama 3.1 8B Instruct をデプロイする引数のサンプルセットは次のとおりです。
python -m hex_llm.server.api_server \
--model=meta-llama/Llama-3.1-8B-Instruct \
--tensor_parallel_size=4 \
--max_model_len=4096
--hbm_utilization_factor=0.95
ct5lp-hightpu-4t
に Llama 3.1 70B Instruct AWQ をデプロイする場合の引数のサンプルセットは次のとおりです。
python -m hex_llm.server.api_server \
--model=hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4 \
--tensor_parallel_size=4 \
--max_model_len=4096
--hbm_utilization_factor=0.45
Cloud TPU の割り当てをリクエストする
Model Garden では、デフォルトの割り当ては us-west1
リージョンの Cloud TPU v5e チップ 32 個です。この割り当ては、ワンクリック デプロイと Colab Enterprise ノートブックのデプロイに適用されます。割り当て値の引き上げをリクエストするには、割り当ての調整をリクエストするをご覧ください。