Torch Batch Processing API#

In this guide, you'll learn about the Torch Batch Processing API and how to use it to perform batch inference (also known as offline inference).

Visit the API reference: pytorch.experimental.torch_batch_process API Reference


This is an experimental API and may change at any time.


The Torch Batch Processing API takes in (1) a dataset and (2) a user-defined processor class and runs distributed data processing.

This API automatically handles the following for you:

  • sharding the dataset across the available workers

  • applying user-defined logic to each batch of data

  • synchronizing between workers

  • tracking job progress to enable preemption and resumption of the trial

This is a flexible API that can be used for many different tasks, including batch (offline) inference.
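The sharding behavior can be pictured with a small standalone sketch. This is illustrative Python only, not Determined's internal implementation: the idea is that each of N workers handles a disjoint subset of batches, so together they cover the dataset exactly once.

```python
# Illustrative sketch of round-robin batch sharding across workers.
# NOT Determined's internal code; it only shows how a dataset can be
# split so each worker processes a disjoint subset of batches.
def shard_batches(num_batches: int, num_workers: int, rank: int) -> list[int]:
    """Return the batch indices assigned to the worker with this rank."""
    return [i for i in range(num_batches) if i % num_workers == rank]

# With 10 batches and 4 workers, each batch is assigned to exactly one worker.
shards = [shard_batches(10, 4, r) for r in range(4)]
```

Because every batch index lands in exactly one shard, no synchronization is needed while batches are being processed; workers only coordinate at checkpoints.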

If you have some trained models in a Checkpoint or a Model with more than one ModelVersion inside, you can associate the trial with the Checkpoint or ModelVersion used in a given inference run to aggregate custom inference metrics.

You can then query those Checkpoint or ModelVersion objects using the Python SDK to see all metrics associated with them.


The main arguments to torch_batch_process() are a processor class and a dataset.


In the experiment config file, use a distributed launcher, since the API requires information (such as rank) that is set by the launcher. Below is an example, where my_processing_script.py stands in for your script that calls torch_batch_process():

entrypoint: >-
    python3 -m determined.launch.torch_distributed
    python3 my_processing_script.py
resources:
  slots_per_trial: 4


During __init__() of TorchBatchProcessor, we pass in a TorchBatchProcessorContext object, which contains useful methods that can be used within the TorchBatchProcessor class.

Your processor class should be a subclass of TorchBatchProcessor. The two functions you must implement are __init__() and process_batch(); the other lifecycle functions are optional.

TorchBatchProcessor is compatible with Determined's MetricReducer. You can pass a MetricReducer to TorchBatchProcessor as follows:

class MyProcessor(TorchBatchProcessor):
    def __init__(self, context):
        self.reducer = context.wrap_reducer(reducer=AccuracyMetricReducer(), name="accuracy")
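The AccuracyMetricReducer above is not defined in this guide. A sketch of what it might look like follows; in a real trial it would subclass determined.pytorch.MetricReducer, but here a stand-in base class is stubbed so the sketch is self-contained, and the update() method is a user-defined convention rather than part of the interface.

```python
class MetricReducer:  # stand-in for determined.pytorch.MetricReducer
    def reset(self): ...
    def per_slot_reduce(self): ...
    def cross_slot_reduce(self, per_slot_metrics): ...

class AccuracyMetricReducer(MetricReducer):
    """Hypothetical reducer that aggregates accuracy across workers."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, predictions, labels):
        # Called from process_batch() with per-batch results.
        self.correct += sum(p == l for p, l in zip(predictions, labels))
        self.total += len(labels)

    def reset(self):
        self.correct = 0
        self.total = 0

    def per_slot_reduce(self):
        # Each worker reports its local counts.
        return self.correct, self.total

    def cross_slot_reduce(self, per_slot_metrics):
        # Combine the per-worker counts into a single accuracy value.
        correct = sum(c for c, _ in per_slot_metrics)
        total = sum(t for _, t in per_slot_metrics)
        return correct / total if total else 0.0
```

Reducing raw counts (rather than per-worker averages) keeps the final accuracy exact even when workers process different numbers of samples.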

How To Perform Batch (Offline) Inference#

In this section, we’ll learn how to perform batch inference using the Torch Batch Processing API.

Step 1: Define an InferenceProcessor#

The first step is to define an InferenceProcessor. Initialize your model in the __init__() function of the InferenceProcessor, and implement the process_batch() function with your inference logic.

You can optionally implement on_checkpoint_start() and on_finish() to be run before every checkpoint and after all the data has been processed, respectively.

Define custom processor class
import pathlib

import torch

class InferenceProcessor(TorchBatchProcessor):
    def __init__(self, context):
        self.context = context
        self.model = context.prepare_model_for_inference(get_model())
        self.output = []
        self.last_index = 0

    def process_batch(self, batch, batch_idx) -> None:
        model_input = batch[0]
        model_input = self.context.to_device(model_input)

        with torch.no_grad():
            pred = self.model(model_input)
            output = {"predictions": pred, "input": batch}
            self.output.append(output)

        self.last_index = batch_idx

    def on_checkpoint_start(self):
        # During checkpoint, persist the prediction results accumulated so far
        if len(self.output) == 0:
            return
        file_name = f"prediction_output_{self.last_index}"
        with self.context.upload_path() as path:
            file_path = pathlib.Path(path, file_name)
            torch.save(self.output, file_path)

        self.output = []

Step 2: Initialize the Dataset#

Initialize the dataset you want to process.

Initialize dataset
import os

import filelock
import torchvision as tv
from torchvision import transforms

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

# The file lock prevents multiple workers on the same node from
# downloading the dataset concurrently
with filelock.FileLock(os.path.join("/tmp", "inference.lock")):
    inference_data = tv.datasets.CIFAR10(
        root="/data", train=False, download=True, transform=transform
    )

Step 3: Pass the InferenceProcessor Class and Dataset#

Pass the InferenceProcessor class and the dataset to torch_batch_process.

Pass processor class and dataset to torch_batch_process
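A minimal call might look like the following sketch. The batch_size and checkpoint_interval keyword arguments are assumptions here; verify them against the torch_batch_process API reference. This snippet only runs inside a Determined trial launched with the distributed launcher, not as a standalone script.

```python
# Sketch only: requires a Determined cluster and the distributed launcher.
# batch_size and checkpoint_interval are assumed keyword arguments;
# check the torch_batch_process API reference for the exact signature.
from determined.pytorch.experimental import torch_batch_process

torch_batch_process(
    InferenceProcessor,
    inference_data,
    batch_size=64,
    checkpoint_interval=10,
)
```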

Step 4: Send and Query Custom Inference Metrics (Optional)#

Report metrics anywhere in the trial to have them aggregated for the Checkpoint or ModelVersion in question.

For example, you could send metrics in on_finish().

def on_finish(self):
    self.context.report_metrics(
        group="inference",
        steps_completed=self.rank,
        metrics={
            "my_metric": 1.0,
        },
    )

And check the metric afterwards from the SDK:

from determined.experimental import client

# Checkpoint
ckpt = client.get_checkpoint("<CHECKPOINT_UUID>")
metrics = ckpt.get_metrics("inference")

# Or Model Version
model = client.get_model("<MODEL_NAME>")
model_version = model.get_version(MODEL_VERSION_NUM)
metrics = model_version.get_metrics("inference")