Module Utils

opacus.utils.module_utils.are_state_dict_equal(sd1, sd2)

Compares two state dicts and logs any discrepancies it finds.
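The sketch below shows one way such a comparison can work. It is not the Opacus implementation. It assumes that two state dicts are "equal" when they have the same keys and every pair of tensors is numerically close according to torch.allclose. The function name state_dicts_match is hypothetical.

```python
import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


def state_dicts_match(sd1, sd2) -> bool:
    """Hypothetical sketch: compare two state dicts, logging the first discrepancy."""
    # Different key sets mean the dicts cannot describe the same parameters.
    if set(sd1) != set(sd2):
        logger.info("Key mismatch: %s", set(sd1) ^ set(sd2))
        return False
    for name, t1 in sd1.items():
        t2 = sd2[name]
        # Shapes must agree before values can be compared elementwise.
        if t1.shape != t2.shape or not torch.allclose(t1, t2):
            logger.info("Tensor mismatch for %s", name)
            return False
    return True


a = nn.Linear(3, 2)
b = nn.Linear(3, 2)
b.load_state_dict(a.state_dict())  # copy a's weights into b
print(state_dicts_match(a.state_dict(), b.state_dict()))  # True
with torch.no_grad():
    b.bias.add_(1.0)  # perturb b so the comparison fails
print(state_dicts_match(a.state_dict(), b.state_dict()))  # False
```

Logging at INFO level keeps the check quiet by default. A caller debugging a mismatch can raise the logger's verbosity to see which key or tensor differed.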

opacus.utils.module_utils.clone_module(module)

Utility to clone an nn.Module. PyTorch does not always support copy.deepcopy() on modules, so this function instead serializes the module to an in-memory BytesIO buffer and reads it back.

Parameters

module (Module) – The module to clone

Return type

Module

Returns

The clone of module
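The serialize-and-reload approach described above can be sketched with torch.save and torch.load on an io.BytesIO buffer. This is an illustration, not the Opacus source. The function name clone_via_buffer is hypothetical. The sketch also assumes PyTorch 1.13 or later, where torch.load accepts weights_only. Recent releases need weights_only=False to unpickle a full module rather than just tensors.

```python
import io

import torch
import torch.nn as nn


def clone_via_buffer(module: nn.Module) -> nn.Module:
    """Hypothetical sketch: deep-copy a module by round-tripping it through memory."""
    buffer = io.BytesIO()
    torch.save(module, buffer)  # pickle the whole module, not just its state dict
    buffer.seek(0)  # rewind so torch.load reads from the start of the buffer
    # Full-module pickles require weights_only=False on recent PyTorch versions.
    return torch.load(buffer, weights_only=False)


original = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))
clone = clone_via_buffer(original)

# The clone has equal values but independent storage.
with torch.no_grad():
    clone[0].weight.zero_()
print(original[0].weight.abs().sum().item() > 0)  # True
```

Because the clone is rebuilt from bytes, it shares no tensors with the original. This is the property that makes it safe to modify the clone (for example, when wrapping it for private training) without touching the source model.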

opacus.utils.module_utils.get_submodule(module, target)

Returns the submodule given by target if it exists; otherwise raises an error.

This is a copy of PyTorch 1.9’s get_submodule() implementation, included here so that PyTorch 1.8 is also supported. It can be removed in favour of module.get_submodule() once Opacus drops support for PyTorch 1.8.

See more details at https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=get_submodule#torch.nn.Module.get_submodule

Parameters
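  • module (Module) – The module in which to look up the submodule

  • target (str) – Fully qualified name of the submodule, with nested names separated by dots (for example, encoder.0)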

Return type

Module

Returns

The submodule given by target if it exists

Raises

AttributeError – If the submodule does not exist
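The following sketch shows how a dotted path can be resolved by walking attributes one level at a time. It is modelled on the behaviour of PyTorch's get_submodule and is not the Opacus source. The function name lookup_submodule is hypothetical. It assumes, as PyTorch does, that an empty target refers to the module itself and that a path ending at a non-module attribute (such as a Parameter) is an error.

```python
import torch.nn as nn


def lookup_submodule(module: nn.Module, target: str) -> nn.Module:
    """Hypothetical sketch: resolve a dotted path such as "encoder.0" to a submodule."""
    if target == "":
        return module  # the empty path refers to the module itself
    current = module
    for name in target.split("."):
        if not hasattr(current, name):
            raise AttributeError(f"{type(current).__name__} has no attribute `{name}`")
        current = getattr(current, name)  # child modules are reachable as attributes
        if not isinstance(current, nn.Module):
            raise AttributeError(f"`{name}` is not an nn.Module")
    return current


model = nn.Sequential()
model.add_module("encoder", nn.Sequential(nn.Linear(8, 4), nn.ReLU()))
model.add_module("head", nn.Linear(4, 2))
print(lookup_submodule(model, "encoder.0"))  # Linear(in_features=8, out_features=4, bias=True)
```

On PyTorch 1.9 and later, model.get_submodule("encoder.0") returns the same object.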

opacus.utils.module_utils.parametrized_modules(module)

Recursively iterates over all submodules, returning those that directly own parameters (as opposed to “wrapper modules” that only organize other modules).

Return type

Iterable[Module]
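The distinction between parametrized modules and wrapper modules can be illustrated with Module.modules() combined with parameters(recurse=False). This is a sketch, not necessarily how Opacus implements it. The function name modules_with_own_params is hypothetical. The sketch assumes "has parameters" means at least one parameter registered directly on the module.

```python
from typing import Iterable

import torch.nn as nn


def modules_with_own_params(module: nn.Module) -> Iterable[nn.Module]:
    """Hypothetical sketch: yield modules that directly own at least one parameter."""
    for m in module.modules():  # recursive walk that also includes `module` itself
        # recurse=False only looks at parameters registered on `m`, so containers
        # such as nn.Sequential (which hold parameters only via children) are skipped.
        if any(p is not None for p in m.parameters(recurse=False)):
            yield m


model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))
print([type(m).__name__ for m in modules_with_own_params(model)])  # ['Linear', 'Linear']
```

Neither the Sequential container nor the parameter-free ReLU appears in the output.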

opacus.utils.module_utils.requires_grad(module, *, recurse=False)

Checks if any parameters in a specified module require gradients.

Parameters
  • module (Module) – PyTorch module whose parameters are to be examined.

  • recurse (bool) – Whether to also check the parameters of the submodules of module; if False, only parameters registered directly on module are checked

Return type

bool

Returns

True if any of the checked parameters require gradients, False otherwise
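The effect of the recurse flag can be shown with a small stand-in built on Module.parameters(recurse=...). This is a sketch of the documented behaviour, not the Opacus source. The function name any_requires_grad is hypothetical.

```python
import torch.nn as nn


def any_requires_grad(module: nn.Module, *, recurse: bool = False) -> bool:
    """Hypothetical sketch: True if any checked parameter has requires_grad set."""
    return any(p.requires_grad for p in module.parameters(recurse=recurse))


block = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
block[0].requires_grad_(False)  # freeze the first layer only

print(any_requires_grad(block))                # False: Sequential owns no parameters itself
print(any_requires_grad(block, recurse=True))  # True: block[1] is still trainable
```

With the default recurse=False, a container reports False even when its children are trainable. Callers that care about the whole subtree should pass recurse=True.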

opacus.utils.module_utils.trainable_modules(module)

Recursively iterates over all submodules, returning those that have parameters and are trainable (i.e., at least one of their parameters requires a gradient).

Return type

Iterable[Module]
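As a rough illustration, trainable modules can be selected by keeping only those parametrized modules whose own parameters include one with requires_grad=True. This is a sketch under that assumption, not the Opacus implementation. The function name trainable_param_modules is hypothetical.

```python
from typing import Iterable

import torch.nn as nn


def trainable_param_modules(module: nn.Module) -> Iterable[nn.Module]:
    """Hypothetical sketch: yield modules with at least one trainable parameter of their own."""
    for m in module.modules():
        # Only parameters registered directly on `m` are considered, so wrapper
        # modules are skipped and frozen layers drop out.
        if any(p.requires_grad for p in m.parameters(recurse=False)):
            yield m


model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))
model[0].requires_grad_(False)  # e.g. a frozen feature extractor

trainable = list(trainable_param_modules(model))
print(len(trainable), trainable[0] is model[2])  # 1 True
```

Freezing a layer with requires_grad_(False) is enough to remove it from the result. The layer still appears among the parametrized modules, because it still owns parameters.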