This tutorial demonstrates how to use Opacus's non-wrapping mode (wrap_model=False), which provides compatibility with transformer models and other complex architectures by avoiding model wrapping.
By default, Opacus wraps the model in a GradSampleModule to compute per-sample gradients. This wrapper can cause issues:
- isinstance(model, MyModel) returns False after wrapping.
- State dict keys gain a _module. prefix, which can complicate checkpoint loading.
- Models that rely on custom __getattr__ (e.g., HuggingFace Transformers) may not work as expected.

Non-wrapping mode attaches hooks directly to the model without wrapping it, maintaining the original model structure.
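To see why the _module. prefix is a nuisance, consider loading a wrapped model's checkpoint into a plain model: the prefix must be stripped by hand. A minimal sketch with a toy dictionary (not Opacus API, just plain string manipulation):

```python
# Toy state dict as a wrapped model would produce it (keys carry "_module.")
wrapped_state = {"_module.fc1.weight": 0.1, "_module.fc1.bias": 0.2}

# Loading into an unwrapped model requires stripping the prefix manually
clean_state = {k.removeprefix("_module."): v for k, v in wrapped_state.items()}

print(sorted(clean_state))  # ['fc1.bias', 'fc1.weight']
```

Non-wrapping mode avoids this bookkeeping entirely, since the keys are clean from the start.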
First, we import the necessary libraries and create a synthetic dataset:
import warnings
warnings.simplefilter("ignore")
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader
# Create synthetic dataset
n_samples = 1000
n_features = 20
n_classes = 10
X = torch.randn(n_samples, n_features)
y = torch.randint(0, n_classes, (n_samples,))
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
We define a simple classifier for this tutorial:
class SimpleClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x
model = SimpleClassifier(n_features, 64, n_classes)
print(f"Model type: {type(model).__name__}")
print(f"isinstance check: {isinstance(model, SimpleClassifier)}")
Model type: SimpleClassifier
isinstance check: True
We compare the default wrapped mode with non-wrapping mode:
from opacus import PrivacyEngine
# === Default wrapped mode ===
model_wrapped = SimpleClassifier(n_features, 64, n_classes)
optimizer_wrapped = optim.Adam(model_wrapped.parameters(), lr=0.001)
privacy_engine = PrivacyEngine()
model_wrapped, optimizer_wrapped, dataloader_wrapped = privacy_engine.make_private(
    module=model_wrapped,
    optimizer=optimizer_wrapped,
    data_loader=dataloader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    # wrap_model=True is the default
)
print("=== Wrapped Mode (default) ===")
print(f"Model type: {type(model_wrapped).__name__}")
print(f"isinstance check: {isinstance(model_wrapped, SimpleClassifier)}")
print(f"State dict keys (first 3): {list(model_wrapped.state_dict().keys())[:3]}")
print()
=== Wrapped Mode (default) ===
Model type: GradSampleModule
isinstance check: False
State dict keys (first 3): ['_module.fc1.weight', '_module.fc1.bias', '_module.fc2.weight']
# === Non-wrapping mode ===
model_nowrap = SimpleClassifier(n_features, 64, n_classes)
optimizer_nowrap = optim.Adam(model_nowrap.parameters(), lr=0.001)
privacy_engine2 = PrivacyEngine()
hooks, optimizer_nowrap, dataloader_nowrap = privacy_engine2.make_private(
    module=model_nowrap,
    optimizer=optimizer_nowrap,
    data_loader=dataloader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    wrap_model=False,  # Enable non-wrapping mode
)
print("=== Non-Wrapping Mode ===")
print(f"Hooks type: {type(hooks).__name__}")
print(f"Model type: {type(model_nowrap).__name__}")
print(f"isinstance check: {isinstance(model_nowrap, SimpleClassifier)}")
print(f"State dict keys (first 3): {list(model_nowrap.state_dict().keys())[:3]}")
print("The model instance remains unchanged and can be used directly.")
=== Non-Wrapping Mode ===
Hooks type: GradSampleHooks
Model type: SimpleClassifier
isinstance check: True
State dict keys (first 3): ['fc1.weight', 'fc1.bias', 'fc2.weight']
The model instance remains unchanged and can be used directly.
Key differences:
- Wrapped mode: the model becomes a GradSampleModule, isinstance() returns False, and state dict keys carry the _module. prefix.
- Non-wrapping mode: the model remains a SimpleClassifier, isinstance() returns True, and state dict keys are clean.

In non-wrapping mode, make_private returns a hooks object for cleanup. The original model remains unchanged and should be used for all operations.
# Your model
model = SimpleClassifier(...)
# Make private returns hooks
hooks, optimizer, dataloader = privacy_engine.make_private(
    module=model,
    wrap_model=False,
    # ...
)
# Recommended: Use the original model instance
output = model(input) # The original model is used directly
state_dict = model.state_dict() # Use it normally
model.train() # Standard state transitions work
# Avoid using the hooks object for model operations
# hooks.state_dict() # These operations are not supported by the hooks object
# hooks(input)
The hooks object is primarily used for cleanup: hooks.cleanup().
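Because cleanup is manual in non-wrapping mode, a try/finally block guarantees the hooks are detached even if training raises. The sketch below uses a stand-in hooks class (FakeHooks is hypothetical; only its cleanup() method mirrors the real hooks object):

```python
class FakeHooks:
    """Stand-in for the hooks object returned by make_private(wrap_model=False)."""
    def __init__(self):
        self.active = True

    def cleanup(self):
        # The real object detaches forward/backward hooks here
        self.active = False


hooks = FakeHooks()
try:
    pass  # ... training loop would run here ...
finally:
    hooks.cleanup()  # always runs, even if the loop raised

print(hooks.active)  # False
```

The same try/finally shape applies verbatim to the real hooks object from make_private.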
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_nowrap = model_nowrap.to(device)
criterion = nn.CrossEntropyLoss()
EPOCHS = 3
DELTA = 1e-5
for epoch in range(EPOCHS):
    model_nowrap.train()
    total_loss = 0
    for batch_idx, (data, target) in enumerate(dataloader_nowrap):
        data, target = data.to(device), target.to(device)
        optimizer_nowrap.zero_grad()
        output = model_nowrap(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer_nowrap.step()
        total_loss += loss.item()

    epsilon = privacy_engine2.get_epsilon(DELTA)
    avg_loss = total_loss / len(dataloader_nowrap)
    print(
        f"Epoch {epoch + 1}/{EPOCHS} | Loss: {avg_loss:.4f} | ε: {epsilon:.2f} (δ={DELTA})"
    )
Epoch 3/3 | Loss: 2.3059 | ε: 2.18 (δ=1e-05)
# Clean up hooks
hooks.cleanup()
print("Hooks cleaned up successfully")
# The model is restored to its original state
print("Model can now be used normally")
Hooks cleaned up successfully
Model can now be used normally
The cleanup() method:
The cleanup() method removes the hooks registered by make_private() and deletes the per-sample gradient attributes they attach (grad_sample, _forward_counter). Without cleanup, these hooks and attributes remain, which can increase memory usage and interfere with subsequent non-private training.
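Conceptually, cleanup deletes the extra attributes that the per-sample gradient hooks attach to each parameter. A toy illustration of that idea (Param and toy_cleanup are made up for this sketch; this is not the Opacus implementation):

```python
class Param:
    """Toy stand-in for a parameter that hooks have annotated."""


p = Param()
p.grad_sample = [1, 2, 3]   # per-sample gradients accumulated by hooks
p._forward_counter = 1      # bookkeeping attribute used during forward passes


def toy_cleanup(param):
    # Delete the attributes that per-sample gradient hooks attached
    for attr in ("grad_sample", "_forward_counter"):
        if hasattr(param, attr):
            delattr(param, attr)


toy_cleanup(p)
print(hasattr(p, "grad_sample"))  # False
```

After the real cleanup(), the model is likewise back to a plain nn.Module with no leftover state.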
In non-wrapping mode, checkpoints can be saved and loaded using the original model instance. The state dict keys remain clean, without the _module. prefix.
# Save checkpoint (train a fresh model first)
model_save = SimpleClassifier(n_features, 64, n_classes)
optimizer_save = optim.Adam(model_save.parameters(), lr=0.001)
privacy_engine3 = PrivacyEngine()
hooks_save, optimizer_save, dataloader_save = privacy_engine3.make_private(
    module=model_save,
    optimizer=optimizer_save,
    data_loader=dataloader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    wrap_model=False,
)
# Use model_save directly - it's unchanged
# Save state dict - keys do not have the _module. prefix
torch.save(
    {
        "model_state_dict": model_save.state_dict(),
        "optimizer_state_dict": optimizer_save.state_dict(),
    },
    "checkpoint.pt",
)
print("Checkpoint saved with clean state dict keys")
# Clean up
hooks_save.cleanup()
Checkpoint saved with clean state dict keys
# Load checkpoint into a new model
model_load = SimpleClassifier(n_features, 64, n_classes)
optimizer_load = optim.Adam(model_load.parameters(), lr=0.001)
checkpoint = torch.load("checkpoint.pt")
model_load.load_state_dict(checkpoint["model_state_dict"]) # No prefix issues
optimizer_load.load_state_dict(checkpoint["optimizer_state_dict"])
print("Checkpoint loaded successfully")
# Continue training with DP if needed
privacy_engine4 = PrivacyEngine()
hooks_load, optimizer_load, dataloader_load = privacy_engine4.make_private(
    module=model_load,
    optimizer=optimizer_load,
    data_loader=dataloader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    wrap_model=False,
)
print("Ready to continue training")
Checkpoint loaded successfully
Ready to continue training
make_private_with_epsilon

Non-wrapping mode is also compatible with make_private_with_epsilon, which automatically calculates the noise multiplier based on your target privacy budget.
model_eps = SimpleClassifier(n_features, 64, n_classes)
optimizer_eps = optim.Adam(model_eps.parameters(), lr=0.001)
dataloader_eps = DataLoader(dataset, batch_size=32, shuffle=True)
privacy_engine5 = PrivacyEngine()
hooks_eps, optimizer_eps, dataloader_eps = privacy_engine5.make_private_with_epsilon(
    module=model_eps,
    optimizer=optimizer_eps,
    data_loader=dataloader_eps,
    target_epsilon=3.0,
    target_delta=1e-5,
    epochs=EPOCHS,
    max_grad_norm=1.0,
    wrap_model=False,  # Works with non-wrapping mode
)
print("Target epsilon: 3.0")
print(f"Computed noise multiplier: {optimizer_eps.noise_multiplier:.3f}")
# Cleanup when finished
hooks_eps.cleanup()
Target epsilon: 3.0
Computed noise multiplier: 0.875
Use wrap_model=False when:

- The model relies on custom __getattr__ (e.g., HuggingFace Transformers).
- isinstance() checks against the original model class are required.
- Clean state dict keys without _module. prefixes are preferred.

Use the default wrap_model=True when:

- Automatic cleanup is preferred over calling hooks.cleanup() manually.
- ExpandedWeights mode (grad_sample_mode="ew") is needed.
Note: ExpandedWeights mode (grad_sample_mode="ew") is not supported with non-wrapping mode, as it requires overriding the model's .forward() method.
| Feature | Wrapped Mode | Non-Wrapping Mode |
|---|---|---|
| API parameter | wrap_model=True (default) | wrap_model=False |
| Model type preserved | No | Yes |
| isinstance() works | No | Yes |
| State dict keys | _module. prefix | Clean |
| Cleanup required | Automatic | Manual (.cleanup()) |
| HuggingFace compatibility | May have issues | Better |
| ExpandedWeights support | Yes | No |
Key takeaway: Non-wrapping mode offers an alternative for architectures where model wrapping is problematic, at the cost of manual cleanup via the hooks object.