Module Modification
This module includes utilities for modifying model layers, for example replacing layers of one type with another.
opacus.utils.module_modification.convert_batchnorm_modules(model, converter=<function _batchnorm_to_groupnorm>)
Converts all BatchNorm modules to another module (defaults to GroupNorm) that is privacy-compliant.
- Parameters
model (Module) – Module instance, potentially with sub-modules.
converter (Callable[[_BatchNorm], Module]) – Function or a lambda that converts an instance of a BatchNorm module to another nn.Module.
- Return type
Module
- Returns
Model with all the BatchNorm types replaced by another operation by using the provided converter, defaulting to GroupNorm if one isn’t provided.
Example
>>> from torchvision.models import resnet50
>>> from torch import nn
>>> from opacus.utils.module_modification import convert_batchnorm_modules
>>> model = resnet50()
>>> print(model.layer1[0].bn1)
BatchNorm2d module details
>>> model = convert_batchnorm_modules(model)
>>> print(model.layer1[0].bn1)
GroupNorm module details
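The converter argument accepts any callable that maps a BatchNorm instance to a replacement module. The following is a minimal sketch of such a converter in plain PyTorch, not the library's own _batchnorm_to_groupnorm: the name batchnorm_to_groupnorm_sketch, the max_groups parameter, and the rule for choosing the number of groups are all assumptions made for illustration.

```python
import torch
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm


def batchnorm_to_groupnorm_sketch(module: _BatchNorm, max_groups: int = 32) -> nn.Module:
    """Hypothetical converter: build a GroupNorm with the same channel count."""
    num_channels = module.num_features
    num_groups = min(max_groups, num_channels)
    # GroupNorm requires num_channels to be divisible by num_groups,
    # so step down until a divisor is found (1 always works).
    while num_channels % num_groups != 0:
        num_groups -= 1
    return nn.GroupNorm(num_groups, num_channels, affine=True)


# Convert a single layer and check that it still accepts the same input shape.
gn = batchnorm_to_groupnorm_sketch(nn.BatchNorm2d(64))
out = gn(torch.randn(4, 64, 8, 8))
print(gn, tuple(out.shape))
```

Under these assumptions, the function could be passed as convert_batchnorm_modules(model, converter=batchnorm_to_groupnorm_sketch).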
opacus.utils.module_modification.nullify_batchnorm_modules(root)
Replaces all the BatchNorm submodules (e.g. torch.nn.BatchNorm1d, torch.nn.BatchNorm2d etc.) in root with torch.nn.Identity.
- Parameters
root (Module) – Module for which to replace BatchNorm submodules.
- Return type
Module
- Returns
Module with all the BatchNorm submodules replaced with Identity. root is modified and is equal to the return value.
Notes
In most cases, replacing a BatchNorm module with Identity will heavily affect the convergence of the model.
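To illustrate why, the sketch below (plain PyTorch, with made-up input statistics chosen only for the demonstration) passes the same batch through a BatchNorm layer in training mode and through an Identity layer. BatchNorm rescales each feature to roughly zero mean and unit variance, whereas Identity returns the shifted, scaled activations unchanged, so later layers see a very different input distribution.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Synthetic activations with mean around 5 and standard deviation around 3
# (illustrative values, not taken from any real model).
x = 5.0 + 3.0 * torch.randn(256, 8)

bn = nn.BatchNorm1d(8)      # modules are in training mode by default
identity = nn.Identity()

bn_out = bn(x)              # normalized per feature using batch statistics
id_out = identity(x)        # returned as-is

print("BatchNorm mean/std:", bn_out.mean().item(), bn_out.std(unbiased=False).item())
print("Identity  mean/std:", id_out.mean().item(), id_out.std(unbiased=False).item())
```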
opacus.utils.module_modification.replace_all_modules(root, target_class, converter)
Converts all the submodules of root that have the same type as target_class, using the given converter.
This method is useful for replacing modules that are not supported by the Privacy Engine.
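- Parameters
root – Module whose submodules are to be replaced.
target_class – Class of the submodules to replace.
converter – Function or a lambda that converts an instance of target_class to another module.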
- Return type
Module
- Returns
Module with all the target_class types replaced using the converter. root is modified and is equal to the return value.
Example
>>> from torchvision.models import resnet18
>>> from torch import nn
>>> from opacus.utils.module_modification import replace_all_modules
>>> model = resnet18()
>>> print(model.layer1[0].bn1)
BatchNorm2d(64, eps=1e-05, ...
>>> model = replace_all_modules(model, nn.BatchNorm2d, lambda _: nn.Identity())
>>> print(model.layer1[0].bn1)
Identity()
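One way such a replacement could work is a recursive walk over named_children(). The sketch below is an illustration in plain PyTorch, not the Opacus implementation: replace_modules_sketch is a hypothetical name, and matching with isinstance (which also matches subclasses of target_class) is an assumption.

```python
from typing import Callable, Type

from torch import nn


def replace_modules_sketch(
    root: nn.Module,
    target_class: Type[nn.Module],
    converter: Callable[[nn.Module], nn.Module],
) -> nn.Module:
    """Hypothetical helper: recursively swap matching submodules of root."""
    # If root itself matches, there is no parent to patch; return the converted module.
    if isinstance(root, target_class):
        return converter(root)
    # Snapshot the children so reassigning attributes does not disturb iteration.
    for name, child in list(root.named_children()):
        setattr(root, name, replace_modules_sketch(child, target_class, converter))
    return root


# Small nested model used only for the demonstration.
model = nn.Sequential(
    nn.Linear(4, 4),
    nn.BatchNorm1d(4),
    nn.Sequential(nn.BatchNorm1d(4), nn.ReLU()),
)
result = replace_modules_sketch(model, nn.BatchNorm1d, lambda _: nn.Identity())
print(result)
```

As in the documented function, the root module in this sketch is modified and returned, except when root itself is an instance of target_class, in which case the converted module is returned instead.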