# Source code for opacus.optimizers.ddpoptimizer

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Callable, Optional

import torch
from torch.optim import Optimizer

from .optimizer import DPOptimizer


class DistributedDPOptimizer(DPOptimizer):
    """
    :class:`~opacus.optimizers.optimizer.DPOptimizer` compatible with
    distributed data processing.

    Gradient privatization is delegated to ``DPOptimizer``. Noise is added on
    rank 0 only; the other ranks use their accumulated ``summed_grad`` as-is.
    The resulting gradients are then summed across all workers with
    ``torch.distributed.all_reduce``, and divided by the world size when
    ``loss_reduction == "mean"``, before the wrapped optimizer takes its step.
    Because only one worker contributes noise, the summed gradient carries
    exactly one noise draw.

    The default ``torch.distributed`` process group must be initialized before
    construction, since ``__init__`` queries the rank and world size.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        *,
        noise_multiplier: float,
        max_grad_norm: float,
        expected_batch_size: Optional[int],
        loss_reduction: str = "mean",
        generator=None,
        secure_mode: bool = False,
    ):
        """
        Args:
            optimizer: the optimizer to wrap; forwarded to ``DPOptimizer``.
            noise_multiplier: forwarded to ``DPOptimizer``.
            max_grad_norm: forwarded to ``DPOptimizer``.
            expected_batch_size: forwarded to ``DPOptimizer``.
            loss_reduction: forwarded to ``DPOptimizer``. It is also used by
                :meth:`reduce_gradients`: with ``"mean"`` the all-reduced
                gradients are divided by the world size.
            generator: forwarded to ``DPOptimizer``.
            secure_mode: forwarded to ``DPOptimizer``.
        """
        super().__init__(
            optimizer,
            noise_multiplier=noise_multiplier,
            max_grad_norm=max_grad_norm,
            expected_batch_size=expected_batch_size,
            loss_reduction=loss_reduction,
            generator=generator,
            secure_mode=secure_mode,
        )
        self.rank = torch.distributed.get_rank()
        self.world_size = torch.distributed.get_world_size()

    def add_noise(self):
        """
        Add noise on rank 0 only.

        Rank 0 defers to ``DPOptimizer.add_noise``. Every other rank sets
        ``p.grad`` to its (noise-free) ``summed_grad`` reshaped like ``p``.
        After the SUM all-reduce in :meth:`reduce_gradients`, the noise from
        rank 0 is therefore counted exactly once.
        """
        # Noise only gets added to the first worker
        if self.rank == 0:
            super().add_noise()
        else:
            for p in self.params:
                p.grad = p.summed_grad.view_as(p)

    def reduce_gradients(self):
        """
        Sum ``p.grad`` across all workers in place.

        Parameters with ``requires_grad=False`` are skipped. When
        ``loss_reduction == "mean"`` the sum is divided by the world size,
        turning it into a mean over workers.
        """
        for p in self.params:
            if not p.requires_grad:
                continue
            torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.SUM)
            if self.loss_reduction == "mean":
                p.grad /= self.world_size

    def step(
        self, closure: Optional[Callable[[], float]] = None
    ) -> Optional[torch.Tensor]:
        """
        Privatize and all-reduce the gradients, then apply the wrapped optimizer.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss. It is run once, under ``torch.enable_grad()``,
                *before* the gradients are privatized. It is deliberately not
                forwarded to the wrapped optimizer.

        Returns:
            ``None`` if ``pre_step()`` returns a falsy value (step skipped).
            Otherwise, the loss returned by ``closure`` when one was given, or
            the return value of the wrapped optimizer's ``step()``.
        """
        loss = None
        if closure is not None:
            # Evaluate the closure *before* pre_step(). pre_step() is
            # inherited from DPOptimizer and is expected to clip, noise and
            # scale the gradients, so it must see the closure's gradients.
            # Passing the closure to the wrapped optimizer instead would
            # re-run backward after privatization and all-reduce, so the
            # update would use non-private, non-reduced gradients.
            with torch.enable_grad():
                loss = closure()

        if not self.pre_step():
            return None

        self.reduce_gradients()
        step_result = self.original_optimizer.step()
        # Preserve the usual Optimizer.step(closure) contract of returning the loss.
        return loss if closure is not None else step_result