Module Inspection

This module includes utilities for inspecting model layers, such as checking whether layers satisfy a given predicate and getting a layer's type.

class opacus.utils.module_inspection.ModelInspector(name, predicate, check_leaf_nodes_only=True, message=None)

Inspects a model against a given predicate. If a module has children, the predicate is checked on all of them recursively.

Example

>>> import torch.nn as nn
>>> from opacus.utils.module_inspection import ModelInspector
>>> inspector = ModelInspector('simple', lambda x: isinstance(x, nn.Conv2d))
>>> print(inspector.validate(nn.Conv2d(1, 1, 1)))
True

Parameters
  • name (str) – String to represent the predicate.

  • predicate (Callable[[Module], bool]) – Callable boolean function which tests a hypothesis on a module.

  • check_leaf_nodes_only (bool) – If True, only the leaf nodes of a module are checked. Here, leaf nodes are modules that have parameters of their own.

  • message (Optional[str]) – Optional message describing a violation of this predicate.

Notes

The predicate is not applied to non-leaf modules unless check_leaf_nodes_only is set to False. For example, given the predicate:

lambda model: isinstance(model, nn.Sequential)

the nn.Sequential container itself (which has no parameters of its own) is never tested, so it always passes the check unless check_leaf_nodes_only is set to False.
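To make the leaf-only behavior concrete, here is a minimal sketch in plain PyTorch. The helper `inspect_modules` is hypothetical and is not Opacus's implementation; it assumes "leaf" means "has parameters of its own", as described for check_leaf_nodes_only, and walks the model with `named_modules()`, collecting the names of violating modules.

```python
from typing import Callable, List, Tuple

import torch.nn as nn


def inspect_modules(
    model: nn.Module,
    predicate: Callable[[nn.Module], bool],
    check_leaf_nodes_only: bool = True,
) -> Tuple[bool, List[str]]:
    """Hypothetical helper mirroring the documented ModelInspector semantics."""
    violators: List[str] = []
    # named_modules() yields the model itself first, then every submodule recursively.
    for name, module in model.named_modules():
        has_own_params = next(module.parameters(recurse=False), None) is not None
        # Leaf-only mode (assumption): modules without parameters of their own are never tested.
        if check_leaf_nodes_only and not has_own_params:
            continue
        if not predicate(module):
            violators.append(f"{name or '<root>'} ({type(module).__name__})")
    return not violators, violators


model = nn.Sequential(nn.Conv2d(1, 4, 3), nn.ReLU(), nn.Linear(4, 2))
is_sequential = lambda m: isinstance(m, nn.Sequential)

# Leaf-only (default): the container and ReLU are skipped; Conv2d and Linear fail.
print(inspect_modules(model, is_sequential))
# (False, ['0 (Conv2d)', '2 (Linear)'])

# All nodes: the container itself passes, but every child fails.
print(inspect_modules(model, is_sequential, check_leaf_nodes_only=False))
# (False, ['0 (Conv2d)', '1 (ReLU)', '2 (Linear)'])
```

In the sketch, the container can only be tested when check_leaf_nodes_only is False, which matches the note above.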

validate(model)

Checks if the provided module satisfies the predicate specified upon creation of the ModelInspector.

Parameters

model (Module) – PyTorch module on which the predicate must be evaluated and satisfied.

Return type

bool

Returns

Flag indicating whether the predicate is satisfied.
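The recursive check can be sketched with `nn.Module.modules()`, which visits the model and all of its descendants at any depth. The helper `validate_all` below is hypothetical and corresponds to applying the predicate to every node (as with check_leaf_nodes_only=False); the batch-norm predicate is only an illustrative choice.

```python
import torch.nn as nn


def validate_all(model: nn.Module, predicate) -> bool:
    """Hypothetical stand-in: True only if the predicate holds for every module in the tree."""
    # model.modules() yields the model itself and all descendants, at any depth.
    return all(predicate(m) for m in model.modules())


def no_batchnorm(m: nn.Module) -> bool:
    # Example predicate: reject any batch-normalization layer.
    return not isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))


flat = nn.Sequential(nn.Conv2d(1, 4, 3), nn.ReLU())
nested = nn.Sequential(nn.Conv2d(1, 4, 3), nn.Sequential(nn.BatchNorm2d(4), nn.ReLU()))

print(validate_all(flat, no_batchnorm))    # True
print(validate_all(nested, no_batchnorm))  # False: the BatchNorm2d sits two levels deep
```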

opacus.utils.module_inspection.get_layer_type(layer)

Returns the name of the type of the given layer.

Parameters

layer (Module) – The module corresponding to the layer whose type is being queried.

Return type

str

Returns

Name of the class of the layer.
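As a sketch, the class name of a module is available as `type(layer).__name__`. The function below is a local stand-in for illustration, not Opacus's source; the `Counter` tally shows one way such a helper might be used to summarize which layer types a model contains.

```python
from collections import Counter

import torch.nn as nn


def get_layer_type(layer: nn.Module) -> str:
    # Sketch: the name of the module's class, e.g. "Conv2d" for nn.Conv2d.
    return type(layer).__name__


model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
print(get_layer_type(model[0]))  # Linear

# Tally the layer types found anywhere in the model (including the container itself).
print(Counter(get_layer_type(m) for m in model.modules()))
```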

opacus.utils.module_inspection.has_no_param(module)

Checks if a module does not have any parameters.

Parameters

module (Module) – The module on which this function is being evaluated.

Return type

bool

Returns

Flag indicating if the provided module does not have any parameters.
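A minimal sketch of this check in plain PyTorch is shown below. It assumes, as a labeled assumption, that only the module's own parameters count (`recurse=False`), consistent with the leaf-node definition used by ModelInspector; the function is a local stand-in, not the library's code.

```python
import torch.nn as nn


def has_no_param(module: nn.Module) -> bool:
    # Assumption: only the module's *own* parameters count (recurse=False).
    # next(..., None) returns None when the parameter iterator is empty.
    return next(module.parameters(recurse=False), None) is None


print(has_no_param(nn.ReLU()))                       # True: activations have no parameters
print(has_no_param(nn.Linear(2, 2)))                 # False: weight and bias
print(has_no_param(nn.Sequential(nn.Linear(2, 2))))  # True under the recurse=False assumption
```

With `recurse=True` instead, the `nn.Sequential` wrapper would report False, because its child's parameters would be counted too.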

opacus.utils.module_inspection.requires_grad(module, recurse=False)

Checks if any parameters in a specified module require gradients.

Parameters
  • module (Module) – PyTorch module whose parameters are examined.

  • recurse (bool) – If True, the parameters of all sub-modules are also examined; by default, only the module's own parameters are checked.

Return type

bool

Returns

Flag indicating whether any parameters require gradients.
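The behavior described above, including the effect of recurse, can be sketched with `nn.Module.parameters(recurse=...)`. This is a local stand-in written for illustration, not Opacus's source; freezing uses the standard `Tensor.requires_grad_` method.

```python
import torch.nn as nn


def requires_grad(module: nn.Module, recurse: bool = False) -> bool:
    # True if at least one parameter (own only, or also sub-modules' if recurse) is trainable.
    return any(p.requires_grad for p in module.parameters(recurse=recurse))


linear = nn.Linear(2, 2)
model = nn.Sequential(linear)

print(requires_grad(linear))               # True
print(requires_grad(model))                # False: the container has no parameters of its own
print(requires_grad(model, recurse=True))  # True: found via the child Linear

# Freeze the layer: afterwards no parameter requires gradients.
for p in linear.parameters():
    p.requires_grad_(False)
print(requires_grad(model, recurse=True))  # False
```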