# Source code for opacus.scheduler

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Dict

from .optimizers import DPOptimizer


class _NoiseScheduler:
    """Base class for schedulers that adjust an optimizer's ``noise_multiplier``.

    Subclasses implement :meth:`get_noise_multiplier`. Each call to
    :meth:`step` advances ``last_epoch`` by one and writes the value returned
    by :meth:`get_noise_multiplier` onto ``optimizer.noise_multiplier``.

    The constructor calls :meth:`step` once, so with the default
    ``last_epoch=-1`` the scheduler starts at epoch ``0``.
    """

    def __init__(self, optimizer: DPOptimizer, *, last_epoch=-1):
        """
        Args:
            optimizer: Optimizer whose ``noise_multiplier`` attribute is managed.
            last_epoch: Index of the last epoch. ``step`` is invoked once here,
                so the scheduler begins at ``last_epoch + 1``.
        """
        self.optimizer = optimizer
        self.last_epoch = last_epoch
        # Apply the schedule for the starting epoch right away.
        self.step()

    def state_dict(self) -> Dict:
        """Return the scheduler state as a :class:`dict`.

        Every instance attribute is included except the wrapped ``optimizer``.
        """
        snapshot = dict(self.__dict__)
        snapshot.pop("optimizer", None)
        return snapshot

    def load_state_dict(self, state_dict: Dict):
        """Restore scheduler state.

        Args:
            state_dict (dict): scheduler state, as returned by
                :meth:`state_dict`. Its entries overwrite the matching
                instance attributes.
        """
        # Only attributes are restored; optimizer.noise_multiplier is not
        # modified here.
        vars(self).update(state_dict)

    def get_noise_multiplier(self):
        """Return the noise multiplier for the current ``last_epoch``.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def step(self):
        """Advance one epoch and push the new noise multiplier to the optimizer."""
        self.last_epoch += 1
        self.optimizer.noise_multiplier = self.get_noise_multiplier()


class ExponentialNoise(_NoiseScheduler):
    """
    Decays the noise_multiplier by gamma every epoch.

    With the default ``last_epoch=-1`` the constructor's initial step lands on
    epoch 0, which leaves the optimizer's current noise_multiplier unchanged.
    Every later step multiplies the optimizer's *current* noise_multiplier by
    ``gamma`` (chainable form). When constructed with ``last_epoch >= 0``, that
    initial step therefore already applies one factor of ``gamma``.
    """

    def __init__(self, optimizer: DPOptimizer, *, gamma: float, last_epoch: int = -1):
        """
        Args:
            optimizer: Wrapped optimizer
            gamma: Multiplicative factor of noise multiplier decay.
            last_epoch: The index of last epoch
        """
        # gamma must exist before super().__init__(), which calls step() and
        # therefore get_noise_multiplier().
        self.gamma = gamma
        super().__init__(optimizer, last_epoch=last_epoch)

    def get_noise_multiplier(self):
        """Return the noise multiplier for the current ``last_epoch``."""
        if self.last_epoch == 0:
            return self.optimizer.noise_multiplier
        # Chainable: scale whatever value the optimizer currently holds.
        return self.optimizer.noise_multiplier * self.gamma


class LambdaNoise(_NoiseScheduler):
    """
    Sets the noise_multiplier to the initial noise_multiplier times a given function.

    The "initial" value is the optimizer's noise_multiplier at construction
    time. On every step the optimizer receives
    ``base_noise_multiplier * noise_lambda(last_epoch)``. This includes the
    step the constructor performs, so with the default ``last_epoch=-1`` the
    value is set to ``base_noise_multiplier * noise_lambda(0)`` immediately.
    """

    def __init__(
        self,
        optimizer: DPOptimizer,
        *,
        noise_lambda: Callable[[int], float],
        last_epoch: int = -1,
    ):
        """
        Args:
            optimizer: Wrapped optimizer.
            noise_lambda: A function which computes a multiplicative factor
                given an integer epoch.
            last_epoch: The index of last epoch. Default: -1.
        """
        # NOTE(review): the inherited state_dict() returns every attribute
        # except the optimizer, so it includes noise_lambda. A lambda or local
        # function there cannot be serialized with the standard pickle module.
        # Confirm how scheduler checkpoints are saved.
        self.noise_lambda = noise_lambda
        # Snapshot the base value before super().__init__() calls step() and
        # overwrites optimizer.noise_multiplier. Later values are computed from
        # this base, not chained.
        self.base_noise_multiplier = optimizer.noise_multiplier
        super().__init__(optimizer, last_epoch=last_epoch)

    def get_noise_multiplier(self):
        """Return ``base_noise_multiplier * noise_lambda(last_epoch)``."""
        return self.base_noise_multiplier * self.noise_lambda(self.last_epoch)


class StepNoise(_NoiseScheduler):
    """
    Decays the noise_multiplier by gamma every step_size epochs.

    With the default ``last_epoch=-1`` the constructor's initial step lands on
    epoch 0, which leaves the optimizer's current noise_multiplier unchanged.
    After that, the optimizer's *current* noise_multiplier is multiplied by
    ``gamma`` on each epoch that is a positive multiple of ``step_size``
    (chainable form).
    """

    def __init__(
        self,
        optimizer: DPOptimizer,
        *,
        step_size: int,
        gamma: float,
        last_epoch: int = -1,
    ):
        """
        Args:
            optimizer: Wrapped optimizer.
            step_size: Period of noise multiplier decay, in epochs. Must be positive.
            gamma: Multiplicative factor of noise multiplier decay.
            last_epoch: The index of last epoch

        Raises:
            ValueError: If ``step_size`` is not positive.
        """
        # Fail fast. Otherwise a zero step_size only surfaces later as a
        # ZeroDivisionError from the modulo in get_noise_multiplier().
        if step_size <= 0:
            raise ValueError(f"step_size must be a positive integer, got {step_size}")
        self.step_size = step_size
        self.gamma = gamma
        super().__init__(optimizer, last_epoch=last_epoch)

    def get_noise_multiplier(self):
        """Return the noise multiplier for the current ``last_epoch``."""
        # Only change noise_multiplier when at a 'step'
        if self.last_epoch == 0 or self.last_epoch % self.step_size != 0:
            return self.optimizer.noise_multiplier
        return self.gamma * self.optimizer.noise_multiplier