Crie perfis de cargas de trabalho do PyTorch XLA
A otimização do desempenho é uma parte crucial da criação de modelos de aprendizagem automática eficientes. Pode usar a ferramenta de criação de perfis XProf para medir o desempenho das suas cargas de trabalho de aprendizagem automática. O XProf permite-lhe capturar rastreios detalhados da execução do seu modelo em dispositivos XLA. Estes rastreios podem ajudar a identificar gargalos de desempenho, compreender a utilização do dispositivo e otimizar o seu código.
Este guia descreve o processo de captura programática de um rastreio do seu script PyTorch XLA e de visualização através do XProf e do Tensorboard.
Capture um rastreio
Pode capturar um rastreio adicionando algumas linhas de código ao seu script de
preparação existente. A ferramenta principal para capturar um rastreio é o módulo torch_xla.debug.profiler
, que é normalmente importado com o alias xp
.
1. Inicie o servidor do criador de perfis
Antes de poder capturar um rastreio, tem de iniciar o servidor do criador de perfis. Este servidor é executado em segundo plano no seu script e recolhe os dados de rastreio. Pode iniciá-lo chamando xp.start_server()
perto do início do bloco de execução principal.
2. Defina a duração do rastreio
Inclua o código que quer analisar em chamadas xp.start_trace()
e xp.stop_trace()
. A função start_trace
recebe um caminho para um diretório onde os ficheiros de rastreio são guardados.
É uma prática comum envolver o ciclo de treino principal para captar as operações mais relevantes.
# The directory where the trace files are stored.
log_dir = '/root/logs/'
# Start tracing
xp.start_trace(log_dir)
# ... your training loop or other code to be profiled ...
train_mnist()
# Stop tracing
xp.stop_trace()
3. Adicione etiquetas de rastreio personalizadas
Por predefinição, os rastreios capturados são funções Pytorch XLA de baixo nível e podem ser difíceis de navegar. Pode adicionar etiquetas personalizadas a secções específicas do seu código
usando o gestor de contexto xp.Trace()
. Estas etiquetas aparecem como blocos com nomes na vista de cronologia do criador de perfis, o que facilita muito a identificação de operações específicas, como a preparação de dados, a passagem direta ou o passo do otimizador.
O exemplo seguinte mostra como pode adicionar contexto a diferentes partes de um passo de preparação.
def forward(self, x):
# This entire block will be labeled 'forward' in the trace
with xp.Trace('forward'):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = x.view(-1, 7*7*64)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
with torch_xla.step():
with xp.Trace('train_step_data_prep_and_forward'):
optimizer.zero_grad()
data, target = data.to(device), target.to(device)
output = model(data)
with xp.Trace('train_step_loss_and_backward'):
loss = loss_fn(output, target)
loss.backward()
with xp.Trace('train_step_optimizer_step_host'):
optimizer.step()
Exemplo completo
O exemplo seguinte mostra como capturar um rastreio de um script do PyTorch XLA com base no ficheiro mnist_xla.py.
import torch
import torch.optim as optim
from torchvision import datasets, transforms
# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
def train_mnist():
# ... (model definition and data loading code) ...
print("Starting training...")
# ... (training loop as defined in the previous section) ...
print("Training finished!")
if __name__ == '__main__':
# 1. Start the profiler server
server = xp.start_server(9012)
# 2. Start capturing the trace and define the output directory
xp.start_trace('/root/logs/')
# Run the training function that contains custom trace labels
train_mnist()
# 3. Stop the trace
xp.stop_trace()
Visualize o rastreio
Quando o script terminar, os ficheiros de rastreio são guardados no diretório que
especificou (por exemplo, /root/logs/
). Pode visualizar este rastreio através do
XProf e do TensorBoard.
Instale o TensorBoard.
pip install tensorboard_plugin_profile tensorboard
Inicie o TensorBoard. Indique ao TensorBoard o diretório de registo que usou em
xp.start_trace()
:tensorboard --logdir /root/logs/
Ver o perfil. Abra o URL fornecido pelo TensorBoard no seu navegador de Internet (normalmente,
http://localhost:6006
). Navegue para o separador PyTorch XLA - Profile para ver o rastreio interativo. Pode ver as etiquetas personalizadas que criou e analisar o tempo de execução de diferentes partes do seu modelo.
Se usar o Google Cloud para executar as suas cargas de trabalho, recomendamos a ferramenta cloud-diagnostics-xprof. Oferece uma experiência de visualização e recolha de perfis simplificada através de VMs que executam o Tensorboard e o XProf.