# Source code for opacus.optimizers.ddpoptimizer
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Callable, Optional
import torch
from torch.optim import Optimizer
from .optimizer import DPOptimizer
class DistributedDPOptimizer(DPOptimizer):
    """
    :class:`~opacus.optimizers.optimizer.DPOptimizer` compatible with
    distributed data processing.

    Only rank 0 adds noise (see :meth:`add_noise`). Before the wrapped
    optimizer steps, gradients are summed across all workers with
    ``torch.distributed.all_reduce``. When ``loss_reduction == "mean"``, the
    sum is also divided by the world size (see :meth:`reduce_gradients`).

    ``__init__`` calls ``torch.distributed.get_rank()`` and
    ``torch.distributed.get_world_size()``. The default process group must
    therefore be initialized before this optimizer is constructed.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        *,
        noise_multiplier: float,
        max_grad_norm: float,
        expected_batch_size: Optional[int],
        loss_reduction: str = "mean",
        generator=None,
        secure_mode: bool = False,
    ):
        """
        Args:
            optimizer: optimizer to wrap; forwarded to ``DPOptimizer``.
            noise_multiplier: forwarded unchanged to ``DPOptimizer``.
            max_grad_norm: forwarded unchanged to ``DPOptimizer``.
            expected_batch_size: forwarded unchanged to ``DPOptimizer``.
            loss_reduction: forwarded to ``DPOptimizer``. When ``"mean"``,
                :meth:`reduce_gradients` also averages over workers.
            generator: forwarded unchanged to ``DPOptimizer``.
            secure_mode: forwarded unchanged to ``DPOptimizer``.
        """
        super().__init__(
            optimizer,
            noise_multiplier=noise_multiplier,
            max_grad_norm=max_grad_norm,
            expected_batch_size=expected_batch_size,
            loss_reduction=loss_reduction,
            generator=generator,
            secure_mode=secure_mode,
        )
        # Cached once. ``rank`` selects the single noise-adding worker, and
        # ``world_size`` is the divisor used when averaging reduced gradients.
        self.rank = torch.distributed.get_rank()
        self.world_size = torch.distributed.get_world_size()

    def add_noise(self):
        """
        Add noise on rank 0 only. Other ranks copy their gradient sums
        into ``p.grad`` unchanged.

        The per-worker gradients are later summed across ranks by
        :meth:`reduce_gradients`, so noise from one worker ends up in the
        reduced gradient. NOTE(review): this assumes ``add_noise`` is invoked
        from the parent's ``pre_step`` before ``reduce_gradients`` runs.
        Confirm against ``DPOptimizer``.
        """
        # Noise only gets added to the first worker
        if self.rank == 0:
            super().add_noise()
        else:
            for p in self.params:
                # NOTE(review): ``summed_grad`` is presumably the accumulated
                # clipped per-sample gradient set by the parent class; it is
                # reshaped to the parameter's shape here. Confirm.
                p.grad = p.summed_grad.view_as(p)

    def reduce_gradients(self):
        """
        Sum ``p.grad`` in place across all workers. Then divide it by the
        world size when ``loss_reduction == "mean"``.

        ``all_reduce`` is a collective operation. Every rank must call it for
        the same parameters, in the same order, or the job will hang.
        """
        for p in self.params:
            if not p.requires_grad:
                continue
            torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.SUM)
            if self.loss_reduction == "mean":
                p.grad /= self.world_size

    def step(
        self, closure: Optional[Callable[[], float]] = None
    ) -> Optional[torch.Tensor]:
        """
        Perform one distributed differentially private optimization step.

        Args:
            closure: optional callable that recomputes the loss. It runs
                with gradients enabled, before ``pre_step``.

        Returns:
            The wrapped optimizer's ``step()`` result if it is not ``None``.
            Otherwise, the value returned by ``closure``, matching the
            ``torch.optim.Optimizer.step`` convention. This is ``None`` if no
            closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        if not self.pre_step():
            # ``pre_step`` returned falsy: do not reduce gradients or step
            # the wrapped optimizer.
            return loss

        self.reduce_gradients()
        # The wrapped optimizer is stepped without a closure, so it usually
        # returns None. In that case, return the closure's loss instead of
        # silently discarding it.
        result = self.original_optimizer.step()
        return loss if result is None else result