det.pytorch.samplers API Reference#

class determined.pytorch.samplers.DistributedBatchSampler(batch_sampler: torch.utils.data.sampler.BatchSampler, num_workers: int, rank: int)#

DistributedBatchSampler will iterate through an underlying batch sampler and return batches which belong to this shard.

DistributedBatchSampler is different from the PyTorch built-in torch.utils.data.distributed.DistributedSampler, because the built-in DistributedSampler is meant to be called before the BatchSampler and is meant to be a stand-alone sampler.

DistributedBatchSampler has a potential gotcha when wrapping a non-repeating BatchSampler: if the length of the BatchSampler is not divisible by the number of replicas, the length of the resulting DistributedBatchSampler will differ based on the rank. In that case, the divergent paths of multiple workers could cause problems during training. This is not a problem in Determined, because PyTorchTrial always uses RepeatBatchSampler during training and does not require that workers stay in step during validation.
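
For illustration, here is a minimal sketch of one batch sampler sharded across two workers. The toy dataset of 10 records and batch size of 2 are arbitrary choices, not part of this API:

    import torch
    from determined.pytorch import samplers

    base = torch.utils.data.BatchSampler(
        torch.utils.data.SequentialSampler(range(10)), batch_size=2, drop_last=False
    )

    shard0 = samplers.DistributedBatchSampler(base, num_workers=2, rank=0)
    shard1 = samplers.DistributedBatchSampler(base, num_workers=2, rank=1)

    # len(base) == 5 is not divisible by 2, so the per-rank lengths differ,
    # which is the gotcha described above.
    print(len(shard0), list(shard0))
    print(len(shard1), list(shard1))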

class determined.pytorch.samplers.DistributedSampler(sampler: torch.utils.data.sampler.Sampler, num_workers: int, rank: int)#

DistributedSampler will iterate through an underlying sampler and return samples which belong to this shard.

DistributedSampler is different from the PyTorch built-in torch.utils.data.DistributedSampler because theirs is meant to be a stand-alone sampler. Theirs does shuffling and assumes a constant-size dataset as input. Ours is meant to be used as a building block in a chain of samplers, so it accepts a sampler as input that may or may not be constant-size.
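
A minimal sketch, assuming a toy SequentialSampler as the inner sampler, of DistributedSampler used as one link in a chain:

    import torch
    from determined.pytorch import samplers

    base = torch.utils.data.SequentialSampler(range(8))

    # Rank 0 of 2 iterates only its own shard of the sample indices.
    shard = samplers.DistributedSampler(base, num_workers=2, rank=0)
    print(list(shard))  # a disjoint subset of 0..7; rank 1 would see the rest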

class determined.pytorch.samplers.RepeatBatchSampler(batch_sampler: torch.utils.data.sampler.BatchSampler)#

RepeatBatchSampler yields batch indices indefinitely by repeatedly iterating through the batches of another BatchSampler. __len__ is just the length of the underlying BatchSampler.
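
A minimal sketch of taking a fixed number of batches from the infinite stream with itertools.islice; the toy dataset of 4 records and batch size of 2 are assumptions for illustration:

    import itertools
    import torch
    from determined.pytorch import samplers

    base = torch.utils.data.BatchSampler(
        torch.utils.data.SequentialSampler(range(4)), batch_size=2, drop_last=False
    )
    repeated = samplers.RepeatBatchSampler(base)

    print(len(repeated))                        # 2, the length of the underlying BatchSampler
    print(list(itertools.islice(repeated, 5)))  # expected: [[0, 1], [2, 3], [0, 1], [2, 3], [0, 1]]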

class determined.pytorch.samplers.RepeatSampler(sampler: torch.utils.data.sampler.Sampler)#

RepeatSampler yields sample indices indefinitely by repeatedly iterating through another Sampler. __len__ is just the length of the underlying Sampler.
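
The same idea at the sample level; the toy dataset of 3 records is an arbitrary choice:

    import itertools
    import torch
    from determined.pytorch import samplers

    base = torch.utils.data.SequentialSampler(range(3))
    repeated = samplers.RepeatSampler(base)

    print(len(repeated))                        # 3, the length of the underlying Sampler
    print(list(itertools.islice(repeated, 7)))  # expected: [0, 1, 2, 0, 1, 2, 0]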

class determined.pytorch.samplers.ReproducibleShuffleBatchSampler(batch_sampler: torch.utils.data.sampler.BatchSampler, seed: int)#

ReproducibleShuffleBatchSampler will apply a deterministic shuffle based on a seed.

Warning

Always shuffle before skipping and before repeating. Skip-before-shuffle would break the reproducibility of the shuffle, and repeat-before-shuffle would cause the shuffle to hang as it iterates through an infinite sampler.

Warning

Always prefer ReproducibleShuffleSampler over this class when possible. The reason is that shuffling at the level of individual samples results in a superior shuffle, where the contents of each batch vary between epochs, rather than just the order of batches.
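
A minimal sketch of the deterministic shuffle; the toy BatchSampler and the seed value are assumptions for illustration:

    import torch
    from determined.pytorch import samplers

    base = torch.utils.data.BatchSampler(
        torch.utils.data.SequentialSampler(range(8)), batch_size=2, drop_last=False
    )

    a = samplers.ReproducibleShuffleBatchSampler(base, seed=777)
    b = samplers.ReproducibleShuffleBatchSampler(base, seed=777)

    # The same seed should produce the same shuffled order of (intact) batches.
    assert list(a) == list(b)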

class determined.pytorch.samplers.ReproducibleShuffleSampler(sampler: torch.utils.data.sampler.Sampler, seed: int)#

ReproducibleShuffleSampler will apply a deterministic shuffle based on a seed.

Warning

Always shuffle before skipping and before repeating. Skip-before-shuffle would break the reproducibility of the shuffle, and repeat-before-shuffle would cause the shuffle to hang as it iterates through an infinite sampler.
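
To make the ordering in this warning concrete, here is a minimal sketch that shuffles first, then repeats, then skips. The seed, dataset size, and skip count are arbitrary, and the chain is an illustration rather than the exact chain Determined builds internally:

    import torch
    from determined.pytorch import samplers

    base = torch.utils.data.SequentialSampler(range(100))

    shuffled = samplers.ReproducibleShuffleSampler(base, seed=42)  # shuffle the finite base sampler
    repeated = samplers.RepeatSampler(shuffled)                    # then repeat into an infinite stream
    resumed = samplers.SkipSampler(repeated, skip=10)              # then skip once, e.g. when resuming training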

class determined.pytorch.samplers.SkipBatchSampler(batch_sampler: torch.utils.data.sampler.BatchSampler, skip: int)#

SkipBatchSampler skips some batches from an underlying BatchSampler, and yields the rest.

Always skip after you repeat when you are continuing training, or you will apply the skip on every epoch.

Because the SkipBatchSampler is only meant to be used on a training dataset (we never checkpoint during evaluation), and because the training dataset should always be repeated before applying the skip (so you only skip once rather than many times), the length reported is always the length of the underlying sampler, regardless of the size of the skip.
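
A minimal sketch of applying the skip on top of a repeated BatchSampler, so the first skip batches of the infinite stream are dropped exactly once; the toy dataset, batch size, and skip count are assumptions:

    import itertools
    import torch
    from determined.pytorch import samplers

    base = torch.utils.data.BatchSampler(
        torch.utils.data.SequentialSampler(range(6)), batch_size=2, drop_last=False
    )
    repeated = samplers.RepeatBatchSampler(base)
    resumed = samplers.SkipBatchSampler(repeated, skip=2)  # e.g. 2 batches already trained

    print(len(resumed))                        # 3: the length of the underlying sampler, regardless of the skip
    print(list(itertools.islice(resumed, 3)))  # expected: [[4, 5], [0, 1], [2, 3]]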

class determined.pytorch.samplers.SkipSampler(sampler: torch.utils.data.sampler.Sampler, skip: int)#

SkipSampler skips some records from an underlying Sampler, and yields the rest.

Always skip after you repeat when you are continuing training, or you will apply the skip on every epoch.

Warning

When trying to achieve reproducibility after pausing and restarting, prefer SkipBatchSampler over this SkipSampler, unless you are sure that your dataset will always yield identically sized batches. This is because Determined counts the number of batches trained but not the number of records trained, so reproducibility when skipping records is only possible if the number of records to skip can be reliably calculated from the batch size and the number of batches trained.

Because the SkipSampler is only meant to be used on a training dataset (we never checkpoint during evaluation), and because the training dataset should always be repeated before applying the skip (so you only skip once rather than many times), the length reported is always the length of the underlying sampler, regardless of the size of the skip.
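
As a concrete illustration of that calculation, the names batch_size and batches_trained below are hypothetical stand-ins for values you would take from your own configuration and checkpoint:

    import torch
    from determined.pytorch import samplers

    batch_size = 32          # assumed fixed-size batches
    batches_trained = 100    # e.g. recovered from the checkpoint being resumed
    records_to_skip = batch_size * batches_trained

    base = torch.utils.data.SequentialSampler(range(10_000))
    repeated = samplers.RepeatSampler(base)
    resumed = samplers.SkipSampler(repeated, skip=records_to_skip)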