Use ML Diagnostics with MaxText
MaxText is an open-source Large Language Model (LLM) library built to run on
TPUs and GPUs for model training. The ML Diagnostics SDK comes pre-integrated
with MaxText, so you can enable ML Diagnostics by setting the
managed_mldiagnostics flag.
When ML Diagnostics is enabled, you can do the following:
- Create a managed MachineLearning run with all the MaxText configs.
- Upload profiling traces, if profiling is enabled with the profiler="xplane" flag.
- Upload training metrics at the interval set by the log_period flag.
Set the following flags to use ML Diagnostics with MaxText:
managed_mldiagnostics: True # Enable the managed diagnostics
managed_mldiagnostics_run_group: GROUP_NAME # Used to group multiple runs. (Optional)
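Because managed_mldiagnostics_run_group is optional, a launcher script might add it only when a group name is supplied. The following Bash sketch shows one way to do that. mldiag_flags is a hypothetical helper introduced here for illustration. It is not part of MaxText or the ML Diagnostics SDK, and it only prints the command-line overrides listed above.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of MaxText): prints the ML Diagnostics
# overrides for MaxText.train, one key=value pair per line.
mldiag_flags() {
  local group="${1:-}"                        # optional run-group name
  local flags=("managed_mldiagnostics=True")  # always enable ML Diagnostics
  if [[ -n "$group" ]]; then
    # managed_mldiagnostics_run_group is optional, so add it only when given.
    flags+=("managed_mldiagnostics_run_group=${group}")
  fi
  printf '%s\n' "${flags[@]}"
}
```

Each flag is printed on its own line. You could collect the output into a Bash array with mapfile -t and append it to the other MaxText.train arguments.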
To enable ML Diagnostics in MaxText, either change the configuration file of
your run or pass the flags on the command line. For example, pass the following
flags to the MaxText.train command:
python3 -m MaxText.train \
src/MaxText/configs/base.yml \
run_name="demo-mldiagnostics-run-2" \
model_name="<your_chosen_model>" \
base_output_directory=gs://<your_gcs_folder>/ \
dataset_type=synthetic \
steps=100 \
log_period=10 \
profiler=xplane \
upload_all_profiler_results=True \
managed_mldiagnostics=True \
managed_mldiagnostics_run_group="demo-mldiagnostics-group"
In this example, setting upload_all_profiler_results=True captures profiles
from all hosts in a multi-host run.
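To reuse the example command across runs, you can wrap it in a small Bash function. This is a sketch under stated assumptions. The function name run_maxtext_with_mldiag and the MODEL_NAME, GCS_FOLDER, RUN_GROUP, and DRY_RUN environment variables are hypothetical names chosen for illustration and are not MaxText settings. The key=value flags are copied from the command above. With DRY_RUN=1, the function prints the command instead of starting training.

```shell
#!/usr/bin/env bash
set -euo pipefail

# Sketch: launch the example MaxText run with ML Diagnostics enabled.
# run_maxtext_with_mldiag, MODEL_NAME, GCS_FOLDER, RUN_GROUP and DRY_RUN are
# hypothetical names for illustration; only the key=value flags come from MaxText.
run_maxtext_with_mldiag() {
  local run_name="${1:?usage: run_maxtext_with_mldiag RUN_NAME}"
  # Same flags as the example command; RUN_GROUP falls back to the demo group.
  local args=(
    -m MaxText.train
    src/MaxText/configs/base.yml
    "run_name=${run_name}"
    "model_name=${MODEL_NAME:?set MODEL_NAME}"
    "base_output_directory=gs://${GCS_FOLDER:?set GCS_FOLDER}/"
    dataset_type=synthetic
    steps=100
    log_period=10
    profiler=xplane
    upload_all_profiler_results=True
    managed_mldiagnostics=True
    "managed_mldiagnostics_run_group=${RUN_GROUP:-demo-mldiagnostics-group}"
  )
  if [[ "${DRY_RUN:-0}" == "1" ]]; then
    # Print the command instead of starting training.
    echo python3 "${args[@]}"
  else
    python3 "${args[@]}"
  fi
}
```

For example, running DRY_RUN=1 MODEL_NAME=<your_chosen_model> GCS_FOLDER=<your_gcs_folder> run_maxtext_with_mldiag demo-mldiagnostics-run-2 prints the full command so you can check it before launching. Keeping the flags in a Bash array preserves each key=value pair as a single argument.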