Checkpoints

Determined provides APIs for downloading checkpoints and loading them into memory in a Python process.

This guide discusses:

  1. Querying model checkpoints from trials and experiments.

  2. Loading model checkpoints in Python.

  3. Storing additional user-defined metadata in a checkpoint.

  4. Using the Determined CLI to download checkpoints to disk.

The Checkpoint Export API is a subset of the features found in the client module.

Querying Checkpoints

Use the Experiment class to reference an experiment. The list_checkpoints() method, when called without arguments, returns checkpoints sorted based on the metric and smaller_is_better values from the experiment configuration’s searcher field.

For example, with the following experiment configuration snippet, Determined sorts checkpoints by the loss metric in ascending order.

searcher:
  metric: "loss"
  smaller_is_better: true

Once the experiment has produced at least one checkpoint, you can run the Python code below. It retrieves the experiment's checkpoints as a sorted list of Checkpoint instances and selects the first one, which is the checkpoint with the best validation metric.

from determined.experimental import client

experiment_id = 1  # Replace with your experiment ID.
checkpoint = client.get_experiment(experiment_id).list_checkpoints()[0]
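To make the default ordering concrete, here is a plain-Python model of how the searcher's metric and smaller_is_better settings translate into a sort. The helper `order_by_searcher` and the dictionaries standing in for Checkpoint objects are hypothetical; this is a sketch of the documented ordering rule, not Determined's implementation.

```python
from typing import Any, Dict, List


def order_by_searcher(
    checkpoints: List[Dict[str, Any]], metric: str, smaller_is_better: bool
) -> List[Dict[str, Any]]:
    """Return checkpoints best-first, mirroring the searcher-based default ordering.

    Hypothetical helper: each checkpoint is modeled as a dict with a
    "validation_metrics" mapping instead of a real Checkpoint object.
    """
    # Ascending order when smaller is better (e.g. loss), descending otherwise.
    return sorted(
        checkpoints,
        key=lambda ckpt: ckpt["validation_metrics"][metric],
        reverse=not smaller_is_better,
    )


checkpoints = [
    {"uuid": "a", "validation_metrics": {"loss": 0.30}},
    {"uuid": "b", "validation_metrics": {"loss": 0.12}},
    {"uuid": "c", "validation_metrics": {"loss": 0.25}},
]

# With metric "loss" and smaller_is_better: true, index 0 is the best checkpoint.
best = order_by_searcher(checkpoints, metric="loss", smaller_is_better=True)[0]
print(best["uuid"])  # b
```

This is why indexing the result with [0] selects the best checkpoint: the direction of the sort, not the indexing, encodes whether lower or higher values are better.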

To sort checkpoints by any metric, use the sort_by argument to specify the metric and order_by to define the sorting order (ascending or descending).

from determined.experimental import client

experiment_id = 1  # Replace with your experiment ID.
checkpoints = client.get_experiment(experiment_id).list_checkpoints(
    sort_by="accuracy",
    order_by=client.OrderBy.DESC,
)

To sort checkpoints using preset checkpoint parameters, use the CheckpointSortBy class. The example below fetches all checkpoints for an experiment, sorting them by trial ID in descending order.

from determined.experimental import checkpoint, client

experiment_id = 1  # Replace with your experiment ID.
checkpoints = client.get_experiment(experiment_id).list_checkpoints(
    sort_by=checkpoint.CheckpointSortBy.TRIAL_ID,
    order_by=client.OrderBy.DESC,
)

For fine-grained control over checkpoint selection within a single trial, use the Trial class. Its list_checkpoints() method mirrors the list_checkpoints() method of an experiment.

The following code illustrates methods to select specific checkpoints from a trial:

from determined.experimental import checkpoint, client

trial_id = 1  # Replace with your trial ID.
trial = client.get_trial(trial_id)

# Most recent checkpoint, by end time.
most_recent_checkpoint = trial.list_checkpoints(
    sort_by=checkpoint.CheckpointSortBy.END_TIME,
    order_by=client.OrderBy.DESC,
    max_results=1
)[0]

# Sort checkpoints by "accuracy" metric, if your training code reports it.
most_accurate_checkpoint = trial.list_checkpoints(
    sort_by="accuracy",
    order_by=client.OrderBy.DESC,
    max_results=1
)[0]

# Fetch a specific checkpoint directly by its UUID.
specific_checkpoint = client.get_checkpoint(uuid="uuid-for-checkpoint")
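The trial-level selections above combine three arguments: sort_by (a preset field such as END_TIME, or the name of a metric your training code reports), order_by, and max_results. The sketch below models that combination over plain dictionaries. The `select` helper, the `Order` enum standing in for client.OrderBy, and the dictionary fields are all hypothetical; real Checkpoint objects expose their data differently.

```python
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional


class Order(Enum):
    # Hypothetical stand-in for client.OrderBy.
    ASC = "asc"
    DESC = "desc"


def select(
    checkpoints: List[Dict[str, Any]],
    sort_key: str,
    order: Order,
    max_results: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """Sort by a top-level field or, failing that, a validation metric, then truncate."""

    def key(ckpt: Dict[str, Any]) -> Any:
        # In this model, preset fields (like end_time) live at the top level
        # and reported metrics live under "validation_metrics".
        if sort_key in ckpt:
            return ckpt[sort_key]
        return ckpt["validation_metrics"][sort_key]

    ordered = sorted(checkpoints, key=key, reverse=order is Order.DESC)
    return ordered if max_results is None else ordered[:max_results]


trial_checkpoints = [
    {"uuid": "x", "end_time": datetime(2024, 1, 1, 10), "validation_metrics": {"accuracy": 0.91}},
    {"uuid": "y", "end_time": datetime(2024, 1, 1, 12), "validation_metrics": {"accuracy": 0.88}},
    {"uuid": "z", "end_time": datetime(2024, 1, 1, 11), "validation_metrics": {"accuracy": 0.94}},
]

most_recent = select(trial_checkpoints, "end_time", Order.DESC, max_results=1)[0]
most_accurate = select(trial_checkpoints, "accuracy", Order.DESC, max_results=1)[0]
print(most_recent["uuid"], most_accurate["uuid"])  # y z
```

With order DESC and max_results=1, the result is simply the maximum under the chosen key, which is why the "most recent" and "most accurate" queries differ only in their sort_by argument.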

Using the Checkpoint Class

The Checkpoint class can both download the checkpoint from persistent storage and load it into memory in a Python process.

The download() method downloads a checkpoint from persistent storage to a directory on the local file system. By default, checkpoints are downloaded to checkpoints/<checkpoint-uuid>/ (relative to the current working directory). The download() method accepts path as an optional parameter, which changes the checkpoint download location.

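from determined.experimental import client

experiment_id = 1  # Replace with your experiment ID.
checkpoint = client.get_experiment(experiment_id).list_checkpoints()[0]

# Download to the default location, checkpoints/<checkpoint-uuid>/.
checkpoint_path = checkpoint.download()

# Download to a directory of your choice instead.
specific_path = checkpoint.download(path="specific-checkpoint-path")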
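Because download() returns the path of an ordinary local directory, standard-library tools are enough to see what a checkpoint contains before loading it. The `describe_checkpoint_dir` helper below is hypothetical, and the temporary directory and file names in the example are made up; the actual contents of a checkpoint depend on the framework and your training code.

```python
import tempfile
from pathlib import Path
from typing import Dict


def describe_checkpoint_dir(checkpoint_path: str) -> Dict[str, int]:
    """Map each file under a checkpoint directory to its size in bytes."""
    root = Path(checkpoint_path)
    return {
        # Use forward slashes so keys look the same on every OS.
        p.relative_to(root).as_posix(): p.stat().st_size
        for p in sorted(root.rglob("*"))
        if p.is_file()
    }


# In practice you would pass the path returned by checkpoint.download();
# a temporary directory with made-up files stands in for it here.
with tempfile.TemporaryDirectory() as checkpoint_path:
    Path(checkpoint_path, "model.bin").write_bytes(b"\x00" * 16)
    Path(checkpoint_path, "extra").mkdir()
    Path(checkpoint_path, "extra", "info.json").write_text('{"step": 1}')
    contents = describe_checkpoint_dir(checkpoint_path)

print(contents)
```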

After a checkpoint has been downloaded, it can be loaded into memory with a framework-specific helper function. The return type and behavior differ depending on whether you are using PyTorch or TensorFlow.

PyTorch Checkpoints

For PyTorch models, pytorch.load_trial_from_checkpoint_path() takes the path of a downloaded checkpoint and returns a parameterized instance of the trial class specified by the entrypoint field of the experiment config. The trained model can then be accessed from the model attribute of the returned Trial object, as shown in the following snippet.

from determined import pytorch
from determined.experimental import client

experiment_id = 1  # Replace with your experiment ID.
checkpoint = client.get_experiment(experiment_id).list_checkpoints()[0]
path = checkpoint.download()
trial = pytorch.load_trial_from_checkpoint_path(path)
model = trial.model

# `samples` is a batch of input tensors that you provide.
predictions = model(samples)

PyTorch checkpoints are saved using pickle and loaded as PyTorch API objects (see the PyTorch documentation for details).

TensorFlow Checkpoints

For TensorFlow models, keras.load_model_from_checkpoint_path() returns a compiled model with its weights loaded. This is the same model returned by the build_model() method of the trial class specified by the entrypoint field of the experiment config. The trained model can then be used to make predictions, as shown in the following snippet.

from determined import keras
from determined.experimental import client

experiment_id = 1  # Replace with your experiment ID.
checkpoint = client.get_experiment(experiment_id).list_checkpoints()[0]
path = checkpoint.download()
model = keras.load_model_from_checkpoint_path(path)

# `samples` is a batch of inputs that you provide.
predictions = model(samples)

TensorFlow checkpoints are saved in either the saved_model or the h5 format and are loaded as trackable objects (see the documentation for tf.compat.v1.saved_model.load_v2 for details).

Adding User-Defined Checkpoint Metadata

You can add arbitrary user-defined metadata to a checkpoint via the Python SDK. This feature is useful for storing post-training metrics, labels, information related to deployment, etc.

from determined.experimental import client

experiment_id = 1  # Replace with your experiment ID.
checkpoint = client.get_experiment(experiment_id).list_checkpoints()[0]
checkpoint.add_metadata({"environment": "production"})

# Metadata is stored in Determined and accessible on the checkpoint object.
print(checkpoint.metadata)

You can store an arbitrarily nested dictionary with the add_metadata() method. If a top-level key already exists, the entire tree beneath it is overwritten.

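from determined.experimental import client

experiment_id = 1  # Replace with your experiment ID.
checkpoint = client.get_experiment(experiment_id).list_checkpoints()[0]
checkpoint.add_metadata({"metrics": {"loss": 0.12}})
# The second call replaces the whole "metrics" tree, so "loss" is discarded.
checkpoint.add_metadata({"metrics": {"acc": 0.92}})

print(checkpoint.metadata)  # {'metrics': {'acc': 0.92}}, if no other metadata was set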
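Because add_metadata() replaces a top-level key's whole subtree, updating one nested value without losing its siblings means merging on the client first. The sketch below models both behaviors over plain dictionaries. `apply_add_metadata` and `merge_nested` are hypothetical helpers, not part of the SDK; the commented lines show one way the merged value could be sent back using the checkpoint.metadata attribute and add_metadata() method shown above.

```python
import copy
from typing import Any, Dict


def apply_add_metadata(current: Dict[str, Any], new: Dict[str, Any]) -> Dict[str, Any]:
    """Model add_metadata(): each top-level key in `new` replaces the old subtree."""
    result = copy.deepcopy(current)
    result.update(copy.deepcopy(new))
    return result


def merge_nested(base: Dict[str, Any], updates: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge `updates` into a copy of `base`, keeping untouched siblings."""
    merged = copy.deepcopy(base)
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_nested(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


metadata: Dict[str, Any] = apply_add_metadata({}, {"metrics": {"loss": 0.12}})

# Sending only the new value overwrites the whole "metrics" tree.
overwritten = apply_add_metadata(metadata, {"metrics": {"acc": 0.92}})
print(overwritten)  # {'metrics': {'acc': 0.92}}

# Merging locally first keeps "loss". With the SDK, this might look like:
#   merged = merge_nested(checkpoint.metadata, {"metrics": {"acc": 0.92}})
#   checkpoint.add_metadata({"metrics": merged["metrics"]})
merged = merge_nested(metadata, {"metrics": {"acc": 0.92}})
kept = apply_add_metadata(metadata, {"metrics": merged["metrics"]})
print(kept)  # {'metrics': {'loss': 0.12, 'acc': 0.92}}
```

Note that a read-merge-write sequence like this is not atomic: if another process updates the same key between the read and the write, one of the updates is lost.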

To remove metadata, use the remove_metadata() method. It accepts a list of top-level keys and deletes the entire tree beneath each of them.

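from determined.experimental import client

experiment_id = 1  # Replace with your experiment ID.
checkpoint = client.get_experiment(experiment_id).list_checkpoints()[0]
# Delete the "metrics" key and everything nested beneath it.
checkpoint.remove_metadata(["metrics"])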

Downloading Checkpoints using the CLI

The Determined CLI can be used to view all the checkpoints associated with an experiment:

$ det experiment list-checkpoints <experiment-id>

Checkpoints are saved to external storage, according to the checkpoint_storage section in the experiment configuration. Each checkpoint has a UUID, which is used as the name of the checkpoint directory on the external storage system. For example, if the experiment is configured to save checkpoints to a shared file system:

checkpoint_storage:
  type: shared_fs
  host_path: /mnt/nfs-volume-1

A checkpoint with UUID b3ed462c-a6c9-41e9-9202-5cb8ff00e109 can be found in the directory /mnt/nfs-volume-1/b3ed462c-a6c9-41e9-9202-5cb8ff00e109.
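The mapping from UUID to directory in the shared_fs example above can be written as a small function. This is a sketch that assumes the checkpoint_storage section has been parsed into a dictionary with only type and host_path set, as in the snippet; `shared_fs_checkpoint_dir` is a hypothetical helper, and other checkpoint_storage options, which are not covered here, may change the layout.

```python
from pathlib import PurePosixPath
from typing import Any, Dict


def shared_fs_checkpoint_dir(
    checkpoint_storage: Dict[str, Any], checkpoint_uuid: str
) -> PurePosixPath:
    """Return the directory holding a checkpoint on shared_fs storage."""
    if checkpoint_storage.get("type") != "shared_fs":
        raise ValueError("this sketch only handles shared_fs checkpoint storage")
    # Each checkpoint is stored in a directory named after its UUID.
    return PurePosixPath(checkpoint_storage["host_path"]) / checkpoint_uuid


storage = {"type": "shared_fs", "host_path": "/mnt/nfs-volume-1"}
path = shared_fs_checkpoint_dir(storage, "b3ed462c-a6c9-41e9-9202-5cb8ff00e109")
print(path)  # /mnt/nfs-volume-1/b3ed462c-a6c9-41e9-9202-5cb8ff00e109
```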

Determined offers the following CLI commands for downloading checkpoints locally:

  1. det checkpoint download

  2. det trial download

  3. det experiment download

Warning

When downloading checkpoints from a shared file system, Determined assumes that the same shared file system is mounted locally.

The det checkpoint download command downloads a checkpoint for the given UUID as shown below:

# Download a specific checkpoint.
det checkpoint download 46985143-af68-4d48-ab91-a6447052ca49

The command should display output resembling the following upon successfully downloading the checkpoint.

Local checkpoint path:
checkpoints/46985143-af68-4d48-ab91-a6447052ca49

     Batch | Checkpoint UUID                      | Validation Metrics
-----------+--------------------------------------+---------------------------------------------
      1000 | 46985143-af68-4d48-ab91-a6447052ca49 | {
           |                                      |     "num_inputs": 0,
           |                                      |     "validation_metrics": {
           |                                      |         "loss": 7.906739711761475,
           |                                      |         "accuracy": 0.9646000266075134,
           |                                      |         "global_step": 1000,
           |                                      |         "average_loss": 0.12492649257183075
           |                                      |     }
           |                                      | }

The det trial download command downloads checkpoints for a specified trial. Similar to the Trial API, the det trial download command accepts --best, --latest, and --uuid options.

# Download best checkpoint.
det trial download <trial_id> --best
# Download best checkpoint to a particular directory.
det trial download <trial_id> --best --output-dir local_checkpoint

The command should display output resembling the following upon successfully downloading the checkpoint.

Local checkpoint path:
checkpoints/46985143-af68-4d48-ab91-a6447052ca49

     Batch | Checkpoint UUID                      | Validation Metrics
-----------+--------------------------------------+---------------------------------------------
      1000 | 46985143-af68-4d48-ab91-a6447052ca49 | {
           |                                      |     "num_inputs": 0,
           |                                      |     "validation_metrics": {
           |                                      |         "loss": 7.906739711761475,
           |                                      |         "accuracy": 0.9646000266075134,
           |                                      |         "global_step": 1000,
           |                                      |         "average_loss": 0.12492649257183075
           |                                      |     }
           |                                      | }

The --latest and --uuid options are used as follows:

# Download the most recent checkpoint.
det trial download <trial_id> --latest

# Download a specific checkpoint.
det trial download <trial_id> --uuid <uuid-for-checkpoint>

Finally, the det experiment download command provides a similar experience to using the Python SDK.

# Download the best checkpoint for a given experiment.
det experiment download <experiment_id>

# Download the best 3 checkpoints for a given experiment.
det experiment download <experiment_id> --top-n 3

The command should display output resembling the following upon successfully downloading the checkpoints.

Local checkpoint path:
checkpoints/8d45f621-8652-4268-8445-6ae9a735e453

     Batch | Checkpoint UUID                      | Validation Metrics
-----------+--------------------------------------+------------------------------------------
       400 | 8d45f621-8652-4268-8445-6ae9a735e453 | {
           |                                      |     "num_inputs": 56,
           |                                      |     "validation_metrics": {
           |                                      |         "val_loss": 0.26509127765893936,
           |                                      |         "val_categorical_accuracy": 1
           |                                      |     }
           |                                      | }

Local checkpoint path:
checkpoints/62131ba1-983c-49a8-98ef-36207611d71f

     Batch | Checkpoint UUID                      | Validation Metrics
-----------+--------------------------------------+------------------------------------------
      1600 | 62131ba1-983c-49a8-98ef-36207611d71f | {
           |                                      |     "num_inputs": 50,
           |                                      |     "validation_metrics": {
           |                                      |         "val_loss": 0.04411194706335664,
           |                                      |         "val_categorical_accuracy": 1
           |                                      |     }
           |                                      | }

Local checkpoint path:
checkpoints/a36d2a61-a384-44f7-a84b-8b30b09cb618

     Batch | Checkpoint UUID                      | Validation Metrics
-----------+--------------------------------------+------------------------------------------
       400 | a36d2a61-a384-44f7-a84b-8b30b09cb618 | {
           |                                      |     "num_inputs": 46,
           |                                      |     "validation_metrics": {
           |                                      |         "val_loss": 0.07265569269657135,
           |                                      |         "val_categorical_accuracy": 1
           |                                      |     }
           |                                      | }
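When scripting downloads, the local paths can be recovered from the command's output. The sketch below assumes the output follows the "Local checkpoint path:" layout shown above, which may change between Determined versions. `parse_local_paths` and `download_top_checkpoints` are hypothetical helpers; the latter simply runs the det experiment download command with the --top-n option documented above and requires the det CLI on your PATH.

```python
import re
import subprocess
from typing import List

# Matches the path on the line following each "Local checkpoint path:" header.
_LOCAL_PATH = re.compile(r"^Local checkpoint path:[ \t]*\r?\n[ \t]*(\S+)", re.MULTILINE)


def parse_local_paths(cli_output: str) -> List[str]:
    """Extract every local checkpoint path printed by a det ... download command."""
    return _LOCAL_PATH.findall(cli_output)


def download_top_checkpoints(experiment_id: int, top_n: int) -> List[str]:
    """Run `det experiment download --top-n` and return the local paths it reports."""
    result = subprocess.run(
        ["det", "experiment", "download", str(experiment_id), "--top-n", str(top_n)],
        capture_output=True,
        text=True,
        check=True,
    )
    return parse_local_paths(result.stdout)


# Abbreviated copy of the sample output above, used instead of a live cluster.
sample_output = """Local checkpoint path:
checkpoints/8d45f621-8652-4268-8445-6ae9a735e453

     Batch | Checkpoint UUID                      | Validation Metrics
       400 | 8d45f621-8652-4268-8445-6ae9a735e453 | {

Local checkpoint path:
checkpoints/62131ba1-983c-49a8-98ef-36207611d71f
"""

print(parse_local_paths(sample_output))
```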

Next Steps