"""
utils.py
This module contains various utility functions and configurations for 
working with the CIFAR-10 dataset and PyTorch Lightning-based training 
for creating and training a surrogate model. This file supports the 
"create_surrogate_model.py" file. 
"""
import math
from typing import Tuple, Union
import torch
import torchvision as tv
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset, TensorDataset
# Per-channel statistics (one entry per colour channel) used to normalize
# CIFAR-10 tensors and, via ``inverse_normalize``, to undo that normalization.
cifar10_normalize_values = dict(
    mean=[0.4914, 0.4822, 0.4465],
    std=[0.2470, 0.2435, 0.2616],
)
[docs]
def load_cifar10(train=True, require_normalize=False, root="./data") -> Dataset:
    """Return the CIFAR-10 dataset, downloading it if it is not present.

    Args:
        train: If True, load the training split; otherwise the test split.
        require_normalize: Forwarded, together with ``train``, to
            ``get_transforms`` to build the dataset's transform.
        root: Directory the dataset is read from / downloaded into. Defaults
            to ``"./data"``, which is relative to the current working
            directory.

    Returns:
        A ``torchvision.datasets.CIFAR10`` instance.
    """
    return tv.datasets.CIFAR10(
        root,
        download=True,
        train=train,
        transform=get_transforms(train, require_normalize),
    )
[docs]
def inverse_normalize(batch: torch.Tensor, normalize_values: dict) -> torch.Tensor:
    """Convert a normalized tensor back to its original scale.

    Computes ``batch * std + mean`` channel-wise, the inverse of
    ``(x - mean) / std``.

    Args:
        batch: Tensor whose dimension 1 is the channel axis. The statistics
            are broadcast with shape ``(1, C, 1, 1)``, so this is intended for
            4-D ``(N, C, H, W)`` batches.
        normalize_values: Mapping with ``"mean"`` and ``"std"`` sequences, one
            value per channel (e.g. ``cifar10_normalize_values``).

    Returns:
        The un-normalized tensor, on the same device as ``batch``.
    """
    # Use ``batch.device`` rather than ``batch.get_device()``: the latter
    # returns -1 for CPU tensors, which is not a valid target for ``.to()``.
    device = batch.device
    mean = torch.Tensor(normalize_values["mean"]).to(device)
    std = torch.Tensor(normalize_values["std"]).to(device)
    return batch * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)
[docs]
def get_labels(dataloader: DataLoader) -> torch.Tensor:
    """Extract labels from a dataloader.

    Iterates over ``dataloader`` once and concatenates, along the first
    dimension and in iteration order, the ``y`` part of every ``(x, y)`` batch.

    The result's length follows the batches actually yielded rather than
    ``len(dataloader.dataset)``. It therefore stays correct when the loader
    drops its last batch or uses a sampler that covers only part of (or
    repeats) the dataset.

    Args:
        dataloader: DataLoader yielding ``(x, y)`` pairs.

    Returns:
        A ``torch.long`` CPU tensor of labels, empty if the loader yields no
        batches.
    """
    labels = [
        torch.as_tensor(y, dtype=torch.long, device="cpu") for _, y in dataloader
    ]
    if not labels:
        # torch.cat rejects an empty list; an empty loader simply has no labels.
        return torch.zeros(0, dtype=torch.long)
    return torch.cat(labels)
[docs]
def get_data(dataloader: DataLoader) -> torch.Tensor:
    """Gather the inputs of a dataloader into a single tensor.

    Walks ``dataloader`` once, keeps the ``x`` part of every ``(x, y)`` batch
    and joins them along the first dimension in iteration order. A loader
    that yields no batches makes ``torch.concat`` fail.
    """
    inputs = [x for x, _ in dataloader]
    return torch.concat(inputs)
[docs]
def choose_dataset(
    dataset: Dataset, n_sample: Union[int, float], num_workers=1
) -> Dataset:
    """Randomly choose samples from a dataset without replacement.

    Args:
        dataset: Map-style dataset yielding ``(x, y)`` pairs.
        n_sample: How many samples to draw. An ``int`` is an absolute count
            in ``[1, len(dataset)]``. A ``float`` in ``(0, 1)`` is a fraction
            of ``len(dataset)``, rounded down.
        num_workers: Forwarded to the internal ``DataLoader``.

    Returns:
        A ``TensorDataset`` holding the chosen inputs and labels.

    Raises:
        ValueError: If ``n_sample`` is neither a positive int nor a float in
            ``(0, 1)``, or if it asks for more samples than ``dataset`` holds.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    is_valid = (isinstance(n_sample, int) or n_sample < 1) and n_sample > 0
    if not is_valid:
        raise ValueError("n_sample is invalid.")
    # Drawing every sample (a permutation) is still "without replacement".
    if n_sample > len(dataset):
        raise ValueError("This function does not allow replacement.")
    if isinstance(n_sample, float):
        n_sample = math.floor(len(dataset) * n_sample)

    # Shuffle once and read only as many batches as needed to cover n_sample.
    dataloader = DataLoader(
        dataset, batch_size=512, shuffle=True, num_workers=num_workers
    )
    xs, ys = [], []
    n_seen = 0
    for x, y in dataloader:
        xs.append(x)
        ys.append(y)
        n_seen += len(x)
        if n_seen >= n_sample:
            break
    return TensorDataset(torch.concat(xs)[:n_sample], torch.concat(ys)[:n_sample])
[docs]
def find_clip_range(dataset: Dataset) -> Tuple[float, float]:
    """Return the range of a dataset.

    Scans every ``(x, y)`` pair and returns ``(min, max)`` over all values of
    ``x``. NaN sample extrema are ignored, as they compare false.

    WARNING: Adversarial examples should NOT use a clip range after normalization.
    The scale of the perturbation will be wrong.

    Args:
        dataset: Iterable of ``(x, y)`` pairs where ``x`` is a tensor.

    Returns:
        A ``(min_x, max_x)`` tuple of Python numbers.

    Raises:
        ValueError: If ``dataset`` yields no samples.
    """
    max_x = -torch.inf
    min_x = torch.inf
    n_seen = 0
    for x, _ in dataset:
        n_seen += 1
        sample_max = x.max().item()
        sample_min = x.min().item()
        if sample_max > max_x:
            max_x = sample_max
        if sample_min < min_x:
            min_x = sample_min
    if n_seen == 0:
        # Without this check the caller would get a confusing error (or +/-inf).
        raise ValueError("Cannot find the clip range of an empty dataset.")
    return min_x, max_x