Get started with the ML Diagnostics SDK
The ML Diagnostics Python SDK can be integrated with ML workloads to collect and manage workload metrics, configs, and profiles on Google Cloud. This guide shows you how to create machine learning runs, collect and manage workload metrics and configs, deploy managed XProf resources, and enable programmatic and on-demand profile capture.
For more information on using the ML Diagnostics SDK, see the google-cloud-mldiagnostics repository.
Install ML Diagnostics SDK
Install the google-cloud-mldiagnostics
library:
pip install google-cloud-mldiagnostics
Import the following packages in your ML workload code:
from google_cloud_mldiagnostics import machinelearning_run
from google_cloud_mldiagnostics import metrics
from google_cloud_mldiagnostics import xprof
Enable Cloud Logging
The SDK uses the standard Python logging module to output metrics and config
information. To route these logs to Cloud Logging, install and configure the
google-cloud-logging library. This lets you view SDK logs, logged metrics, and
your own application logs within the Google Cloud console.
Install the google-cloud-logging library:
pip install google-cloud-logging
Configure logging in your script by attaching the Cloud Logging handler to the Python root logger. Add the following lines to the beginning of your Python script:
import logging
import google.cloud.logging
# Instantiate a Cloud Logging client
logging_client = google.cloud.logging.Client()
# Attach the Cloud Logging handler to the Python root logger
logging_client.setup_logging()
# Standard logging calls will go to Cloud Logging
logging.info("SDK logs and application logs will appear in Cloud Logging.")
Enable detailed logging
By default, the logging level is set to INFO. To receive more detailed logs
from the SDK, such as machine learning run details, set the logging level to
DEBUG after calling setup_logging():
import logging
import google.cloud.logging
logging_client = google.cloud.logging.Client()
logging_client.setup_logging()
logging.getLogger().setLevel(logging.DEBUG) # Enable DEBUG level logs
logging.debug("This is a debug message.")
logging.info("This is an info message.")
With DEBUG enabled, you receive additional SDK diagnostics in
Cloud Logging. For example:
DEBUG:google_cloud_mldiagnostics.core.global_manager:current run details:
{'name': 'projects/my-gcp-project/locations/us-central1/mlRuns/my-run-12345',
'gcs_path': 'gs://my-bucket/profiles', ...}
Create a machine learning run
To use the ML Diagnostics platform, you must first create a machine learning run. This involves instrumenting your ML workload with the SDK to perform logging, collect metrics, and enable profile tracing.
The following is a basic example that initializes Cloud Logging, creates a
machine learning run (MLRun), records metrics, and captures a profile:
import logging
import os
import google.cloud.logging
from google_cloud_mldiagnostics import machinelearning_run, metrics, xprof, metric_types
# 1. Set up Cloud Logging
# Make sure to pip install google-cloud-logging
logging_client = google.cloud.logging.Client()
logging_client.setup_logging()
# Optional: Set logging level to DEBUG for more detailed SDK logs
logging.getLogger().setLevel(logging.DEBUG)
# 2. Define and start machinelearning run
try:
    run = machinelearning_run(
        name="<run_name>",
        run_group="<run_group>",
        configs={ "epochs": 100, "batch_size": 32 },
        project="<some_project>",
        region="<some_region>",
        gcs_path="gs://<some_bucket>",
        on_demand_xprof=True,
    )
    logging.info(f"MLRun created: {run.name}")

    # 3. Collect metrics during your run
    metrics.record(metric_types.MetricType.LOSS, 0.123, step=1)
    logging.info("Loss metric recorded.")

    # 4. Capture profiles programmatically
    with xprof():
        # ... your code to profile here ...
        pass
    logging.info("Profile captured.")

except Exception as e:
    logging.error(f"Error during MLRun: {e}", exc_info=True)
The code example uses the following variables:
| Variable | Requirement | Description |
|---|---|---|
| `name` | Required | An identifier for the specific run. The SDK automatically creates a machine-learning-run-id to ensure that run names are unique. |
| `run_group` | Optional | An identifier that can help group multiple runs belonging to the same experiment. For example, all runs associated with a TPU slice size sweep could belong to the same group. |
| `project` | Optional | If not specified, the project is extracted from the Google Cloud CLI configuration. |
| `region` | Required | All Cluster Director locations are supported except us-east5. This flag can be set by an argument for each command, or with the command `gcloud config set compute/region`. |
| `configs` | Optional | Key-value pairs containing configuration parameters for the run. If configs are not defined, default software and system configs appear but the ML workload configs do not. |
| `gcs_path` | Conditionally required | The Cloud Storage location where all profiles are saved. For example: `gs://my-bucket` or `gs://my-bucket/folder1`. Required only if the SDK is used for profile capture. |
| `on_demand_xprof` | Optional | Starts the xprofz daemon on port 9999 to enable on-demand profiling. You can enable both on-demand profiling and programmatic profiling in the same code, as long as they don't occur at the same time. |
The following configs are automatically collected by the SDK and don't need to
be specified within machinelearning_run:
- Software configs: Framework, framework version, XLA flags.
- System configs: Device type, number of slices, slice size, number of hosts.
Project and region information is stored as machine learning run metadata. The region used for the machine learning run does not have to match the region used for the workload run.
Write configs
Many workloads contain too many configs to define directly in the
machinelearning_run definition. In these cases, you can write configs to your
run using JSON or YAML.
import yaml
import json
# Read the YAML file
with open('config.yaml', 'r') as yaml_file:
    # Parse YAML into a Python dictionary
    yaml_data = yaml.safe_load(yaml_file)
# Define machinelearning run
machinelearning_run(
    name="RUN_NAME",
    run_group="GROUP_NAME",
    configs=yaml_data,
    project="PROJECT_NAME",
    region="REGION",
    gcs_path="gs://BUCKET_NAME",
)
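If your configs live in a JSON file instead, the standard library covers the same pattern. The following is a sketch; `load_configs` and the `config.json` filename are illustrative names, not part of the SDK:

```python
import json

def load_configs(path):
    """Parse a JSON config file into a dict suitable for machinelearning_run(configs=...)."""
    with open(path, 'r') as f:
        return json.load(f)

# Example usage (RUN_NAME and config.json are placeholders):
# run = machinelearning_run(name="RUN_NAME", configs=load_configs("config.json"), ...)
```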
Collect metrics
You can collect model metrics, model performance metrics, and system metrics with the SDK, and visualize them as average values and as time series charts.
The SDK provides two functions for recording metrics: metrics.record() for
capturing individual data points, and metrics.record_metrics() for recording
multiple metrics in a single batch. Both functions write metrics to
Cloud Logging, enabling visualization and analysis.
To record a single metric:
# Record a metric only with time as the x-axis
metrics.record(metric_types.MetricType.LOSS, 0.123)
# Record a metric with time and step as the x-axis
metrics.record(metric_types.MetricType.LOSS, 0.123, step=1)
To record multiple metrics:
from google_cloud_mldiagnostics import metric_types

# In your workload code, machinelearning_run must already have been called

for step in range(num_steps):
    if (step + 1) % 10 == 0:
        metrics.record_metrics([
            # Model quality metrics
            {"metric_name": metric_types.MetricType.LEARNING_RATE, "value": step_size},
            {"metric_name": metric_types.MetricType.LOSS, "value": loss},
            {"metric_name": metric_types.MetricType.GRADIENT_NORM, "value": gradient},
            {"metric_name": metric_types.MetricType.TOTAL_WEIGHTS, "value": total_weights},
            # Model performance metrics
            {"metric_name": metric_types.MetricType.STEP_TIME, "value": step_time},
            {"metric_name": metric_types.MetricType.THROUGHPUT, "value": throughput},
            {"metric_name": metric_types.MetricType.LATENCY, "value": latency},
            {"metric_name": metric_types.MetricType.TFLOPS, "value": tflops},
            {"metric_name": metric_types.MetricType.MFU, "value": mfu},
        ], step=step + 1)
The following system metrics are automatically collected by the SDK from libTPU,
psutil, and JAX libraries:
- TPU TensorCore utilization
- TPU duty cycle
- HBM utilization
- Host CPU utilization
- Host memory utilization
You don't need to manually specify these metrics. These system metrics have time as the default x-axis.
The following predefined metric keys will automatically appear in the Google Cloud console if assigned. These metrics aren't calculated automatically; they are predefined keys that you can assign values to.
- Model quality metric keys: LEARNING_RATE, LOSS, GRADIENT_NORM, TOTAL_WEIGHTS.
- Model performance metric keys: STEP_TIME, THROUGHPUT, LATENCY, MFU, TFLOPS.
The predefined metrics, as well as any user-defined metrics, can be recorded with time as the x-axis, or with both time and step. You can record any custom metric in the workload.
The following example captures a single metric for the workload, which you can view in the Model Metrics tab for the specific machine learning run:
metrics.record("custom_metrics_1", step_size, step=step + 1)
To record multiple metrics in one call, use the record_metrics method. For
example:
metrics.record_metrics([
    # Model quality metrics
    {"metric_name": metric_types.MetricType.LEARNING_RATE, "value": step_size},
    {"metric_name": metric_types.MetricType.LOSS, "value": loss},
    {"metric_name": metric_types.MetricType.GRADIENT_NORM, "value": gradient},
    {"metric_name": metric_types.MetricType.TOTAL_WEIGHTS, "value": total_weights},
    # Model performance metrics
    {"metric_name": metric_types.MetricType.STEP_TIME, "value": step_time},
    {"metric_name": metric_types.MetricType.THROUGHPUT, "value": throughput},
    {"metric_name": metric_types.MetricType.LATENCY, "value": latency},
    {"metric_name": metric_types.MetricType.TFLOPS, "value": tflops},
    {"metric_name": metric_types.MetricType.MFU, "value": mfu},
    # Custom metrics
    {"metric_name": "custom_metrics_1", "value": <value>},
    {"metric_name": "custom_metrics_2", "value": <value>},
    {"metric_name": "avg_mtp_acceptance_rate_percent", "value": <value>},
    {"metric_name": "dpo_reward_accuracy", "value": <value>},
], step=step + 1)
Capture profiles
You can capture XProf profiles of your ML workload with programmatic capture or on-demand capture (manual capture). Programmatic capture involves embedding profiling commands directly into your machine learning code, and explicitly stating when to start and stop recording data. On-demand capture occurs in real-time, where you trigger the profiler while the workload is already actively running.
The SDK commands to capture profiles are framework-agnostic since all framework-level profiling commands are automatically integrated into ML Diagnostics profiling commands. This means that your profiling code is not dependent on the framework you use.
Programmatic profile capture
Programmatic capture requires you to annotate your model code and specify where you want to capture profiles. Typically, you capture a profile for a few training steps, or profile a specific block of code within your model.
You can perform programmatic profile capture with the ML Diagnostics SDK in the following ways:
- API-based collection: Control profiling with `start()` and `stop()` methods.
- Decorator-based collection: Annotate functions with `@xprof(run)` for automatic profiling.
- Context manager: Use `with xprof()` for scope-based profiling that automatically handles `start()` and `stop()` operations.
You can use the same profile capture code across all frameworks. All the profile sessions are captured in the Cloud Storage bucket defined in the machine learning run.
# Support collection via APIs
prof = xprof() # Updates metadata and starts xprofz collector
prof.start() # Collects traces to bucket
# ..... Your code execution here
# ....
prof.stop()
# Also supports collection via decorators
@xprof()
def abc(self):
    # does something
    pass
# Use xprof as a context manager to automatically start and stop collection
with xprof() as prof:
    # Your training or execution code here
    train_model()
    evaluate_model()
Multi-host (process) profiling
During programmatic profiling, the SDK starts profiling on each host (process) where ML workload code is executing. If the list of nodes is not provided, all hosts are included.
# starts profiling on all nodes
prof = xprof()
prof.start()
# ...
prof.stop()
By default, calling the prof.start() method without the session_id argument
on multiple hosts results in separate trace sessions, one for each host. To
group traces from different hosts into a single, unified multi-host session in
XProf, ensure that the prof.start() method is called with the same
session_id argument on all participating hosts. For example:
# Use the same session_id on all hosts to group traces
prof = xprof()
prof.start(session_id="profiling_session")
# ...
prof.stop()
To enable profiling for specific hosts:
# starts profiling on nodes with indexes 0 and 2
prof = xprof(process_index_list=[0,2])
prof.start()
# ...
prof.stop()
On-demand profile capture
Use on-demand profile capture when you want to capture profiles in an ad hoc manner, or when programmatic profile capture is not already enabled. On-demand capture is helpful when there are problems with model metrics during the run, and you want to capture profiles in those moments to diagnose the issues.
To enable on-demand profile capture, configure the run with on-demand support:
# Define machinelearning run
machinelearning_run(
    name="<run_name>",
    # specify where profiling data is stored
    gcs_path="gs://<bucket>",
    ...
    # enable on-demand profiling; starts the xprofz daemon on port 9999
    on_demand_xprof=True
)
You can use the same profile capture code across all frameworks. All profile sessions are captured in the Cloud Storage bucket defined in the machine learning run.
For on-demand profiling on GKE, deploy GKE
connection-operator and injection-webhook into the GKE cluster. This
ensures that your machine learning run can locate the GKE nodes
it is running on, and the on-demand capture drop-down can autopopulate those
nodes. For more information, see Configure
GKE cluster.
Package workload for GKE
You can use a Dockerfile to package an application that uses the ML Diagnostics
SDK. Install the google-cloud-logging package for Cloud Logging
integration. For example:
# Base image (user's choice, e.g., python:3.10-slim, or a base with ML frameworks)
FROM python:3.11-slim
# Install base utilities
RUN pip install --no-cache-dir --upgrade pip
# Install SDK and Logging client
# psutil is installed as a dependency of google-cloud-mldiagnostics
RUN pip install --no-cache-dir \
google-cloud-mldiagnostics \
google-cloud-logging
# Optional: For JAX/TPU workloads
# RUN pip install --no-cache-dir "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \
#     pip install --no-cache-dir libtpu xprof
# Add your application code
COPY ./app /app
WORKDIR /app
# Run your script
CMD ["python", "your_train_script.py"]
Deploy workload
After integrating the SDK with your workload, package the workload in an image
and create your YAML file with the specified image. Label the workload in the
YAML file with managed-mldiagnostics-gke=true.
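As a sketch, the label might be attached to a GKE Job manifest as follows; apart from the managed-mldiagnostics-gke: "true" label and the workload container name, all names and fields here are illustrative assumptions:

```yaml
apiVersion: batch/v1
kind: Job
metadata:
  name: my-training-job          # illustrative job name
  namespace: YOUR_NAMESPACE
  labels:
    managed-mldiagnostics-gke: "true"   # enables ML Diagnostics for this workload
spec:
  template:
    spec:
      containers:
      - name: workload
        image: IMAGE_NAME        # the image you packaged above
      restartPolicy: Never
```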
For GKE:
kubectl apply -f YAML_FILE_NAME
For Compute Engine, connect to the VM using SSH and run the Python code for your workload:
source venv/bin/activate
python3.11 WORKLOAD_FILE_NAME
After deploying the workload, find your job name by searching for your workload namespace:
kubectl get job -n YOUR_NAMESPACE
You can find the run name and link in your kubectl logs by passing the job
name and namespace. You must specify the workload container (for example: -c
workload) because the ML Diagnostics sidecar handles its own logging.
kubectl logs jobs/JOB_NAME -n YOUR_NAMESPACE -c workload