ModuleValidator

class opacus.validators.module_validator.ModuleValidator

Encapsulates all the validation logic required by Opacus. It also serves as a namespace that holds the registered validators and fixers.

classmethod fix(module, **kwargs)

Make the module and its submodules DP-compatible by running the registered custom fixers.

Parameters:
  • module (Module) – The root module to be made compatible.

  • **kwargs – Keyword arguments forwarded to the registered fixers.

Return type:

Module

Returns:

Fixed module.
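The fixing pass can be pictured with the dependency-free sketch below. It uses neither Opacus nor PyTorch: ToyModule, ToyBatchNorm, ToyGroupNorm, the FIXERS dict, fix_tree, and the num_groups keyword are hypothetical stand-ins. The sketch walks the module tree, replaces every submodule whose exact type has a registered fixer, and forwards **kwargs to that fixer; it is an illustration of the idea, not the Opacus implementation.

```python
import copy
from typing import Callable, Dict


class ToyModule:
    """Hypothetical stand-in for nn.Module: a node with named children."""

    def __init__(self, **children: "ToyModule") -> None:
        self.children: Dict[str, "ToyModule"] = dict(children)


class ToyBatchNorm(ToyModule):  # hypothetical layer that needs fixing
    def __init__(self, num_features: int) -> None:
        super().__init__()
        self.num_features = num_features


class ToyGroupNorm(ToyModule):  # hypothetical DP-compatible replacement
    def __init__(self, num_groups: int, num_channels: int) -> None:
        super().__init__()
        self.num_groups = num_groups
        self.num_channels = num_channels


def fix_batchnorm(module: ToyBatchNorm, **kwargs) -> ToyModule:
    # num_groups is a hypothetical keyword argument used only in this sketch
    return ToyGroupNorm(kwargs.get("num_groups", 1), module.num_features)


# Registry keyed by exact class, in the spirit of a FIXERS registry
FIXERS: Dict[type, Callable[..., ToyModule]] = {ToyBatchNorm: fix_batchnorm}


def _fix_children(node: ToyModule, **kwargs) -> None:
    for name, child in list(node.children.items()):
        fixer = FIXERS.get(type(child))  # exact-type lookup
        if fixer is not None:
            child = fixer(child, **kwargs)  # forward kwargs to the fixer
            node.children[name] = child
        _fix_children(child, **kwargs)  # recurse into the (possibly new) child


def fix_tree(root: ToyModule, **kwargs) -> ToyModule:
    """Return a fixed copy of root; the input tree is left untouched."""
    root = copy.deepcopy(root)
    fixer = FIXERS.get(type(root))
    if fixer is not None:
        root = fixer(root, **kwargs)
    _fix_children(root, **kwargs)
    return root


model = ToyModule(block=ToyModule(norm=ToyBatchNorm(16)))
fixed = fix_tree(model, num_groups=4)
print(type(fixed.children["block"].children["norm"]).__name__)  # ToyGroupNorm
```

The sketch deep-copies its input so the caller's tree is not mutated; that is a choice made for the sketch, not a documented guarantee of fix.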

classmethod fix_and_validate(module, **kwargs)

Fix the module and its submodules first, then run validation.

Parameters:
  • module (Module) – The root module to be fixed and validated.

  • **kwargs – Keyword arguments forwarded to fix().

Return type:

Module

Returns:

Fixed module.

Raises:

UnsupportedModuleError in case of validation failures.
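The control flow described here (fix first, then validate strictly) can be sketched in a few lines. The names fix_and_validate_sketch, toy_fix, toy_validate, ToyUnsupportedModuleError, the replacement keyword, and the dict-based "modules" are all hypothetical; the fix and validate steps are passed in as callables so the sketch stays independent of Opacus and PyTorch.

```python
from typing import Any, Callable, Dict, List


class ToyUnsupportedModuleError(ValueError):
    """Hypothetical stand-in for UnsupportedModuleError."""


def fix_and_validate_sketch(
    module: Any,
    fix: Callable[..., Any],
    validate: Callable[..., List[Exception]],
    **kwargs: Any,
) -> Any:
    fixed = fix(module, **kwargs)  # 1. run the fixers, forwarding kwargs
    validate(fixed, strict=True)  # 2. raise if anything is still unsupported
    return fixed  # 3. only reached when validation passes


# Toy "modules" are dicts mapping layer name -> layer kind (hypothetical).
UNSUPPORTED = {"BatchNorm", "Unfixable"}


def toy_fix(module: Dict[str, str], **kwargs: Any) -> Dict[str, str]:
    replacement = kwargs.get("replacement", "GroupNorm")
    return {k: (replacement if v == "BatchNorm" else v) for k, v in module.items()}


def toy_validate(module: Dict[str, str], *, strict: bool = False) -> List[Exception]:
    errors: List[Exception] = [
        ToyUnsupportedModuleError(f"{name}: {kind} is not supported")
        for name, kind in module.items()
        if kind in UNSUPPORTED
    ]
    if strict and errors:
        raise ToyUnsupportedModuleError(errors)
    return errors


model = {"conv": "Conv2d", "norm": "BatchNorm"}
print(fix_and_validate_sketch(model, toy_fix, toy_validate))
# {'conv': 'Conv2d', 'norm': 'GroupNorm'}
```

A module containing a kind that no fixer handles ("Unfixable" here) makes the strict validation step raise instead of returning.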

classmethod is_valid(module)

Check whether the module and its submodules are valid by running the registered custom validators.

Parameters:

module (Module) – The root module to validate.

Return type:

bool

Returns:

True if no validation errors are found, False otherwise.

classmethod validate(module, *, strict=False)

Validate the module and its submodules by running the registered custom validators. Depending on the strict flag, either returns the list of errors or raises an exception.

Parameters:
  • module (Module) – The root module to validate.

  • strict (bool) – If True, raise an exception when validation fails; if False, return the list of errors.

Raises:

UnsupportedModuleError if strict is True and validation fails.

Return type:

List[UnsupportedModuleError]
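The strict/non-strict contract, and one natural way to express is_valid in terms of it, can be sketched without any dependencies: each registered validator returns a (possibly empty) list of errors, the lists are concatenated over the whole module tree, and the result is either returned or raised. ToyModule, ToyBatchNorm, VALIDATORS, validate_tree, is_valid_tree, and ToyUnsupportedModuleError are hypothetical names, not Opacus APIs.

```python
from typing import Callable, Dict, Iterator, List


class ToyUnsupportedModuleError(ValueError):
    """Hypothetical stand-in for UnsupportedModuleError."""


class ToyModule:
    """Hypothetical stand-in for nn.Module: a node with named children."""

    def __init__(self, **children: "ToyModule") -> None:
        self.children = dict(children)

    def modules(self) -> Iterator["ToyModule"]:
        yield self  # the root is validated too
        for child in self.children.values():
            yield from child.modules()


class ToyBatchNorm(ToyModule):
    """Hypothetical layer that this sketch flags as unsupported."""


def validate_batchnorm(module: ToyModule) -> List[ToyUnsupportedModuleError]:
    return [ToyUnsupportedModuleError("BatchNorm is not DP-compatible")]


VALIDATORS: Dict[type, Callable[[ToyModule], List[ToyUnsupportedModuleError]]] = {
    ToyBatchNorm: validate_batchnorm,
}


def validate_tree(root: ToyModule, *, strict: bool = False) -> List[ToyUnsupportedModuleError]:
    errors: List[ToyUnsupportedModuleError] = []
    for module in root.modules():
        validator = VALIDATORS.get(type(module))  # exact-type lookup
        if validator is not None:
            errors.extend(validator(module))
    if strict and errors:
        raise ToyUnsupportedModuleError(errors)
    return errors


def is_valid_tree(root: ToyModule) -> bool:
    # "valid" here means that non-strict validation collected no errors
    return len(validate_tree(root, strict=False)) == 0


model = ToyModule(features=ToyModule(norm=ToyBatchNorm()))
print(len(validate_tree(model)), is_valid_tree(model))  # 1 False
```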

opacus.validators.utils.register_module_fixer(target_class_or_classes, validator_class=<class 'opacus.validators.module_validator.ModuleValidator'>)

Registers the decorated function as the fixer for target_class_or_classes: the function that is invoked whenever an incompatible module of that class needs to be fixed so that it can be trained with Opacus. You may supply your own validator_class to hold the FIXERS registry. Every fixer has the same signature:

>>> @register_module_fixer(MyCustomModel)
... def fix(module: nn.Module, **kwargs) -> nn.Module:
...    pass

For examples, see the existing fixers in Opacus under opacus.validators.
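What a decorator of this kind does, including the option of supplying a separate validator_class, can be sketched without Opacus. register_fixer, ToyValidatorRegistry, MyRegistry, ToyBatchNorm, ToyGroupNorm, and the num_groups keyword are hypothetical; the sketch only assumes that the registry is a class-level dict keyed by the target class.

```python
from typing import Callable, Dict, Sequence, Type, Union


class ToyValidatorRegistry:
    """Hypothetical class playing the role of validator_class (holds FIXERS)."""

    FIXERS: Dict[type, Callable] = {}


def register_fixer(
    target_class_or_classes: Union[Type, Sequence[Type]],
    validator_class: Type[ToyValidatorRegistry] = ToyValidatorRegistry,
):
    """Sketch: record the decorated function as the fixer of each target class."""
    targets = (
        list(target_class_or_classes)
        if isinstance(target_class_or_classes, (list, tuple))
        else [target_class_or_classes]
    )

    def decorator(fn: Callable) -> Callable:
        for target in targets:
            validator_class.FIXERS[target] = fn
        return fn  # the decorated function is returned unchanged

    return decorator


class ToyBatchNorm:  # hypothetical incompatible layer
    def __init__(self, num_features: int) -> None:
        self.num_features = num_features


class ToyGroupNorm:  # hypothetical compatible replacement
    def __init__(self, num_groups: int, num_channels: int) -> None:
        self.num_groups = num_groups
        self.num_channels = num_channels


class MyRegistry(ToyValidatorRegistry):
    FIXERS: Dict[type, Callable] = {}  # own dict, so the default registry is untouched


@register_fixer(ToyBatchNorm, validator_class=MyRegistry)
def fix_batchnorm(module: ToyBatchNorm, **kwargs) -> ToyGroupNorm:
    # num_groups is a hypothetical keyword argument used only in this sketch
    return ToyGroupNorm(kwargs.get("num_groups", 1), module.num_features)


print(ToyBatchNorm in MyRegistry.FIXERS, ToyBatchNorm in ToyValidatorRegistry.FIXERS)
# True False
```

Giving the custom registry its own FIXERS dict keeps its fixers separate from those in the default registry.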

opacus.validators.utils.register_module_validator(target_class_or_classes, validator_class=<class 'opacus.validators.module_validator.ModuleValidator'>)

Registers the decorated function as the validator for target_class_or_classes: the function that is invoked whenever a module of that class is checked for compatibility with Opacus training. You may supply your own validator_class to hold the VALIDATORS registry. Every validator has the same signature:

>>> @register_module_validator(MyCustomModel)
... def validate(module: nn.Module, **kwargs) -> List[opacus.validators.errors.UnsupportedError]:
...    pass

For examples, see the existing validators in Opacus under opacus.validators.
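Registering a single validator for several classes at once can be sketched without Opacus as follows. register_validator, ToyValidatorRegistry, ToyUnsupportedError, ToyInstanceNorm1d, ToyInstanceNorm2d, and the track_running_stats rule are hypothetical; the sketch only illustrates the documented contract that a validator returns a list of errors, empty when the module is fine.

```python
from typing import Callable, Dict, List, Sequence, Type, Union


class ToyUnsupportedError(ValueError):
    """Hypothetical stand-in for opacus.validators.errors.UnsupportedError."""


class ToyValidatorRegistry:
    """Hypothetical class playing the role of validator_class (holds VALIDATORS)."""

    VALIDATORS: Dict[type, Callable[..., List[ToyUnsupportedError]]] = {}


def register_validator(
    target_class_or_classes: Union[Type, Sequence[Type]],
    validator_class: Type[ToyValidatorRegistry] = ToyValidatorRegistry,
):
    """Sketch: record the decorated function as the validator of each target class."""
    targets = (
        list(target_class_or_classes)
        if isinstance(target_class_or_classes, (list, tuple))
        else [target_class_or_classes]
    )

    def decorator(fn: Callable) -> Callable:
        for target in targets:
            validator_class.VALIDATORS[target] = fn
        return fn

    return decorator


class ToyInstanceNorm1d:  # hypothetical layers sharing one validator
    def __init__(self, track_running_stats: bool = False) -> None:
        self.track_running_stats = track_running_stats


class ToyInstanceNorm2d(ToyInstanceNorm1d):
    pass


@register_validator([ToyInstanceNorm1d, ToyInstanceNorm2d])
def validate_instancenorm(module, **kwargs) -> List[ToyUnsupportedError]:
    # Hypothetical rule for this sketch: running statistics are not allowed
    if module.track_running_stats:
        return [ToyUnsupportedError(f"{type(module).__name__}: track_running_stats=True")]
    return []


check = ToyValidatorRegistry.VALIDATORS[ToyInstanceNorm2d]
print([str(e) for e in check(ToyInstanceNorm2d(track_running_stats=True))])
# ['ToyInstanceNorm2d: track_running_stats=True']
```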