Opacus strives to enable private training of PyTorch models with minimal code changes on the user side. As you might have learnt by following the README and the introductory tutorials, Opacus does this by consuming your model, dataloader, and optimizer and returning wrapped counterparts that can perform privacy-related functions.
While most of the common models work with Opacus, not all of them do:

- Modules with no trainable parameters (nn.ReLU, nn.Tanh, etc.) and frozen modules (with parameters whose requires_grad is set to False) are compatible.
- Modules whose per-sample gradients Opacus knows how to compute are compatible; the modules handled by GradSampleModule and the implementations offered by opacus.layers have this property.
- Modules such as BatchNorm are not DP friendly, as a sample's normalized value depends on the other samples in the batch, and hence are incompatible with Opacus.
- Modules such as InstanceNorm are DP friendly, except under certain configurations (e.g., when track_running_stats is On).

It is unreasonable to expect you to remember all of this and take care of it yourself. This is why Opacus provides a ModuleValidator to take care of it for you.
ModuleValidator internals

The ModuleValidator class has two primary class methods: validate() and fix().
As the name suggests, validate() validates a given module's compatibility with Opacus by ensuring it is in training mode and is of type GradSampleModule (i.e., the module can capture per-sample gradients). More importantly, this method also checks the sub-modules and their configurations for compatibility issues (more on this in the next section).
The fix()
method attempts to make the module compatible with Opacus.
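For example, here is a minimal sketch of how you might use these two methods directly. The toy model is purely illustrative, and it assumes that validate() accepts a strict flag (as in recent Opacus releases) that reports errors instead of raising when set to False:

import torch.nn as nn
from opacus.validators import ModuleValidator

# A toy model containing a BatchNorm layer (illustrative only)
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())

# With strict=False, validate() returns the list of compatibility errors it found
errors = ModuleValidator.validate(model, strict=False)
print(errors)  # expect a complaint about the BatchNorm layer

# fix() returns a module with the offending sub-modules replaced
model = ModuleValidator.fix(model)
print(ModuleValidator.validate(model, strict=False))  # expect an empty list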
In Opacus 0.x, the specific checks for each of the supported modules and the necessary replacements were done centrally in the validator with a series of if
checks. Adding new validation checks and fixes would have necessitated modifying the core Opacus code. In Opacus 1.0, this has been modularised by allowing you to register your own custom validator and fixer.
In the rest of the tutorial, we will consider nn.BatchNorm
as an example and show exactly how to do that.
We know that the BatchNorm module is not privacy friendly, and hence the validator should throw an error, say, like this:
def validate_batchnorm(module):
    return [Exception("BatchNorm is not supported")]
To register the above, all you need to do is decorate the method as follows.
import torch.nn as nn
from opacus.validators import register_module_validator

@register_module_validator(
    [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm]
)
def validate_batchnorm(module):
    return [Exception("BatchNorm is not supported")]
That's it! The above will register validate_batchnorm() for all of these modules: [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm], and this method will be automatically called along with the other validators when you do privacy_engine.make_private().

The decorator essentially adds your method to ModuleValidator's register for it to be cycled through during the validation phase.
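Here is a quick sketch of the effect. The toy model is illustrative, and it assumes the custom validate_batchnorm() above has already been registered in the current session:

import torch.nn as nn
from opacus.validators import ModuleValidator

model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))

# The freshly registered validator is now consulted for BatchNorm layers
errors = ModuleValidator.validate(model, strict=False)
print(errors)  # expect the "BatchNorm is not supported" exception from above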
Just one nit: it is recommended that you make your validation exceptions as clear as possible. Opacus's validation for the above looks as follows:
@register_module_validator(
    [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm]
)
def validate(module):
    return [
        ShouldReplaceModuleError(
            "BatchNorm cannot support training with differential privacy. "
            "The reason for it is that BatchNorm makes each sample's normalized value "
            "depend on its peers in a batch, ie the same sample x will get normalized to "
            "a different value depending on who else is in its batch. "
            "Privacy-wise, this means that we would have to put a privacy mechanism there too. "
            "While it can in principle be done, there are now multiple normalization layers that "
            "do not have this issue: LayerNorm, InstanceNorm and their generalization GroupNorm "
            "are all privacy-safe since they don't have this property. "
            "We offer utilities to automatically replace BatchNorms to GroupNorms and we will "
            "release pretrained models to help transition, such as GN-ResNet ie a ResNet using "
            "GroupNorm, pretrained on ImageNet"
        )
    ]  # quite a mouthful, but is super clear! ;)
Validating is good, but can we fix the issue when possible? The answer, of course, is yes. And the syntax is pretty much the same as that of the validator.
BatchNorm, for example, can be replaced with GroupNorm without any meaningful loss of performance while still being privacy friendly. In Opacus, we do it as follows:
import logging

import torch.nn as nn
from opacus.validators.utils import register_module_fixer

logger = logging.getLogger(__name__)


def _batchnorm_to_groupnorm(module) -> nn.GroupNorm:
    """
    Converts a BatchNorm ``module`` to a GroupNorm module.
    This is a helper function.

    Args:
        module: BatchNorm module to be replaced

    Returns:
        GroupNorm module that can replace the BatchNorm module provided

    Notes:
        A default value of 32 is chosen for the number of groups based on the
        paper *Group Normalization* https://arxiv.org/abs/1803.08494
    """
    return nn.GroupNorm(
        min(32, module.num_features), module.num_features, affine=module.affine
    )


@register_module_fixer(
    [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm]
)
def fix(module) -> nn.GroupNorm:
    logger.info(
        "The default batch_norm fixer replaces BatchNorm with GroupNorm."
        " The batch_norm validator module also offers implementations to replace"
        " it with InstanceNorm or Identity. Please check them out and override the"
        " fixer if those are more suitable for your needs."
    )
    return _batchnorm_to_groupnorm(module)
Opacus does NOT automatically fix the module for you when you call privacy_engine.make_private()
; it expects the module to be compliant before it is passed in. However, this can easily be done as follows:
import torch
from opacus.validators import ModuleValidator

model = torch.nn.Linear(2, 1)
if not ModuleValidator.is_valid(model):
    model = ModuleValidator.fix(model)
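To make the replacement concrete, here is a small sketch (the toy model is again just an illustration) of the default fixer turning BatchNorm into GroupNorm:

import torch.nn as nn
from opacus.validators import ModuleValidator

model = nn.Sequential(nn.Conv2d(3, 64, 3), nn.BatchNorm2d(64), nn.ReLU())
model = ModuleValidator.fix(model)
print(model)
# The BatchNorm2d(64) layer should now show up as GroupNorm(32, 64, ...),
# following the min(32, num_features) rule from _batchnorm_to_groupnorm above.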
If you want to use a custom fixer in place of the one provided, you can simply decorate your function with this same decorator. Note that the order of registration matters: the last function to be registered will be the one used. For example, to replace only BatchNorm2d with InstanceNorm (while keeping the default replacement of BatchNorm1d and BatchNorm3d with GroupNorm), you can do:
import torch.nn as nn
from opacus.validators import register_module_fixer

@register_module_fixer([nn.BatchNorm2d])
def fix_batchnorm2d(module):
    return nn.InstanceNorm2d(module.num_features)
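With that fixer registered, here is a sketch of the resulting behaviour. The container below is illustrative; fix() only rewrites the layers, it does not run a forward pass:

import torch.nn as nn
from opacus.validators import ModuleValidator

model = nn.Sequential(nn.BatchNorm1d(16), nn.BatchNorm2d(16))
model = ModuleValidator.fix(model)
print(model)
# Expect the BatchNorm1d to become GroupNorm(16, 16, ...) via the default fixer,
# and the BatchNorm2d to become InstanceNorm2d(16, ...) via fix_batchnorm2d above.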
Hope this tutorial was helpful! We welcome you to peek into the code under opacus/validators/
for details. If you have any questions or comments, please don't hesitate to post them on our forum.