Cloud TPU에서 PyTorch를 사용하여 Resnet50 학습

이 튜토리얼에서는 Cloud TPU 기기에서 PyTorch를 사용하여 ResNet-50 모델을 학습시키는 방법을 보여줍니다. PyTorch 및 ImageNet 데이터 세트를 사용하는 다른 TPU 최적화 이미지 분류 모델에 같은 패턴을 적용할 수 있습니다.

이 튜토리얼의 모델은 최초로 레지듀얼 네트워크(ResNet) 아키텍처를 도입한 이미지 인식을 위한 딥 레지듀얼 학습을 바탕으로 합니다. 이 튜토리얼에서는 50 레이어 변형판 ResNet-50을 사용하며 PyTorch/XLA를 통한 모델 학습을 보여줍니다.

TPU VM 만들기

  1. Cloud Shell 창을 엽니다.

    Cloud Shell 열기

  2. TPU VM을 만듭니다.

    gcloud compute tpus tpu-vm create your-tpu-name \
    --accelerator-type=v3-8 \
    --version=tpu-ubuntu2204-base \
    --zone=us-central1-a \
    --project=your-project
  3. SSH를 사용하여 TPU VM에 연결합니다.

    gcloud compute tpus tpu-vm ssh  your-tpu-name --zone=us-central1-a
  4. TPU VM에 PyTorch/XLA를 설치합니다.

    (vm)$ pip install torch torch_xla[tpu] torchvision -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
  5. PyTorch/XLA GitHub 저장소를 클론합니다.

    (vm)$ git clone --depth=1 https://github.com/pytorch/xla.git
  6. 허위 데이터로 학습 스크립트 실행

    (vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1