Source code for opacus.data_loader

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Type, Union

import torch
from opacus.utils.uniform_sampler import (
    DistributedUniformWithReplacementSampler,
    UniformWithReplacementSampler,
)
from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Sampler
from torch.utils.data._utils.collate import default_collate
from torch.utils.data.dataloader import _collate_fn_t


logger = logging.getLogger(__name__)


class CollateFnWithEmpty:
    """
    Collate function wrapper that handles empty batches by preserving batch structure.

    This wrapper is stateful and learns the expected batch structure from the
    first non-empty batch it processes. When an empty batch is encountered, it
    generates an empty batch with the same structure (tensors, dicts, lists,
    tuples, namedtuples, or nested combinations) but with zero-length batch
    dimensions.

    This is particularly useful for Poisson sampling in differential privacy,
    where batch sizes can vary and occasionally result in empty batches.

    Args:
        collator_fn: The original collate function to wrap. If None, returns
            batch as-is.
        batch_first: If True, batch dimension is the first dimension (index 0).
            If False, batch dimension is the second dimension (index 1).
            Default: True
        rand_on_empty: If True, returns tensors filled with random values
            (0 or 1) with batch dimension set to 1 when encountering empty
            batches. If False, returns tensors with batch dimension set to 0.
            Default: False
        sample_empty_shapes: Optional shapes used to build a fallback batch of
            zero-valued tensors when an empty batch arrives before any
            non-empty batch has been seen. Only used together with ``dtypes``.
        dtypes: Optional dtypes paired element-wise with
            ``sample_empty_shapes`` for the fallback batch.

    Example:
        >>> collate_fn = CollateFnWithEmpty(default_collate)
        >>> # First batch: [{"x": tensor([1, 2]), "y": tensor([3, 4])}]
        >>> # Empty batch: [] -> {"x": tensor([]), "y": tensor([])}

    Note:
        The structure of empty batches is taken from the first non-empty batch.
        If an empty batch arrives earlier, a warning is logged and either a
        list of zero-valued tensors (when both ``sample_empty_shapes`` and
        ``dtypes`` are given) or an empty list is returned.

        Only torch.Tensor, dict (Mapping), list, and tuple types are supported.
        If your collate function returns other types, a TypeError will be
        raised to preserve DP guarantees (returning non-empty data for empty
        batches would violate the privacy guarantee).
    """

    def __init__(
        self,
        collator_fn: Optional[_collate_fn_t],
        batch_first: bool = True,
        rand_on_empty: bool = False,
        sample_empty_shapes: Optional[Sequence[Tuple]] = None,
        dtypes: Optional[Sequence[Union[torch.dtype, Type]]] = None,
    ) -> None:
        self.wrapped_collator_fn = collator_fn
        self.batch_first = batch_first
        self.rand_on_empty = rand_on_empty
        self.sample_empty_shapes = sample_empty_shapes
        self.dtypes = dtypes
        # Deep copy of the first non-empty collated batch; used purely as a
        # structural template for later empty batches.
        self.first_batch = None

    def __call__(self, batch: List[Any]) -> Union[torch.Tensor, List, Mapping]:
        """
        Collate ``batch``, or synthesize an empty batch when ``batch`` is empty.

        Args:
            batch: list of samples drawn for this step (possibly empty).

        Returns:
            The wrapped collate function's output for non-empty input; for
            empty input, a structurally matching empty batch (or the fallback
            described in the class docstring if no template is known yet).
        """
        if len(batch) > 0:
            if not self.wrapped_collator_fn:
                output = batch
            else:
                output = self.wrapped_collator_fn(batch)
            if self.first_batch is None:
                # Copy so that later in-place edits by the caller cannot
                # corrupt the stored template.
                self.first_batch = copy.deepcopy(output)
        else:
            if self.first_batch is None:
                if self.sample_empty_shapes is not None and self.dtypes is not None:
                    logger.warning(
                        "First batch is empty. We are using a list of zero-valued "
                        "tensors as a batch. This may cause issues if the model "
                        "expects a different batch format. To fix, use more data, "
                        "increase epsilon, or increase sampling rate."
                    )
                    return [
                        torch.zeros(shape, dtype=dtype)
                        for shape, dtype in zip(self.sample_empty_shapes, self.dtypes)
                    ]
                else:
                    logger.warning(
                        "First batch is empty. We are using an empty list as a "
                        "batch. This may cause issues if the model expects a "
                        "different batch format. To fix, use more data, increase "
                        "epsilon, or increase sampling rate."
                    )
                    return []
            # materialize into empty with the same structure as list/dict
            output = self._make_empty_batch(self.first_batch)
        return output

    def _make_empty_batch(
        self, sample: Union[torch.Tensor, Mapping, List, Any]
    ) -> Union[torch.Tensor, Mapping, List, Any]:
        """
        Recursively build an empty counterpart of ``sample``.

        Args:
            sample: a tensor, Mapping, list or tuple (possibly nested) that
                serves as the structural template.

        Returns:
            An object with the same nesting as ``sample`` in which every tensor
            has its batch dimension set to 0 (or 1, filled with random 0/1
            values, when ``rand_on_empty`` is set). Mappings are returned as
            plain dicts; lists, tuples and namedtuples keep their type.

        Raises:
            TypeError: if ``sample`` contains a type other than those above.
        """
        if torch.is_tensor(sample):
            batch_len = 1 if self.rand_on_empty else 0
            batch_dim = 0 if self.batch_first else 1
            if sample.dim() > batch_dim:
                shape = list(sample.shape)
                shape[batch_dim] = batch_len
            else:
                # The tensor has no axis at the batch position (0-dim tensor,
                # or 1-D tensor with batch_first=False): fall back to a 1-D
                # tensor instead of failing with IndexError.
                shape = [batch_len]
            if self.rand_on_empty:
                return torch.randint(
                    0, 2, shape, dtype=sample.dtype, device=sample.device
                )
            return torch.empty(shape, dtype=sample.dtype, device=sample.device)

        if isinstance(sample, Mapping):
            return {k: self._make_empty_batch(v) for k, v in sample.items()}

        if isinstance(sample, (list, tuple)):
            converted = [self._make_empty_batch(v) for v in sample]
            if isinstance(sample, tuple) and hasattr(sample, "_fields"):
                # namedtuple constructors take one positional argument per
                # field rather than a single iterable.
                return type(sample)(*converted)
            return type(sample)(converted)

        # Unsupported type - raise error to preserve DP guarantees
        raise TypeError(
            f"Unsupported batch type: {type(sample).__name__}. "
            f"CollateFnWithEmpty only supports batches containing torch.Tensor, "
            f"dict (Mapping), list, or tuple types. "
            f"If you need support for a different output type, please open an issue at "
            f"Opacus or submit a PR."
        )
def wrap_collate_with_empty(
    *,
    collate_fn: Optional[_collate_fn_t],
    batch_first: bool = True,
    rand_on_empty: bool = False,
    sample_empty_shapes: Optional[Sequence[Tuple]] = None,
    dtypes: Optional[Sequence[Union[torch.dtype, Type]]] = None,
) -> CollateFnWithEmpty:
    """
    Build a ``CollateFnWithEmpty`` around the given collate function.

    The returned object is stateful: it remembers the structure of the first
    non-empty batch it collates and reuses that structure to produce properly
    shaped empty batches later on.

    Args:
        collate_fn: collate function to wrap. If None, returns batches as-is.
        batch_first: Flag to indicate if the input tensor to the corresponding
            module has the first dimension representing the batch. If set to
            True, dimensions on input tensor are expected be
            ``[batch_size, ...]``, otherwise ``[K, batch_size, ...]``
        rand_on_empty: set ``True`` to return a batch containing random numbers
            when encountering empty batches rather than tensors with
            zero-length batch dimensions
        sample_empty_shapes: forwarded unchanged to ``CollateFnWithEmpty``
        dtypes: forwarded unchanged to ``CollateFnWithEmpty``

    Returns:
        CollateFnWithEmpty: A callable that is equivalent to input
        ``collate_fn`` for non-empty batches and outputs empty tensors with the
        same structure when the input batch is empty. The structure is learned
        from the first non-empty batch.

    Example:
        >>> from torch.utils.data._utils.collate import default_collate
        >>> collate = wrap_collate_with_empty(collate_fn=default_collate)
        >>> # First batch defines structure
        >>> result = collate([{"x": torch.tensor([1, 2])}])
        >>> # Empty batch uses learned structure
        >>> empty = collate([])  # Returns {"x": torch.tensor([])}
    """
    wrapper = CollateFnWithEmpty(
        collator_fn=collate_fn,
        batch_first=batch_first,
        rand_on_empty=rand_on_empty,
        sample_empty_shapes=sample_empty_shapes,
        dtypes=dtypes,
    )
    return wrapper
def shape_safe(x: Any) -> Tuple:
    """Return ``x.shape`` if the attribute exists, otherwise an empty tuple."""
    try:
        return x.shape
    except AttributeError:
        return ()
def dtype_safe(x: Any) -> Union[torch.dtype, Type]:
    """Return ``x.dtype`` if the attribute exists, otherwise ``type(x)``."""
    try:
        return x.dtype
    except AttributeError:
        return type(x)
class DPDataLoader(DataLoader):
    """
    DataLoader subclass that always does Poisson sampling and supports empty
    batches by default.

    Typically instantiated via ``DPDataLoader.from_data_loader()`` method based
    on another DataLoader. DPDataLoader would preserve the behaviour of the
    original data loader, except for the two aspects.

    First, it switches ``batch_sampler`` to ``UniformWithReplacementSampler``,
    thus enabling Poisson sampling (i.e. each element in the dataset is
    selected to be in the next batch with a certain probability defined by
    ``sample_rate`` parameter).
    NB: this typically leads to a batches of variable size.
    NB2: By default, ``sample_rate`` is calculated based on the ``batch_size``
    of the original data loader, so that the average batch size stays the same

    Second, it wraps collate function with support for empty batches.
    Most PyTorch modules will happily process tensors of shape ``(0, N, ...)``,
    but many collate functions will fail to produce such a batch. As with the
    Poisson sampling empty batches become a possibility, we need a DataLoader
    that can handle them.
    """

    def __init__(
        self,
        dataset: Dataset,
        *,
        sample_rate: float,
        collate_fn: Optional[_collate_fn_t] = None,
        drop_last: bool = False,
        generator=None,
        distributed: bool = False,
        batch_first: bool = True,
        rand_on_empty: bool = False,
        **kwargs,
    ):
        """

        Args:
            dataset: See :class:`torch.utils.data.DataLoader`
            sample_rate: probability with which each element of the dataset is
                included in the next batch.
            num_workers: See :class:`torch.utils.data.DataLoader`
            collate_fn: See :class:`torch.utils.data.DataLoader`
            pin_memory: See :class:`torch.utils.data.DataLoader`
            drop_last: See :class:`torch.utils.data.DataLoader`. Ignored (with
                a warning) because it is not compatible with DPDataLoader.
            timeout: See :class:`torch.utils.data.DataLoader`
            worker_init_fn: See :class:`torch.utils.data.DataLoader`
            multiprocessing_context: See :class:`torch.utils.data.DataLoader`
            generator: Random number generator used to sample elements
            prefetch_factor: See :class:`torch.utils.data.DataLoader`
            persistent_workers: See :class:`torch.utils.data.DataLoader`
            distributed: set ``True`` if you'll be using DPDataLoader in a DDP
                environment. Selects between
                ``DistributedUniformWithReplacementSampler`` and
                ``UniformWithReplacementSampler`` sampler implementations
            batch_first: Flag to indicate if the input tensor to the
                corresponding module has the first dimension representing the
                batch. If set to True, dimensions on input tensor are expected
                be ``[batch_size, ...]``, otherwise ``[K, batch_size, ...]``
            rand_on_empty: set ``True`` to return a batch containing random
                numbers when encountering empty batches rather than tensors
                with zero-length batch dimensions
        """

        self.sample_rate = sample_rate
        self.distributed = distributed

        if distributed:
            batch_sampler = DistributedUniformWithReplacementSampler(
                total_size=len(dataset),  # type: ignore[assignment, arg-type]
                sample_rate=sample_rate,
                generator=generator,
            )
        else:
            batch_sampler = UniformWithReplacementSampler(
                num_samples=len(dataset),  # type: ignore[assignment, arg-type]
                sample_rate=sample_rate,
                generator=generator,
            )

        # Fetch the probe sample exactly once: __getitem__ may be expensive
        # (disk I/O, transforms) and, with random transforms, two separate
        # fetches could yield shapes and dtypes that do not match each other.
        first_sample = dataset[0]
        sample_empty_shapes = [(0, *shape_safe(x)) for x in first_sample]
        dtypes = [dtype_safe(x) for x in first_sample]

        if collate_fn is None:
            collate_fn = default_collate

        if drop_last:
            logger.warning(
                "Ignoring drop_last as it is not compatible with DPDataLoader."
            )

        super().__init__(
            dataset=dataset,
            batch_sampler=batch_sampler,
            collate_fn=wrap_collate_with_empty(
                collate_fn=collate_fn,
                batch_first=batch_first,
                rand_on_empty=rand_on_empty,
                sample_empty_shapes=sample_empty_shapes,
                dtypes=dtypes,
            ),
            generator=generator,
            **kwargs,
        )

    @classmethod
    def from_data_loader(
        cls,
        data_loader: DataLoader,
        *,
        distributed: bool = False,
        generator=None,
        batch_first: bool = True,
        rand_on_empty: bool = False,
    ):
        """
        Creates new ``DPDataLoader`` based on passed ``data_loader`` argument.

        Args:
            data_loader: Any DataLoader instance. Must not be over an
                ``IterableDataset``
            distributed: set ``True`` if you'll be using DPDataLoader in a DDP
                environment
            generator: Random number generator used to sample elements.
                Defaults to generator from the original data loader.
            batch_first: Flag to indicate if the input tensor to the
                corresponding module has the first dimension representing the
                batch. If set to True, dimensions on input tensor are expected
                be ``[batch_size, ...]``, otherwise ``[K, batch_size, ...]``
            rand_on_empty: set ``True`` to return a batch containing random
                numbers when encountering empty batches rather than tensors
                with zero-length batch dimensions

        Returns:
            New DPDataLoader instance, with all attributes and parameters
            inherited from the original data loader, except for sampling
            mechanism.

        Raises:
            ValueError: if ``data_loader`` wraps an ``IterableDataset``.

        Examples:
            >>> x, y = torch.randn(64, 5), torch.randint(0, 2, (64,))
            >>> dataset = TensorDataset(x,y)
            >>> data_loader = DataLoader(dataset, batch_size=4)
            >>> dp_data_loader = DPDataLoader.from_data_loader(data_loader)
        """

        if isinstance(data_loader.dataset, IterableDataset):
            raise ValueError("Uniform sampling is not supported for IterableDataset")

        return cls(
            dataset=data_loader.dataset,
            # One expected batch per step keeps the average batch size equal
            # to that of the original loader.
            sample_rate=1 / len(data_loader),
            num_workers=data_loader.num_workers,
            collate_fn=data_loader.collate_fn,
            pin_memory=data_loader.pin_memory,
            drop_last=data_loader.drop_last,
            timeout=data_loader.timeout,
            worker_init_fn=data_loader.worker_init_fn,
            multiprocessing_context=data_loader.multiprocessing_context,
            # Explicit None check: only a missing generator falls back to the
            # original loader's generator.
            generator=generator if generator is not None else data_loader.generator,
            prefetch_factor=data_loader.prefetch_factor,
            persistent_workers=data_loader.persistent_workers,
            distributed=distributed,
            batch_first=batch_first,
            rand_on_empty=rand_on_empty,
        )
def _is_supported_batch_sampler(sampler: Sampler): return ( isinstance(sampler, BatchSampler) or isinstance(sampler, UniformWithReplacementSampler) or isinstance(sampler, DistributedUniformWithReplacementSampler) )
def switch_generator(*, data_loader: DataLoader, generator):
    """
    Creates new instance of a ``DataLoader``, with the exact same behaviour of
    the provided data loader, except for the source of randomness.

    Typically used to enhance a user-provided data loader object with
    cryptographically secure random number generator

    Note:
        The batch sampler object of ``data_loader`` is reused by the returned
        loader and its generator is replaced in place, so the input loader's
        sampler is affected as well.

    Args:
        data_loader: Any ``DataLoader`` object
        generator: Random number generator object

    Returns:
        New ``DataLoader`` object with the exact same behaviour as the input
        data loader, except for the source of randomness.

    Raises:
        ValueError: if ``data_loader`` has no supported batch sampler, or if
            the sampler wrapped by its ``BatchSampler`` has no ``generator``
            attribute.
    """
    batch_sampler = data_loader.batch_sampler

    if batch_sampler is None or not _is_supported_batch_sampler(batch_sampler):
        raise ValueError(
            "Non-batch processing is not supported: Opacus always assumes one of the input dimensions to be batch dimension."
        )

    if isinstance(batch_sampler, BatchSampler):
        if not hasattr(batch_sampler.sampler, "generator"):
            raise ValueError(
                "Target sampler doesn't have generator attribute: nothing to switch"
            )

        batch_sampler.sampler.generator = generator
    else:
        batch_sampler.generator = generator

    # drop_last is intentionally not forwarded: DataLoader treats it as
    # mutually exclusive with batch_sampler and raises ValueError when it is
    # True. The reused batch sampler already carries the drop_last policy.
    return DataLoader(
        dataset=data_loader.dataset,
        batch_sampler=batch_sampler,
        num_workers=data_loader.num_workers,
        collate_fn=data_loader.collate_fn,
        pin_memory=data_loader.pin_memory,
        timeout=data_loader.timeout,
        worker_init_fn=data_loader.worker_init_fn,
        multiprocessing_context=data_loader.multiprocessing_context,
        generator=generator,
        prefetch_factor=data_loader.prefetch_factor,
        persistent_workers=data_loader.persistent_workers,
    )