Opacus

Training with Non-Wrapping Mode

This tutorial demonstrates how to use Opacus's non-wrapping mode (wrap_model=False), which provides compatibility with transformer models and other complex architectures by avoiding model wrapping.

What is Non-Wrapping Mode?

By default, Opacus wraps the model in a GradSampleModule to compute per-sample gradients. This wrapper can cause issues:

  • Type checking: isinstance(model, MyModel) returns False after wrapping.
  • State dict: Keys get a _module. prefix, which can complicate checkpoint loading.
  • Attribute access: Models with custom __getattr__ (e.g., HuggingFace Transformers) may not work as expected.
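The toy, pure-Python sketch below shows where the _module. prefix and the failing isinstance() check come from. It assumes that, as in PyTorch, state-dict keys are dotted paths from the root module. ToyModule, ToyClassifier, and ToyWrapper are hypothetical stand-ins, not PyTorch or Opacus classes. Once the original model is registered as a child named _module, every key gains that prefix. The outer object is also an instance of the wrapper class rather than the model class.

```python
class ToyModule:
    """Hypothetical stand-in for nn.Module: parameters plus named child modules."""

    def __init__(self, params=None, children=None):
        self.params = params or {}
        self.children = children or {}

    def state_dict(self, prefix=""):
        # Keys are dotted paths from the root, like PyTorch's naming scheme.
        out = {prefix + name: value for name, value in self.params.items()}
        for child_name, child in self.children.items():
            out.update(child.state_dict(prefix + child_name + "."))
        return out


class ToyClassifier(ToyModule):
    def __init__(self):
        super().__init__(children={
            "fc1": ToyModule({"weight": [[0.1]], "bias": [0.0]}),
            "fc2": ToyModule({"weight": [[0.2]], "bias": [0.0]}),
        })


class ToyWrapper(ToyModule):
    """Mimics the wrapping pattern: the original model becomes a child named _module."""

    def __init__(self, module):
        super().__init__(children={"_module": module})


toy_model = ToyClassifier()
toy_wrapped = ToyWrapper(toy_model)

print(list(toy_model.state_dict()))        # ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
print(list(toy_wrapped.state_dict())[:2])  # ['_module.fc1.weight', '_module.fc1.bias']
print(isinstance(toy_wrapped, ToyClassifier), isinstance(toy_model, ToyClassifier))  # False True
```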

Non-wrapping mode instead attaches the per-sample gradient hooks directly to the original model. The model object, its type, and its state-dict keys therefore stay unchanged.
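The sketch below makes this concrete for a single linear layer y = W x + b. Assume a forward hook has recorded each sample's input x_n, and a backward hook has received g_n, the gradient of that sample's loss with respect to y_n. The per-sample weight gradient is then the outer product g_n x_n^T, and the per-sample bias gradient is g_n. This is plain Python on lists for illustration only. linear_per_sample_grads is a hypothetical helper, not Opacus's grad-sampler code.

```python
def linear_per_sample_grads(activations, backprops):
    """Per-sample gradients for a linear layer y = W x + b.

    activations: list of input vectors x_n (one per sample)
    backprops:   list of output gradients g_n = dL_n/dy_n (one per sample)
    Returns (weight_grads, bias_grads) with weight_grads[n][i][j] = g_n[i] * x_n[j].
    """
    weight_grads = [
        [[g_i * x_j for x_j in x] for g_i in g]
        for x, g in zip(activations, backprops)
    ]
    bias_grads = [list(g) for g in backprops]
    return weight_grads, bias_grads


# Two samples, input dimension 2, output dimension 1
acts = [[1.0, 2.0], [3.0, -1.0]]
grads = [[0.5], [2.0]]
w_grads, b_grads = linear_per_sample_grads(acts, grads)
print(w_grads)  # [[[0.5, 1.0]], [[6.0, -2.0]]]
print(b_grads)  # [[0.5], [2.0]]

# Summing the per-sample gradients recovers the usual batch gradient of sum_n L_n.
out_dim, in_dim = len(grads[0]), len(acts[0])
batch_w = [[sum(w[i][j] for w in w_grads) for j in range(in_dim)] for i in range(out_dim)]
print(batch_w)  # [[6.5, -1.0]]
```

Keeping the individual g_n x_n^T terms, rather than only their sum, is what allows each sample's gradient to be clipped separately later.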

Setup

First, we import the necessary libraries and create a synthetic dataset:

In [1]:
import warnings

warnings.simplefilter("ignore")

import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader

# Create synthetic dataset
n_samples = 1000
n_features = 20
n_classes = 10

X = torch.randn(n_samples, n_features)
y = torch.randint(0, n_classes, (n_samples,))

dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

Define a Model

We define a simple classifier for this tutorial:

In [2]:
class SimpleClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = SimpleClassifier(n_features, 64, n_classes)
print(f"Model type: {type(model).__name__}")
print(f"isinstance check: {isinstance(model, SimpleClassifier)}")
Model type: SimpleClassifier
isinstance check: True

Comparison: Wrapped vs. Non-Wrapped

We compare the default wrapped mode with non-wrapping mode:

In [3]:
from opacus import PrivacyEngine

# === Default wrapped mode ===
model_wrapped = SimpleClassifier(n_features, 64, n_classes)
optimizer_wrapped = optim.Adam(model_wrapped.parameters(), lr=0.001)

privacy_engine = PrivacyEngine()
model_wrapped, optimizer_wrapped, dataloader_wrapped = privacy_engine.make_private(
    module=model_wrapped,
    optimizer=optimizer_wrapped,
    data_loader=dataloader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    # wrap_model=True is the default
)

print("=== Wrapped Mode (default) ===")
print(f"Model type: {type(model_wrapped).__name__}")
print(f"isinstance check: {isinstance(model_wrapped, SimpleClassifier)}")
print(f"State dict keys (first 3): {list(model_wrapped.state_dict().keys())[:3]}")
print()
=== Wrapped Mode (default) ===
Model type: GradSampleModule
isinstance check: False
State dict keys (first 3): ['_module.fc1.weight', '_module.fc1.bias', '_module.fc2.weight']

In [4]:
# === Non-wrapping mode ===
model_nowrap = SimpleClassifier(n_features, 64, n_classes)
optimizer_nowrap = optim.Adam(model_nowrap.parameters(), lr=0.001)

privacy_engine2 = PrivacyEngine()
hooks, optimizer_nowrap, dataloader_nowrap = privacy_engine2.make_private(
    module=model_nowrap,
    optimizer=optimizer_nowrap,
    data_loader=dataloader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    wrap_model=False,  # Enable non-wrapping mode
)

print("=== Non-Wrapping Mode ===")
print(f"Hooks type: {type(hooks).__name__}")
print(f"Model type: {type(model_nowrap).__name__}")
print(f"isinstance check: {isinstance(model_nowrap, SimpleClassifier)}")
print(f"State dict keys (first 3): {list(model_nowrap.state_dict().keys())[:3]}")
print("The model instance remains unchanged and can be used directly.")
=== Non-Wrapping Mode ===
Hooks type: GradSampleHooks
Model type: SimpleClassifier
isinstance check: True
State dict keys (first 3): ['fc1.weight', 'fc1.bias', 'fc2.weight']
The model instance remains unchanged and can be used directly.

Key differences:

  • Wrapped: Model becomes GradSampleModule, isinstance() returns False, keys have _module. prefix.
  • Non-wrapped: Model remains SimpleClassifier, isinstance() returns True, state dict keys are clean.
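In addition to the model-related differences above, both modes return a new data loader. Assuming Opacus's default Poisson sampling (poisson_sampling=True), that loader includes each example in a batch independently with probability q = 1 / len(data_loader). Batch sizes therefore vary around an expected value of q × len(dataset). The pure-Python sketch below works through those numbers for this tutorial's dataset. poisson_batch is a hypothetical helper, not an Opacus function.

```python
import math
import random

n_samples, batch_size = 1000, 32

# Batches per epoch in the original DataLoader (drop_last=False)
num_batches = math.ceil(n_samples / batch_size)  # 32
sample_rate = 1 / num_batches                    # q = 0.03125
expected_batch_size = sample_rate * n_samples    # 31.25


def poisson_batch(indices, q, rng):
    """Include each index independently with probability q (hypothetical helper)."""
    return [i for i in indices if rng.random() < q]


rng = random.Random(0)
sizes = [len(poisson_batch(range(n_samples), sample_rate, rng)) for _ in range(2000)]
print(num_batches, sample_rate, expected_batch_size)  # 32 0.03125 31.25
print(f"observed: mean {sum(sizes) / len(sizes):.2f}, min {min(sizes)}, max {max(sizes)}")
```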

Model Usage in Non-Wrapping Mode

In non-wrapping mode, make_private returns a hooks object in the position where wrapped mode returns the model. The original model is left unchanged and should be used for all model operations.

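# Your model and optimizer, as defined earlier in this tutorial
model = SimpleClassifier(n_features, 64, n_classes)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# In non-wrapping mode, make_private returns hooks instead of a wrapped model
privacy_engine = PrivacyEngine()
hooks, optimizer, dp_dataloader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=dataloader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    wrap_model=False,
)

# Recommended: use the original model instance
inputs, _ = next(iter(dp_dataloader))
output = model(inputs)            # the original model is called directly
state_dict = model.state_dict()   # keys have no _module. prefix
model.train()                     # standard mode switches work

# Avoid using the hooks object for model operations; it is not a model
# hooks.state_dict()              # not supported by the hooks object
# hooks(inputs)                   # not supported by the hooks object

hooks.cleanup()                   # detach the hooks when you are done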

The main purpose of the hooks object is cleanup: call hooks.cleanup() once DP training is finished.

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_nowrap = model_nowrap.to(device)
criterion = nn.CrossEntropyLoss()

EPOCHS = 3
DELTA = 1e-5

for epoch in range(EPOCHS):
    model_nowrap.train()
    total_loss = 0

    for batch_idx, (data, target) in enumerate(dataloader_nowrap):
        data, target = data.to(device), target.to(device)

        optimizer_nowrap.zero_grad()
        output = model_nowrap(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer_nowrap.step()

        total_loss += loss.item()

    epsilon = privacy_engine2.get_epsilon(DELTA)
    avg_loss = total_loss / len(dataloader_nowrap)
    print(
        f"Epoch {epoch + 1}/{EPOCHS} | Loss: {avg_loss:.4f} | ε: {epsilon:.2f} (δ={DELTA})"
    )
Epoch 3/3 | Loss: 2.3059 | ε: 2.18 (δ=1e-05)
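Each optimizer_nowrap.step() call above turns the per-sample gradients collected by the hooks into a DP-SGD update. The pure-Python sketch below shows the standard recipe on plain lists under four simplifying assumptions:

  • All parameters are flattened into one vector.
  • Each per-sample gradient is clipped to L2 norm max_grad_norm.
  • Gaussian noise with standard deviation noise_multiplier × max_grad_norm is added to the sum.
  • The result is averaged over the expected batch size.

dp_sgd_gradient is a hypothetical name. Opacus's DPOptimizer applies the same idea to each parameter tensor, but this is not its code.

```python
import math
import random


def dp_sgd_gradient(per_sample_grads, max_grad_norm, noise_multiplier,
                    expected_batch_size, rng):
    """Clip each per-sample gradient, sum them, add Gaussian noise, and average.

    per_sample_grads: non-empty list of equal-length lists, one flattened gradient per sample
    """
    summed = [0.0] * len(per_sample_grads[0])
    for g in per_sample_grads:
        norm = math.sqrt(sum(v * v for v in g))
        # Shrink the gradient only if its L2 norm exceeds max_grad_norm.
        factor = min(1.0, max_grad_norm / (norm + 1e-6))
        for i, v in enumerate(g):
            summed[i] += v * factor
    # Noise is scaled to the clipping bound, the per-sample sensitivity of the sum.
    std = noise_multiplier * max_grad_norm
    return [(s + rng.gauss(0.0, std)) / expected_batch_size for s in summed]


rng = random.Random(0)
grads = [[3.0, 4.0], [0.3, 0.4]]  # per-sample L2 norms: 5.0 and 0.5

# With noise_multiplier=0.0, only clipping and averaging remain visible.
clean = dp_sgd_gradient(grads, max_grad_norm=1.0, noise_multiplier=0.0,
                        expected_batch_size=2, rng=rng)
print([round(v, 4) for v in clean])  # [0.45, 0.6]

# With this tutorial's settings, the result is randomized.
noisy = dp_sgd_gradient(grads, max_grad_norm=1.0, noise_multiplier=1.0,
                        expected_batch_size=2, rng=rng)
print(noisy)
```

The first sample is scaled down to norm 1.0 before summing, so no single example can move the update by more than max_grad_norm. That bound is what the added noise is calibrated against.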
In [6]:
# Clean up hooks
hooks.cleanup()
print("Hooks cleaned up successfully")

# The model is restored to its original state
print("Model can now be used normally")
Hooks cleaned up successfully
Model can now be used normally

What Does Cleanup Do?

The cleanup() method:

  1. Removes all hooks attached during make_private().
  2. Deletes attributes added to parameters (e.g., grad_sample, _forward_counter).
  3. Restores the model to its original state.
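A toy hook registry in plain Python can mimic the three steps above. ToyParam, ToyLayer, and ToyHooks are hypothetical stand-ins that only illustrate the pattern of removable hooks plus temporary attributes. This is not how Opacus implements cleanup() internally. To keep the example short, the toy stores a placeholder grad_sample in a forward hook.

```python
class ToyParam:
    """Stand-in for a parameter tensor that can carry extra attributes."""


class ToyLayer:
    """Stand-in for a layer that keeps a list of removable forward hooks."""

    def __init__(self):
        self.weight = ToyParam()
        self._forward_hooks = []

    def register_forward_hook(self, fn):
        self._forward_hooks.append(fn)
        # Return a "handle": calling it removes this hook again.
        return lambda: self._forward_hooks.remove(fn)

    def __call__(self, x):
        for fn in list(self._forward_hooks):
            fn(self, x)
        return x


class ToyHooks:
    """Attaches hooks to existing layers in place and can undo everything it added."""

    def __init__(self, layers):
        self.layers = layers
        self.removers = [layer.register_forward_hook(self._record) for layer in layers]

    @staticmethod
    def _record(layer, x):
        # Placeholder only: a real implementation computes per-sample gradients
        # in a backward hook. Here we just attach an attribute to the parameter.
        layer.weight.grad_sample = list(x)

    def cleanup(self):
        for remove in self.removers:   # 1. remove all hooks
            remove()
        self.removers = []
        for layer in self.layers:      # 2. delete attributes added to parameters
            if hasattr(layer.weight, "grad_sample"):
                del layer.weight.grad_sample
        # 3. each layer is back in its original state


toy_layer = ToyLayer()
toy_hooks = ToyHooks([toy_layer])
toy_layer([1.0, 2.0])
print(hasattr(toy_layer.weight, "grad_sample"), len(toy_layer._forward_hooks))  # True 1
toy_hooks.cleanup()
print(hasattr(toy_layer.weight, "grad_sample"), len(toy_layer._forward_hooks))  # False 0
```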

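Without cleanup, these hooks and attributes stay attached to the model, which can:

  • Increase memory usage (for example, through stored activations and grad_sample tensors).
  • Interfere with subsequent training, because the hooks keep running on every forward and backward pass.
  • Lead to unexpected behavior when the model is later used outside DP training.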
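Cleanup is manual, so it helps to make sure it always runs, even if training raises an exception. A small context manager can do this. The dp_hooks_cleanup helper below is hypothetical and not part of Opacus. It assumes only that the object returned by make_private(..., wrap_model=False) exposes the cleanup() method described above. A dummy object stands in for that object so the sketch runs on its own.

```python
from contextlib import contextmanager


@contextmanager
def dp_hooks_cleanup(hooks):
    """Yield `hooks` and always call hooks.cleanup() on exit (hypothetical helper)."""
    try:
        yield hooks
    finally:
        hooks.cleanup()


class DummyHooks:
    """Stand-in for the hooks object; it only records whether cleanup() ran."""

    def __init__(self):
        self.cleaned = False

    def cleanup(self):
        self.cleaned = True


dummy = DummyHooks()
try:
    with dp_hooks_cleanup(dummy):
        raise RuntimeError("simulated failure during training")
except RuntimeError:
    pass
print(dummy.cleaned)  # True
```

With the real objects, the training loop from In [5] would go inside a with dp_hooks_cleanup(hooks): block.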

Saving and Loading Checkpoints

In non-wrapping mode, checkpoints can be saved and loaded using the original model instance. The state dict keys remain clean, without the _module. prefix.
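Checkpoints saved earlier in wrapped mode still carry the _module. prefix. The minimal conversion sketch below assumes the checkpoint holds a plain state dict, that is, a mapping from key strings to tensors. strip_module_prefix is a hypothetical helper, and plain Python numbers stand in for tensors so the example runs without PyTorch.

```python
def strip_module_prefix(state_dict, prefix="_module."):
    """Return a copy of state_dict with a leading wrapper prefix removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


wrapped_sd = {"_module.fc1.weight": 1.0, "_module.fc1.bias": 0.0, "_module.fc2.weight": 2.0}
clean_sd = strip_module_prefix(wrapped_sd)
print(list(clean_sd))  # ['fc1.weight', 'fc1.bias', 'fc2.weight']
```

You can then pass the returned dict to load_state_dict() on an unwrapped SimpleClassifier. Keys without the prefix pass through unchanged, so the helper is safe to apply to checkpoints from either mode.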

In [7]:
# Save a checkpoint from a fresh DP-enabled model (the training loop is omitted for brevity)
model_save = SimpleClassifier(n_features, 64, n_classes)
optimizer_save = optim.Adam(model_save.parameters(), lr=0.001)

privacy_engine3 = PrivacyEngine()
hooks_save, optimizer_save, dataloader_save = privacy_engine3.make_private(
    module=model_save,
    optimizer=optimizer_save,
    data_loader=dataloader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    wrap_model=False,
)

# Use model_save directly - it's unchanged
# Save state dict - keys do not have the _module. prefix
torch.save(
    {
        "model_state_dict": model_save.state_dict(),
        "optimizer_state_dict": optimizer_save.state_dict(),
    },
    "checkpoint.pt",
)
print("Checkpoint saved with clean state dict keys")

# Clean up
hooks_save.cleanup()
Checkpoint saved with clean state dict keys
In [8]:
# Load checkpoint into a new model
model_load = SimpleClassifier(n_features, 64, n_classes)
optimizer_load = optim.Adam(model_load.parameters(), lr=0.001)

checkpoint = torch.load("checkpoint.pt")
model_load.load_state_dict(checkpoint["model_state_dict"])  # No prefix issues
optimizer_load.load_state_dict(checkpoint["optimizer_state_dict"])
print("Checkpoint loaded successfully")

# Continue training with DP if needed
privacy_engine4 = PrivacyEngine()
hooks_load, optimizer_load, dataloader_load = privacy_engine4.make_private(
    module=model_load,
    optimizer=optimizer_load,
    data_loader=dataloader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    wrap_model=False,
)
print("Ready to continue training")
Checkpoint loaded successfully
Ready to continue training

Using make_private_with_epsilon

Non-wrapping mode is also compatible with make_private_with_epsilon, which automatically calculates the noise multiplier based on your target privacy budget.

In [9]:
model_eps = SimpleClassifier(n_features, 64, n_classes)
optimizer_eps = optim.Adam(model_eps.parameters(), lr=0.001)
dataloader_eps = DataLoader(dataset, batch_size=32, shuffle=True)

privacy_engine5 = PrivacyEngine()
hooks_eps, optimizer_eps, dataloader_eps = privacy_engine5.make_private_with_epsilon(
    module=model_eps,
    optimizer=optimizer_eps,
    data_loader=dataloader_eps,
    target_epsilon=3.0,
    target_delta=1e-5,
    epochs=EPOCHS,
    max_grad_norm=1.0,
    wrap_model=False,  # Works with non-wrapping mode
)

print("Target epsilon: 3.0")
print(f"Computed noise multiplier: {optimizer_eps.noise_multiplier:.3f}")

# Cleanup when finished
hooks_eps.cleanup()
Target epsilon: 3.0
Computed noise multiplier: 0.875
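Conceptually, make_private_with_epsilon looks for the smallest noise multiplier whose privacy cost over the planned number of steps stays within target_epsilon. The sketch below shows that search as a bisection over σ. To stay self-contained, it uses a deliberately simplified accountant: the Rényi-DP bound for T compositions of a Gaussian mechanism without subsampling, converted to (ε, δ) as ε = min over α of [T·α / (2σ²) + log(1/δ) / (α - 1)]. Opacus's accountants also model the sampling, so they report a much smaller σ, such as the 0.875 above. The helper names toy_epsilon and find_noise_multiplier are hypothetical.

```python
import math

# Rényi orders to optimize over (an arbitrary grid chosen for this sketch)
ORDERS = [1 + x / 10 for x in range(1, 100)] + list(range(11, 256))


def toy_epsilon(sigma, steps, delta):
    """Epsilon for `steps` compositions of a Gaussian mechanism with noise
    multiplier `sigma` and no subsampling (toy accountant)."""
    return min(
        steps * alpha / (2 * sigma**2) + math.log(1 / delta) / (alpha - 1)
        for alpha in ORDERS
    )


def find_noise_multiplier(target_epsilon, delta, steps, iters=60):
    """Bisection for the smallest sigma with toy_epsilon(sigma) <= target_epsilon."""
    low, high = 0.0, 1.0
    while toy_epsilon(high, steps, delta) > target_epsilon:
        low, high = high, 2 * high  # grow the bracket until the target is met
    for _ in range(iters):
        mid = (low + high) / 2
        if toy_epsilon(mid, steps, delta) > target_epsilon:
            low = mid
        else:
            high = mid
    return high


steps = 3 * 32  # EPOCHS * batches per epoch in this tutorial
sigma = find_noise_multiplier(target_epsilon=3.0, delta=1e-5, steps=steps)
print(f"toy noise multiplier: {sigma:.3f}")
print(f"toy epsilon at that sigma: {toy_epsilon(sigma, steps, 1e-5):.3f}")
```

Bisection works here because epsilon decreases monotonically as σ grows, so the bracket [low, high] always contains the boundary.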

When to Use Non-Wrapping Mode?

Use wrap_model=False when:

  • Working with HuggingFace Transformers or models with custom __getattr__.
  • isinstance() checks are required.
  • Clean state dicts without _module. prefixes are preferred.
  • The pipeline relies on model type introspection.

Use default wrap_model=True when:

  • Working with standard models without complex introspection needs.
  • Automatic cleanup is preferred (wrapper discarded when out of scope).

Note: ExpandedWeights mode (grad_sample_mode="ew") is not supported with non-wrapping mode, as it requires overriding the model's .forward() method.

Summary

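| Feature                   | Wrapped Mode              | Non-Wrapping Mode        |
|---------------------------|---------------------------|--------------------------|
| API parameter             | wrap_model=True (default) | wrap_model=False         |
| Model type preserved      | No                        | Yes                      |
| isinstance() works        | No                        | Yes                      |
| State dict keys           | _module. prefix           | Clean (no prefix)        |
| Cleanup                   | Automatic                 | Manual (hooks.cleanup()) |
| HuggingFace compatibility | May have issues           | Better                   |
| ExpandedWeights support   | Yes                       | No                       |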

Key takeaway: Non-wrapping mode provides a compatible alternative for architectures where model wrapping is problematic. Manual cleanup using the hooks object is required.

Copyright © 2026 Meta Platforms, Inc.