DP Data Loader
- class opacus.data_loader.DPDataLoader(dataset, *, sample_rate, collate_fn=None, drop_last=False, generator=None, distributed=False, **kwargs)
DataLoader subclass that always does Poisson sampling and supports empty batches by default.
Typically instantiated via the DPDataLoader.from_data_loader() method, based on another DataLoader. DPDataLoader preserves the behaviour of the original data loader except for two aspects.
First, it switches batch_sampler to UniformWithReplacementSampler, thus enabling Poisson sampling (i.e. each element in the dataset is selected to be in the next batch with a certain probability defined by the sample_rate parameter). NB: this typically leads to batches of variable size. NB2: by default, sample_rate is calculated based on the batch_size of the original data loader, so that the average batch size stays the same.
Second, it wraps the collate function with support for empty batches. Most PyTorch modules will happily process tensors of shape (0, N, ...), but many collate functions will fail to produce such a batch. Since empty batches become possible with Poisson sampling, we need a DataLoader that can handle them.
- Parameters:
dataset (Dataset) – See torch.utils.data.DataLoader
sample_rate (float) – probability with which each element of the dataset is included in the next batch.
num_workers – See torch.utils.data.DataLoader
collate_fn (Optional[Callable[[List[TypeVar(_T)]], Any]]) – See torch.utils.data.DataLoader
pin_memory – See torch.utils.data.DataLoader
drop_last (bool) – See torch.utils.data.DataLoader
timeout – See torch.utils.data.DataLoader
worker_init_fn – See torch.utils.data.DataLoader
multiprocessing_context – See torch.utils.data.DataLoader
generator – Random number generator used to sample elements
prefetch_factor – See torch.utils.data.DataLoader
persistent_workers – See torch.utils.data.DataLoader
distributed (bool) – set True if you’ll be using DPDataLoader in a DDP environment. Selects between the DistributedUniformWithReplacementSampler and UniformWithReplacementSampler sampler implementations.
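The following is a minimal pure-Python sketch of the Poisson sampling idea described above. Each index is kept independently with probability sample_rate, so batch sizes vary. For a dataset of N elements, an empty batch occurs with probability (1 - sample_rate) ** N. With N = 64 and sample_rate = 1/16 that is about 1.6%. The helper names poisson_batch and empty_batch_probability are hypothetical. The code illustrates the sampling scheme only and is not Opacus's UniformWithReplacementSampler implementation.

```python
import random
from typing import List


def poisson_batch(num_samples: int, sample_rate: float, rng: random.Random) -> List[int]:
    """Keep each index independently with probability `sample_rate`.

    Hypothetical helper for illustration; not Opacus's sampler code.
    """
    return [i for i in range(num_samples) if rng.random() < sample_rate]


def empty_batch_probability(num_samples: int, sample_rate: float) -> float:
    """Probability that no index is selected: (1 - q) ** N."""
    return (1.0 - sample_rate) ** num_samples


if __name__ == "__main__":
    rng = random.Random(0)  # fixed seed so the sketch is reproducible
    num_samples, sample_rate = 64, 1 / 16
    # Batch sizes fluctuate around num_samples * sample_rate = 4.
    sizes = [len(poisson_batch(num_samples, sample_rate, rng)) for _ in range(10)]
    print("batch sizes:", sizes)
    print("P(empty batch):", empty_batch_probability(num_samples, sample_rate))
```

This variability is exactly why the collate function has to be wrapped. An empty batch is rare for a single step, but over thousands of training steps it becomes likely to occur at least once.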
- classmethod from_data_loader(data_loader, *, distributed=False, generator=None)
Creates a new DPDataLoader based on the passed data_loader argument.
- Parameters:
data_loader (DataLoader) – Any DataLoader instance. Must not be over an IterableDataset.
distributed (bool) – set True if you’ll be using DPDataLoader in a DDP environment
generator – Random number generator used to sample elements. Defaults to the generator from the original data loader.
- Returns:
New DPDataLoader instance, with all attributes and parameters inherited from the original data loader, except for the sampling mechanism.
Examples
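>>> import torch
>>> from torch.utils.data import DataLoader, TensorDataset
>>> from opacus.data_loader import DPDataLoader
>>> x, y = torch.randn(64, 5), torch.randint(0, 2, (64,))
>>> dataset = TensorDataset(x, y)
>>> data_loader = DataLoader(dataset, batch_size=4)
>>> dp_data_loader = DPDataLoader.from_data_loader(data_loader)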
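NB2 above says the default sample_rate is derived from the original batch_size. The sketch below assumes the rate is one over the number of batches in the original loader. That is our reading and is not confirmed by this page. Under this assumption, the example above works out as follows: 64 samples with batch_size=4 give 16 batches, so sample_rate = 1/16 and the expected batch size is 64 * 1/16 = 4. The helpers default_sample_rate and expected_batch_size are hypothetical.

```python
import math


def default_sample_rate(dataset_size: int, batch_size: int, drop_last: bool = False) -> float:
    """ASSUMPTION: rate = 1 / (number of batches of the original DataLoader)."""
    # len(DataLoader) rounds down with drop_last=True, up otherwise.
    if drop_last:
        num_batches = dataset_size // batch_size
    else:
        num_batches = math.ceil(dataset_size / batch_size)
    return 1.0 / num_batches


def expected_batch_size(dataset_size: int, sample_rate: float) -> float:
    """Mean of the Binomial(dataset_size, sample_rate) batch size."""
    return dataset_size * sample_rate


if __name__ == "__main__":
    q = default_sample_rate(64, 4)  # 16 batches in the original loader
    print(q, expected_batch_size(64, q))
```

When batch_size does not divide the dataset size, the expected batch size only approximates batch_size. For instance, under this assumption 100 samples with batch_size=32 give 4 batches, sample_rate = 0.25, and an expected batch size of 25.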
- opacus.data_loader.collate(batch, *, collate_fn, sample_empty_shapes, dtypes)
Wraps collate_fn to handle empty batches.
Default collate_fn implementations typically can’t handle batches of length zero. Since this is a possible case for Poisson sampling, we need to wrap the collate method, producing tensors with the correct shape and size (albeit with a zero-size batch dimension).
- opacus.data_loader.switch_generator(*, data_loader, generator)
Creates a new instance of a DataLoader with the exact same behaviour as the provided data loader, except for the source of randomness.
Typically used to enhance a user-provided data loader object with a cryptographically secure random number generator.
- Parameters:
data_loader (DataLoader) – Any DataLoader object
generator – Random number generator object
- Returns:
New DataLoader object with the exact same behaviour as the input data loader, except for the source of randomness.
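switch_generator is described as changing only the source of randomness. The sketch below does not call it. Instead, it uses plain PyTorch to show that, for a shuffling DataLoader, the torch.Generator passed as generator fully determines the sampling order. As a result, two loaders seeded identically yield the same order. The helper shuffled_order is hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def shuffled_order(loader: DataLoader) -> list:
    """Concatenate one epoch of sample values to expose the sampling order."""
    # Each batch is a one-element list because TensorDataset yields 1-tuples.
    return torch.cat([x for (x,) in loader]).tolist()


dataset = TensorDataset(torch.arange(10))

# Same dataset, same settings, identically seeded generators.
loader_a = DataLoader(dataset, batch_size=3, shuffle=True,
                      generator=torch.Generator().manual_seed(42))
loader_b = DataLoader(dataset, batch_size=3, shuffle=True,
                      generator=torch.Generator().manual_seed(42))

# The generator is the only source of randomness, so the orders match.
assert shuffled_order(loader_a) == shuffled_order(loader_b)
```

Swapping in a different generator, for example a cryptographically secure one, changes only where the random draws come from. Batch size, collate function and other settings stay as they were, which matches the behaviour described above.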
- opacus.data_loader.wrap_collate_with_empty(*, collate_fn, sample_empty_shapes, dtypes)
Wraps given collate function to handle empty batches.
- Parameters:
- Returns:
New collate function, which is equivalent to the input collate_fn for non-empty batches and outputs empty tensors with shapes from sample_empty_shapes if the input batch is of size 0.
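Below is a minimal sketch of the behaviour described for wrap_collate_with_empty, written from scratch for illustration. make_empty_aware_collate is a hypothetical name, and this is not Opacus's code. The sketch assumes each entry of sample_empty_shapes already includes the zero-size batch dimension, e.g. (0, 5). It falls back to PyTorch's default_collate when no collate_fn is given.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple

import torch
from torch.utils.data import default_collate


def make_empty_aware_collate(
    collate_fn: Optional[Callable[[List[Any]], Any]],
    sample_empty_shapes: Sequence[Tuple[int, ...]],
    dtypes: Sequence[torch.dtype],
) -> Callable[[List[Any]], Any]:
    """Hypothetical illustration of wrapping a collate function for empty batches."""
    inner = collate_fn if collate_fn is not None else default_collate

    def collate(batch: List[Any]) -> Any:
        if len(batch) > 0:
            # Non-empty batch: behave exactly like the wrapped collate function.
            return inner(batch)
        # Empty batch: one zero-length tensor per dataset field.
        # ASSUMPTION: each shape already starts with the zero-size batch dimension.
        return [
            torch.zeros(shape, dtype=dtype)
            for shape, dtype in zip(sample_empty_shapes, dtypes)
        ]

    return collate


collate = make_empty_aware_collate(None, [(0, 5), (0,)], [torch.float32, torch.int64])
x, y = collate([])  # x.shape == (0, 5), y.shape == (0,)
```

Returning a list mirrors what default_collate produces for tuple samples such as those of a TensorDataset. Downstream code can therefore unpack empty and non-empty batches the same way.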