EstimatorTrial is deprecated and will be removed in a future version. TensorFlow has advised
Estimator users to switch to Keras since TensorFlow 2.0 was released. Consequently, we recommend
users of EstimatorTrial to switch to the
In this guide, you’ll learn how to use the Estimator API.
Visit the API reference
This document guides you through training an Estimator model in Determined. You need to implement a
trial class that inherits from
EstimatorTrial and specify it as the
entrypoint in the experiment configuration.
Define Optimizer and Datasets
Before loading data, read Prepare Data to understand how to work with different data sources.
To use tf.estimator models with Determined, you’ll need to wrap your optimizer and datasets
using the context's wrapper functions, such as wrap_dataset(). Note that the concrete context
object where these functions will be found will be in
Determined supports proper reduction of arbitrary validation metrics during distributed training by
allowing users to define custom reducers for their metrics. Custom reducers can be either a function
or an implementation of the
context.make_metric() for more
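As a rough illustration of why a custom reducer matters, the sketch below combines per-batch partial results into a globally correct mean. It is plain Python and does not call Determined. The function name `reduce_mean_from_partials`, the `(sum, count)` tuple layout, and the sample values are assumptions made for this example, and it assumes the reducer receives a list of per-batch values gathered from all workers.

```python
from typing import List, Tuple


def reduce_mean_from_partials(per_batch: List[Tuple[float, int]]) -> float:
    """Combine per-batch (sum, count) pairs into one global mean."""
    total = sum(s for s, _ in per_batch)
    count = sum(n for _, n in per_batch)
    if count == 0:
        raise ValueError("cannot reduce an empty set of batches")
    return total / count


# Hypothetical per-batch results from two workers with uneven batch sizes.
worker_a = [(8.0, 4), (6.0, 4)]
worker_b = [(1.0, 1)]
all_batches = worker_a + worker_b

# Correct reduction: weight every sample equally across all batches.
global_mean = reduce_mean_from_partials(all_batches)

# Naive reduction: averaging per-batch means over-weights the small batch.
naive_mean = sum(s / n for s, n in all_batches) / len(all_batches)
```

The two results differ whenever batch sizes are uneven, which is the case a custom reducer is meant to handle during distributed validation.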
A checkpoint includes the model definition (Python source code), the experiment configuration file, the network architecture, and the values of the model’s parameters (i.e., weights) and hyperparameters. When a stateful optimizer is used during training, checkpoints also include the state of the optimizer (e.g., the learning rate). You can also embed arbitrary metadata in checkpoints via the Python SDK.
TensorFlow Estimator trials are checkpointed using the SavedModel format. Please consult the TensorFlow documentation for details on how to restore models from the SavedModel format.
To execute arbitrary Python code during the lifecycle of a
RunHook extends tf.estimator.SessionRunHook. When utilizing
determined.estimator.RunHook, users can use native estimator hooks such as
and Determined hooks such as
Example usage of
determined.estimator.RunHook that adds custom metadata to checkpoints:
```python
import os

import tensorflow as tf
import determined.estimator


class MyHook(determined.estimator.RunHook):
    def __init__(self, context, metadata) -> None:
        self._context = context
        self._metadata = metadata

    def on_checkpoint_end(self, checkpoint_dir) -> None:
        # Write the custom metadata into the checkpoint directory.
        with open(os.path.join(checkpoint_dir, "metadata.txt"), "w") as fp:
            fp.write(self._metadata)


class MyEstimatorTrial(determined.estimator.EstimatorTrial):
    ...

    def build_train_spec(self) -> tf.estimator.TrainSpec:
        # make_input_fn() is assumed to be defined elsewhere in the model code.
        return tf.estimator.TrainSpec(
            make_input_fn(),
            hooks=[MyHook(self.context, "my_metadata")],
        )
```
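The file-writing logic in `on_checkpoint_end` can be checked locally before running an experiment. The sketch below is a minimal stand-alone version that needs neither Determined nor TensorFlow. `MetadataWriter` is a hypothetical stand-in class, not part of the Determined API, and a temporary directory plays the role of the checkpoint directory.

```python
import os
import tempfile


class MetadataWriter:
    """Hypothetical stand-in holding only the hook's file-writing logic."""

    def __init__(self, metadata: str) -> None:
        self._metadata = metadata

    def on_checkpoint_end(self, checkpoint_dir: str) -> None:
        # Same behavior as MyHook.on_checkpoint_end: write metadata.txt.
        with open(os.path.join(checkpoint_dir, "metadata.txt"), "w") as fp:
            fp.write(self._metadata)


# Simulate a checkpoint directory and read back what the hook wrote.
with tempfile.TemporaryDirectory() as checkpoint_dir:
    MetadataWriter("my_metadata").on_checkpoint_end(checkpoint_dir)
    with open(os.path.join(checkpoint_dir, "metadata.txt")) as fp:
        written = fp.read()
```

Keeping the file-writing logic independent of the training context, as here, makes it easy to exercise outside a cluster.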