Batch Memory Manager

class opacus.utils.batch_memory_manager.BatchMemoryManager(*, data_loader, max_physical_batch_size, optimizer)[source]

Context manager to manage memory consumption during training.

Allows setting a hard limit on the physical batch size with a one-line code change. It can be used both for simulating large logical batches with limited memory and for safeguarding against occasional large batches produced by UniformWithReplacementSampler.

Note that it doesn't modify the input DataLoader; you need to use the new DataLoader returned by the context manager.

BatchSplittingSampler will split large logical batches into smaller sub-batches of at most a given maximum size. On every step, the optimizer checks whether the batch was the last physical batch of a logical one and changes its behaviour accordingly.

If it was not the last, optimizer.step() will only clip per-sample gradients and sum them into p.summed_grad. optimizer.zero_grad() will clear p.grad_sample, but will leave p.grad and p.summed_grad untouched.

If the batch was the last one of the current logical batch, then optimizer.step() and optimizer.zero_grad() will behave normally.
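The two cases above can be illustrated with a toy model of the accumulation logic. This is only a sketch under stated assumptions: ToyOptimizer, skip_next_step and applied_updates are hypothetical names invented for the illustration, gradients are plain floats, clipping is done per scalar value rather than per-sample norm, and no noise is added. It is not the Opacus implementation.

```python
class ToyOptimizer:
    """Hypothetical stand-in for a DP optimizer; not the Opacus class."""

    def __init__(self):
        self.skip_next_step = False  # True while more physical batches remain
        self.grad_sample = []        # per-sample gradients of the current physical batch
        self.summed_grad = 0.0       # running sum over the current logical batch
        self.applied_updates = []    # sums that a real parameter update would consume

    def step(self, max_grad_norm=1.0):
        # Always clip each per-sample gradient and add it to the running sum.
        clipped = [max(-max_grad_norm, min(max_grad_norm, g)) for g in self.grad_sample]
        self.summed_grad += sum(clipped)
        if self.skip_next_step:
            return False  # not the last physical batch: no parameter update
        self.applied_updates.append(self.summed_grad)
        return True

    def zero_grad(self):
        # Per-sample gradients are always cleared.
        self.grad_sample = []
        # The accumulated sum is cleared only once the logical batch is over.
        if not self.skip_next_step:
            self.summed_grad = 0.0


opt = ToyOptimizer()
physical_batches = [[0.5, 2.0], [-3.0, 0.25]]  # one logical batch of 4 samples
for i, grads in enumerate(physical_batches):
    opt.skip_next_step = i < len(physical_batches) - 1
    opt.grad_sample = grads
    opt.step()
    opt.zero_grad()
```

After the loop, a single update has been recorded for the whole logical batch, and the running sum has been reset for the next one.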


>>> # Assuming you've initialized your objects and passed them to PrivacyEngine.
>>> # For this example we assume data_loader is initialized with batch_size=4
>>> model, optimizer, data_loader = _init_private_training()
>>> criterion = nn.CrossEntropyLoss()
>>> with BatchMemoryManager(
...     data_loader=data_loader, max_physical_batch_size=2, optimizer=optimizer
... ) as new_data_loader:
...     for data, label in new_data_loader:
...         assert len(data) <= 2 # physical batch is no more than 2
...         output = model(data)
...         loss = criterion(output, label)
...         loss.backward()
...         # optimizer won't actually make a step unless logical batch is over
...         optimizer.step()
...         # optimizer won't actually clear gradients unless logical batch is over
...         optimizer.zero_grad()

class opacus.utils.batch_memory_manager.BatchSplittingSampler(*, sampler, max_batch_size, optimizer)[source]

Samples according to the underlying instance of Sampler, but splits the index sequences into smaller chunks.

Used to split large logical batches into physical batches of a smaller size, while coordinating with DPOptimizer when the logical batch has ended.

  • sampler (Sampler[List[int]]) – Wrapped Sampler instance

  • max_batch_size (int) – Max size of emitted chunk of indices

  • optimizer (DPOptimizer) – optimizer instance to notify when the logical batch is over
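A minimal sketch of the splitting idea, assuming plain Python lists as index batches. The function split_logical_batches and its notify callback are hypothetical names for this illustration; the callback stands in for the optimizer notification. The sketch uses simple slicing, so the library may distribute indices across chunks differently.

```python
import math


def split_logical_batches(sampler, max_batch_size, notify):
    """Yield chunks of at most max_batch_size indices per logical batch.

    notify(True) is called before every chunk except the last one of a
    logical batch; notify(False) is called before the last chunk.
    """
    for batch_idxs in sampler:
        # An empty logical batch still produces one (empty) physical batch.
        n_chunks = max(1, math.ceil(len(batch_idxs) / max_batch_size))
        chunks = [
            batch_idxs[i * max_batch_size:(i + 1) * max_batch_size]
            for i in range(n_chunks)
        ]
        for chunk in chunks[:-1]:
            notify(True)   # more physical batches follow: skip the real step
            yield chunk
        notify(False)      # last physical batch: the real step may happen
        yield chunks[-1]


signals = []
logical_batches = [[0, 1, 2, 3, 4], [5, 6]]  # any iterable of index lists
physical = list(split_logical_batches(logical_batches, 2, signals.append))
```

Here the first logical batch of five indices is emitted as three physical batches and the second as one, with the skip signal raised for every chunk except the last of each logical batch.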

opacus.utils.batch_memory_manager.wrap_data_loader(*, data_loader, max_batch_size, optimizer)[source]

Replaces batch_sampler in the input data loader with BatchSplittingSampler

  • data_loader (DataLoader) – DataLoader to wrap

  • max_batch_size (int) – max physical batch size we want to emit

  • optimizer (DPOptimizer) – DPOptimizer instance used for training


Returns a new DataLoader instance with batch_sampler wrapped in BatchSplittingSampler.