"""
utils.py
This module contains various utility functions and configurations for
working with the CIFAR-10 dataset and PyTorch Lightning-based training
for creating and training a surrogate model. This file supports the
"create_surrogate_model.py" file.
"""
import math
from typing import Tuple, Union
import torch
import torchvision as tv
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset, TensorDataset
# Normalization statistics for CIFAR-10: three values per statistic, one per
# image channel (presumably in RGB order -- confirm against the transform
# pipeline that consumes them).
cifar10_normalize_values = dict(
    mean=[0.4914, 0.4822, 0.4465],
    std=[0.2470, 0.2435, 0.2616],
)
[docs]
def load_cifar10(train=True, require_normalize=False) -> Dataset:
    """Build a torchvision CIFAR-10 dataset rooted at ``./data``.

    Args:
        train: Forwarded to ``tv.datasets.CIFAR10`` as ``train`` and to
            ``get_transforms`` as its first argument.
        require_normalize: Forwarded to ``get_transforms`` as its second
            argument.

    Returns:
        The ``tv.datasets.CIFAR10`` instance, created with ``download=True``.
    """
    # The transform pipeline comes from a helper defined outside this block.
    transform = get_transforms(train, require_normalize)
    return tv.datasets.CIFAR10(
        "./data", train=train, download=True, transform=transform
    )
[docs]
def inverse_normalize(batch: torch.Tensor, normalize_values: dict) -> torch.Tensor:
    """Undo a per-channel normalization, i.e. compute ``batch * std + mean``.

    Args:
        batch: Tensor whose dimension 1 indexes the channels; the statistics
            are reshaped to ``(1, C, 1, 1)`` and broadcast against it.
        normalize_values: Mapping with the keys ``"mean"`` and ``"std"``,
            each a sequence with one value per channel.

    Returns:
        A new tensor with the normalization reverted, on the same device as
        ``batch``.
    """
    # Use ``batch.device`` rather than ``batch.get_device()``: the latter
    # returns -1 for CPU tensors, which ``Tensor.to`` rejects as a device.
    device = batch.device
    mean = torch.tensor(normalize_values["mean"], device=device)
    std = torch.tensor(normalize_values["std"], device=device)
    return batch * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)
[docs]
def get_labels(dataloader: DataLoader) -> torch.Tensor:
    """Collect the labels yielded by a dataloader into one tensor.

    Each batch is expected to unpack into ``(inputs, labels)``. Only the
    labels the dataloader actually yields are returned, so the result is not
    padded with zeros when the loader produces fewer samples than
    ``len(dataloader.dataset)`` (for example with ``drop_last=True``).

    Args:
        dataloader: Loader whose batches are ``(inputs, labels)`` pairs.

    Returns:
        A CPU tensor of dtype ``torch.long`` holding the labels in iteration
        order; empty if the dataloader yields no batches.
    """
    # Cast per batch so the result keeps the long dtype and CPU placement
    # regardless of how the labels were stored.
    chunks = [
        torch.as_tensor(y, dtype=torch.long, device="cpu") for _, y in dataloader
    ]
    if not chunks:
        return torch.zeros(0, dtype=torch.long)
    return torch.cat(chunks)
[docs]
def get_data(dataloader: DataLoader) -> torch.Tensor:
    """Gather the input half of every batch into a single tensor.

    Args:
        dataloader: Loader whose batches unpack into ``(inputs, labels)``.

    Returns:
        All input batches concatenated along dimension 0, in iteration order.
    """
    inputs = [features for features, _ in dataloader]
    return torch.concat(inputs)
[docs]
def choose_dataset(
    dataset: Dataset, n_sample: Union[int, float], num_workers=1
) -> Dataset:
    """Draw a random subset of a dataset without replacement.

    Args:
        dataset: Source dataset yielding ``(input, label)`` pairs.
        n_sample: Either a positive integer count smaller than
            ``len(dataset)``, or a float strictly between 0 and 1 that is
            interpreted as a fraction of the dataset (rounded down).
        num_workers: Worker count passed to the internal ``DataLoader``.

    Returns:
        A ``TensorDataset`` holding the selected inputs and labels.
    """
    is_valid = (isinstance(n_sample, int) or n_sample < 1) and n_sample > 0
    assert is_valid, "n_sample is invalid."
    dataset_size = len(dataset)
    assert n_sample < dataset_size, "This function does not allow replacement."

    # A fractional request is converted to an absolute number of samples.
    if isinstance(n_sample, float):
        target = math.floor(dataset_size * n_sample)
    else:
        target = n_sample

    # Shuffled batches provide the randomness; reading stops as soon as
    # enough samples have been gathered.
    loader = DataLoader(
        dataset, batch_size=512, shuffle=True, num_workers=num_workers
    )
    inputs, targets = [], []
    collected = 0
    for batch_x, batch_y in loader:
        collected += len(batch_x)
        inputs.append(batch_x)
        targets.append(batch_y)
        if collected >= target:
            break

    # The last batch may overshoot, so trim to the requested size.
    chosen_x = torch.concat(inputs)[:target]
    chosen_y = torch.concat(targets)[:target]
    return TensorDataset(chosen_x, chosen_y)
[docs]
def find_clip_range(dataset: Dataset) -> Tuple[float, float]:
    """Return the smallest and largest value found in a dataset's inputs.

    WARNING: Adversarial examples should NOT use a clip range after normalization.
    The scale of the perturbation will be wrong.

    Args:
        dataset: Iterable of ``(input, label)`` pairs whose inputs are
            tensors.

    Returns:
        A ``(min, max)`` tuple of Python scalars over all inputs.

    Raises:
        ValueError: If the dataset yields no samples.
    """
    lowest = math.inf
    highest = -math.inf
    seen_any = False
    for x, _ in dataset:
        seen_any = True
        sample_min = x.min().item()
        sample_max = x.max().item()
        if sample_min < lowest:
            lowest = sample_min
        if sample_max > highest:
            highest = sample_max
    # With no samples there is no meaningful range; fail with a clear
    # message rather than an attribute error on the float sentinels.
    if not seen_any:
        raise ValueError("Cannot compute the clip range of an empty dataset.")
    return lowest, highest