Crie perfis de cargas de trabalho do PyTorch XLA

A otimização do desempenho é uma parte crucial da criação de modelos de aprendizagem automática eficientes. Pode usar a ferramenta de criação de perfis XProf para medir o desempenho das suas cargas de trabalho de aprendizagem automática. O XProf permite-lhe capturar rastreios detalhados da execução do seu modelo em dispositivos XLA. Estes rastreios podem ajudar a identificar gargalos de desempenho, compreender a utilização do dispositivo e otimizar o seu código.

Este guia descreve o processo de captura programática de um rastreio do seu script PyTorch XLA e de visualização através do XProf e do Tensorboard.

Capture um rastreio

Pode capturar um rastreio adicionando algumas linhas de código ao seu script de preparação existente. A ferramenta principal para capturar um rastreio é o módulo torch_xla.debug.profiler, que é normalmente importado com o alias xp.

1. Inicie o servidor do criador de perfis

Antes de poder capturar um rastreio, tem de iniciar o servidor do criador de perfis. Este servidor é executado em segundo plano no seu script e recolhe os dados de rastreio. Pode iniciá-lo chamando xp.start_server() perto do início do bloco de execução principal.

2. Defina a duração do rastreio

Inclua o código que quer analisar em chamadas xp.start_trace() e xp.stop_trace(). A função start_trace recebe um caminho para um diretório onde os ficheiros de rastreio são guardados.

É uma prática comum envolver o ciclo de treino principal para captar as operações mais relevantes.

# The directory where the trace files are stored.
log_dir = '/root/logs/'

# Start tracing
xp.start_trace(log_dir)

# ... your training loop or other code to be profiled ...
train_mnist()

# Stop tracing
xp.stop_trace()

3. Adicione etiquetas de rastreio personalizadas

Por predefinição, os rastreios capturados são funções Pytorch XLA de baixo nível e podem ser difíceis de navegar. Pode adicionar etiquetas personalizadas a secções específicas do seu código usando o gestor de contexto xp.Trace(). Estas etiquetas aparecem como blocos com nomes na vista de cronologia do criador de perfis, o que facilita muito a identificação de operações específicas, como a preparação de dados, a passagem direta ou o passo do otimizador.

O exemplo seguinte mostra como pode adicionar contexto a diferentes partes de um passo de preparação.

def forward(self, x):
    # This entire block will be labeled 'forward' in the trace
    with xp.Trace('forward'):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
    with torch_xla.step():
        with xp.Trace('train_step_data_prep_and_forward'):
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = model(data)

        with xp.Trace('train_step_loss_and_backward'):
            loss = loss_fn(output, target)
            loss.backward()

        with xp.Trace('train_step_optimizer_step_host'):
            optimizer.step()

Exemplo completo

O exemplo seguinte mostra como capturar um rastreio de um script do PyTorch XLA com base no ficheiro mnist_xla.py.

import torch
import torch.optim as optim
from torchvision import datasets, transforms

# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

def train_mnist():
    # ... (model definition and data loading code) ...
    print("Starting training...")
    # ... (training loop as defined in the previous section) ...
    print("Training finished!")

if __name__ == '__main__':
    # 1. Start the profiler server
    server = xp.start_server(9012)

    # 2. Start capturing the trace and define the output directory
    xp.start_trace('/root/logs/')

    # Run the training function that contains custom trace labels
    train_mnist()

    # 3. Stop the trace
    xp.stop_trace()

Visualize o rastreio

Quando o script terminar, os ficheiros de rastreio são guardados no diretório que especificou (por exemplo, /root/logs/). Pode visualizar este rastreio através do XProf e do TensorBoard.

  1. Instale o TensorBoard.

    pip install tensorboard_plugin_profile tensorboard
  2. Inicie o TensorBoard. Indique ao TensorBoard o diretório de registo que usou em xp.start_trace():

    tensorboard --logdir /root/logs/
  3. Ver o perfil. Abra o URL fornecido pelo TensorBoard no seu navegador de Internet (normalmente, http://localhost:6006). Navegue para o separador PyTorch XLA - Profile para ver o rastreio interativo. Pode ver as etiquetas personalizadas que criou e analisar o tempo de execução de diferentes partes do seu modelo.

Se usar o Google Cloud para executar as suas cargas de trabalho, recomendamos a ferramenta cloud-diagnostics-xprof. Oferece uma experiência de visualização e recolha de perfis simplificada através de VMs que executam o Tensorboard e o XProf.