En raison de la nature distribuée de JAX avec Pathways, certaines opérations peuvent ne pas être bien mises à l'échelle en raison des frais généraux de communication. Bien que Pathways minimise ces frais généraux avec des fonctionnalités telles que la distribution asynchrone, vous devez tenir compte de certains éléments lorsque vous portez des charges de travail JAX vers Pathways ou que vous mettez à l'échelle une charge de travail JAX avec Pathways sur un grand nombre d'accélérateurs.
Avant de commencer
Vérifiez que vous disposez bien des éléments suivants :
- Outils Kubernetes installés
- gcloud CLI installée
- API TPU activée
- API Google Kubernetes Engine activée
Index du processus
JAX avec Pathways traite tous les appareils de votre cluster Pathways comme des appareils locaux. Cela simplifie la gestion des appareils et permet à JAX d'utiliser toutes les ressources disponibles. En pratique, cela signifie que :
jax.process_index()est toujours 0 pour tous les appareils.jax.devices()etjax.local_devices()renvoient tous les appareils TPU pour l'ensemble de la tâche.
Type de matériel et colocation
Pour optimiser les performances, placez tous les composants Pathways et la tâche utilisateur dans la
même Google Cloud zone cloud. Utilisez un processeur volumineux, tel que le proxy IFRT et le gestionnaire de ressources. Nous vous recommandons au moins un n2-standard-64 dédié, qui est fourni avec 64 processeurs virtuels et 256 Go de mémoire.
PathwaysUtils
Pathways-utils est un dépôt GitHub basé sur Python qui fournit des utilitaires et des outils essentiels vous permettant de simplifier le déploiement et l’exécution des charges de travail JAX sur l’architecture Pathways on Cloud. Ce package gère les adaptations nécessaires pour l'environnement cloud, ce qui permet aux développeurs JAX de se concentrer sur leurs workflows de machine learning de base avec une configuration minimale spécifique à la plate-forme. Plus précisément, il offre les avantages suivants :
- Un backend JAX "proxy" : ce backend personnalisé permet à votre application JAX d'utiliser l'infrastructure Pathways en définissant la variable d'environnement
JAX_PLATFORMS=proxy. - Utilitaires de profilage intégrés : fonctionnalités de profilage qui vous permettent de comprendre les performances de votre application. En utilisant des API de profilage JAX standards telles que
jax.profiler.start_traceetjax.profiler.start_server, vous pouvez profiler non seulement votre code JAX, mais également les composants Pathways sous-jacents, ce qui vous offre une vue globale de l'exécution dans l'environnement cloud. - Points de contrôle distribués avec Orbax : un gestionnaire de points de contrôle Orbax personnalisé qui vous permet d'utiliser des points de contrôle distribués et de les restaurer lorsque vous utilisez la bibliothèque Orbax dans l'environnement Pathways. Cette intégration vise à fonctionner sans nécessiter de modification de votre code de point de contrôle Orbax existant, à condition qu'il importe
pathwaysutils. - Primitives d'entraînement élastiques : fournit des primitives d'entraînement élastiques fondamentales que vous pouvez utiliser pour créer des workflows d'entraînement robustes et évolutifs à l'aide de Pathways. Ces primitives permettent à vos tâches d'entraînement de s'adapter de manière dynamique aux modifications des ressources disponibles, ce qui améliore l'efficacité et la résilience dans les environnements cloud.
Points de contrôle
Orbax est entièrement testé avec Pathways pour
la création et la restauration de points de contrôle distribués avec Cloud Storage. Lorsque vous
appelez import pathwaysutils; pathwaysutils.initialize() dans votre train.py, un
ArrayHandler personnalisé est enregistré. Il gère efficacement les opérations de point de contrôle
via le proxy IFRT, ce qui permet aux nœuds de calcul Pathways sur les accélérateurs d'enregistrer et de restaurer directement les données.
Python colocalisé
Python colocalisé est une API JAX Open Source qui vous permet d'exécuter du code Python spécifié par l'utilisateur directement sur les hôtes TPU ou GPU, ce qui est plus simple dans JAX multi-contrôleur JAX. Cela permet aux tâches nécessitant davantage de calcul, telles que le chargement des données et la création de points de contrôle, d'éviter le transfert de données entre le client et les machines TPU. Pour configurer votre cluster Pathways afin qu'il exécute l'API JAX Python colocalisée, suivez les instructions du fichier Lisez-moi de Python colocalisé. Ces instructions expliquent comment démarrer un side-car Python colocalisé en même temps que vos nœuds de calcul Pathways.
Chargement des données
Pendant l'entraînement, nous chargeons à plusieurs reprises des lots à partir d'un ensemble de données pour les fournir au modèle. Il est important de disposer d'un chargeur de données asynchrone et efficace qui segmente le lot sur les hôtes pour éviter de priver les accélérateurs de travail. Lorsque vous exécutez l'entraînement avec Pathways, le chargeur de données s'exécute sur une VM de processeur (contrairement à une VM TPU utilisée sur les configurations multi-contrôleurs) et distribue les données aux VM TPU. Cela entraîne une latence plus élevée lors de la lecture des données, mais elle est partiellement atténuée en lisant à l'avance un nombre X de lots sur l'hôte du processeur et en distribuant les données lues de manière asynchrone aux TPU. Cette solution est suffisante lorsque vous exécutez à petite ou moyenne échelle.
Pour des performances optimales à grande échelle, nous vous recommandons vivement de colocaliser votre pipeline de données d'entrée en utilisant Python colocalisé pour exécuter votre pipeline de données directement sur les accélérateurs. Cela élimine le goulot d'étranglement du processeur et exploite les interconnexions rapides du TPU pour le transfert de données.
Vous trouverez une implémentation de référence de la migration d'un pipeline d'entrée basé sur TFDS
dans l'implémentation RemoteIterator dans
multihost_dataloading.py.
Cette implémentation fonctionne à la fois sur JAX multi-contrôleur et sur Pathways de manière distribuée à l'aide de l'API JAX Python colocalisée.
Gestion des versions de Jax
Les versions de Pathways sont étroitement liées aux versions de JAX pour garantir la compatibilité et la stabilité. Pour éviter tout problème potentiel, vérifiez que vos artefacts Pathways et votre version de JAX sont alignés. Chaque version de Pathways spécifie clairement les
versions de JAX compatibles via un tag au format jax-<version>.
Cache de compilation
Le cache de compilation persistant Pathways est une fonctionnalité qui permet aux serveurs Pathways de stocker des exécutables XLA compilés dans un emplacement persistant, tel que Cloud Storage, afin d'éviter toute compilation redondante. Cette fonctionnalité est activée par défaut. L'emplacement du cache est transmis en tant qu'option --gcs_scratch_location aux conteneurs de gestionnaire de ressources et de nœud de calcul Pathways. Pour réduire au minimum les coûts de stockage associés, le cache associe une règle de cycle de vie à l'emplacement Cloud Storage. Il existe une limite de 50 règles par bucket Cloud Storage. Par conséquent, nous vous recommandons d'utiliser un emplacement Cloud Storage commun pour toutes les charges de travail.
Ce cache est semblable au cache de compilation JAX
qui est désactivé par pathwaysutils.initialize() pour les charges de travail Pathways.
Les autorisations Cloud Storage suivantes sont requises pour le cache de compilation :
storage.buckets.get: pour récupérer les métadonnées du bucket.storage.buckets.update: essentiel pour que Pathways configure des règles de cycle de vie des objets afin d'appliquer la valeur TTL pour l'éviction du cache.storage.objects.list: pour répertorier les objets de cache existants dans le bucket.storage.objects.create: pour écrire de nouveaux exécutables compilés dans le cache.storage.objects.get: pour lire les exécutables mis en cache à partir du bucket.
Profilage
Vous pouvez utiliser le profileur JAX pour générer des traces d'un programme JAX. Il existe deux méthodes courantes compatibles avec Pathways :
- Programmatique
- Capturer des profils de manière programmatique à partir de votre code JAX
- Manuelle
- Capturer des profils à la demande après avoir démarré le serveur de profileur à partir de votre code JAX
Dans les deux cas, les profils sont écrits dans un bucket Cloud Storage. Plusieurs fichiers de trace sont créés dans le bucket Cloud Storage, potentiellement dans différents dossiers d'horodatage, par exemple :
- Processus Python principal qui a appelé la trace (généralement votre VM de notebook) :
<jax-client-vm-name>.xplane.pb - Proxy IFRT Pathways :
client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - Gestionnaire de ressources Pathways :
server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - Nœud de calcul Pathways :
server.*<tpu-node-name>.xplane.pb
Ces fichiers de trace peuvent être analysés avec TensorBoard en exécutant la commande suivante. Pour en savoir plus sur TensorBoard et tous ses outils de profilage, consultez la section Optimiser les performances de TensorFlow à l'aide de Profiler.
# verify trace files are present gsutil ls -l -r gs://BUCKET/PREFIX # View on tensorboard tensorboard --logdir=gs://BUCKET/PREFIX
Remplacez les éléments suivants :
BUCKET: bucket Cloud Storage pour stocker les fichiers de tracePREFIX: chemin d'accès dans votre bucket Cloud Storage pour stocker les fichiers de trace
Capture de profil programmatique
Capturez un profil à partir de votre code. Les profils sont enregistrés dans
gs://<bucket>/<prefix> sous un répertoire d'horodatage.
import jax import pathwaysutils pathwaysutils.initialize() jax.profiler.start_trace("gs://BUCKET/PREFIX") # Run the operations to be profiled key = jax.random.key(0) x = jax.random.normal(key, (5000, 5000)) y = x @ x y.block_until_ready() jax.profiler.stop_trace()
Capture de profil manuelle
Pour capturer manuellement les informations de profil, vous devez démarrer le serveur de profileur à partir de votre code Python :
import jax
import pathwaysutils
pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)
# Your JAX code
jax.profiler.stop_server() # this is functionally a no-op
Lorsque le serveur de profileur est en cours d'exécution, vous pouvez capturer un profil et exporter les données vers l'emplacement Cloud Storage cible :
export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>
curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port
Vous trouverez des informations de timing pour les méthodes de client proxy IFRT telles que Compile et Execute dans la trace de votre programme. Ces événements, qui détaillent les interactions avec le serveur gRPC du proxy IFRT lors de la compilation et de l'exécution, s'affichent sur le thread nommé GrpcClientSessionUserFuturesWorkQueue. En examinant ce thread dans votre trace, vous pouvez obtenir des informations sur les performances de ces opérations.
Options XLA
Lorsque vous utilisez Pathways, vous devez définir les options XLA dans le conteneur pathways-proxy. Vous pouvez le faire à l'aide de XPK ou de l'API PathwaysJob.
Lorsque vous utilisez XPK, définissez les options XLA comme suit :
--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"
Lorsque vous utilisez l'API PathwaysJob, définissez les options XLA comme suit :
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customFlags: - --xla_flag_1=value1 - --xla_flag_2=value2
Remplacez les éléments suivants :
USER: votre Google Cloud nom d'utilisateurvalue[n]: les options XLA que vous souhaitez définir
Vidage HLO
Pour approfondir les entrées HLO (High Level Optimizer) fournies au compilateur XLA, vous pouvez configurer Pathways pour qu'il vide le HLO dans un emplacement Cloud Storage spécifié comme suit :
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customEnv: - name: XLA_FLAGS value: "--xla_dump_to=gs://your-bucket/your-desired-prefix/"
Étape suivante
- Créer un cluster GKE avec Pathways
- Inférence multihôte avec Pathways
- Charges de travail par lot avec Pathways
- Mode interactif Pathways
- Entraînement résilient avec Pathways
- Dépannage de Pathways