Ce tutoriel explique comment entraîner le modèle ResNet-50 sur un appareil Cloud TPU avec PyTorch. La même procédure peut s'appliquer à d'autres modèles de classification d'image optimisés pour TPU, qui utilisent PyTorch et l'ensemble de données ImageNet.
Le modèle utilisé dans ce tutoriel est basé sur l'article Deep Residual Learning for Image Recognition (Deep learning résiduel pour la reconnaissance d'images), qui présente l'architecture de réseau résiduel (ResNet). Le tutoriel emploie la variante à 50 couches, ResNet-50, et illustre l'entraînement du modèle à l'aide de PyTorch/XLA.
Objectifs
- Préparer l'ensemble de données
- Exécuter la tâche d'entraînement
- Vérifier les résultats
Coûts
Dans ce document, vous utilisez les composants facturables suivants de Google Cloud:
- Compute Engine
- Cloud TPU
Obtenez une estimation des coûts en fonction de votre utilisation prévue,
utilisez le simulateur de coût.
Avant de commencer
Avant de commencer ce tutoriel, vérifiez que votre Google Cloud projet est correctement configuré.
- Connectez-vous à votre Google Cloud compte. Si vous débutez sur Google Cloud, créez un compte pour évaluer les performances de nos produits en conditions réelles. Les nouveaux clients bénéficient également de 300 $de crédits sans frais pour exécuter, tester et déployer des charges de travail.
-
In the Google Cloud console, on the project selector page, select or create a Google Cloud project.
Roles required to select or create a project
- Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
-
Create a project: To create a project, you need the Project Creator role
(
roles/resourcemanager.projectCreator), which contains theresourcemanager.projects.createpermission. Learn how to grant roles.
-
Verify that billing is enabled for your Google Cloud project.
-
In the Google Cloud console, on the project selector page, select or create a Google Cloud project.
Roles required to select or create a project
- Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
-
Create a project: To create a project, you need the Project Creator role
(
roles/resourcemanager.projectCreator), which contains theresourcemanager.projects.createpermission. Learn how to grant roles.
-
Verify that billing is enabled for your Google Cloud project.
Ce tutoriel utilise des composants facturables de Google Cloud. Consultez la grille tarifaire de Cloud TPU pour estimer vos coûts. Veillez à nettoyer les ressources que vous avez créées lorsque vous avez terminé, afin d'éviter des frais inutiles.
Créer une VM TPU
Ouvrez une fenêtre Cloud Shell.
Créez une VM TPU.
gcloud compute tpus tpu-vm create your-tpu-name \ --accelerator-type=v5litepod-8 \ --version=tpu-ubuntu2204-base \ --zone=us-central1-a \ --project=your-project
Connectez-vous à votre VM TPU à l'aide de SSH :
gcloud compute tpus tpu-vm ssh your-tpu-name --zone=us-central1-a
Installez PyTorch/XLA sur votre VM TPU :
(vm)$ pip install torch torch_xla[tpu] torchvision -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
Clonez le dépôt GitHub PyTorch/XLA.
(vm)$ git clone --depth=1 https://github.com/pytorch/xla.git
Exécutez le script d'entraînement avec des données factices.
(vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
Effectuer un nettoyage
Pour éviter que les ressources utilisées dans ce tutoriel soient facturées sur votre compte Google Cloud, supprimez le projet contenant les ressources, ou conservez le projet et supprimez chaque ressource individuellement.
Déconnectez-vous de la VM TPU :
(vm) $ exit
Votre invite de commande devrait maintenant être
username@projectname, indiquant que vous êtes dans Cloud Shell.Supprimez votre VM TPU.
$ gcloud compute tpus tpu-vm delete your-tpu-name \ --zone=us-central1-a
Étapes suivantes
- Entraîner des modèles de diffusion avec PyTorch
- Résoudre les problèmes liés à PyTorch sur les TPU
- Documentation PyTorch/XLA