En raison de la nature distribuée de JAX avec Pathways, certaines opérations peuvent ne pas bien s'adapter en raison des frais généraux de communication. Bien que Pathways minimise ces frais généraux grâce à des fonctionnalités telles que l'envoi asynchrone, vous devez tenir compte de certains éléments lorsque vous portez des charges de travail JAX vers Pathways ou que vous mettez à l'échelle une charge de travail JAX avec Pathways sur un grand nombre d'accélérateurs.
Avant de commencer
Vérifiez que vous disposez bien des éléments suivants :
- Outils Kubernetes installés
- gcloud CLI installé
- Activer l'API TPU
- Activer l'API Google Kubernetes Engine
Index du processus
JAX avec Pathways traite tous les appareils de votre cluster Pathways comme des appareils locaux. Cela simplifie la gestion des appareils et permet à JAX d'utiliser toutes les ressources disponibles. En pratique, cela signifie :
jax.process_index()est toujours défini sur 0 pour tous les appareils.jax.devices()etjax.local_devices()renvoient tous les appareils TPU de l'ensemble du job.
Type de matériel et colocation
Pour optimiser les performances, placez tous les composants Pathways et le job utilisateur dans la même zone cloud Google Cloud . Utilisez un processeur volumineux comme le proxy IFRT et le gestionnaire de ressources. Nous vous recommandons au moins un n2-standard-64 dédié, qui est fourni avec 64 processeurs virtuels et 256 Go de mémoire.
PathwaysUtils
Pathways-utils est un dépôt GitHub basé sur Python qui fournit des utilitaires et des outils essentiels vous permettant de simplifier le déploiement et l'exécution des charges de travail JAX sur l'architecture Pathways on Cloud. Ce package gère les adaptations nécessaires pour l'environnement cloud, ce qui permet aux développeurs JAX de se concentrer sur leurs principaux workflows de machine learning avec une configuration minimale spécifique à la plate-forme. Plus précisément, il offre les avantages suivants :
- Un backend JAX "proxy" : ce backend personnalisé permet à votre application JAX d'utiliser l'infrastructure Pathways en définissant la variable d'environnement
JAX_PLATFORMS=proxy. - Utilitaires de profilage intégrés : fonctionnalités de profilage qui vous permettent de comprendre les performances de votre application. En utilisant des API de profilage JAX standards telles que
jax.profiler.start_traceetjax.profiler.start_server, vous pouvez profiler non seulement votre code JAX, mais aussi les composants Pathways sous-jacents, ce qui vous donne une vue globale de l'exécution dans l'environnement cloud. - Point de contrôle distribué avec Orbax : un gestionnaire de points de contrôle Orbax personnalisé qui vous permet d'utiliser des points de contrôle distribués et de restaurer vos points de contrôle lorsque vous utilisez la bibliothèque Orbax dans l'environnement Pathways. Cette intégration vise à fonctionner sans nécessiter de modification de votre code de point de contrôle Orbax existant, à condition qu'il importe
pathwaysutils. - Primitives d'entraînement élastique : fournit des primitives d'entraînement élastique de base que vous pouvez utiliser pour créer des workflows d'entraînement robustes et évolutifs à l'aide de Pathways. Ces primitives permettent à vos jobs d'entraînement de s'adapter de manière dynamique aux changements de ressources disponibles, ce qui améliore l'efficacité et la résilience dans les environnements cloud.
Points de contrôle
Orbax est minutieusement testé avec Pathways pour la création et la restauration distribuées de points de contrôle avec Cloud Storage. Lorsque vous appelez import pathwaysutils; pathwaysutils.initialize() dans votre train.py, un ArrayHandler personnalisé est enregistré pour gérer efficacement les opérations de point de contrôle via le proxy IFRT, ce qui permet aux nœuds de calcul Pathways sur les accélérateurs d'enregistrer et de restaurer directement les données.
Python colocalisé
Colocated Python est une API JAX Open Source qui vous permet d'exécuter du code Python spécifié par l'utilisateur directement sur les hôtes TPU ou GPU, ce qui est plus simple dans JAX multicontrôleur. Cela permet d'éviter le transfert de données entre le client et les machines TPU pour les tâches nécessitant davantage de puissance de calcul, comme le chargement de données et la création de points de contrôle. Pour configurer votre cluster Pathways afin d'exécuter l'API JAX Python colocalisée, suivez les instructions du fichier README Python colocalisé. Ces instructions expliquent comment démarrer un fichier side-car Python colocalisé en même temps que vos workers Pathways.
Chargement des données…
Pendant l'entraînement, nous chargeons à plusieurs reprises des lots à partir d'un ensemble de données pour les fournir au modèle. Il est important de disposer d'un chargeur de données asynchrone et efficace qui segmente le lot sur les hôtes pour éviter de priver les accélérateurs de travail. Lorsque vous exécutez l'entraînement avec Pathways, le chargeur de données s'exécute sur une VM de processeur (contrairement à une VM TPU qui est utilisée dans les configurations multi-contrôleurs) et distribue les données aux VM TPU. Cela entraîne une latence plus élevée lors de la lecture des données, mais celle-ci est partiellement atténuée en lisant à l'avance un certain nombre de lots sur l'hôte du processeur et en envoyant les données lues de manière asynchrone aux TPU. Cette solution est suffisante pour une échelle petite à moyenne.
Pour des performances optimales à grande échelle, nous vous recommandons vivement de colocaliser votre pipeline de données d'entrée en utilisant Python colocalisé pour exécuter votre pipeline de données directement sur les accélérateurs. Cela élimine le goulot d'étranglement du processeur et exploite les interconnexions rapides de la TPU pour le transfert de données.
Vous trouverez une implémentation de référence de la migration d'un pipeline d'entrée basé sur TFDS dans l'implémentation RemoteIterator dans multihost_dataloading.py.
Cette implémentation fonctionne à la fois sur JAX et Pathways multicontrôleurs de manière distribuée à l'aide de l'API Python JAX colocalisée.
Gestion des versions de Jax
Les versions de Pathways sont étroitement liées aux versions de JAX pour garantir la compatibilité et la stabilité. Pour éviter tout problème potentiel, vérifiez que vos artefacts Pathways et votre version JAX sont alignés. Chaque version de Pathways spécifie clairement les versions JAX compatibles via un tag au format jax-<version>.
Cache de compilation
Le cache de compilation persistant Pathways est une fonctionnalité qui permet aux serveurs Pathways de stocker les exécutables XLA compilés dans un emplacement persistant, tel que Cloud Storage, pour éviter les compilations redondantes. Cette fonctionnalité est activée par défaut. L'emplacement du cache est transmis en tant qu'indicateur --gcs_scratch_location aux conteneurs du gestionnaire de ressources et du nœud de calcul Pathways. Pour réduire au minimum les coûts de stockage associés, le cache associe une règle de cycle de vie à l'emplacement Cloud Storage. Le nombre de règles par bucket Cloud Storage est limité à 50. Par conséquent, nous vous recommandons d'utiliser un emplacement Cloud Storage commun pour toutes les charges de travail.
Ce cache est semblable au cache de compilation JAX, qui est désactivé par pathwaysutils.initialize() pour les charges de travail Pathways.
Profilage
Vous pouvez utiliser le profileur JAX pour générer des traces d'un programme JAX. Deux méthodes courantes sont compatibles avec les parcours :
- Programmatique
- Capturer des profils de manière programmatique à partir de votre code JAX
- Manuel
- Capturer des profils à la demande après avoir démarré le serveur du profileur à partir de votre code JAX
Dans les deux cas, les profils sont écrits dans un bucket Cloud Storage. Plusieurs fichiers de trace seront créés dans le bucket Cloud Storage, potentiellement dans des dossiers d'horodatage différents. Par exemple :
- Processus Python principal qui a appelé la trace (généralement la VM de votre notebook) :
<jax-client-vm-name>.xplane.pb - Proxy Pathways IFRT :
client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - Gestionnaire de ressources Pathways :
server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - Employé(s) Pathways :
server.*<tpu-node-name>.xplane.pb
Ces fichiers de trace peuvent être analysés avec TensorBoard en exécutant la commande suivante. Pour en savoir plus sur TensorBoard et tous ses outils de profilage, consultez Optimiser les performances de TensorFlow à l'aide de Profiler.
# verify trace files are present gsutil ls -l -r gs://BUCKET/PREFIX # View on tensorboard tensorboard --logdir=gs://BUCKET/PREFIX
Remplacez les éléments suivants :
BUCKET: bucket Cloud Storage permettant de stocker les fichiers de tracePREFIX: chemin d'accès dans votre bucket Cloud Storage pour stocker les fichiers de trace
Capture programmatique de profils
Capturez un profil depuis votre code. Les profils sont enregistrés dans gs://<bucket>/<prefix> sous un répertoire d'horodatage.
import jax import pathwaysutils pathwaysutils.initialize() jax.profiler.start_trace("gs://BUCKET/PREFIX") # Run the operations to be profiled key = jax.random.key(0) x = jax.random.normal(key, (5000, 5000)) y = x @ x y.block_until_ready() jax.profiler.stop_trace()
Capture manuelle de profil
Pour capturer manuellement des informations de profil, vous devez démarrer le serveur du profileur à partir de votre code Python :
import jax
import pathwaysutils
pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)
# Your JAX code
jax.profiler.stop_server() # this is functinoally a no-op
Pendant que le serveur de profileur est en cours d'exécution, vous pouvez capturer un profil et exporter les données vers l'emplacement Cloud Storage cible :
export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>
curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port
Vous trouverez des informations sur le timing des méthodes du client proxy IFRT, telles que Compile et Execute, dans la trace de votre programme. Ces événements, qui détaillent les interactions avec le serveur gRPC du proxy IFRT lors de la compilation et de l'exécution, s'affichent sur le thread nommé GrpcClientSessionUserFuturesWorkQueue. En examinant ce thread dans votre trace, vous pouvez obtenir des informations sur les performances de ces opérations.
Options XLA
Lorsque vous utilisez Pathways, vous devez définir les indicateurs XLA dans le conteneur pathways-proxy. Vous pouvez le faire à l'aide de XPK ou de l'API PathwaysJob.
Lorsque vous utilisez XPK, définissez les indicateurs XLA comme suit :
--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"
Lorsque vous utilisez l'API PathwaysJob, définissez les indicateurs XLA comme suit :
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customFlags: - --xla_flag_1=value1 - --xla_flag_2=value2
Remplacez les éléments suivants :
USER: votre nom d'utilisateur Google Cloudvalue[n]: les indicateurs XLA que vous souhaitez définir
Vidage HLO
Pour examiner en détail les entrées HLO (High Level Optimizer) fournies au compilateur XLA, vous pouvez configurer Pathways pour qu'il les transfère vers un emplacement Cloud Storage spécifié, comme suit :
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customEnv: - name: XLA_FLAGS value: "--xla_dump_to=gs://your-gcs-bucket/your-desired-prefix/"
Étapes suivantes
- Créer un cluster GKE avec Pathways
- Inférence multihôte avec Pathways
- Charges de travail par lot avec Pathways
- Mode interactif des parcours
- Entraînement résilient avec Pathways
- Parcours de dépannage