使用 Ray 擴充機器學習工作負載

本文詳細說明如何在 TPU 上使用 Ray 和 JAX 執行機器學習 (ML) 工作負載。使用 Ray 時,TPU 有兩種模式:以裝置為中心的模式 (PyTorch/XLA)以主機為中心的模式 (JAX)

本文假設您已設定 TPU 環境。詳情請參閱下列資源:

以裝置為中心的模式 (PyTorch/XLA)

以裝置為中心的模式保留了許多傳統 PyTorch 的程式設計風格。在這個模式下,您會新增 XLA 裝置類型,這與任何其他 PyTorch 裝置的運作方式相同。每個個別程序都會與一個 XLA 裝置互動。

如果您已熟悉 GPU 版 PyTorch,並想使用類似的程式碼抽象化,就很適合使用這個模式。

以下章節說明如何在不使用 Ray 的情況下,在一或多個裝置上執行 PyTorch/XLA 工作負載,然後說明如何使用 Ray 在多部主機上執行相同的工作負載。

建立 TPU

  1. 為 TPU 建立參數建立環境變數。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5p-8
    export RUNTIME_VERSION=v2-alpha-tpuv5

    環境變數說明

    變數 說明
    PROJECT_ID 您的 Google Cloud 專案 ID。使用現有專案或建立新專案
    TPU_NAME TPU 的名稱。
    ZONE 要建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱 TPU 區域和區域
    ACCELERATOR_TYPE 加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各個 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
    RUNTIME_VERSION Cloud TPU 軟體版本

  2. 使用下列指令建立具有 8 個核心的 v5p TPU VM:

    gcloud compute tpus tpu-vm create $TPU_NAME \
       --zone=$ZONE \
       --accelerator-type=$ACCELERATOR_TYPE  \
       --version=$RUNTIME_VERSION
  3. 使用下列指令連線至 TPU VM:

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE

如果您使用 GKE,請參閱 GKE 上的 KubeRay 設定指南。

安裝需求

在 TPU VM 上執行下列指令,安裝必要依附元件:

  1. 將下列內容儲存為檔案。例如 requirements.txt

    --find-links https://storage.googleapis.com/libtpu-releases/index.html
    --find-links https://storage.googleapis.com/libtpu-wheels/index.html
    torch~=2.6.0
    torch_xla[tpu]~=2.6.0
    ray[default]==2.40.0
    
  2. 如要安裝必要依附元件,請執行下列指令:

    pip install -r requirements.txt
    

如果您在 GKE 上執行工作負載,建議建立 Dockerfile 來安裝必要依附元件。如需範例,請參閱 GKE 說明文件中的「在 TPU 節點上執行工作負載」。

在單一裝置上執行 PyTorch/XLA 工作負載

以下範例說明如何在單一裝置 (即 TPU 晶片) 上建立 XLA 張量。這與 PyTorch 處理其他裝置類型的方式類似。

  1. 將下列程式碼片段儲存到檔案中。例如 workload.py

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    
    t = torch.randn(2, 2, device=xm.xla_device())
    print(t.device)
    print(t)
    

    import torch_xla 匯入陳述式會初始化 PyTorch/XLA,而 xm.xla_device() 函式會傳回目前的 XLA 裝置 (TPU 晶片)。

  2. PJRT_DEVICE 環境變數設為 TPU。

    export PJRT_DEVICE=TPU
    
  3. 執行指令碼。

    python workload.py
    

    輸出結果看起來與下列內容相似。確認輸出內容顯示已找到 XLA 裝置。

    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    

在多部裝置上執行 PyTorch/XLA

  1. 更新上一節的程式碼片段,以便在多部裝置上執行。

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    
    def _mp_fn(index):
        t = torch.randn(2, 2, device=xm.xla_device())
        print(t.device)
        print(t)
    
    if __name__ == '__main__':
        torch_xla.launch(_mp_fn, args=())
    
  2. 執行指令碼。

    python workload.py
    

    如果您在 TPU v5p-8 上執行程式碼片段,輸出結果會與下列內容相似:

    xla:0
    xla:0
    xla:0
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    xla:0
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    

torch_xla.launch() 會採用兩個引數:函式和參數清單。這項作業會為每個可用的 XLA 裝置建立程序,並呼叫引數中指定的函式。在這個例子中,有 4 個可用的 TPU 裝置,因此 torch_xla.launch() 會建立 4 個程序,並在每個裝置上呼叫 _mp_fn()。每個程序只能存取一部裝置,因此每部裝置的索引都是 0,且所有程序都會列印 xla:0

使用 Ray 在多個主機上執行 PyTorch/XLA

後續章節會說明如何在較大的多主機 TPU 切片上執行相同的程式碼片段。如要進一步瞭解多主機 TPU 架構,請參閱「系統架構」。

在本範例中,您會手動設定 Ray。如果您已熟悉如何設定 Ray,可以跳到最後一節「執行 Ray 工作負載」。如要進一步瞭解如何為正式環境設定 Ray,請參閱下列資源:

建立多主機 TPU VM

  1. 為 TPU 建立參數建立環境變數。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5p-16
    export RUNTIME_VERSION=v2-alpha-tpuv5

    環境變數說明

    變數 說明
    PROJECT_ID 您的 Google Cloud 專案 ID。使用現有專案或建立新專案
    TPU_NAME TPU 的名稱。
    ZONE 要建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱 TPU 區域和區域
    ACCELERATOR_TYPE 加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各個 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
    RUNTIME_VERSION Cloud TPU 軟體版本

  2. 使用下列指令建立具有 2 部主機的多主機 TPU v5p (每部主機上都有 4 個 TPU 晶片的 v5p-16):

    gcloud compute tpus tpu-vm create $TPU_NAME \
       --zone=$ZONE \
       --accelerator-type=$ACCELERATOR_TYPE \
       --version=$RUNTIME_VERSION

設定 Ray

TPU v5p-16 具有 2 個 TPU 主機,每個主機有 4 個 TPU 晶片。在本範例中,您會在一個主機上啟動 Ray 頭部節點,並將第二個主機新增為 Ray 叢集的工作站節點。

  1. 使用 SSH 連線至第一個主機。

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=0
  2. 使用與安裝需求部分相同的需求檔案安裝依附元件。

    pip install -r requirements.txt
    
  3. 啟動 Ray 程序。

    ray start --head --port=6379
    

    輸出看起來類似以下內容:

    Enable usage stats collection? This prompt will auto-proceed in 10 seconds to avoid blocking cluster startup. Confirm [Y/n]: y
    Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details.
    
    Local node IP: 10.130.0.76
    
    --------------------
    Ray runtime started.
    --------------------
    
    Next steps
    To add another node to this Ray cluster, run
        ray start --address='10.130.0.76:6379'
    
    To connect to this Ray cluster:
        import ray
        ray.init()
    
    To terminate the Ray runtime, run
        ray stop
    
    To view the status of the cluster, use
        ray status
    

    這個 TPU 主機現在是 Ray 頭部節點。請記下顯示如何將其他節點新增至 Ray 叢集的行,類似於下列內容:

    To add another node to this Ray cluster, run
        ray start --address='10.130.0.76:6379'
    

    您將在後續步驟中使用此指令。

  4. 檢查 Ray 叢集狀態:

    ray status
    

    輸出看起來類似以下內容:

    ======== Autoscaler status: 2025-01-14 22:03:39.385610 ========
    Node status
    ---------------------------------------------------------------
    Active:
    1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
    Pending:
    (no pending nodes)
    Recent failures:
    (no failures)
    
    Resources
    ---------------------------------------------------------------
    Usage:
    0.0/208.0 CPU
    0.0/4.0 TPU
    0.0/1.0 TPU-v5p-16-head
    0B/268.44GiB memory
    0B/119.04GiB object_store_memory
    0.0/1.0 your-tpu-name
    
    Demands:
    (no resource demands)
    

    由於您目前只新增了頭部節點,因此叢集只包含 4 個 TPU (0.0/4.0 TPU)。

    現在頭部節點正在執行,您可以將第二部主機新增至叢集。

  5. 使用 SSH 連線至第二部主機。

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=1
  6. 使用與「安裝必要條件」一節中相同的必要條件檔案,安裝依附元件。

    pip install -r requirements.txt
    
  7. 啟動 Ray 程序。如要將這個節點新增至現有 Ray 叢集,請使用 ray start 指令輸出內容中的指令。請務必在下列指令中取代 IP 位址和通訊埠:

    ray start --address='10.130.0.76:6379'

    輸出看起來類似以下內容:

    Local node IP: 10.130.0.80
    [2025-01-14 22:30:07,397 W 75572 75572] global_state_accessor.cc:463: Retrying to get node with node ID 35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
    
    --------------------
    Ray runtime started.
    --------------------
    
    To terminate the Ray runtime, run
    ray stop
    
  8. 再次檢查 Ray 狀態:

    ray status
    

    輸出看起來類似以下內容:

    ======== Autoscaler status: 2025-01-14 22:45:21.485617 ========
    Node status
    ---------------------------------------------------------------
    Active:
    1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
    1 node_35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
    Pending:
    (no pending nodes)
    Recent failures:
    (no failures)
    
    Resources
    ---------------------------------------------------------------
    Usage:
    0.0/416.0 CPU
    0.0/8.0 TPU
    0.0/1.0 TPU-v5p-16-head
    0B/546.83GiB memory
    0B/238.35GiB object_store_memory
    0.0/2.0 your-tpu-name
    
    Demands:
    (no resource demands)
    

    第二個 TPU 主機現在是叢集中的節點。可用資源清單現在會顯示 8 個 TPU (0.0/8.0 TPU)。

執行 Ray 工作負載

  1. 更新程式碼片段,在 Ray 叢集上執行:

    import os
    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    import ray
    
    import torch.distributed as dist
    import torch_xla.runtime as xr
    from torch_xla._internal import pjrt
    
    # Defines the local PJRT world size, the number of processes per host.
    LOCAL_WORLD_SIZE = 4
    # Defines the number of hosts in the Ray cluster.
    NUM_OF_HOSTS = 4
    GLOBAL_WORLD_SIZE = LOCAL_WORLD_SIZE * NUM_OF_HOSTS
    
    def init_env():
        local_rank = int(os.environ['TPU_VISIBLE_CHIPS'])
    
        pjrt.initialize_multiprocess(local_rank, LOCAL_WORLD_SIZE)
        xr._init_world_size_ordinal()
    
    # This decorator signals to Ray that the `print_tensor()` function should be run on a single TPU chip.
    @ray.remote(resources={"TPU": 1})
    def print_tensor():
        # Initializes the runtime environment on each Ray worker. Equivalent to
        # the `torch_xla.launch call` in the Run PyTorch/XLA on multiple devices section.
        init_env()
    
        t = torch.randn(2, 2, device=xm.xla_device())
        print(t.device)
        print(t)
    
    ray.init()
    
    # Uses Ray to dispatch the function call across available nodes in the cluster.
    tasks = [print_tensor.remote() for _ in range(GLOBAL_WORLD_SIZE)]
    ray.get(tasks)
    
    ray.shutdown()
    
  2. 在 Ray 頭部節點上執行指令碼。將 ray-workload.py 替換為指令碼的路徑。

    python ray-workload.py

    輸出看起來類似以下內容:

    WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
    xla:0
    xla:0
    xla:0
    xla:0
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    xla:0
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    

    輸出內容表示已在多主機 TPU 配量中的每個 XLA 裝置 (本例為 8 個裝置) 上成功呼叫函式。

以主機為中心的模式 (JAX)

以下各節說明以主機為中心的 JAX 模式。JAX 採用函式程式設計範例,並支援高階單一程式、多重資料 (SPMD) 語意。JAX 程式碼的設計是讓每個程序與單一主機上的多個裝置同時互動,而不是與單一 XLA 裝置互動。

JAX 專為高效能運算而設計,可有效運用 TPU 進行大規模訓練和推論。如果您熟悉函式程式設計概念,就能充分發揮 JAX 的潛力,因此非常適合使用這個模式。

這些操作說明假設您已設定 Ray 和 TPU 環境,包括含有 JAX 和其他相關套件的軟體環境。如要建立 Ray TPU 叢集,請按照「使用 TPU 啟動 KubeRay 的 GKE 叢集」Google Cloud 中的操作說明進行。如要進一步瞭解如何將 TPU 與 KubeRay 搭配使用,請參閱「將 TPU 與 KubeRay 搭配使用」。

在單一主機 TPU 上執行 JAX 工作負載

以下範例指令碼示範如何在具有單一主機 TPU (例如 v6e-4) 的 Ray 叢集上執行 JAX 函式。如果您使用多主機 TPU,這個指令碼會因 JAX 的多重控制器執行模型而停止回應。如要進一步瞭解如何在多主機 TPU 上執行 Ray,請參閱「在多主機 TPU 上執行 JAX 工作負載」。

  1. 為 TPU 建立參數建立環境變數。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-a
    export ACCELERATOR_TYPE=v6e-4
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    環境變數說明

    變數 說明
    PROJECT_ID 您的 Google Cloud 專案 ID。使用現有專案或建立新專案
    TPU_NAME TPU 的名稱。
    ZONE 要建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱 TPU 區域和區域
    ACCELERATOR_TYPE 加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各個 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
    RUNTIME_VERSION Cloud TPU 軟體版本

  2. 使用下列指令建立具有 4 個核心的 v6e TPU VM:

    gcloud compute tpus tpu-vm create $TPU_NAME \
       --zone=$ZONE \
       --accelerator-type=$ACCELERATOR_TYPE  \
       --version=$RUNTIME_VERSION
  3. 使用下列指令連線至 TPU VM:

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE
  4. 在 TPU 上安裝 JAX 和 Ray。

    pip install ray jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  5. 將下列程式碼儲存到檔案中。例如 ray-jax-single-host.py

    import ray
    import jax
    
    @ray.remote(resources={"TPU": 4})
    def my_function() -> int:
        return jax.device_count()
    
    h = my_function.remote()
    print(ray.get(h)) # => 4
    

    如果您習慣使用 GPU 執行 Ray,使用 TPU 時會發現一些主要差異:

    • 請勿設定 num_gpus,而是將 TPU 指定為自訂資源,並設定 TPU 晶片數量。
    • 使用每個 Ray 工作站節點的晶片數量指定 TPU。舉例來說,如果您使用 v6e-4,並將 TPU 設為 4 來執行遠端函式,就會耗用整個 TPU 主機。
    • 這與 GPU 通常的執行方式不同,因為每個主機都有一個程序。不建議將 TPU 設為 4 以外的數字。
      • 例外狀況:如果您有單一主機 v6e-8v5litepod-8,則應將這個值設為 8。
  6. 執行指令碼。

    python ray-jax-single-host.py

在多主機 TPU 上執行 JAX 工作負載

下列範例指令碼示範如何在具有多主機 TPU 的 Ray 叢集上執行 JAX 函式。範例指令碼使用 v6e-16。

  1. 為 TPU 建立參數建立環境變數。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-a
    export ACCELERATOR_TYPE=v6e-16
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    環境變數說明

    變數 說明
    PROJECT_ID 您的 Google Cloud 專案 ID。使用現有專案或建立新專案
    TPU_NAME TPU 的名稱。
    ZONE 要建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱 TPU 區域和區域
    ACCELERATOR_TYPE 加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各個 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
    RUNTIME_VERSION Cloud TPU 軟體版本

  2. 使用下列指令建立具有 16 個核心的 v6e TPU VM:

    gcloud compute tpus tpu-vm create $TPU_NAME \
       --zone=$ZONE \
       --accelerator-type=$ACCELERATOR_TYPE  \
       --version=$RUNTIME_VERSION
  3. 在所有 TPU 工作站上安裝 JAX 和 Ray。

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
       --zone=$ZONE \
       --worker=all \
       --command="pip install ray jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
  4. 將下列程式碼儲存到檔案中。例如 ray-jax-multi-host.py

    import ray
    import jax
    
    @ray.remote(resources={"TPU": 4})
    def my_function() -> int:
        return jax.device_count()
    
    ray.init()
    num_tpus = ray.available_resources()["TPU"]
    num_hosts = int(num_tpus) # 4
    h = [my_function.remote() for _ in range(num_hosts)]
    print(ray.get(h)) # [16, 16, 16, 16]
    

    如果您習慣使用 GPU 執行 Ray,使用 TPU 時會發現一些主要差異:

    • 與 GPU 上的 PyTorch 工作負載類似:
    • 與 GPU 上的 PyTorch 工作負載不同,JAX 可全面掌握叢集中可用的裝置。
  5. 將指令碼複製到所有 TPU 工作站。

    gcloud compute tpus tpu-vm scp ray-jax-multi-host.py $TPU_NAME: --zone=$ZONE --worker=all
  6. 執行指令碼。

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
       --zone=$ZONE \
       --worker=all \
       --command="python ray-jax-multi-host.py"

執行 Multislice JAX 工作負載

Multislice 可讓您在單一 TPU Pod 內或透過資料中心網路的多個 Pod,執行跨多個 TPU Slice 的工作負載。

您可以使用 ray-tpu 套件,簡化 Ray 與 TPU 切片的互動。

使用 pip 安裝 ray-tpu

pip install ray-tpu

如要進一步瞭解如何使用 ray-tpu 套件,請參閱 GitHub 存放區的「開始使用」一節。如需使用 Multislice 的範例,請參閱「在 Multislice 上執行」。

使用 Ray 和 MaxText 自動化調度管理工作負載

如要進一步瞭解如何搭配使用 Ray 和 MaxText,請參閱「使用 MaxText 執行訓練工作」。

TPU 和 Ray 資源

Ray 會以不同於 GPU 的方式處理 TPU,以因應使用上的差異。在下列範例中,共有九個 Ray 節點:

  • Ray 頭部節點在 n1-standard-16 VM 上執行。
  • Ray 工作站節點會在兩個 v6e-16 TPU 上執行。每個 TPU 包含四個工作站。
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
 1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
 1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
 1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
 1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
 1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
 1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
 1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
 1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
 1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/727.0 CPU
 0.0/32.0 TPU
 0.0/2.0 TPU-v6e-16-head
 0B/5.13TiB memory
 0B/1.47TiB object_store_memory
 0.0/4.0 tpu-group-0
 0.0/4.0 tpu-group-1

Demands:
 (no resource demands)

資源使用情況欄位說明:

  • CPU:叢集中可用的 CPU 總數。
  • TPU:叢集中的 TPU 晶片數量。
  • TPU-v6e-16-head:資源的特殊 ID,對應至 TPU 節點的 Worker 0。這對存取個別 TPU 節點至關重要。
  • memory:應用程式使用的 Worker 堆積記憶體。
  • object_store_memory:應用程式使用 ray.put 在物件儲存庫中建立物件,以及從遠端函式傳回值時所用的記憶體。
  • tpu-group-0tpu-group-1:個別 TPU 節點的專屬 ID。這對於在切片上執行工作非常重要。這些欄位設為 4,是因為在 v6e-16 中,每個 TPU 節點有 4 個主機。