DP Data Loader
- class opacus.data_loader.CollateFnWithEmpty(collator_fn, batch_first=True, rand_on_empty=False, sample_empty_shapes=None, dtypes=None)
Collate function wrapper that handles empty batches by preserving batch structure.
This wrapper is stateful and learns the expected batch structure from the first non-empty batch it processes. When an empty batch is encountered, it generates an empty batch with the same structure (tensors, dicts, lists, or nested combinations) but with zero-length batch dimensions.
This is particularly useful for Poisson sampling in differential privacy, where batch sizes can vary and occasionally result in empty batches.
- Parameters:
collator_fn (Optional[Callable[[list[TypeVar(_T)]], Any]]) – The original collate function to wrap. If None, the batch is returned as-is.
batch_first (bool) – If True, the batch dimension is the first dimension (index 0). If False, the batch dimension is the second dimension (index 1). Default: True
rand_on_empty (bool) – If True, returns tensors filled with random values (0 or 1) with the batch dimension set to 1 when an empty batch is encountered. If False, returns tensors with the batch dimension set to 0. Default: False
Example
>>> collate_fn = CollateFnWithEmpty(default_collate)
>>> # First batch: [{"x": tensor([1, 2]), "y": tensor([3, 4])}]
>>> # Empty batch: [] -> {"x": tensor([]), "y": tensor([])}
Note
The first batch processed must be non-empty, as it defines the structure for all subsequent empty batches.
Only torch.Tensor, dict (Mapping), list, and tuple types are supported. If your collate function returns other types, a TypeError will be raised to preserve DP guarantees (returning non-empty data for empty batches would violate the privacy guarantee).
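The stateful, structure-learning behaviour described above can be sketched in plain Python. This is a hypothetical stand-in, not the Opacus implementation: `array.array` plays the role of a tensor (its typecode standing in for the dtype), and the names `StructureLearningCollate`, `empty_like_structure`, and `columns_collate` are invented for illustration. Raising `RuntimeError` when the very first batch is empty is also an assumption of this sketch.

```python
from array import array
from typing import Any, Callable, List, Optional


def empty_like_structure(template: Any) -> Any:
    """Return a copy of `template`'s container structure with zero-length leaves."""
    if isinstance(template, array):
        # array.array stands in for a tensor: keep its typecode (dtype), drop its contents.
        return array(template.typecode)
    if isinstance(template, dict):
        return {key: empty_like_structure(value) for key, value in template.items()}
    if isinstance(template, (list, tuple)):
        return type(template)(empty_like_structure(value) for value in template)
    # Unsupported leaf types are rejected instead of leaking stale, non-empty data.
    raise TypeError(f"unsupported batch element type: {type(template).__name__}")


class StructureLearningCollate:
    """Hypothetical analogue of CollateFnWithEmpty (not the Opacus code)."""

    def __init__(self, collate_fn: Optional[Callable[[List[Any]], Any]]) -> None:
        self.collate_fn = collate_fn
        self.template: Any = None  # learned from the first non-empty batch

    def __call__(self, samples: List[Any]) -> Any:
        if samples:
            batch = self.collate_fn(samples) if self.collate_fn is not None else samples
            if self.template is None:
                self.template = batch  # remember the structure for later empty batches
            return batch
        if self.template is None:
            raise RuntimeError("the first batch must be non-empty to learn the structure")
        return empty_like_structure(self.template)


def columns_collate(samples: List[dict]) -> dict:
    """Hypothetical collate fn: list of {"x": float, "y": int} -> dict of arrays."""
    return {
        "x": array("d", (s["x"] for s in samples)),
        "y": array("q", (s["y"] for s in samples)),
    }


collate = StructureLearningCollate(columns_collate)
first = collate([{"x": 0.5, "y": 1}, {"x": 1.5, "y": 0}])
empty = collate([])
print(first["x"].tolist(), len(empty["x"]), empty["y"].typecode)  # [0.5, 1.5] 0 q
```

The empty result keeps the same keys and leaf types as the first batch, which is what lets downstream code treat an empty batch like any other.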
- class opacus.data_loader.DPDataLoader(dataset, *, sample_rate, collate_fn=None, drop_last=False, generator=None, distributed=False, batch_first=True, rand_on_empty=False, **kwargs)
DataLoader subclass that always does Poisson sampling and supports empty batches by default.
Typically instantiated via the DPDataLoader.from_data_loader() method based on another DataLoader. DPDataLoader preserves the behaviour of the original data loader except in two respects.
First, it switches the batch_sampler to UniformWithReplacementSampler, thus enabling Poisson sampling (i.e. each element in the dataset is selected to be in the next batch with a certain probability defined by the sample_rate parameter). NB: this typically leads to batches of variable size. NB2: By default, sample_rate is calculated based on the batch_size of the original data loader, so that the average batch size stays the same.
Second, it wraps the collate function with support for empty batches. Most PyTorch modules will happily process tensors of shape (0, N, ...), but many collate functions will fail to produce such a batch. Because Poisson sampling makes empty batches possible, the DataLoader must be able to handle them.
- Parameters:
dataset (Dataset) – See torch.utils.data.DataLoader
sample_rate (float) – Probability with which each element of the dataset is included in the next batch.
num_workers – See torch.utils.data.DataLoader
collate_fn (Optional[Callable[[list[TypeVar(_T)]], Any]]) – See torch.utils.data.DataLoader
pin_memory – See torch.utils.data.DataLoader
drop_last (bool) – See torch.utils.data.DataLoader
timeout – See torch.utils.data.DataLoader
worker_init_fn – See torch.utils.data.DataLoader
multiprocessing_context – See torch.utils.data.DataLoader
generator – Random number generator used to sample elements
prefetch_factor – See torch.utils.data.DataLoader
persistent_workers – See torch.utils.data.DataLoader
distributed (bool) – Set True if you'll be using DPDataLoader in a DDP environment. Selects between the DistributedUniformWithReplacementSampler and UniformWithReplacementSampler sampler implementations.
rand_on_empty (bool) – Set True to return a batch containing random numbers when encountering empty batches, rather than tensors with zero-length batch dimensions.
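Poisson sampling as described for DPDataLoader can be illustrated with a minimal stdlib sketch. The helper `poisson_batches` is hypothetical and is not the UniformWithReplacementSampler code; we also assume that `sample_rate = batch_size / len(dataset)` and that an epoch keeps `ceil(len(dataset) / batch_size)` batches, which leaves the expected batch size equal to the original `batch_size` while individual batches, including empty ones, vary.

```python
import math
import random
from typing import Iterator, List


def poisson_batches(num_samples: int, sample_rate: float, num_batches: int,
                    rng: random.Random) -> Iterator[List[int]]:
    """Yield index batches; every index joins each batch independently with prob sample_rate."""
    for _ in range(num_batches):
        yield [i for i in range(num_samples) if rng.random() < sample_rate]


num_samples, batch_size = 64, 4
# Assumption: keep the epoch length and expected batch size of the original loader.
num_batches = math.ceil(num_samples / batch_size)
sample_rate = batch_size / num_samples  # expected batch size: 64 * 0.0625 = 4

rng = random.Random(0)
sizes = [len(batch) for batch in poisson_batches(num_samples, sample_rate, num_batches, rng)]
print(sizes)  # variable sizes around 4; a 0 would be an empty batch
```

With these numbers, the probability that a given batch is empty is (15/16)^64, or about 1.6%, which is why the wrapped collate function has to cope with empty batches.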
- classmethod from_data_loader(data_loader, *, distributed=False, generator=None, batch_first=True, rand_on_empty=False)
Creates a new DPDataLoader based on the passed data_loader argument.
- Parameters:
data_loader (DataLoader) – Any DataLoader instance. Must not be over an IterableDataset.
distributed (bool) – Set True if you'll be using DPDataLoader in a DDP environment.
generator – Random number generator used to sample elements. Defaults to the generator from the original data loader.
batch_first (bool) – Flag to indicate whether the input tensor to the corresponding module has the first dimension representing the batch. If set to True, dimensions of the input tensor are expected to be [batch_size, ...], otherwise [K, batch_size, ...].
rand_on_empty (bool) – Set True to return a batch containing random numbers when encountering empty batches, rather than tensors with zero-length batch dimensions.
- Returns:
New DPDataLoader instance, with all attributes and parameters inherited from the original data loader, except for the sampling mechanism.
Examples
>>> x, y = torch.randn(64, 5), torch.randint(0, 2, (64,))
>>> dataset = TensorDataset(x, y)
>>> data_loader = DataLoader(dataset, batch_size=4)
>>> dp_data_loader = DPDataLoader.from_data_loader(data_loader)
- opacus.data_loader.switch_generator(*, data_loader, generator)
Creates a new instance of a DataLoader with exactly the same behaviour as the provided data loader, except for the source of randomness.
Typically used to enhance a user-provided data loader object with a cryptographically secure random number generator.
- Parameters:
data_loader (DataLoader) – Any DataLoader object
generator – Random number generator object
- Returns:
New DataLoader object with the exact same behaviour as the input data loader, except for the source of randomness.
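The idea of copying every setting of a loader while swapping only its random number generator can be shown with a small stdlib analogue. `ToyLoader` and `switch_rng` are hypothetical names, and `random.SystemRandom` is used here merely as a readily available OS-backed secure generator; this sketch does not claim to mirror how switch_generator rebuilds a torch DataLoader.

```python
import random
from dataclasses import dataclass, replace
from typing import Iterator, List, Sequence


@dataclass(frozen=True)
class ToyLoader:
    """Hypothetical shuffling loader; `rng` is its only source of randomness."""
    dataset: Sequence[int]
    batch_size: int
    rng: random.Random

    def __iter__(self) -> Iterator[List[int]]:
        order = list(range(len(self.dataset)))
        self.rng.shuffle(order)  # randomness comes only from self.rng
        for start in range(0, len(order), self.batch_size):
            yield [self.dataset[i] for i in order[start:start + self.batch_size]]


def switch_rng(loader: ToyLoader, rng: random.Random) -> ToyLoader:
    """Copy every setting of `loader` except its random number generator."""
    return replace(loader, rng=rng)


loader = ToyLoader(dataset=list(range(10)), batch_size=4, rng=random.Random(42))
secure = switch_rng(loader, random.SystemRandom())
print([len(b) for b in secure])  # [4, 4, 2]
```

Keeping the generator as the single injected dependency is what makes such a swap safe: batch sizes and dataset coverage are unchanged, and only the order of elements depends on the new generator.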
- opacus.data_loader.wrap_collate_with_empty(*, collate_fn, batch_first=True, rand_on_empty=False, sample_empty_shapes=None, dtypes=None)
Wraps given collate function to handle empty batches.
This function returns a stateful CollateFnWithEmpty instance that learns the batch structure from the first non-empty batch and uses this structure to generate properly shaped empty batches when needed.
- Parameters:
collate_fn (Optional[Callable[[list[TypeVar(_T)]], Any]]) – Collate function to wrap. If None, batches are returned as-is.
batch_first (bool) – Flag to indicate whether the input tensor to the corresponding module has the first dimension representing the batch. If set to True, dimensions of the input tensor are expected to be [batch_size, ...], otherwise [K, batch_size, ...].
rand_on_empty (bool) – Set True to return a batch containing random numbers when encountering empty batches, rather than tensors with zero-length batch dimensions.
- Returns:
- A callable that is equivalent to the input collate_fn for non-empty batches and outputs empty tensors with the same structure when the input batch is empty. The structure is learned from the first non-empty batch.
- Return type:
Example
>>> from torch.utils.data._utils.collate import default_collate
>>> collate = wrap_collate_with_empty(collate_fn=default_collate)
>>> # First batch defines structure
>>> result = collate([{"x": torch.tensor([1, 2])}])
>>> # Empty batch uses learned structure
>>> empty = collate([])  # Returns {"x": torch.tensor([])}
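The batch_first flag only changes which axis is treated as the batch axis. The rule can be sketched as a pure shape computation, under the assumption that an empty batch keeps every non-batch dimension of the shape learned from the first non-empty batch; `empty_shape` is a hypothetical helper, not part of Opacus. With PyTorch installed, such a shape could be passed to `torch.zeros(shape, dtype=...)` to build the actual empty tensor.

```python
from typing import Tuple


def empty_shape(learned_shape: Tuple[int, ...], batch_first: bool = True) -> Tuple[int, ...]:
    """Zero out the batch dimension of a shape learned from a non-empty batch.

    batch_first=True  -> batch dim is index 0: [batch_size, ...]
    batch_first=False -> batch dim is index 1: [K, batch_size, ...]
    """
    batch_dim = 0 if batch_first else 1
    if len(learned_shape) <= batch_dim:
        raise ValueError(f"shape {learned_shape} has no dimension {batch_dim}")
    # Keep every other dimension; only the batch axis becomes zero-length.
    return learned_shape[:batch_dim] + (0,) + learned_shape[batch_dim + 1:]


print(empty_shape((32, 3, 28, 28)))                    # (0, 3, 28, 28)
print(empty_shape((50, 32, 128), batch_first=False))  # (50, 0, 128)
```

Preserving the non-batch dimensions means layers that check feature sizes (for example, an input width of 128) still accept the empty batch.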