Pathways でインタラクティブなワークロードを実行する

Pathways インタラクティブ ワークロードは、Pathways クラスタをホストする GKE クラスタの一部ではない VM 内で実行されるリモート JAX ワークロードです。バッチ ワークロードとは異なり、インタラクティブ ワークロード オペレーションが完了しても Pathways クラスタ コンポーネントはシャットダウンされず、他の JAX クライアントによる接続が可能な状態が維持されます。このドキュメントでは、インタラクティブ ワークロードの例として Jupyter ノートブックを使用します。

JAX ユーザーは、IFRT インターフェースを使用して、Pathways クラスタにコマンドを送信します。JAX コードは、ターミナル、ノートブック、Python 互換の環境のいずれから実行されても、Pathways リソースとシームレスにやり取りできます。

始める前に

インストールに必要なもの:

インタラクティブ モードで Pathways を実行する

xpk または kubectl を使用して、インタラクティブ モードで Pathways を実行できます。

XPK

  1. 次の環境変数を設定します。

    export WORKLOAD=WORKLOAD
    export WORKLOAD_NODEPOOL_COUNT=WORKLOAD_NODEPOOL_COUNT
    export TPU_TYPE=TPU_TYPE
    export PROJECT_ID=PROJECT
    export ZONE=ZONE \
    export CLUSTER=CLUSTER

    次のように置き換えます。

    • WORKLOAD: ワークロードを識別するための一意の名前に設定します
    • WORKLOAD_NODEPOOL_COUNT: Pathways ワークロードで使用されるノードプールの数
    • TPU_TYPE: TPU タイプは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされている TPU タイプの詳細については、TPU のバージョンをご覧ください。
    • PROJECT: 実際の Google Cloud プロジェクト ID
    • ZONE: ワークロードを実行する予定のゾーン
    • CLUSTER: GKE クラスタの名前
  2. クラスタに Pathways コンテナを作成します。ヘッドレス ワークロードを実行するには、次のコマンドを実行します。

    xpk workload create-pathways \
    --headless \
    --workload=${WORKLOAD} \
    --num-slices=${WORKLOAD_NODEPOOL_COUNT} \
    --tpu-type=${TPU_TYPE} \
    --project=${PROJECT} \
    --zone=${ZONE} \
    --cluster=${CLUSTER}

この時点で、JAX ワークロードは IFRT プロキシ サーバーに接続できます。

kubectl

次の YAML は、main コンテナを指定していない点を除き、バッチ ワークロード YAML と同じです。

  1. プレースホルダを置き換え、次の YAML をコピーして、pathways-headless-workload.yaml というファイルに貼り付けます。
    apiVersion: pathways-job.pathways.domain/v1
    kind: PathwaysJob
    metadata:
      name: pathways-USERNAME
    spec:
      maxRestarts: MAX_RESTARTS
      workers:
        - type: TPU_MACHINE_TYPE
          topology: TOPOLOGY
          numSlices: WORKLOAD_NODEPOOL_COUNT
      pathwaysDir: gs://BUCKET_NAME
      controller:
        deploymentMode: default
        
    次のように置き換えます。
    • USERNAME : ユーザー名
    • MAX_RESTARTS : PathwaysJob を再起動できる最大回数
    • TPU_MACHINE_TYPE : 使用する TPU マシンタイプ。サポートされている値の例: ct6e-standard-8t、ct5p-hightpu-4t
    • TOPOLOGY : TPU トポロジ
    • WORKLOAD_NODEPOOL_COUNT : Pathways ワークロードで使用されるノードプールの数
    • BUCKET_NAME : 一時ファイルの保存に使用される Cloud Storage バケット
    前の YAML の WORKLOAD_NODEPOOL_COUNT で指定されたノードプール(pathways-worker レプリカ)の数を変更するには、この PathwaysJob を削除し、更新されたノードプール数で新しい PathwaysJob を作成する必要があります。また、接続されているノートブックを再起動して、新しい Pathways クラスタとの接続を確立する必要があります。
  2. pathways-headless-workload.yaml ファイルを適用します。
      kubectl apply -f pathways-headless-workload.yaml
      
  3. kubectl get pods を実行して、Pod 内のすべてのコンテナが実行されていることを確認します。次の出力は 2 スライス v5p 2x2x2 のものです。ここで、USER はコマンドを実行しているユーザーの ID です。
        NAME                                         READY   STATUS    RESTARTS   AGE
        pathways-USER-pathways-head-0-0-n848j      2/2     Running   0          49s
        pathways-USER-pathways-workers-0-0-jxt2z   1/1     Running   0          71s
        pathways-USER-pathways-workers-0-1-cxmhc   1/1     Running   0          70s
        pathways-USER-pathways-workers-1-0-5kmz9   1/1     Running   0          71s
        pathways-USER-pathways-workers-1-1-vg5n4   1/1     Running   0          71s
        

インタラクティブ モードで Pathways クラスタに接続する

ポート転送の有無にかかわらず、Pathways クラスタに接続できます。次のいずれかのセクションを使用して、Pathways クラスタに接続します。

ポート転送を使用して接続する

この時点で、ポート転送(クラスタのコントロール プレーンにアクセスできる任意のホストから)を使用してプロキシ サーバーにアクセスできます。

ワークロードに適したコマンドを使用します。

XPK

PROXY_POD=$(kubectl get pods | grep ${WORKLOAD}-pathways-head | awk '{print $1}')
PROXY_PORT=29000
kubectl port-forward ${PROXY_POD} ${PROXY_PORT}:${PROXY_PORT}

次のような出力が表示されます。

Forwarding from 127.0.0.1:29000 -> 29000
Forwarding from [::1]:29000 -> 29000

kubectl

PROXY_POD=$(kubectl get pods | grep pathways-${USER}-pathways-head | awk '{print $1}')
PROXY_PORT=29000
kubectl port-forward ${PROXY_POD} ${PROXY_PORT}:${PROXY_PORT}

次のような出力が表示されます。

Forwarding from 127.0.0.1:29000 -> 29000
Forwarding from [::1]:29000 -> 29000

同じホストで、新しいターミナル ウィンドウを開きます。JAX_PLATFORMSJAX_BACKEND_TARGET 環境変数を設定し、pathwaysutilsjax をインポートする Python スクリプトを実行します。

python3 -m venv .venv
source .venv/bin/activate
pip install pathwaysutils jax

JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 python -c 'import pathwaysutils; import jax; import pprint; pathwaysutils.initialize(); pprint.pprint(jax.devices())'

次のような出力が表示されます。

[device(144,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
device(145,TPU_DEVICE,coords=[1,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
device(146,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
device(147,TPU_DEVICE,coords=[1,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
device(148,TPU_DEVICE,coords=[0,0,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
device(149,TPU_DEVICE,coords=[1,0,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
device(150,TPU_DEVICE,coords=[0,1,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
device(151,TPU_DEVICE,coords=[1,1,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
device(162,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
device(163,TPU_DEVICE,coords=[1,0,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
device(164,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
device(165,TPU_DEVICE,coords=[1,1,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
device(166,TPU_DEVICE,coords=[0,0,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
device(167,TPU_DEVICE,coords=[1,0,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
device(168,TPU_DEVICE,coords=[0,1,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
device(169,TPU_DEVICE,coords=[1,1,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3)]
Waiting up to 5 seconds.
Sent all pending logs.
2024-11-13 21:38:51.267523: W external/xla/xla/python/ifrt_proxy/client/grpc_client.cc:63] IFRT proxy server disconnected: CANCELLED: Cancelled

ポート転送を使用せずに VPC 内のホストから接続する

ポート転送を使用しない場合は、Cloud DNS または内部ロードバランサを使用して Pathways クラスタに接続できます。

Cloud DNS を使用して接続する

クラスタで Cloud DNS を有効にすると、Cloud DNS プロバイダが kube-dns から Cloud DNS に切り替わります。有効にすると、Cloud DNS 名用に Virtual Private Cloud に限定公開 Cloud DNS ゾーンが作成されます。詳細については、GKE 向け Cloud DNS の使用をご覧ください。

クラスタ、追加の VPC、または VPC スコープで Cloud DNS を有効にすると、Virtual Private Cloud 内の非 GKE VM から Kubernetes Cloud DNS 名を解決できます。名前の形式は <service_name>.<namespace>.svc.<custom_dns_domain> です。Pathways ヘッド Pod には、<jobset_name>-pathways-head-0-0.<jobset_name>.<namespace>.svc.<custom_dns_domain> という名前のサービスがあります。

次のコマンドは、Cloud DNS を使用して Pathways クラスタに接続する方法を示しています。

  1. リーダーの Cloud DNS エントリが GKE 以外のホストから解決可能であることを確認します。

    XPK

    host WORKLOAD-pathways-head-0-0.WORKLOAD.default.svc.USERNAME-test

    次のような出力が表示されます。

    <WORKLOAD>-pathways-head-0-0.<WORKLOAD>.default.svc.<user>-test has address 10.0.2.75

    kubectl

    host pathways-USERNAME-pathways-head-0-0.pathways-USERNAME.default.svc.USERNAME-test

    次のような出力が表示されます。

    pathways-<user>-pathways-head-0-0.pathways-<user>.default.svc.<user>-test has address 10.0.2.75
  2. Cloud DNS 名を使用して Pathways クラスタに接続します。

    XPK

    JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://WORKLOAD-pathways-head-0-0.WORKLOAD.default.svc.USERNAME-test:29000 python -c 'import pathwaysutils; import jax; import pprint; pathwaysutils.initialize(); pprint.pprint(jax.devices())'

    kubectl

    JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://pathways-USERNAME-pathways-head-0-0.pathways-USERNAME.default.svc.USERNAME-test:29000 python -c 'import pathwaysutils; import jax; import pprint; pathwaysutils.initialize(); pprint.pprint(jax.devices())'

    次のような出力が表示されます。

    [device(216,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
    device(217,TPU_DEVICE,coords=[1,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
    device(218,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
    device(219,TPU_DEVICE,coords=[1,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
    device(220,TPU_DEVICE,coords=[0,0,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
    device(221,TPU_DEVICE,coords=[1,0,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
    device(222,TPU_DEVICE,coords=[0,1,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
    device(223,TPU_DEVICE,coords=[1,1,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
    device(234,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
    device(235,TPU_DEVICE,coords=[1,0,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
    device(236,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
    device(237,TPU_DEVICE,coords=[1,1,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
    device(238,TPU_DEVICE,coords=[0,0,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
    device(239,TPU_DEVICE,coords=[1,0,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
    device(240,TPU_DEVICE,coords=[0,1,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
    device(241,TPU_DEVICE,coords=[1,1,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3)]
    Waiting up to 5 seconds.
    Sent all pending logs.
    2024-11-14 00:02:49.882044: W external/xla/xla/python/ifrt_proxy/client/grpc_client.cc:63] IFRT proxy server disconnected: CANCELLED: Cancelled

内部ロードバランサを使用して接続する

パスウェイ デプロイを指す VPC のプライベート IP アドレスに対して、内部ロードバランサによってバックアップされるサービスを作成します。これには、クラスタで Cloud DNS を有効にする必要はありません

VM が多いクラスタの場合は、内部ロードバランサを作成する際に ILB サブセット化を有効にすることをおすすめします。詳細については、既存のクラスタで GKE のサブセット化を有効にするをご覧ください。ILB サブセット化が有効になっていない場合、クラスタ内のすべてのノードがすべての内部ロードバランサのバックエンド インスタンス グループの一部になります。これは 250 ノードを超えてスケーリングされません。ILB サブセット化が有効になっている場合、GKE はインスタンス グループではなくネットワーク エンドポイント グループを作成し、サービスのサービング Pod のいずれかを実行しているノードのみが含まれます。ILB サブセットを有効にすると、1 回限りの設定遅延(約 15 分)が発生します。次のコマンドは、ILB サブセットを有効にする方法を示しています。

gcloud container clusters update ${CLUSTER} \
  --project=${PROJECT} \
  [--zone=${ZONE} | --region=${REGION}] \
  --enable-l4-ilb-subsetting

ILB サブセッティングを有効にすると、次の YAML を使用して LoadBalancer タイプの Kubernetes サービスを作成できます。これにより、GKE はクラスタの VPC 内に内部ロードバランサを作成します。

apiVersion: v1
kind: Service
metadata:
  name: pathways-USERNAME-ilb
  annotations:
    networking.gke.io/load-balancer-type: "Internal"
    networking.gke.io/internal-load-balancer-allow-global-access: "true"
spec:
  type: LoadBalancer
  externalTrafficPolicy: Local
  selector:
    jobset.sigs.k8s.io/jobset-name: pathways-USER
    jobset.sigs.k8s.io/replicatedjob-name: pathways-head
  ports:
  - name: tcp-port
    protocol: TCP
    port: 29000
    targetPort: 29000

USER を Google Cloud ユーザー ID で更新し、ファイルを pathways-headless-ilb.yaml として保存します。

次のようにマニフェストを適用します。

kubectl apply -f pathways-headless-ilb.yaml

ロードバランサが作成されると(約 1 分後)、EXTERNAL-IP 列に値が表示されます。

kubectl get services
NAME                  TYPE           CLUSTER-IP      EXTERNAL-IP   PORT(S)        AGE
pathways-$USER       ClusterIP      None            <none>        <none>         30m
pathways-$USER-ilb   LoadBalancer   34.118.232.46   10.0.0.22     80:31246/TCP   2m41s

クラスタと同じ VPC 内のホストでポート転送を行わずに、パスウェイ デプロイにアクセスできます。

JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://10.0.0.22:29000 python -c 'import pathwaysutils; import jax; import pprint; pathwaysutils.initialize(); pprint.pprint(jax.devices())'

次のような出力が表示されます。

[device(288,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
 device(289,TPU_DEVICE,coords=[1,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
 device(290,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
 device(291,TPU_DEVICE,coords=[1,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
 device(292,TPU_DEVICE,coords=[0,0,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
 device(293,TPU_DEVICE,coords=[1,0,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
 device(294,TPU_DEVICE,coords=[0,1,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
 device(295,TPU_DEVICE,coords=[1,1,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
 device(306,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
 device(307,TPU_DEVICE,coords=[1,0,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
 device(308,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
 device(309,TPU_DEVICE,coords=[1,1,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
 device(310,TPU_DEVICE,coords=[0,0,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
 device(311,TPU_DEVICE,coords=[1,0,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
 device(312,TPU_DEVICE,coords=[0,1,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
 device(313,TPU_DEVICE,coords=[1,1,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3)]
Waiting up to 5 seconds.
Sent all pending logs.
2024-11-14 00:30:07.296917: W external/xla/xla/python/ifrt_proxy/client/grpc_client.cc:63] IFRT proxy server disconnected: CANCELLED: Cancelled

Jupyter ノートブック

Vertex AI を使用して Jupyter ノートブックを作成することも、セルフホストの Jupyter ノートブックを作成することもできます。

Vertex AI Workbench インスタンスを作成する

Pathways クラスタを設定して検証したら、Vertex AI Jupyter ノートブックから GKE TPU VM にアクセスできます。次の設定手順では、GKE Pathways クラスタが同じ Virtual Private Cloud ネットワークに存在することを前提としています(特に構成していない限り、これがデフォルトのネットワークです)。Vertex AI Workbench コンソールに移動します。

[インスタンス] タブで [新規作成] ボタンを使用して、新しい Workbench インスタンスを作成します。ネットワークが GKE クラスタのネットワークと同じであることを確認します。コマンドラインを使用して、新しい Workbench インスタンスを作成できます。

gcloud workbench instances create INSTANCE_NAME \
--machine-type=e2-standard-4 \
--data-disk-size=100 \
--location=ZONE \
[--network=NETWORK]

インスタンスが作成されたら、そのインスタンスに移動して [Jupyterlab を開く] をクリックします。

セルフホスト型の Jupyter ノートブック インスタンスを作成する

次のコマンドは、XPK を使用してセルフホスト Jupyter ノートブック インスタンスを作成する方法を示しています。

xpk workload create-pathways \
--workload=${WORKLOAD} \
--num-slices=${WORKLOAD_NODEPOOL_COUNT} \
--tpu-type=${TPU_TYPE} \
--project=${PROJECT} \
--zone=${ZONE} \
--cluster=${CLUSTER} \
--docker-image=jupyter/base-notebook \
--command "start-notebook.sh"

次の YAML は、kubectl を使用してセルフホストの Jupyter ノートブック インスタンスを作成する方法を示しています。ヘッドレス Pathways クラスタの作成後に、次の YAML を適用します。詳細については、kubectl を使用してインタラクティブ モードで Pathways を実行するをご覧ください。

apiVersion: batch/v1
kind: Job
metadata:
  name: jupyter-notebook-USERNAME
spec:
  template:
    spec:
      restartPolicy: OnFailure
      containers:
      - name: jupyter-notebook
        image: jupyter/base-notebook  # Use the appropriate Jupyter image
        ports:
        - containerPort: 8888

ポート転送を使用してローカルマシンからノートブックに接続します。

XPK

  MAIN_POD=$(kubectl get pods | grep ${WORKLOAD}-pathways-head | awk '{print $1}')
  kubectl port-forward pod/${MAIN_POD} 8888:8888

kubectl

  MAIN_POD=$(kubectl get pods | grep jupyter-notebook-USERNAME | awk '{print $1}')
  kubectl port-forward pod/${MAIN_POD} 8888:8888

ローカル ブラウザで http://localhost:8888?token=<var>your-token</var> に移動します。<your-token> は、Jupyter ノートブック コンテナのログから取得したトークンに置き換えます。

kubectl logs ${MAIN_POD}

出力は次のようになります。

...
Or copy and paste one of these URLs:
  http://jupyter-notebook-<user>-bbbdh:8888/lab?token=<token>
  http://127.0.0.1:8888/lab?token=<token>

Pathways クラスタへのノートブック接続

  1. Jupyterlab 内から、新しい Python 3 ノートブックを作成します。
  2. Pathways プロキシ サーバーに接続する

ノートブックで、pathwaysutils をインストールするセルを追加し、JAX_PLATFORMSproxy に設定し、JAX_BACKEND_TARGETPROXY_ADDRESS に設定します。

!pip install pathwaysutils
%env JAX_PLATFORMS=proxy
# Replace your proxy address below:
%env JAX_BACKEND_TARGET=PROXY_ADDRESS

2 つ目のセルを「hello world」タイプのチェックとして追加し、Pathways クラスタ内のデバイスを出力します。

import pathwaysutils
import jax

pathwaysutils.initialize()
print(jax.devices())

すべてが正常に動作している場合、Pathways-on-Cloud バックエンドが検出されたことを示すメッセージが表示されます。

リストに表示される JAX デバイスの数は、Pathways クラスタの作成時に指定した TPU チップの数とスライスの数と一致する必要があります。

ノートブックにコードを追加する

独自の JAX コードを追加し、Pathways クラスタの TPU でインタラクティブに実行します。次のコードは、単一のノートブックから 2 つのスライスにわたって計算を実行する方法を示しています。

import jax
import jax.numpy as jnp
from jax import lax
import numpy as np

# You can use JAX APIs as usual across any of the devices.
jax.jit(jnp.sin, device=jax.devices()[-1])(np.pi / 2.)

# pmap can run across all devices on all slices
num_tpus = jax.device_count()
f = jax.pmap(lambda x: lax.psum(1, 'i'), 'i')
x = jnp.arange(num_tpus)
y = f(x)
print(y)

# You can also target devices from a specific slice
slice0_devices = [d for d in jax.devices() if d.slice_index == 0]
f = jax.pmap(lambda x: lax.psum(1, 'i'), 'i', devices=slice0_devices)
x = jnp.arange(len(slice0_devices))
y = f(x)
print(y)
print(y.global_shards)

# You can send data produced on one slice to another slice
slice1_devices = [d for d in jax.devices() if d.slice_index == 1]
g = jax.pmap(lambda x: x + lax.axis_index('i'), 'i', devices=slice1_devices)
z = g(y)
print(z)
print(z.global_shards)

Pathways インタラクティブ クラスタを削除する

XPK

xpk workload delete --workload=WORKLOAD --cluster=CLUSTER --project=PROJECT --zone=ZONE

kubectl

kubectl delete -f pathways-headless-workload.yaml

次のステップ