DPOptimizer
- class opacus.optimizers.optimizer.DPOptimizer(optimizer, *, noise_multiplier, max_grad_norm, expected_batch_size, loss_reduction='mean', generator=None, secure_mode=False)
torch.optim.Optimizer wrapper that adds functionality to clip per-sample gradients and add Gaussian noise. It can be used with any torch.optim.Optimizer subclass as the underlying optimizer.
DPOptimizer assumes that the parameters over which it performs optimization belong to a GradSampleModule and therefore have the grad_sample attribute.
On a high level, DPOptimizer's step looks like this:
1) Aggregate p.grad_sample over all parameters to calculate per-sample norms.
2) Clip p.grad_sample so that each per-sample norm is not above the threshold.
3) Aggregate the clipped per-sample gradients into p.grad.
4) Add Gaussian noise to p.grad, calibrated to the given noise multiplier and max grad norm limit (std = noise_multiplier * max_grad_norm).
5) Call the underlying optimizer to perform the optimization step.
Examples
>>> import torch
>>> from opacus.optimizers.optimizer import DPOptimizer
>>> module = MyCustomModel()
>>> optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
>>> dp_optimizer = DPOptimizer(
...     optimizer=optimizer,
...     noise_multiplier=1.0,
...     max_grad_norm=1.0,
...     expected_batch_size=4,
... )
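The five steps above can be traced in plain Python for a single optimization step. This is a simplified model, not the Opacus implementation: lists of floats stand in for tensors, the helper name dp_sgd_step is hypothetical, the underlying optimizer is assumed to be plain SGD, and the division by expected_batch_size (what scale_grad() does for loss_reduction="mean") is included as an extra stage. The 1e-6 term in the clipping factor is an assumption added here to avoid division by zero.

```python
import math
import random
from typing import List

Vector = List[float]  # stand-in for one parameter's gradient tensor


def dp_sgd_step(
    params: List[Vector],
    per_sample_grads: List[List[Vector]],
    *,
    lr: float,
    noise_multiplier: float,
    max_grad_norm: float,
    expected_batch_size: int,
    rng: random.Random,
) -> List[Vector]:
    """Return updated parameters after one DP-SGD step (hypothetical helper).

    per_sample_grads[i][j] is the gradient of sample i w.r.t. parameter j.
    """
    summed = [[0.0] * len(p) for p in params]
    for sample in per_sample_grads:
        # 1) Per-sample norm computed across *all* parameters.
        norm = math.sqrt(sum(g * g for grad in sample for g in grad))
        # 2) Clip so the per-sample norm does not exceed max_grad_norm.
        factor = min(1.0, max_grad_norm / (norm + 1e-6))
        # 3) Aggregate the clipped per-sample gradients.
        for j, grad in enumerate(sample):
            for k, g in enumerate(grad):
                summed[j][k] += factor * g
    # 4) Gaussian noise with std = noise_multiplier * max_grad_norm.
    std = noise_multiplier * max_grad_norm
    noisy = [[s + rng.gauss(0.0, std) for s in grad] for grad in summed]
    # Average over the expected batch size (loss_reduction="mean").
    scaled = [[g / expected_batch_size for g in grad] for grad in noisy]
    # 5) Plain SGD update standing in for the underlying optimizer.
    return [
        [w - lr * g for w, g in zip(param, grad)]
        for param, grad in zip(params, scaled)
    ]


rng = random.Random(0)
print(dp_sgd_step(
    [[0.0, 0.0]],
    [[[3.0, 4.0]], [[0.3, 0.4]]],  # two samples, one 2-element parameter
    lr=1.0, noise_multiplier=1.0, max_grad_norm=1.0,
    expected_batch_size=2, rng=rng,
))
```

Note that the noise is added once to the summed gradient, not per sample, and that the clipping norm is taken jointly over all parameters of a sample.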
- Parameters:
optimizer (Optimizer) – wrapped optimizer.
noise_multiplier (float) – noise multiplier.
max_grad_norm (float) – max grad norm used for gradient clipping.
expected_batch_size (Optional[int]) – batch_size used for averaging gradients. When using Poisson sampling, the averaging denominator can't be inferred from the actual batch size. Required if loss_reduction="mean", ignored if loss_reduction="sum".
loss_reduction (str) – Indicates if the loss reduction (for aggregating the gradients) is a sum or a mean operation. Can take values "sum" or "mean".
generator – torch.Generator() object used as a source of randomness for the noise.
secure_mode (bool) – if True, uses a noise generation approach robust to floating point arithmetic attacks. See _generate_noise() for details.
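To make the noise scale and the role of generator concrete, here is a plain-Python sketch in which random.Random stands in for torch.Generator; the function name generate_noise is hypothetical. The secure-mode branch is an assumption for illustration only and is not claimed to provide the robustness described above by itself: it sums four independent draws with standard deviation std and halves the result, which keeps the standard deviation at std. Consult _generate_noise() for what Opacus actually does.

```python
import random
from typing import List


def generate_noise(
    std: float, size: int, rng: random.Random, secure_mode: bool = False
) -> List[float]:
    """Draw `size` zero-mean Gaussian values with standard deviation `std` (sketch)."""
    if std == 0.0:
        return [0.0] * size
    if not secure_mode:
        return [rng.gauss(0.0, std) for _ in range(size)]
    # Assumed secure-mode variant: combine 4 independent draws.
    # Their sum has variance 4 * std**2, so halving it restores variance std**2.
    return [sum(rng.gauss(0.0, std) for _ in range(4)) / 2 for _ in range(size)]


noise_multiplier, max_grad_norm = 1.1, 1.0
std = noise_multiplier * max_grad_norm
# The same seed (playing the role of a seeded torch.Generator) gives the same noise.
print(generate_noise(std, 3, random.Random(42)) == generate_noise(std, 3, random.Random(42)))  # True
```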
- property accumulated_iterations: int
Returns the number of batches currently accumulated and not yet processed.
In other words, accumulated_iterations tracks the number of forward/backward passes done in between two optimizer steps. The value would typically be 1, but there are possible exceptions.
Used by privacy accountants to calculate the real sampling rate.
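A rough sketch of how such a counter can be derived, assuming, as a simplification, that each parameter's grad_sample holds either a single per-sample-gradient tensor or a list with one entry per accumulated forward/backward pass. Lists of floats stand in for tensors and both function names are hypothetical. The second function shows one way an accountant could use the count: scaling the nominal sampling rate by the number of passes released in one noisy step.

```python
from typing import List, Union

Tensor = List[float]  # stand-in for a torch.Tensor
GradSample = Union[Tensor, List[Tensor]]


def accumulated_iterations(grad_samples: List[GradSample]) -> int:
    """Count forward/backward passes accumulated since the last step (sketch)."""
    if not grad_samples:
        raise ValueError("No parameters to inspect")
    counts = []
    for gs in grad_samples:
        # A list of tensors means several passes were accumulated.
        if gs and isinstance(gs[0], list):
            counts.append(len(gs))
        else:
            counts.append(1)
    if len(set(counts)) != 1:
        raise ValueError("Inconsistent number of accumulated passes across parameters")
    return counts[0]


def effective_sample_rate(sample_rate: float, n_accumulated: int) -> float:
    """Several accumulated batches are released in one noisy step."""
    return sample_rate * n_accumulated


# Two parameters, each holding per-sample gradients from 3 accumulated passes.
grad_samples = [[[0.1], [0.2], [0.3]], [[0.4], [0.5], [0.6]]]
k = accumulated_iterations(grad_samples)
print(k, effective_sample_rate(0.01, k))
```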
- attach_step_hook(fn)
Attaches a hook to be executed after gradient clipping/noising, but before the actual optimization step.
Most commonly used for privacy accounting.
- Parameters:
fn (Callable[[DPOptimizer], None]) – hook function. Expected signature: foo(optim: DPOptimizer)
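A hook only needs to accept the optimizer and return None. The sketch below shows an accounting-style hook that records one entry per noisy step. types.SimpleNamespace stands in for a DPOptimizer, and the attributes it reads (noise_multiplier, accumulated_iterations) mirror names documented on this page; the history list, the SAMPLE_RATE constant and the name record_step are hypothetical. With a real optimizer you would pass the function to dp_optimizer.attach_step_hook(record_step).

```python
from types import SimpleNamespace
from typing import List, Tuple

SAMPLE_RATE = 0.01  # hypothetical Poisson sampling rate
history: List[Tuple[float, float]] = []  # (noise_multiplier, sample_rate) per step


def record_step(optim) -> None:
    """Step hook with the expected signature foo(optim) -> None (sketch)."""
    history.append(
        (optim.noise_multiplier, SAMPLE_RATE * optim.accumulated_iterations)
    )


# Stand-in for a DPOptimizer; a real one invokes the hook inside step(),
# after clipping/noising and before the wrapped optimizer's update.
fake_optim = SimpleNamespace(noise_multiplier=1.0, accumulated_iterations=1)
record_step(fake_optim)
print(history)  # [(1.0, 0.01)]
```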
- clip_and_accumulate()
Performs gradient clipping. Stores clipped and aggregated gradients into p.summed_grad.
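The accumulation part means that repeated calls add to p.summed_grad instead of overwriting it, which is what lets skipped steps build up a large virtual batch. A minimal plain-Python sketch using a hypothetical Param dataclass with grad_sample and summed_grad fields; for brevity the per-sample norm here is taken over this single parameter only, whereas the optimizer computes it across all parameters, and the 1e-6 stabilizer is an assumption.

```python
import math
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Param:
    grad_sample: List[List[float]]  # one gradient row per sample
    summed_grad: Optional[List[float]] = None


def clip_and_accumulate(p: Param, max_grad_norm: float) -> None:
    """Clip each per-sample gradient, sum over the batch, add to summed_grad."""
    batch_sum = [0.0] * len(p.grad_sample[0])
    for row in p.grad_sample:
        norm = math.sqrt(sum(g * g for g in row))
        factor = min(1.0, max_grad_norm / (norm + 1e-6))
        for k, g in enumerate(row):
            batch_sum[k] += factor * g
    if p.summed_grad is None:
        p.summed_grad = batch_sum
    else:
        # Accumulate across calls (e.g. across skipped steps).
        p.summed_grad = [a + b for a, b in zip(p.summed_grad, batch_sum)]


p = Param(grad_sample=[[0.3, 0.4]])
clip_and_accumulate(p, max_grad_norm=1.0)
p.grad_sample = [[0.0, 0.5]]  # next physical batch
clip_and_accumulate(p, max_grad_norm=1.0)
print(p.summed_grad)  # approximately [0.3, 0.9]
```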
- property grad_samples: List[Tensor]
Returns a flat list of per-sample gradient tensors (one per parameter).
- load_state_dict(state_dict)
Loads the optimizer state.
- Parameters:
state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().
- Return type:
- property param_groups: List[dict]
Returns a list containing a dictionary of all parameters managed by the optimizer.
- pre_step(closure=None)
Performs actions specific to DPOptimizer before calling the underlying optimizer.step().
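The order of operations in pre_step() and step() can be modeled roughly as below. This is a simplified, hypothetical control-flow sketch: ToyDPOptimizer is invented for illustration, its methods record stage names instead of touching tensors, add_noise is a name assumed here for the noise-adding stage, and the boolean returned by pre_step (whether the wrapped optimizer should update) is a convention of this sketch. Only one skip signal is ever pending at a time in the example.

```python
from typing import Callable, List, Optional


class ToyDPOptimizer:
    """Records the order of DP stages instead of doing tensor math."""

    def __init__(self) -> None:
        self.log: List[str] = []
        self.step_hook: Optional[Callable[["ToyDPOptimizer"], None]] = None
        self._skip_queue: List[bool] = []

    def signal_skip_step(self, do_skip: bool = True) -> None:
        self._skip_queue.append(do_skip)

    def pre_step(self) -> bool:
        self.log.append("clip_and_accumulate")
        # Consume the pending skip signal, if any.
        if self._skip_queue and self._skip_queue.pop(0):
            return False  # stop after accumulation: no noise, no update
        self.log.append("add_noise")
        self.log.append("scale_grad")
        if self.step_hook is not None:
            self.step_hook(self)  # e.g. privacy accounting
        return True

    def step(self) -> None:
        if self.pre_step():
            self.log.append("original_optimizer.step")


opt = ToyDPOptimizer()
opt.step_hook = lambda o: o.log.append("step_hook")
opt.signal_skip_step(True)
opt.step()  # skipped: only clipping/accumulation
opt.step()  # full step
print(opt.log)
```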
- scale_grad()
Applies the given loss_reduction to p.grad. Does nothing if loss_reduction="sum". Divides gradients by self.expected_batch_size if loss_reduction="mean".
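With loss_reduction="mean" the denominator is the expected batch size rather than the number of samples actually drawn, which under Poisson sampling varies from batch to batch. A tiny sketch with a hypothetical scale_grad function over a list of floats standing in for p.grad:

```python
from typing import List


def scale_grad(grad: List[float], loss_reduction: str, expected_batch_size: int) -> List[float]:
    """Apply the loss reduction to a clipped-and-noised gradient (sketch)."""
    if loss_reduction == "sum":
        return list(grad)  # nothing to do
    if loss_reduction == "mean":
        return [g / expected_batch_size for g in grad]
    raise ValueError(f"Unexpected loss_reduction: {loss_reduction!r}")


# A Poisson-sampled batch may hold 3 samples even though 4 were expected;
# the denominator is still expected_batch_size.
print(scale_grad([2.0, -1.0], "mean", expected_batch_size=4))  # [0.5, -0.25]
```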
- signal_skip_step(do_skip=True)
Signals the optimizer to skip an optimization step and only perform clipping and per-sample gradient accumulation.
On every call of .step(), the optimizer will check the queue of skipped step signals. If it is non-empty and the latest flag is True, the optimizer will call self.clip_and_accumulate, but won't proceed to adding noise and performing the actual optimization step. It also affects the behaviour of zero_grad(): if the last step was skipped, the optimizer will clear the per-sample gradients accumulated by self.clip_and_accumulate (p.grad_sample), but won't touch the aggregated clipped gradients (p.summed_grad).
Used by BatchMemoryManager to simulate large virtual batches with a limited memory footprint.
- Parameters:
do_skip – flag if next step should be skipped
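The way a memory manager might use this signal can be sketched as follows: a logical batch is split into physical chunks, and every chunk except the last signals a skip, so noise is added and parameters are updated once per logical batch. The function name and the fixed-size chunking rule are hypothetical simplifications, not BatchMemoryManager's actual code, which also has to cope with variable-size Poisson batches.

```python
from typing import List, Tuple


def split_with_skip_flags(
    logical_batch: List[int], max_physical_batch_size: int
) -> List[Tuple[List[int], bool]]:
    """Split a logical batch into physical chunks, pairing each with do_skip."""
    chunks = [
        logical_batch[i : i + max_physical_batch_size]
        for i in range(0, len(logical_batch), max_physical_batch_size)
    ]
    # Skip every step except the one after the final chunk.
    return [(chunk, i < len(chunks) - 1) for i, chunk in enumerate(chunks)]


for chunk, do_skip in split_with_skip_flags(list(range(10)), max_physical_batch_size=4):
    # In a training loop: forward/backward on `chunk`, then
    # optimizer.signal_skip_step(do_skip) followed by optimizer.step().
    print(chunk, do_skip)
```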
- property state: defaultdict
Returns a dictionary holding the current optimization state.
- state_dict()
Returns the state of the optimizer as a dict.
It contains two entries:
state: a Dict holding the current optimization state. Its content differs between optimizer classes, but some common characteristics hold. For example, state is saved per parameter, and the parameter itself is NOT saved. state is a Dictionary mapping parameter IDs to a Dict with the state corresponding to each parameter.
param_groups: a List containing all parameter groups, where each parameter group is a Dict. Each parameter group contains metadata specific to the optimizer, such as learning rate and weight decay, as well as a List of parameter IDs of the parameters in the group.
NOTE: The parameter IDs may look like indices, but they are just IDs associating state with param_group. When loading from a state_dict, the optimizer will zip the param_group params (int IDs) and the optimizer param_groups (actual nn.Parameter objects) in order to match state WITHOUT additional verification.
A returned state dict might look something like:
{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0]
        },
        {
            'lr': 0.001,
            'weight_decay': 0.5,
            ...
            'params': [1, 2, 3]
        }
    ]
}
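The NOTE about matching by position can be illustrated with plain strings standing in for parameters. The hypothetical helper match_state pairs the saved integer IDs with live parameters purely by order, as described above, so reordering the parameters before loading silently attaches state to the wrong ones.

```python
from typing import Any, Dict, List


def match_state(saved: Dict[str, Any], live_groups: List[List[str]]) -> Dict[str, Any]:
    """Map live parameters to saved per-parameter state by position (sketch)."""
    saved_ids = [pid for group in saved["param_groups"] for pid in group["params"]]
    live_params = [p for group in live_groups for p in group]
    id_to_param = dict(zip(saved_ids, live_params))  # no verification
    return {id_to_param[pid]: st for pid, st in saved["state"].items()}


saved = {
    "state": {0: {"momentum_buffer": [0.1]}, 1: {"momentum_buffer": [0.2]}},
    "param_groups": [{"lr": 0.01, "params": [0]}, {"lr": 0.001, "params": [1]}],
}
print(match_state(saved, [["weight"], ["bias"]]))
print(match_state(saved, [["bias"], ["weight"]]))  # order swapped: state swapped too
```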
- step(closure=None)
Performs a single optimization step (parameter update).
- Parameters:
closure (Callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.
- Return type:
Note
Unless otherwise specified, this function should not modify the .grad field of the parameters.
- zero_grad(set_to_none=False)
Clear gradients.
Clears p.grad, p.grad_sample and p.summed_grad for all of its parameters.
Notes
The set_to_none argument only affects p.grad. p.grad_sample and p.summed_grad are never zeroed out and are always set to None. Regular gradients can be zeroed because their shape is always the same. Per-sample gradients do not behave like this, as gradients from different batches are accumulated in a list.
- Parameters:
set_to_none (bool) – instead of setting to zero, set the grads to None. (Only affects regular gradients; per-sample gradients are always set to None.)
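The clearing rules above, combined with the skipped-step behaviour described under signal_skip_step(), can be modeled as follows. The Param dataclass and the last_step_skipped argument are hypothetical stand-ins for state the optimizer tracks internally, and lists of floats play the role of tensors.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Param:
    grad: Optional[List[float]] = None
    grad_sample: Optional[List[List[float]]] = None
    summed_grad: Optional[List[float]] = None


def zero_grad(
    params: List[Param], set_to_none: bool = False, last_step_skipped: bool = False
) -> None:
    for p in params:
        p.grad_sample = None  # per-sample gradients: always set to None
        if not last_step_skipped:
            p.summed_grad = None  # cleared unless the last step was skipped
        if p.grad is not None:
            # set_to_none only affects the regular gradient
            p.grad = None if set_to_none else [0.0] * len(p.grad)


p = Param(grad=[1.0, 2.0], grad_sample=[[1.0, 2.0]], summed_grad=[1.0, 2.0])
zero_grad([p], set_to_none=False, last_step_skipped=True)
print(p)  # Param(grad=[0.0, 0.0], grad_sample=None, summed_grad=[1.0, 2.0])
```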