Use ML Diagnostics with MaxText

MaxText is an open-source large language model (LLM) library for training models on TPUs and GPUs. The ML Diagnostics SDK comes pre-integrated with MaxText. To enable ML Diagnostics with MaxText, set the managed_mldiagnostics flag.

When enabled, ML Diagnostics does the following:

  • Creates a managed MachineLearning run with all the MaxText configs.
  • Uploads profiling traces, if profiling is enabled with the profiler="xplane" flag.
  • Uploads training metrics at the interval defined by the log_period flag.

Set the following flags to use ML Diagnostics with MaxText:

managed_mldiagnostics: True  # Enable the managed diagnostics
managed_mldiagnostics_run_group: GROUP_NAME  # Used to group multiple runs. (Optional)

To enable ML Diagnostics in MaxText, either change the configuration file of your run or pass the flags on the command line. For example, pass the following flags with the MaxText.train command:

python3 -m MaxText.train \
  src/MaxText/configs/base.yml \
  run_name="demo-mldiagnostics-run-2" \
  model_name="<your_chosen_model>" \
  base_output_directory=gs://<your_gcs_folder>/ \
  dataset_type=synthetic \
  steps=100 \
  log_period=10 \
  profiler=xplane \
  upload_all_profiler_results=True \
  managed_mldiagnostics=True \
  managed_mldiagnostics_run_group="demo-mldiagnostics-group"

In this example, setting upload_all_profiler_results=True captures profiles from all hosts in a multi-host run.
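
The configuration-file route mentioned earlier might look like the following minimal sketch. The file name my_mldiagnostics.yml is hypothetical. The sketch assumes that the file sits next to base.yml in src/MaxText/configs/ and inherits from it through the base_config key, as MaxText's bundled configuration files do. Only keys that already appear on this page are set.

```yaml
# my_mldiagnostics.yml (hypothetical file name; assumed to live next to base.yml)
base_config: "base.yml"  # Assumption: inherit every other setting from base.yml

managed_mldiagnostics: True  # Enable the managed diagnostics
managed_mldiagnostics_run_group: "demo-mldiagnostics-group"  # Optional run grouping

profiler: "xplane"  # Needed for profiling traces to be uploaded
upload_all_profiler_results: True  # Capture profiles from all hosts
log_period: 10  # Interval for uploading training metrics
```

With a file like this, you would pass its path to MaxText.train in place of src/MaxText/configs/base.yml and drop the matching key=value flags from the command line. Any flags that remain on the command line, such as run_name or base_output_directory, should still override the file's values.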