ML-Arbeitslasten mit Ray skalieren

In diesem Dokument wird beschrieben, wie Sie Arbeitslasten für maschinelles Lernen (ML) mit Ray und JAX auf TPUs ausführen. Es gibt zwei verschiedene Modi für die Verwendung von TPUs mit Ray: geräteorientierter Modus (PyTorch/XLA) und hostzentrierter Modus (JAX).

In diesem Dokument wird davon ausgegangen, dass Sie bereits eine TPU-Umgebung eingerichtet haben. Weitere Informationen finden Sie in den folgenden Ressourcen:

Geräteorientierter Modus (PyTorch/XLA)

Im geräteorientierten Modus wird der programmatische Stil des klassischen PyTorch beibehalten. In diesem Modus fügen Sie einen neuen XLA-Gerätetyp hinzu, der wie jedes andere PyTorch-Gerät funktioniert. Jeder einzelne Prozess interagiert mit einem XLA-Gerät.

Dieser Modus ist ideal, wenn Sie bereits mit PyTorch auf GPUs vertraut sind und ähnliche Codeabstraktionen verwenden möchten.

In den folgenden Abschnitten wird beschrieben, wie Sie eine PyTorch/XLA-Arbeitslast auf einem oder mehreren Geräten ohne Ray ausführen und dann dieselbe Arbeitslast mit Ray auf mehreren Hosts ausführen.

TPU erstellen

  1. Erstellen Sie Umgebungsvariablen für die Parameter zum Erstellen von TPUs.

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5p-8
    export RUNTIME_VERSION=v2-alpha-tpuv5

    Beschreibungen von Umgebungsvariablen

    Variable Beschreibung
    PROJECT_ID Ihre Projekt-ID in Google Cloud . Verwenden Sie ein vorhandenes oder erstellen Sie ein neues Projekt.
    TPU_NAME Der Name der TPU
    ZONE Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und -Zonen.
    ACCELERATOR_TYPE Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen.
    RUNTIME_VERSION Die Softwareversion der Cloud TPU.

  2. Verwenden Sie den folgenden Befehl, um eine v5p-TPU-VM mit 8 Kernen zu erstellen:

    gcloud compute tpus tpu-vm create $TPU_NAME \
       --zone=$ZONE \
       --accelerator-type=$ACCELERATOR_TYPE  \
       --version=$RUNTIME_VERSION
  3. Stellen Sie mit dem folgenden Befehl eine Verbindung zur TPU-VM her:

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE

Wenn Sie GKE verwenden, finden Sie Informationen zur Einrichtung im Leitfaden zu KubeRay in der GKE.

Installationsanforderungen

Führen Sie die folgenden Befehle auf Ihrer TPU-VM aus, um die erforderlichen Abhängigkeiten zu installieren:

  1. Speichern Sie Folgendes in einer Datei. Beispiel: requirements.txt.

    --find-links https://storage.googleapis.com/libtpu-releases/index.html
    --find-links https://storage.googleapis.com/libtpu-wheels/index.html
    torch~=2.6.0
    torch_xla[tpu]~=2.6.0
    ray[default]==2.40.0
    
  2. Führen Sie den folgenden Befehl aus, um die erforderlichen Abhängigkeiten zu installieren:

    pip install -r requirements.txt
    

Wenn Sie Ihre Arbeitslast in der GKE ausführen, empfehlen wir, ein Dockerfile zu erstellen, in dem die erforderlichen Abhängigkeiten installiert werden. Ein Beispiel finden Sie in der GKE-Dokumentation unter Arbeitslast auf TPU-Slice-Knoten ausführen.

PyTorch/XLA-Arbeitslast auf einem einzelnen Gerät ausführen

Im folgenden Beispiel wird gezeigt, wie Sie einen XLA-Tensor auf einem einzelnen Gerät, einem TPU-Chip, erstellen. Das ist ähnlich wie bei der Verarbeitung anderer Gerätetypen in PyTorch.

  1. Speichern Sie das folgende Code-Snippet in einer Datei. Beispiel: workload.py.

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    
    t = torch.randn(2, 2, device=xm.xla_device())
    print(t.device)
    print(t)
    

    Mit der Importanweisung import torch_xla wird PyTorch/XLA initialisiert und die Funktion xm.xla_device() gibt das aktuelle XLA-Gerät, einen TPU-Chip, zurück.

  2. Legen Sie die Umgebungsvariable PJRT_DEVICE auf „TPU“ fest.

    export PJRT_DEVICE=TPU
    
  3. Führen Sie das Skript aus:

    python workload.py
    

    Die Ausgabe sieht ungefähr so aus: Achten Sie darauf, dass in der Ausgabe angegeben ist, dass das XLA-Gerät gefunden wurde.

    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    

PyTorch/XLA auf mehreren Geräten ausführen

  1. Aktualisieren Sie das Code-Snippet aus dem vorherigen Abschnitt, damit es auf mehreren Geräten ausgeführt werden kann.

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    
    def _mp_fn(index):
        t = torch.randn(2, 2, device=xm.xla_device())
        print(t.device)
        print(t)
    
    if __name__ == '__main__':
        torch_xla.launch(_mp_fn, args=())
    
  2. Führen Sie das Skript aus:

    python workload.py
    

    Wenn Sie das Code-Snippet auf einer TPU v5p-8 ausführen, sieht die Ausgabe in etwa so aus:

    xla:0
    xla:0
    xla:0
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    xla:0
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    

torch_xla.launch() verwendet zwei Argumente: eine Funktion und eine Liste von Parametern. Für jedes verfügbare XLA-Gerät wird ein Prozess erstellt und die in den Argumenten angegebene Funktion aufgerufen. In diesem Beispiel sind 4 TPU-Geräte verfügbar. torch_xla.launch() erstellt also 4 Prozesse und ruft _mp_fn() auf jedem Gerät auf. Jeder Prozess hat nur Zugriff auf ein Gerät. Daher hat jedes Gerät den Index 0 und xla:0 wird für alle Prozesse ausgegeben.

PyTorch/XLA mit Ray auf mehreren Hosts ausführen

In den folgenden Abschnitten wird gezeigt, wie Sie dasselbe Code-Snippet auf einem größeren TPU-Slice mit mehreren Hosts ausführen. Weitere Informationen zur TPU-Architektur mit mehreren Hosts finden Sie unter Systemarchitektur.

In diesem Beispiel richten Sie Ray manuell ein. Wenn Sie bereits mit der Einrichtung von Ray vertraut sind, können Sie mit dem letzten Abschnitt Ray-Arbeitslast ausführen fortfahren. Weitere Informationen zum Einrichten von Ray für eine Produktionsumgebung finden Sie in den folgenden Ressourcen:

TPU-VM mit mehreren Hosts erstellen

  1. Erstellen Sie Umgebungsvariablen für die Parameter zum Erstellen von TPUs.

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5p-16
    export RUNTIME_VERSION=v2-alpha-tpuv5

    Beschreibungen von Umgebungsvariablen

    Variable Beschreibung
    PROJECT_ID Ihre Projekt-ID in Google Cloud . Verwenden Sie ein vorhandenes oder erstellen Sie ein neues Projekt.
    TPU_NAME Der Name der TPU
    ZONE Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und -Zonen.
    ACCELERATOR_TYPE Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen.
    RUNTIME_VERSION Die Softwareversion der Cloud TPU.

  2. Erstellen Sie mit dem folgenden Befehl eine Multi-Host-TPU v5p mit 2 Hosts (eine v5p-16 mit 4 TPU-Chips auf jedem Host):

    gcloud compute tpus tpu-vm create $TPU_NAME \
       --zone=$ZONE \
       --accelerator-type=$ACCELERATOR_TYPE \
       --version=$RUNTIME_VERSION

Ray einrichten

Eine TPU v5p-16 hat 2 TPU-Hosts mit jeweils 4 TPU-Chips. In diesem Beispiel starten Sie den Ray-Head-Knoten auf einem Host und fügen den zweiten Host als Worker-Knoten zum Ray-Cluster hinzu.

  1. Stellen Sie eine SSH-Verbindung zum ersten Host her.

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=0
  2. Installieren Sie Abhängigkeiten mit derselben Anforderungsdatei wie im Abschnitt Anforderungen installieren.

    pip install -r requirements.txt
    
  3. Starten Sie den Ray-Prozess.

    ray start --head --port=6379
    

    Die Ausgabe sieht dann ungefähr so aus:

    Enable usage stats collection? This prompt will auto-proceed in 10 seconds to avoid blocking cluster startup. Confirm [Y/n]: y
    Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details.
    
    Local node IP: 10.130.0.76
    
    --------------------
    Ray runtime started.
    --------------------
    
    Next steps
    To add another node to this Ray cluster, run
        ray start --address='10.130.0.76:6379'
    
    To connect to this Ray cluster:
        import ray
        ray.init()
    
    To terminate the Ray runtime, run
        ray stop
    
    To view the status of the cluster, use
        ray status
    

    Dieser TPU-Host ist jetzt der Ray-Head-Knoten. Notieren Sie sich die Zeilen, in denen beschrieben wird, wie Sie dem Ray-Cluster einen weiteren Knoten hinzufügen, z. B.:

    To add another node to this Ray cluster, run
        ray start --address='10.130.0.76:6379'
    

    Sie verwenden diesen Befehl in einem späteren Schritt.

  4. Prüfen Sie den Status des Ray-Clusters:

    ray status
    

    Die Ausgabe sieht dann ungefähr so aus:

    ======== Autoscaler status: 2025-01-14 22:03:39.385610 ========
    Node status
    ---------------------------------------------------------------
    Active:
    1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
    Pending:
    (no pending nodes)
    Recent failures:
    (no failures)
    
    Resources
    ---------------------------------------------------------------
    Usage:
    0.0/208.0 CPU
    0.0/4.0 TPU
    0.0/1.0 TPU-v5p-16-head
    0B/268.44GiB memory
    0B/119.04GiB object_store_memory
    0.0/1.0 your-tpu-name
    
    Demands:
    (no resource demands)
    

    Der Cluster enthält nur 4 TPUs (0.0/4.0 TPU), da Sie bisher nur den Head-Knoten hinzugefügt haben.

    Nachdem der Head-Knoten ausgeführt wird, können Sie den zweiten Host dem Cluster hinzufügen.

  5. Stellen Sie eine SSH-Verbindung zum zweiten Host her.

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=1
  6. Installieren Sie Abhängigkeiten mit derselben Anforderungsdatei wie im Abschnitt Anforderungen installieren.

    pip install -r requirements.txt
    
  7. Starten Sie den Ray-Prozess. Wenn Sie diesen Knoten dem vorhandenen Ray-Cluster hinzufügen möchten, verwenden Sie den Befehl aus der Ausgabe des ray start-Befehls. Achten Sie darauf, die IP-Adresse und den Port im folgenden Befehl zu ersetzen:

    ray start --address='10.130.0.76:6379'

    Die Ausgabe sieht dann ungefähr so aus:

    Local node IP: 10.130.0.80
    [2025-01-14 22:30:07,397 W 75572 75572] global_state_accessor.cc:463: Retrying to get node with node ID 35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
    
    --------------------
    Ray runtime started.
    --------------------
    
    To terminate the Ray runtime, run
    ray stop
    
  8. Prüfen Sie den Ray-Status noch einmal:

    ray status
    

    Die Ausgabe sieht dann ungefähr so aus:

    ======== Autoscaler status: 2025-01-14 22:45:21.485617 ========
    Node status
    ---------------------------------------------------------------
    Active:
    1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
    1 node_35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
    Pending:
    (no pending nodes)
    Recent failures:
    (no failures)
    
    Resources
    ---------------------------------------------------------------
    Usage:
    0.0/416.0 CPU
    0.0/8.0 TPU
    0.0/1.0 TPU-v5p-16-head
    0B/546.83GiB memory
    0B/238.35GiB object_store_memory
    0.0/2.0 your-tpu-name
    
    Demands:
    (no resource demands)
    

    Der zweite TPU-Host ist jetzt ein Knoten im Cluster. In der Liste der verfügbaren Ressourcen werden jetzt 8 TPUs (0.0/8.0 TPU) angezeigt.

Ray-Arbeitslast ausführen

  1. Aktualisieren Sie das Code-Snippet, damit es im Ray-Cluster ausgeführt wird:

    import os
    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    import ray
    
    import torch.distributed as dist
    import torch_xla.runtime as xr
    from torch_xla._internal import pjrt
    
    # Defines the local PJRT world size, the number of processes per host.
    LOCAL_WORLD_SIZE = 4
    # Defines the number of hosts in the Ray cluster.
    NUM_OF_HOSTS = 4
    GLOBAL_WORLD_SIZE = LOCAL_WORLD_SIZE * NUM_OF_HOSTS
    
    def init_env():
        local_rank = int(os.environ['TPU_VISIBLE_CHIPS'])
    
        pjrt.initialize_multiprocess(local_rank, LOCAL_WORLD_SIZE)
        xr._init_world_size_ordinal()
    
    # This decorator signals to Ray that the `print_tensor()` function should be run on a single TPU chip.
    @ray.remote(resources={"TPU": 1})
    def print_tensor():
        # Initializes the runtime environment on each Ray worker. Equivalent to
        # the `torch_xla.launch call` in the Run PyTorch/XLA on multiple devices section.
        init_env()
    
        t = torch.randn(2, 2, device=xm.xla_device())
        print(t.device)
        print(t)
    
    ray.init()
    
    # Uses Ray to dispatch the function call across available nodes in the cluster.
    tasks = [print_tensor.remote() for _ in range(GLOBAL_WORLD_SIZE)]
    ray.get(tasks)
    
    ray.shutdown()
    
  2. Führen Sie das Skript auf dem Ray-Head-Knoten aus. Ersetzen Sie ray-workload.py durch den Pfad zu Ihrem Skript.

    python ray-workload.py

    Die Ausgabe sieht dann ungefähr so aus:

    WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
    xla:0
    xla:0
    xla:0
    xla:0
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    xla:0
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    

    Die Ausgabe zeigt, dass die Funktion auf jedem XLA-Gerät (in diesem Beispiel 8 Geräte) im TPU-Slice mit mehreren Hosts erfolgreich aufgerufen wurde.

Hostzentrierter Modus (JAX)

In den folgenden Abschnitten wird der hostzentrierte Modus mit JAX beschrieben. JAX verwendet ein funktionales Programmierparadigma und unterstützt SPMD-Semantik (Single Program, Multiple Data) auf höherer Ebene. Anstatt dass jeder Prozess mit einem einzelnen XLA-Gerät interagiert, ist JAX-Code so konzipiert, dass er gleichzeitig auf mehreren Geräten auf einem einzelnen Host ausgeführt wird.

JAX wurde für Hochleistungs-Computing entwickelt und kann TPUs für umfangreiches Training und Inferenz effizient nutzen. Dieser Modus ist ideal, wenn Sie mit den Konzepten der funktionalen Programmierung vertraut sind, damit Sie das volle Potenzial von JAX nutzen können.

In dieser Anleitung wird davon ausgegangen, dass Sie bereits eine Ray- und TPU-Umgebung eingerichtet haben, einschließlich einer Softwareumgebung, die JAX und andere zugehörige Pakete enthält. Wenn Sie einen Ray-TPU-Cluster erstellen möchten, folgen Sie der Anleitung unter GKE-Cluster in Google Cloud mit TPUs für KubeRay starten. Weitere Informationen zur Verwendung von TPUs mit KubeRay finden Sie unter TPUs mit KubeRay verwenden.

JAX-Arbeitslast auf einer TPU mit einem Host ausführen

Das folgende Beispielskript zeigt, wie Sie eine JAX-Funktion in einem Ray-Cluster mit einer Single-Host-TPU wie einer v6e-4 ausführen. Wenn Sie eine Multi-Host-TPU verwenden, reagiert dieses Skript aufgrund des JAX-Ausführungsmodells mit mehreren Controllern nicht mehr. Weitere Informationen zum Ausführen von Ray auf einer TPU mit mehreren Hosts finden Sie unter JAX-Arbeitslast auf einer TPU mit mehreren Hosts ausführen.

  1. Erstellen Sie Umgebungsvariablen für die Parameter zum Erstellen von TPUs.

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-a
    export ACCELERATOR_TYPE=v6e-4
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    Beschreibungen von Umgebungsvariablen

    Variable Beschreibung
    PROJECT_ID Ihre Projekt-ID in Google Cloud . Verwenden Sie ein vorhandenes oder erstellen Sie ein neues Projekt.
    TPU_NAME Der Name der TPU
    ZONE Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und -Zonen.
    ACCELERATOR_TYPE Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen.
    RUNTIME_VERSION Die Softwareversion der Cloud TPU.

  2. Verwenden Sie den folgenden Befehl, um eine v6e-TPU-VM mit 4 Kernen zu erstellen:

    gcloud compute tpus tpu-vm create $TPU_NAME \
       --zone=$ZONE \
       --accelerator-type=$ACCELERATOR_TYPE  \
       --version=$RUNTIME_VERSION
  3. Stellen Sie mit dem folgenden Befehl eine Verbindung zur TPU-VM her:

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE
  4. Installieren Sie JAX und Ray auf Ihrer TPU.

    pip install ray jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  5. Speichern Sie den folgenden Code in einer Datei. Beispiel: ray-jax-single-host.py.

    import ray
    import jax
    
    @ray.remote(resources={"TPU": 4})
    def my_function() -> int:
        return jax.device_count()
    
    h = my_function.remote()
    print(ray.get(h)) # => 4
    

    Wenn Sie es gewohnt sind, Ray mit GPUs auszuführen, gibt es einige wichtige Unterschiede bei der Verwendung von TPUs:

    • Geben Sie TPU als benutzerdefinierte Ressource an und legen Sie die Anzahl der TPU-Chips fest, anstatt num_gpus festzulegen.
    • Geben Sie die TPU mit der Anzahl der Chips pro Ray-Worker-Knoten an. Wenn Sie beispielsweise eine v6e-4 verwenden und eine Remote-Funktion mit TPU auf 4 ausführen, wird der gesamte TPU-Host belegt.
    • Das unterscheidet sich von der üblichen Ausführung von GPUs mit einem Prozess pro Host. Es wird nicht empfohlen, TPU auf eine andere Zahl als 4 festzulegen.
      • Ausnahme: Wenn Sie v6e-8 oder v5litepod-8 mit nur einem Host haben, sollten Sie diesen Wert auf 8 festlegen.
  6. Führen Sie das Skript aus:

    python ray-jax-single-host.py

JAX-Arbeitslast auf einer TPU mit mehreren Hosts ausführen

Das folgende Beispielskript zeigt, wie Sie eine JAX-Funktion in einem Ray-Cluster mit einer TPU mit mehreren Hosts ausführen. Im Beispielskript wird ein v6e-16 verwendet.

  1. Erstellen Sie Umgebungsvariablen für die Parameter zum Erstellen von TPUs.

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-a
    export ACCELERATOR_TYPE=v6e-16
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    Beschreibungen von Umgebungsvariablen

    Variable Beschreibung
    PROJECT_ID Ihre Projekt-ID in Google Cloud . Verwenden Sie ein vorhandenes oder erstellen Sie ein neues Projekt.
    TPU_NAME Der Name der TPU
    ZONE Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und -Zonen.
    ACCELERATOR_TYPE Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen.
    RUNTIME_VERSION Die Softwareversion der Cloud TPU.

  2. Verwenden Sie den folgenden Befehl, um eine v6e-TPU-VM mit 16 Kernen zu erstellen:

    gcloud compute tpus tpu-vm create $TPU_NAME \
       --zone=$ZONE \
       --accelerator-type=$ACCELERATOR_TYPE  \
       --version=$RUNTIME_VERSION
  3. Installieren Sie JAX und Ray auf allen TPU-Workern.

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
       --zone=$ZONE \
       --worker=all \
       --command="pip install ray jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
  4. Speichern Sie den folgenden Code in einer Datei. Beispiel: ray-jax-multi-host.py.

    import ray
    import jax
    
    @ray.remote(resources={"TPU": 4})
    def my_function() -> int:
        return jax.device_count()
    
    ray.init()
    num_tpus = ray.available_resources()["TPU"]
    num_hosts = int(num_tpus) # 4
    h = [my_function.remote() for _ in range(num_hosts)]
    print(ray.get(h)) # [16, 16, 16, 16]
    

    Wenn Sie es gewohnt sind, Ray mit GPUs auszuführen, gibt es einige wichtige Unterschiede bei der Verwendung von TPUs:

    • Ähnlich wie bei PyTorch-Arbeitslasten auf GPUs:
    • Im Gegensatz zu PyTorch-Arbeitslasten auf GPUs hat JAX eine globale Ansicht der verfügbaren Geräte im Cluster.
  5. Kopieren Sie das Skript auf alle TPU-Worker.

    gcloud compute tpus tpu-vm scp ray-jax-multi-host.py $TPU_NAME: --zone=$ZONE --worker=all
  6. Führen Sie das Skript aus:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
       --zone=$ZONE \
       --worker=all \
       --command="python ray-jax-multi-host.py"

Multislice-JAX-Arbeitslast ausführen

Mit Multislice können Sie Arbeitslasten ausführen, die sich über mehrere TPU-Slices in einem einzelnen TPU-Pod oder in mehreren Pods im Rechenzentrumsnetzwerk erstrecken.

Sie können das Paket ray-tpu verwenden, um die Interaktionen von Ray mit TPU-Slices zu vereinfachen.

Installieren Sie ray-tpu mit pip.

pip install ray-tpu

Weitere Informationen zur Verwendung des ray-tpu-Pakets finden Sie im GitHub-Repository unter Erste Schritte. Ein Beispiel für die Verwendung von Multislice finden Sie unter Ausführen mit Multislice.

Arbeitslasten mit Ray und MaxText orchestrieren

Weitere Informationen zur Verwendung von Ray mit MaxText finden Sie unter Trainingsjob mit MaxText ausführen.

TPU- und Ray-Ressourcen

Ray behandelt TPUs anders als GPUs, um den Unterschieden in der Nutzung Rechnung zu tragen. Im folgenden Beispiel gibt es insgesamt neun Ray-Knoten:

  • Der Ray-Head-Knoten wird auf einer n1-standard-16-VM ausgeführt.
  • Die Ray-Worker-Knoten werden auf zwei v6e-16-TPUs ausgeführt. Jede TPU besteht aus vier Workern.
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
 1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
 1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
 1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
 1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
 1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
 1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
 1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
 1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
 1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/727.0 CPU
 0.0/32.0 TPU
 0.0/2.0 TPU-v6e-16-head
 0B/5.13TiB memory
 0B/1.47TiB object_store_memory
 0.0/4.0 tpu-group-0
 0.0/4.0 tpu-group-1

Demands:
 (no resource demands)

Feldbeschreibungen für die Ressourcennutzung:

  • CPU: die Gesamtzahl der im Cluster verfügbaren CPUs.
  • TPU: die Anzahl der TPU-Chips im Cluster.
  • TPU-v6e-16-head: eine spezielle Kennung für die Ressource, die dem Worker 0 eines TPU-Slice entspricht. Das ist wichtig für den Zugriff auf einzelne TPU-Slices.
  • memory: der von Ihrer Anwendung verwendete Worker-Heap-Speicher.
  • object_store_memory: Arbeitsspeicher, der verwendet wird, wenn Ihre Anwendung mit ray.put Objekte im Objektspeicher erstellt oder Werte von Remote-Funktionen zurückgibt.
  • tpu-group-0 und tpu-group-1: eindeutige Kennungen für die einzelnen TPU-Slices. Das ist wichtig, um Jobs auf Slices auszuführen. Diese Felder sind auf 4 festgelegt, da es in einem v6e-16-Slice vier Hosts pro TPU-Slice gibt.