Transférer des charges de travail JAX vers Pathways

En raison de la nature distribuée de JAX avec Pathways, certaines opérations peuvent ne pas bien s'adapter en raison des frais généraux de communication. Bien que Pathways minimise ces frais généraux grâce à des fonctionnalités telles que l'envoi asynchrone, vous devez tenir compte de certains éléments lorsque vous portez des charges de travail JAX vers Pathways ou que vous mettez à l'échelle une charge de travail JAX avec Pathways sur un grand nombre d'accélérateurs.

Avant de commencer

Vérifiez que vous disposez bien des éléments suivants :

Index du processus

JAX avec Pathways traite tous les appareils de votre cluster Pathways comme des appareils locaux. Cela simplifie la gestion des appareils et permet à JAX d'utiliser toutes les ressources disponibles. En pratique, cela signifie :

  • jax.process_index() est toujours défini sur 0 pour tous les appareils.
  • jax.devices() et jax.local_devices() renvoient tous les appareils TPU de l'ensemble du job.

Type de matériel et colocation

Pour optimiser les performances, placez tous les composants Pathways et le job utilisateur dans la même zone cloud Google Cloud . Utilisez un processeur volumineux comme le proxy IFRT et le gestionnaire de ressources. Nous vous recommandons au moins un n2-standard-64 dédié, qui est fourni avec 64 processeurs virtuels et 256 Go de mémoire.

PathwaysUtils

Pathways-utils est un dépôt GitHub basé sur Python qui fournit des utilitaires et des outils essentiels vous permettant de simplifier le déploiement et l'exécution des charges de travail JAX sur l'architecture Pathways on Cloud. Ce package gère les adaptations nécessaires pour l'environnement cloud, ce qui permet aux développeurs JAX de se concentrer sur leurs principaux workflows de machine learning avec une configuration minimale spécifique à la plate-forme. Plus précisément, il offre les avantages suivants :

  • Un backend JAX "proxy" : ce backend personnalisé permet à votre application JAX d'utiliser l'infrastructure Pathways en définissant la variable d'environnement JAX_PLATFORMS=proxy.
  • Utilitaires de profilage intégrés : fonctionnalités de profilage qui vous permettent de comprendre les performances de votre application. En utilisant des API de profilage JAX standards telles que jax.profiler.start_trace et jax.profiler.start_server, vous pouvez profiler non seulement votre code JAX, mais aussi les composants Pathways sous-jacents, ce qui vous donne une vue globale de l'exécution dans l'environnement cloud.
  • Point de contrôle distribué avec Orbax : un gestionnaire de points de contrôle Orbax personnalisé qui vous permet d'utiliser des points de contrôle distribués et de restaurer vos points de contrôle lorsque vous utilisez la bibliothèque Orbax dans l'environnement Pathways. Cette intégration vise à fonctionner sans nécessiter de modification de votre code de point de contrôle Orbax existant, à condition qu'il importe pathwaysutils.
  • Primitives d'entraînement élastique : fournit des primitives d'entraînement élastique de base que vous pouvez utiliser pour créer des workflows d'entraînement robustes et évolutifs à l'aide de Pathways. Ces primitives permettent à vos jobs d'entraînement de s'adapter de manière dynamique aux changements de ressources disponibles, ce qui améliore l'efficacité et la résilience dans les environnements cloud.

Points de contrôle

Orbax est minutieusement testé avec Pathways pour la création et la restauration distribuées de points de contrôle avec Cloud Storage. Lorsque vous appelez import pathwaysutils; pathwaysutils.initialize() dans votre train.py, un ArrayHandler personnalisé est enregistré pour gérer efficacement les opérations de point de contrôle via le proxy IFRT, ce qui permet aux nœuds de calcul Pathways sur les accélérateurs d'enregistrer et de restaurer directement les données.

Python colocalisé

Colocated Python est une API JAX Open Source qui vous permet d'exécuter du code Python spécifié par l'utilisateur directement sur les hôtes TPU ou GPU, ce qui est plus simple dans JAX multicontrôleur. Cela permet d'éviter le transfert de données entre le client et les machines TPU pour les tâches nécessitant davantage de puissance de calcul, comme le chargement de données et la création de points de contrôle. Pour configurer votre cluster Pathways afin d'exécuter l'API JAX Python colocalisée, suivez les instructions du fichier README Python colocalisé. Ces instructions expliquent comment démarrer un fichier side-car Python colocalisé en même temps que vos workers Pathways.

Chargement des données…

Pendant l'entraînement, nous chargeons à plusieurs reprises des lots à partir d'un ensemble de données pour les fournir au modèle. Il est important de disposer d'un chargeur de données asynchrone et efficace qui segmente le lot sur les hôtes pour éviter de priver les accélérateurs de travail. Lorsque vous exécutez l'entraînement avec Pathways, le chargeur de données s'exécute sur une VM de processeur (contrairement à une VM TPU qui est utilisée dans les configurations multi-contrôleurs) et distribue les données aux VM TPU. Cela entraîne une latence plus élevée lors de la lecture des données, mais celle-ci est partiellement atténuée en lisant à l'avance un certain nombre de lots sur l'hôte du processeur et en envoyant les données lues de manière asynchrone aux TPU. Cette solution est suffisante pour une échelle petite à moyenne.

Pour des performances optimales à grande échelle, nous vous recommandons vivement de colocaliser votre pipeline de données d'entrée en utilisant Python colocalisé pour exécuter votre pipeline de données directement sur les accélérateurs. Cela élimine le goulot d'étranglement du processeur et exploite les interconnexions rapides de la TPU pour le transfert de données.

Vous trouverez une implémentation de référence de la migration d'un pipeline d'entrée basé sur TFDS dans l'implémentation RemoteIterator dans multihost_dataloading.py. Cette implémentation fonctionne à la fois sur JAX et Pathways multicontrôleurs de manière distribuée à l'aide de l'API Python JAX colocalisée.

Gestion des versions de Jax

Les versions de Pathways sont étroitement liées aux versions de JAX pour garantir la compatibilité et la stabilité. Pour éviter tout problème potentiel, vérifiez que vos artefacts Pathways et votre version JAX sont alignés. Chaque version de Pathways spécifie clairement les versions JAX compatibles via un tag au format jax-<version>.

Cache de compilation

Le cache de compilation persistant Pathways est une fonctionnalité qui permet aux serveurs Pathways de stocker les exécutables XLA compilés dans un emplacement persistant, tel que Cloud Storage, pour éviter les compilations redondantes. Cette fonctionnalité est activée par défaut. L'emplacement du cache est transmis en tant qu'indicateur --gcs_scratch_location aux conteneurs du gestionnaire de ressources et du nœud de calcul Pathways. Pour réduire au minimum les coûts de stockage associés, le cache associe une règle de cycle de vie à l'emplacement Cloud Storage. Le nombre de règles par bucket Cloud Storage est limité à 50. Par conséquent, nous vous recommandons d'utiliser un emplacement Cloud Storage commun pour toutes les charges de travail.

Ce cache est semblable au cache de compilation JAX, qui est désactivé par pathwaysutils.initialize() pour les charges de travail Pathways.

Profilage

Vous pouvez utiliser le profileur JAX pour générer des traces d'un programme JAX. Deux méthodes courantes sont compatibles avec les parcours :

  • Programmatique
    • Capturer des profils de manière programmatique à partir de votre code JAX
  • Manuel
    • Capturer des profils à la demande après avoir démarré le serveur du profileur à partir de votre code JAX

Dans les deux cas, les profils sont écrits dans un bucket Cloud Storage. Plusieurs fichiers de trace seront créés dans le bucket Cloud Storage, potentiellement dans des dossiers d'horodatage différents. Par exemple :

  • Processus Python principal qui a appelé la trace (généralement la VM de votre notebook) : <jax-client-vm-name>.xplane.pb
  • Proxy Pathways IFRT : client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Gestionnaire de ressources Pathways : server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Employé(s) Pathways : server.*<tpu-node-name>.xplane.pb

Ces fichiers de trace peuvent être analysés avec TensorBoard en exécutant la commande suivante. Pour en savoir plus sur TensorBoard et tous ses outils de profilage, consultez Optimiser les performances de TensorFlow à l'aide de Profiler.

# verify trace files are present
gsutil ls -l -r gs://BUCKET/PREFIX

# View on tensorboard
tensorboard --logdir=gs://BUCKET/PREFIX

Remplacez les éléments suivants :

  • BUCKET : bucket Cloud Storage permettant de stocker les fichiers de trace
  • PREFIX : chemin d'accès dans votre bucket Cloud Storage pour stocker les fichiers de trace

Capture programmatique de profils

Capturez un profil depuis votre code. Les profils sont enregistrés dans gs://<bucket>/<prefix> sous un répertoire d'horodatage.

import jax
import pathwaysutils

pathwaysutils.initialize()

jax.profiler.start_trace("gs://BUCKET/PREFIX")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

Capture manuelle de profil

Pour capturer manuellement des informations de profil, vous devez démarrer le serveur du profileur à partir de votre code Python :

import jax
import pathwaysutils

pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)

# Your JAX code
jax.profiler.stop_server() # this is functinoally a no-op

Pendant que le serveur de profileur est en cours d'exécution, vous pouvez capturer un profil et exporter les données vers l'emplacement Cloud Storage cible :

export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>

curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port

Vous trouverez des informations sur le timing des méthodes du client proxy IFRT, telles que Compile et Execute, dans la trace de votre programme. Ces événements, qui détaillent les interactions avec le serveur gRPC du proxy IFRT lors de la compilation et de l'exécution, s'affichent sur le thread nommé GrpcClientSessionUserFuturesWorkQueue. En examinant ce thread dans votre trace, vous pouvez obtenir des informations sur les performances de ces opérations.

Options XLA

Lorsque vous utilisez Pathways, vous devez définir les indicateurs XLA dans le conteneur pathways-proxy. Vous pouvez le faire à l'aide de XPK ou de l'API PathwaysJob.

Lorsque vous utilisez XPK, définissez les indicateurs XLA comme suit :

--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"

Lorsque vous utilisez l'API PathwaysJob, définissez les indicateurs XLA comme suit :

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customFlags:
    - --xla_flag_1=value1
    - --xla_flag_2=value2

Remplacez les éléments suivants :

  • USER : votre nom d'utilisateur Google Cloud
  • value[n] : les indicateurs XLA que vous souhaitez définir

Vidage HLO

Pour examiner en détail les entrées HLO (High Level Optimizer) fournies au compilateur XLA, vous pouvez configurer Pathways pour qu'il les transfère vers un emplacement Cloud Storage spécifié, comme suit :

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customEnv:
    - name: XLA_FLAGS
      value: "--xla_dump_to=gs://your-gcs-bucket/your-desired-prefix/"

Étapes suivantes