Criar perfis de cargas de trabalho do PyTorch XLA

A otimização de desempenho é essencial para criar modelos de machine learning eficientes. Use a ferramenta de criação de perfil XProf para medir o desempenho das cargas de trabalho de machine learning. O XProf permite capturar traces detalhados da execução do modelo em dispositivos XLA. Esses traces podem ajudar você a identificar gargalos de desempenho, entender a utilização do dispositivo e otimizar o código.

Este guia descreve o processo de captura programática de um trace do script do PyTorch XLA e visualização usando o XProf e o Tensorboard.

Capturar um trace

Para capturar um trace, adicione algumas linhas de código ao script de treinamento atual. A principal ferramenta para capturar um trace é o módulo torch_xla.debug.profiler, que geralmente é importado com o alias xp.

1. Iniciar o servidor do criador de perfil

Antes de capturar um trace, é necessário iniciar o servidor do criador de perfil. Esse servidor é executado em segundo plano no script e coleta os dados de trace. Para iniciá-lo, chame xp.start_server() perto do início do bloco de execução principal.

2. Definir a duração do trace

Encapsule o código que passará pela criação de perfil nas chamadas xp.start_trace() e xp.stop_trace(). A função start_trace usa um caminho para um diretório em que os arquivos de trace são salvos.

É uma prática comum encapsular o loop de treinamento principal para capturar as operações mais relevantes.

# The directory where the trace files are stored.
log_dir = '/root/logs/'

# Start tracing
xp.start_trace(log_dir)

# ... your training loop or other code to be profiled ...
train_mnist()

# Stop tracing
xp.stop_trace()

3. Adicionar rótulos de trace personalizados

Por padrão, os traces capturados são funções XLA do Pytorch de baixo nível e podem ser difíceis de compreender. É possível adicionar rótulos personalizados a seções específicas do código usando o gerenciador de contexto xp.Trace(). Esses rótulos aparecem como blocos nomeados na visualização da linha do tempo do criador de perfil e facilitam a identificação de operações específicas, como preparação de dados, transmissão direta ou etapa do otimizador.

O exemplo a seguir mostra como adicionar contexto a diferentes partes de uma etapa de treinamento.

def forward(self, x):
    # This entire block will be labeled 'forward' in the trace
    with xp.Trace('forward'):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
    with torch_xla.step():
        with xp.Trace('train_step_data_prep_and_forward'):
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = model(data)

        with xp.Trace('train_step_loss_and_backward'):
            loss = loss_fn(output, target)
            loss.backward()

        with xp.Trace('train_step_optimizer_step_host'):
            optimizer.step()

Exemplo completo

O exemplo a seguir mostra como capturar um trace de um script do PyTorch XLA com base no arquivo mnist_xla.py.

import torch
import torch.optim as optim
from torchvision import datasets, transforms

# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

def train_mnist():
    # ... (model definition and data loading code) ...
    print("Starting training...")
    # ... (training loop as defined in the previous section) ...
    print("Training finished!")

if __name__ == '__main__':
    # 1. Start the profiler server
    server = xp.start_server(9012)

    # 2. Start capturing the trace and define the output directory
    xp.start_trace('/root/logs/')

    # Run the training function that contains custom trace labels
    train_mnist()

    # 3. Stop the trace
    xp.stop_trace()

Visualizar o trace

Quando o script for concluído, os arquivos de trace serão salvos no diretório especificado, como /root/logs/. É possível visualizar esse trace usando o XProf e o TensorBoard.

  1. Instale o TensorBoard.

    pip install tensorboard_plugin_profile tensorboard
  2. Inicie o TensorBoard. Aponte o TensorBoard para o diretório de registros usado em xp.start_trace():

    tensorboard --logdir /root/logs/
  3. Acesse o perfil. Abra o URL fornecido pelo TensorBoard no navegador da Web (geralmente http://localhost:6006) e acesse a guia PyTorch XLA – Perfil para conferir o trace interativo. Você poderá conferir os rótulos personalizados que criou e analisar o tempo de execução de diferentes partes do modelo.

Se você usa o Google Cloud para executar cargas de trabalho, recomendamos a ferramenta cloud-diagnostics-xprof. Ela oferece uma experiência simplificada de coleta e visualização de perfis usando VMs que executam o Tensorboard e o XProf.