Genera perfiles de cargas de trabajo de PyTorch XLA

La optimización del rendimiento es una parte fundamental de la compilación de modelos de aprendizaje automático eficientes. Puedes usar la herramienta de generación de perfiles XProf para medir el rendimiento de tus cargas de trabajo de aprendizaje automático. XProf te permite capturar registros detallados de la ejecución de tu modelo en dispositivos XLA. Estos registros pueden ayudarte a identificar cuellos de botella en el rendimiento, comprender el uso del dispositivo y optimizar tu código.

En esta guía, se describe el proceso para capturar de forma programática un seguimiento de tu secuencia de comandos de PyTorch XLA y visualizarlo con XProf y TensorBoard.

Captura un seguimiento

Para capturar un seguimiento, agrega algunas líneas de código a tu secuencia de comandos de entrenamiento existente. La herramienta principal para capturar un seguimiento es el módulo torch_xla.debug.profiler, que suele importarse con el alias xp.

1. Inicia el servidor del generador de perfiles

Antes de capturar un seguimiento, debes iniciar el servidor del generador de perfiles. Este servidor se ejecuta en segundo plano en tu secuencia de comandos y recopila los datos de seguimiento. Puedes iniciarlo llamando a xp.start_server() cerca del comienzo de tu bloque de ejecución principal.

2. Define la duración del seguimiento

Encapsula el código que deseas analizar dentro de las llamadas xp.start_trace() y xp.stop_trace(). La función start_trace toma una ruta de acceso a un directorio en el que se guardan los archivos de seguimiento.

Es una práctica común encapsular el bucle de entrenamiento principal para capturar las operaciones más relevantes.

# The directory where the trace files are stored.
log_dir = '/root/logs/'

# Start tracing
xp.start_trace(log_dir)

# ... your training loop or other code to be profiled ...
train_mnist()

# Stop tracing
xp.stop_trace()

3. Agrega etiquetas de seguimiento personalizadas

De forma predeterminada, los seguimientos capturados son funciones de Pytorch XLA de bajo nivel y pueden ser difíciles de navegar. Puedes agregar etiquetas personalizadas a secciones específicas de tu código con el administrador de contexto xp.Trace(). Estas etiquetas aparecerán como bloques con nombre en la vista de línea de tiempo del generador de perfiles, lo que facilitará la identificación de operaciones específicas, como la preparación de datos, el pase hacia delante o el paso del optimizador.

En el siguiente ejemplo, se muestra cómo puedes agregar contexto a diferentes partes de un paso de entrenamiento.

def forward(self, x):
    # This entire block will be labeled 'forward' in the trace
    with xp.Trace('forward'):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
    with torch_xla.step():
        with xp.Trace('train_step_data_prep_and_forward'):
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = model(data)

        with xp.Trace('train_step_loss_and_backward'):
            loss = loss_fn(output, target)
            loss.backward()

        with xp.Trace('train_step_optimizer_step_host'):
            optimizer.step()

Ejemplo completo

En el siguiente ejemplo, se muestra cómo capturar un seguimiento de secuencia de comandos de PyTorch XLA, basado en el archivo mnist_xla.py.

import torch
import torch.optim as optim
from torchvision import datasets, transforms

# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

def train_mnist():
    # ... (model definition and data loading code) ...
    print("Starting training...")
    # ... (training loop as defined in the previous section) ...
    print("Training finished!")

if __name__ == '__main__':
    # 1. Start the profiler server
    server = xp.start_server(9012)

    # 2. Start capturing the trace and define the output directory
    xp.start_trace('/root/logs/')

    # Run the training function that contains custom trace labels
    train_mnist()

    # 3. Stop the trace
    xp.stop_trace()

Visualiza el seguimiento

Cuando finalice la secuencia de comandos, los archivos de seguimiento se guardarán en el directorio que especificaste (por ejemplo, /root/logs/). Puedes visualizar este seguimiento con XProf y TensorBoard.

  1. Instala TensorBoard.

    pip install tensorboard_plugin_profile tensorboard
  2. Inicia TensorBoard. Dirige TensorBoard al directorio de registros que usaste en xp.start_trace():

    tensorboard --logdir /root/logs/
  3. Consulta el perfil. Abre la URL que proporciona TensorBoard en tu navegador web (generalmente, http://localhost:6006). Navega a la pestaña PyTorch XLA - Profile para ver el seguimiento interactivo. Podrás ver las etiquetas personalizadas que creaste y analizar el tiempo de ejecución de diferentes partes de tu modelo.

Si usas Google Cloud para ejecutar tus cargas de trabajo, te recomendamos la herramienta cloud-diagnostics-xprof. Proporciona una experiencia optimizada de recopilación y visualización de perfiles con VMs que ejecutan TensorBoard y XProf.