# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List

import numpy as np
from opacus.optimizers import DPOptimizer
from opacus.utils.uniform_sampler import (
    DistributedUniformWithReplacementSampler,
    UniformWithReplacementSampler,
)
from torch.utils.data import BatchSampler, DataLoader, Sampler


class BatchSplittingSampler(Sampler[List[int]]):
    """
    Samples according to the underlying instance of ``Sampler``, but splits
    the index sequences into smaller chunks.

    Used to split large logical batches into physical batches of a smaller size,
    while coordinating with DPOptimizer when the logical batch has ended.
    """

    def __init__(
        self,
        *,
        sampler: Sampler[List[int]],
        max_batch_size: int,
        optimizer: DPOptimizer,
    ):
        """
        Args:
            sampler: Wrapped Sampler instance
            max_batch_size: Max size of emitted chunk of indices
            optimizer: optimizer instance to notify when the logical batch is over
        """
        self.sampler = sampler
        self.max_batch_size = max_batch_size
        self.optimizer = optimizer

    def __iter__(self):
        for batch_idxs in self.sampler:
            if len(batch_idxs) == 0:
                # Empty logical batch: do not skip the optimizer step for it
                self.optimizer.signal_skip_step(do_skip=False)
                yield []
                continue
            # Split the logical batch into chunks of at most max_batch_size indices
            split_idxs = np.array_split(
                batch_idxs, math.ceil(len(batch_idxs) / self.max_batch_size)
            )
            split_idxs = [s.tolist() for s in split_idxs]
            for x in split_idxs[:-1]:
                # Intermediate chunks: tell the optimizer to accumulate, not step
                self.optimizer.signal_skip_step(do_skip=True)
                yield x
            # Last chunk of the logical batch: the optimizer may actually step
            self.optimizer.signal_skip_step(do_skip=False)
            yield split_idxs[-1]

    def __len__(self):
        if isinstance(self.sampler, BatchSampler):
            return math.ceil(
                len(self.sampler) * (self.sampler.batch_size / self.max_batch_size)
            )
        elif isinstance(self.sampler, UniformWithReplacementSampler) or isinstance(
            self.sampler, DistributedUniformWithReplacementSampler
        ):
            expected_batch_size = self.sampler.sample_rate * self.sampler.num_samples
            return math.ceil(
                len(self.sampler) * (expected_batch_size / self.max_batch_size)
            )

        return len(self.sampler)
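
# A minimal sketch (not part of the original module) of the chunking performed in
# ``BatchSplittingSampler.__iter__`` above; the batch of 10 indices and
# ``max_batch_size=4`` are made up for illustration:
#
#   >>> batch_idxs = list(range(10))
#   >>> max_batch_size = 4
#   >>> [s.tolist() for s in np.array_split(batch_idxs, math.ceil(len(batch_idxs) / max_batch_size))]
#   [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
#
# The first two chunks are yielded after ``signal_skip_step(do_skip=True)``, so
# DPOptimizer only clips and accumulates per-sample gradients; the final chunk is
# yielded after ``signal_skip_step(do_skip=False)``, letting the optimizer add
# noise and take the actual step.
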
def wrap_data_loader(
    *, data_loader: DataLoader, max_batch_size: int, optimizer: DPOptimizer
):
    """
    Replaces batch_sampler in the input data loader with ``BatchSplittingSampler``

    Args:
        data_loader: Wrapped DataLoader instance
        max_batch_size: max physical batch size we want to emit
        optimizer: DPOptimizer instance used for training

    Returns:
        New DataLoader instance with batch_sampler wrapped in ``BatchSplittingSampler``
    """
    return DataLoader(
        dataset=data_loader.dataset,
        batch_sampler=BatchSplittingSampler(
            sampler=data_loader.batch_sampler,
            max_batch_size=max_batch_size,
            optimizer=optimizer,
        ),
        num_workers=data_loader.num_workers,
        collate_fn=data_loader.collate_fn,
        pin_memory=data_loader.pin_memory,
        timeout=data_loader.timeout,
        worker_init_fn=data_loader.worker_init_fn,
        multiprocessing_context=data_loader.multiprocessing_context,
        generator=data_loader.generator,
        prefetch_factor=data_loader.prefetch_factor,
        persistent_workers=data_loader.persistent_workers,
    )
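
# A hedged usage sketch (``train_loader`` and ``dp_optimizer`` are assumed to come
# from PrivacyEngine.make_private and are not defined in this module):
#
#   >>> capped_loader = wrap_data_loader(
#   ...     data_loader=train_loader, max_batch_size=32, optimizer=dp_optimizer
#   ... )
#
# Every batch drawn from ``capped_loader`` holds at most 32 samples, while
# ``dp_optimizer`` still performs one noisy update per logical batch of
# ``train_loader``.
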
class BatchMemoryManager:
    """
    Context manager to manage memory consumption during training.

    Allows setting a hard limit on the physical batch size with just a one-line
    code change. Can be used both for simulating large logical batches with
    limited memory and for safeguarding against occasional large batches produced
    by :class:`~opacus.utils.uniform_sampler.UniformWithReplacementSampler`.

    Note that it doesn't modify the input DataLoader; you need to use the new
    DataLoader returned by the context manager.

    BatchSplittingSampler will split large logical batches into smaller
    sub-batches with a certain maximum size.

    On every step the optimizer checks whether the batch was the last physical
    batch comprising a logical one, and changes its behaviour accordingly.

    If it was not the last, ``optimizer.step()`` will only clip per-sample
    gradients and sum them into ``p.summed_grad``. ``optimizer.zero_grad()`` will
    clear ``p.grad_sample``, but will leave ``p.grad`` and ``p.summed_grad``
    intact.
    If the batch was the last one of the current logical batch, then
    ``optimizer.step()`` and ``optimizer.zero_grad()`` will behave normally.

    Example:
        >>> # Assuming you've initialized your objects and passed them to PrivacyEngine.
        >>> # For this example we assume data_loader is initialized with batch_size=4
        >>> model, optimizer, data_loader = _init_private_training()
        >>> criterion = nn.CrossEntropyLoss()
        >>> with BatchMemoryManager(
        ...     data_loader=data_loader, max_physical_batch_size=2, optimizer=optimizer
        ... ) as new_data_loader:
        ...     for data, label in new_data_loader:
        ...         assert len(data) <= 2  # physical batch is no more than 2
        ...         output = model(data)
        ...         loss = criterion(output, label)
        ...         loss.backward()
        ...         # optimizer won't actually make a step unless logical batch is over
        ...         optimizer.step()
        ...         # optimizer won't actually clear gradients unless logical batch is over
        ...         optimizer.zero_grad()
    """

    def __init__(
        self,
        *,
        data_loader: DataLoader,
        max_physical_batch_size: int,
        optimizer: DPOptimizer,
    ):
        self.data_loader = data_loader
        self.optimizer = optimizer
        self.max_physical_batch_size = max_physical_batch_size

    def __enter__(self):
        return wrap_data_loader(
            data_loader=self.data_loader,
            max_batch_size=self.max_physical_batch_size,
            optimizer=self.optimizer,
        )

    def __exit__(self, type, value, traceback):
        pass