DP Data Loader

class opacus.data_loader.CollateFnWithEmpty(collator_fn, batch_first=True, rand_on_empty=False, sample_empty_shapes=None, dtypes=None)[source]

Collate function wrapper that handles empty batches by preserving batch structure.

This wrapper is stateful and learns the expected batch structure from the first non-empty batch it processes. When an empty batch is encountered, it generates an empty batch with the same structure (tensors, dicts, lists, or nested combinations) but with zero-length batch dimensions.

This is particularly useful for Poisson sampling in differential privacy, where batch sizes can vary and occasionally result in empty batches.

Parameters:
  • collator_fn (Optional[Callable[[list[TypeVar(_T)]], Any]]) – The original collate function to wrap. If None, returns batch as-is.

  • batch_first (bool) – If True, batch dimension is the first dimension (index 0). If False, batch dimension is the second dimension (index 1). Default: True

  • rand_on_empty (bool) – If True, returns tensors filled with random values (0 or 1) with batch dimension set to 1 when encountering empty batches. If False, returns tensors with batch dimension set to 0. Default: False

Example

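>>> from torch.utils.data._utils.collate import default_collate
>>> from opacus.data_loader import CollateFnWithEmpty
>>> collate_fn = CollateFnWithEmpty(default_collate)
>>> # First batch: [{"x": tensor([1, 2]), "y": tensor([3, 4])}]
>>> # Empty batch: [] -> {"x": tensor([]), "y": tensor([])}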

Note

The first batch processed must be non-empty, as it defines the structure for all subsequent empty batches.

Only torch.Tensor, dict (Mapping), list, and tuple types are supported. If your collate function returns other types, a TypeError will be raised to preserve DP guarantees (returning non-empty data for empty batches would violate the privacy guarantee).
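
The stateful behaviour described above can be illustrated with a minimal pure-Python sketch. It is not the Opacus implementation: plain lists stand in for tensors (a list's length plays the role of the batch dimension), and the names StructureRememberingCollate, empty_like and collate_dicts are hypothetical helpers introduced only for this illustration.

```python
from typing import Any, Callable, List, Optional


def empty_like(template: Any) -> Any:
    """Return a copy of `template` with every leaf list emptied."""
    if isinstance(template, dict):
        return {key: empty_like(value) for key, value in template.items()}
    if isinstance(template, tuple):
        return tuple(empty_like(value) for value in template)
    if isinstance(template, list):
        # A plain list stands in for a tensor here; its "batch
        # dimension" is simply its length.
        return []
    # Refuse unknown types rather than return non-empty data.
    raise TypeError(f"Unsupported type in batch structure: {type(template)}")


class StructureRememberingCollate:
    """Hypothetical stand-in for a stateful collate wrapper."""

    def __init__(self, collate_fn: Callable[[List[Any]], Any]) -> None:
        self.collate_fn = collate_fn
        # Learned from the first non-empty batch.
        self.template: Optional[Any] = None

    def __call__(self, batch: List[Any]) -> Any:
        if len(batch) > 0:
            result = self.collate_fn(batch)
            if self.template is None:
                self.template = empty_like(result)
            return result
        if self.template is None:
            raise RuntimeError("The first batch must be non-empty")
        # Return a fresh empty structure on every empty batch.
        return empty_like(self.template)


def collate_dicts(batch: List[dict]) -> dict:
    """Toy collate function: a list of dicts becomes a dict of lists."""
    return {key: [sample[key] for sample in batch] for key in batch[0]}


collate = StructureRememberingCollate(collate_dicts)
first = collate([{"x": 1, "y": 2}, {"x": 3, "y": 4}])
empty = collate([])
print(first)
print(empty)
```

The sketch raises on unsupported leaf types instead of passing them through, mirroring the rule in the note above that an empty batch must never yield non-empty data.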

class opacus.data_loader.DPDataLoader(dataset, *, sample_rate, collate_fn=None, drop_last=False, generator=None, distributed=False, batch_first=True, rand_on_empty=False, **kwargs)[source]

DataLoader subclass that always does Poisson sampling and supports empty batches by default.

Typically instantiated via the DPDataLoader.from_data_loader() method, based on another DataLoader. DPDataLoader preserves the behaviour of the original data loader except in two respects.

First, it switches batch_sampler to UniformWithReplacementSampler, thus enabling Poisson sampling (i.e. each element in the dataset is selected to be in the next batch with a certain probability defined by the sample_rate parameter). NB: this typically leads to batches of variable size. NB2: By default, sample_rate is calculated based on the batch_size of the original data loader, so that the average batch size stays the same.

Second, it wraps the collate function with support for empty batches. Most PyTorch modules will happily process tensors of shape (0, N, ...), but many collate functions will fail to produce such a batch. Because Poisson sampling makes empty batches possible, we need a DataLoader that can handle them.
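
To make the sampling behaviour concrete, here is a small standard-library sketch of Poisson sampling. It does not use Opacus or PyTorch. The helper poisson_sample_indices is a hypothetical name, and the formula sample_rate = batch_size / dataset_size is an assumption chosen so that the expected batch size matches the original batch_size.

```python
import random
from typing import List


def poisson_sample_indices(
    num_samples: int, sample_rate: float, rng: random.Random
) -> List[int]:
    """Include each index independently with probability `sample_rate`."""
    return [i for i in range(num_samples) if rng.random() < sample_rate]


dataset_size = 64
batch_size = 4
# Assumption: the rate is chosen so the expected batch size equals batch_size.
sample_rate = batch_size / dataset_size

rng = random.Random(0)
batches = [
    poisson_sample_indices(dataset_size, sample_rate, rng) for _ in range(2000)
]
sizes = [len(batch) for batch in batches]

mean_size = sum(sizes) / len(sizes)
print(f"mean batch size: {mean_size:.2f}")
print(f"smallest batch: {min(sizes)}, largest batch: {max(sizes)}")
print(f"empty batches: {sizes.count(0)}")
```

Batch sizes vary from draw to draw while averaging close to batch_size, and with these settings a small fraction of the batches contain no elements at all, which is why the collate function has to cope with an empty input.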

Parameters:
classmethod from_data_loader(data_loader, *, distributed=False, generator=None, batch_first=True, rand_on_empty=False)[source]

Creates new DPDataLoader based on passed data_loader argument.

Parameters:
  • data_loader (DataLoader) – Any DataLoader instance. Must not wrap an IterableDataset.

  • distributed (bool) – set True if you’ll be using DPDataLoader in a DDP environment

  • generator – Random number generator used to sample elements. Defaults to generator from the original data loader.

  • batch_first (bool) – Flag to indicate if the input tensor to the corresponding module has the first dimension representing the batch. If set to True, dimensions on the input tensor are expected to be [batch_size, ...], otherwise [K, batch_size, ...]

  • rand_on_empty (bool) – set True to return a batch containing random numbers when encountering empty batches rather than tensors with zero-length batch dimensions

Returns:

New DPDataLoader instance, with all attributes and parameters inherited from the original data loader, except for sampling mechanism.

Examples

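>>> import torch
>>> from torch.utils.data import DataLoader, TensorDataset
>>> from opacus.data_loader import DPDataLoader
>>> x, y = torch.randn(64, 5), torch.randint(0, 2, (64,))
>>> dataset = TensorDataset(x, y)
>>> data_loader = DataLoader(dataset, batch_size=4)
>>> dp_data_loader = DPDataLoader.from_data_loader(data_loader)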

opacus.data_loader.dtype_safe(x)[source]

Exception-safe getter for dtype attribute.

Return type:

Union[dtype, Type]

opacus.data_loader.shape_safe(x)[source]

Exception-safe getter for shape attribute.

Return type:

Tuple
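
As a rough illustration of what "exception-safe" means for these two getters, the sketch below falls back to a default when the attribute is missing instead of raising AttributeError. The fallback values (an empty tuple for the shape, the object's Python type for the dtype) are assumptions inferred from the documented return types. The names shape_or_empty, dtype_or_type and FakeTensor are hypothetical.

```python
from typing import Any, Tuple


def shape_or_empty(x: Any) -> Tuple:
    """Return x.shape if present, otherwise an empty tuple (assumed fallback)."""
    return tuple(getattr(x, "shape", ()))


def dtype_or_type(x: Any) -> Any:
    """Return x.dtype if present, otherwise type(x) (assumed fallback)."""
    return getattr(x, "dtype", type(x))


class FakeTensor:
    """Hypothetical object exposing tensor-like attributes."""

    shape = (2, 3)
    dtype = "float32"


print(shape_or_empty(FakeTensor()), dtype_or_type(FakeTensor()))
print(shape_or_empty(5), dtype_or_type(5))
```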

opacus.data_loader.switch_generator(*, data_loader, generator)[source]

Creates new instance of a DataLoader, with the exact same behaviour of the provided data loader, except for the source of randomness.

Typically used to equip a user-provided data loader object with a cryptographically secure random number generator.

Parameters:
  • data_loader (DataLoader) – Any DataLoader object

  • generator – Random number generator object

Returns:

New DataLoader object with the exact same behaviour as the input data loader, except for the source of randomness.
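
The idea of "same behaviour, different source of randomness" can be sketched with a toy loader built only from the standard library. This is not how Opacus rebuilds a PyTorch DataLoader. ToyLoader and with_generator are hypothetical names, and random.SystemRandom is used here merely as an example of a generator backed by the operating system's entropy source.

```python
import random
from dataclasses import dataclass, replace
from typing import List


@dataclass(frozen=True)
class ToyLoader:
    """Hypothetical stand-in for a data loader configuration."""

    data: List[int]
    batch_size: int
    generator: random.Random

    def shuffled_batches(self) -> List[List[int]]:
        order = list(self.data)
        self.generator.shuffle(order)
        return [
            order[i:i + self.batch_size]
            for i in range(0, len(order), self.batch_size)
        ]


def with_generator(loader: ToyLoader, generator: random.Random) -> ToyLoader:
    """Return a new loader that differs only in its source of randomness."""
    return replace(loader, generator=generator)


original = ToyLoader(data=list(range(10)), batch_size=4, generator=random.Random(0))
# random.SystemRandom draws from the operating system's entropy source.
secure = with_generator(original, random.SystemRandom())
print(secure.shuffled_batches())
```

The original object is left untouched and a new one is returned, which matches the documented contract of producing a new DataLoader rather than mutating the one passed in.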

opacus.data_loader.wrap_collate_with_empty(*, collate_fn, batch_first=True, rand_on_empty=False, sample_empty_shapes=None, dtypes=None)[source]

Wraps given collate function to handle empty batches.

This function returns a stateful CollateFnWithEmpty instance that learns the batch structure from the first non-empty batch and uses this structure to generate properly shaped empty batches when needed.

Parameters:
  • collate_fn (Optional[Callable[[list[TypeVar(_T)]], Any]]) – collate function to wrap. If None, returns batches as-is.

  • batch_first (bool) – Flag to indicate if the input tensor to the corresponding module has the first dimension representing the batch. If set to True, dimensions on the input tensor are expected to be [batch_size, ...], otherwise [K, batch_size, ...]

  • rand_on_empty (bool) – set True to return a batch containing random numbers when encountering empty batches rather than tensors with zero-length batch dimensions

Returns:

A callable that is equivalent to the input collate_fn for non-empty batches and outputs empty tensors with the same structure when the input batch is empty. The structure is learned from the first non-empty batch.

Return type:

CollateFnWithEmpty

Example

>>> import torch
>>> from torch.utils.data._utils.collate import default_collate
>>> from opacus.data_loader import wrap_collate_with_empty
>>> collate = wrap_collate_with_empty(collate_fn=default_collate)
>>> # First batch defines structure
>>> result = collate([{"x": torch.tensor([1, 2])}])
>>> # Empty batch uses learned structure
>>> empty = collate([])  # Returns {"x": torch.tensor([])}