Genera perfiles de cargas de trabajo de PyTorch XLA
La optimización del rendimiento es una parte fundamental de la compilación de modelos de aprendizaje automático eficientes. Puedes usar la herramienta de generación de perfiles XProf para medir el rendimiento de tus cargas de trabajo de aprendizaje automático. XProf te permite capturar registros detallados de la ejecución de tu modelo en dispositivos XLA. Estos registros pueden ayudarte a identificar cuellos de botella en el rendimiento, comprender el uso del dispositivo y optimizar tu código.
En esta guía, se describe el proceso para capturar de forma programática un seguimiento de tu secuencia de comandos de PyTorch XLA y visualizarlo con XProf y TensorBoard.
Captura un seguimiento
Para capturar un seguimiento, agrega algunas líneas de código a tu secuencia de comandos de entrenamiento
existente. La herramienta principal para capturar un seguimiento es el módulo torch_xla.debug.profiler,
que suele importarse con el alias xp.
1. Inicia el servidor del generador de perfiles
Antes de capturar un seguimiento, debes iniciar el servidor del generador de perfiles. Este
servidor se ejecuta en segundo plano en tu secuencia de comandos y recopila los datos de seguimiento. Puedes
iniciarlo llamando a xp.start_server() cerca del comienzo de tu bloque de ejecución
principal.
2. Define la duración del seguimiento
Encapsula el código que deseas analizar dentro de las llamadas xp.start_trace() y
xp.stop_trace(). La función start_trace toma una ruta de acceso a un directorio
en el que se guardan los archivos de seguimiento.
Es una práctica común encapsular el bucle de entrenamiento principal para capturar las operaciones más relevantes.
# The directory where the trace files are stored.
log_dir = '/root/logs/'
# Start tracing
xp.start_trace(log_dir)
# ... your training loop or other code to be profiled ...
train_mnist()
# Stop tracing
xp.stop_trace()
3. Agrega etiquetas de seguimiento personalizadas
De forma predeterminada, los seguimientos capturados son funciones de Pytorch XLA de bajo nivel y
pueden ser difíciles de navegar. Puedes agregar etiquetas personalizadas a secciones específicas de tu código
con el administrador de contexto xp.Trace(). Estas etiquetas aparecerán como bloques
con nombre en la vista de línea de tiempo del generador de perfiles, lo que facilitará la identificación de operaciones específicas,
como la preparación de datos, el pase hacia delante o el paso del optimizador.
En el siguiente ejemplo, se muestra cómo puedes agregar contexto a diferentes partes de un paso de entrenamiento.
def forward(self, x):
# This entire block will be labeled 'forward' in the trace
with xp.Trace('forward'):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = x.view(-1, 7*7*64)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
with torch_xla.step():
with xp.Trace('train_step_data_prep_and_forward'):
optimizer.zero_grad()
data, target = data.to(device), target.to(device)
output = model(data)
with xp.Trace('train_step_loss_and_backward'):
loss = loss_fn(output, target)
loss.backward()
with xp.Trace('train_step_optimizer_step_host'):
optimizer.step()
Ejemplo completo
En el siguiente ejemplo, se muestra cómo capturar un seguimiento de secuencia de comandos de PyTorch XLA, basado en el archivo mnist_xla.py.
import torch
import torch.optim as optim
from torchvision import datasets, transforms
# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
def train_mnist():
# ... (model definition and data loading code) ...
print("Starting training...")
# ... (training loop as defined in the previous section) ...
print("Training finished!")
if __name__ == '__main__':
# 1. Start the profiler server
server = xp.start_server(9012)
# 2. Start capturing the trace and define the output directory
xp.start_trace('/root/logs/')
# Run the training function that contains custom trace labels
train_mnist()
# 3. Stop the trace
xp.stop_trace()
Visualiza el seguimiento
Cuando finalice la secuencia de comandos, los archivos de seguimiento se guardarán en el directorio que
especificaste (por ejemplo, /root/logs/). Puedes visualizar este seguimiento
con XProf y TensorBoard.
Instala TensorBoard.
pip install tensorboard_plugin_profile tensorboard
Inicia TensorBoard. Dirige TensorBoard al directorio de registros que usaste en
xp.start_trace():tensorboard --logdir /root/logs/
Consulta el perfil. Abre la URL que proporciona TensorBoard en tu navegador web (generalmente,
http://localhost:6006). Navega a la pestaña PyTorch XLA - Profile para ver el seguimiento interactivo. Podrás ver las etiquetas personalizadas que creaste y analizar el tiempo de ejecución de diferentes partes de tu modelo.
Si usas Google Cloud para ejecutar tus cargas de trabajo, te recomendamos la herramienta cloud-diagnostics-xprof. Proporciona una experiencia optimizada de recopilación y visualización de perfiles con VMs que ejecutan TensorBoard y XProf.