Uniform Sampler

class opacus.utils.uniform_sampler.DistributedUniformWithReplacementSampler(*, total_size, sample_rate, shuffle=True, shuffle_seed=0, steps=None, generator=None)[source]

Distributed batch sampler.

Each batch is sampled as follows:
  1. Shuffle the dataset (enabled by default)

  2. Split the dataset among the replicas into chunks of equal size (plus or minus one sample)

  3. Each replica selects each sample of its chunk independently with probability sample_rate

  4. Each replica outputs the selected samples, which form a local batch

Because each sample is selected independently, the sum of the lengths of the local batches follows a binomial distribution (approximately Poisson when sample_rate is small). In particular, the expected length of each local batch is: sample_rate * total_size / num_replicas
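The four steps above can be sketched in plain Python. This is only an illustrative simulation using the standard library, not the Opacus implementation: the function name simulate_local_batches, the strided split, and the per-replica seeds are assumptions made for the example.

```python
import random


def simulate_local_batches(total_size, sample_rate, num_replicas, shuffle_seed=0):
    """Simulate one round of sampling; returns one local batch per replica."""
    indices = list(range(total_size))
    # Step 1: shuffle the dataset with a seed shared by all replicas.
    random.Random(shuffle_seed).shuffle(indices)

    batches = []
    for rank in range(num_replicas):
        # Step 2: a strided split gives chunks whose sizes differ by at most one.
        chunk = indices[rank::num_replicas]
        # Per-replica source of randomness (hypothetical seeding scheme).
        rng = random.Random(1000 + rank)
        # Steps 3 and 4: keep each sample independently with probability
        # sample_rate; the kept samples form the local batch.
        batches.append([i for i in chunk if rng.random() < sample_rate])
    return batches


batches = simulate_local_batches(total_size=1000, sample_rate=0.1, num_replicas=4)
# Expected local batch length: sample_rate * total_size / num_replicas
expected_local = 0.1 * 1000 / 4
print([len(b) for b in batches], expected_local)
```

The printed lengths vary around the expected value, and a local batch can be empty when sample_rate is small.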

Parameters:
  • total_size (int) – total number of samples to sample from

  • sample_rate (float) – probability with which each sample is selected

  • shuffle (bool) – Flag indicating whether to shuffle the elements before dividing them between workers

  • shuffle_seed (int) – Random seed used to shuffle when dividing elements across workers

  • generator – torch.Generator() object used as a source of randomness when selecting items for the next round on a given worker

set_epoch(epoch)[source]

Sets the epoch for this sampler. When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters:

epoch (int) – Epoch number.

Return type:

None
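A minimal sketch of why setting the epoch changes the ordering, assuming the shuffle is seeded from a combination of shuffle_seed and the epoch number. The helper shuffled_indices and the seed formula are assumptions for illustration, written with the standard library rather than torch.

```python
import random


def shuffled_indices(total_size, shuffle_seed, epoch):
    """Return a deterministic permutation that depends on seed and epoch."""
    # Assumption: the shuffle seed is derived as shuffle_seed + epoch, so every
    # replica computes the same permutation within a given epoch.
    rng = random.Random(shuffle_seed + epoch)
    indices = list(range(total_size))
    rng.shuffle(indices)
    return indices


epoch0 = shuffled_indices(total_size=100, shuffle_seed=0, epoch=0)
epoch0_again = shuffled_indices(total_size=100, shuffle_seed=0, epoch=0)
epoch1 = shuffled_indices(total_size=100, shuffle_seed=0, epoch=1)
```

Under this assumption, forgetting to call set_epoch leaves the epoch unchanged, so the same permutation is reused on the next pass.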

class opacus.utils.uniform_sampler.UniformWithReplacementSampler(*, num_samples, sample_rate, generator=None, steps=None)[source]

This sampler selects elements as required by the Sampled Gaussian Mechanism: each sample is selected independently with probability sample_rate. The sampler generates steps batches, where steps defaults to 1/sample_rate.

Parameters:
  • num_samples (int) – total number of samples to sample from.

  • sample_rate (float) – probability used in sampling.

  • generator – torch.Generator() object used as a source of randomness in sampling.

  • steps – Number of steps (iterations of the Sampler)
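The behaviour described for this sampler can be sketched as a plain Python generator. This is a simplified stand-in rather than the Opacus class: the name poisson_batches and the seed argument are hypothetical, and the default for steps is assumed to be 1/sample_rate truncated to an integer.

```python
import random


def poisson_batches(num_samples, sample_rate, steps=None, seed=0):
    """Yield `steps` batches of indices drawn from range(num_samples)."""
    rng = random.Random(seed)
    if steps is None:
        # Assumed default: roughly one pass over the data in expectation.
        steps = int(1 / sample_rate)
    for _ in range(steps):
        # Each index is kept independently with probability sample_rate,
        # so batch sizes vary from step to step and a batch may be empty.
        yield [i for i in range(num_samples) if rng.random() < sample_rate]


batches = list(poisson_batches(num_samples=1000, sample_rate=0.25))
print(len(batches), [len(b) for b in batches])
```

Because every step draws from the full index range, the same sample can appear in more than one batch, which is what "with replacement" refers to here.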