Source code for opacus.validators.module_validator

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import List

import torch.nn as nn
from opacus.utils.module_utils import clone_module, get_submodule, trainable_modules
from opacus.validators.errors import (

logger = logging.getLogger(__name__)

[docs]class ModuleValidator: """ Encapsulates all the validation logic required by Opacus. Also works as a namespace to hold registered validators and fixers. """ VALIDATORS = {} FIXERS = {}
[docs] @classmethod def validate( cls, module: nn.Module, *, strict: bool = False ) -> List[UnsupportedModuleError]: """ Validate module and sub_modules by running registered custom validators. Returns or raises exceptions depending on ``strict`` flag. Args: module: The root module to validate. strict: Boolean to indicate whether to raise errors or return the list of errors. Raises: UnsupportedModuleError in case of validation failures. """ errors = [] # 1. validate that module is in training mode if not errors.append( IllegalModuleConfigurationError("Model needs to be in training mode") ) # 2. perform module specific validations for trainable modules. # TODO: use module name here - it's useful part of error message for _, sub_module in trainable_modules(module): if type(sub_module) in ModuleValidator.VALIDATORS: sub_module_validator = ModuleValidator.VALIDATORS[type(sub_module)] errors.extend(sub_module_validator(sub_module)) # raise/return as needed if strict and len(errors) > 0: raise UnsupportedModuleError(errors) else: return errors
[docs] @classmethod def is_valid(cls, module: nn.Module) -> bool: """ Check if module and sub_modules are valid by running registered custom validators. Args: module: The root module to validate. Returns: bool """ return len(cls.validate(module, strict=False)) == 0
[docs] @classmethod def fix(cls, module: nn.Module, **kwargs) -> nn.Module: """ Make the module and sub_modules DP compatible by running registered custom fixers. Args: module: The root module to be made compatible. **kwargs: Arbitrary keyword arguments. Returns: Fixed module. """ module = clone_module(module) # iterate over all sub_modules # We have to get sub_module names in a list first as we will be # changing the modules inside the loop. sub_module_names = [name for name, _ in trainable_modules(module)] for sub_module_name in sub_module_names: # get sub_module sub_module = get_submodule(module, sub_module_name) # if sub_module has a registered fixer if type(sub_module) in ModuleValidator.FIXERS: # get a replacement for sub_module sub_module_fixer = ModuleValidator.FIXERS[type(sub_module)] new_sub_module = sub_module_fixer(sub_module, **kwargs) # move new_sub_module to the same device as that of sub_module # get module after replacement. module = cls._replace_sub_module( root=module, sub_module_name=sub_module_name, new_sub_module=new_sub_module, ) # log it f"Replaced sub_module {sub_module_name} : {sub_module}" f" with {new_sub_module}" ) # return fixed module return module
@classmethod def _replace_sub_module( cls, *, root: nn.Module, sub_module_name: str, new_sub_module: nn.Module, ) -> None: sub_module_path = sub_module_name.split(".") if ( len(sub_module_path) == 1 and sub_module_path[0] == "" ): # root is the only sub_module of root return new_sub_module else: # replace root's descendant sub_module_parent = root for name in sub_module_path[:-1]: # descend down to sub_module sub_module_parent = sub_module_parent._modules[name] sub_module_parent._modules[sub_module_path[-1]] = new_sub_module return root
[docs] @classmethod def fix_and_validate(cls, module: nn.Module, **kwargs) -> nn.Module: """ Fix the module and sub_modules first, and then run validation. Args: module: The root module to be fixed and validated **kwargs: Arbitrary keyword arguments. Returns: Fixed module. Raises: UnsupportedModuleError in case of validation failures. """ # 1. replace any fixable modules fixed_module = cls.fix(module, **kwargs) # 2. perform module specific validations. cls.validate(fixed_module, strict=True) # return fixed module return fixed_module