"""
normalize_datasets.py
This module contains functions for normalizing and denormalizing a dataset.
"""
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader
# Module-level cache of the most recently computed normalization statistics.
# Written by get_mean_std() and read back by denormalize() (and presumably by
# get_transforms(), defined elsewhere in this module).
normalize_values = {}


def get_mean_std(dataset):
    """
    Determine per-channel mean and standard deviation for a dataset's images.

    Well-known benchmark datasets (CIFAR10, SVHN, GTSRB, CIFAR100,
    Tiny ImageNet, MNIST) are recognized by their class count, image size,
    and split size, and receive their standard published statistics; any
    other dataset falls back to statistics computed directly from the images.

    The result is also stored in the module-level ``normalize_values`` dict
    under the keys ``"mean"`` and ``"std"``.

    Parameters:
        dataset (dataset): Iterable of ``(image_tensor, label)`` pairs where
            every image tensor has the same ``(C, H, W)`` shape.

    Returns:
        tuple: ``(mean, std)`` — two lists with one entry per channel.
    """
    # NOTE: this materializes the entire dataset in memory at once.
    imgs = torch.stack([item[0] for item in dataset], dim=0).numpy()
    num_images, num_channels, height, width = imgs.shape
    num_classes = len({label for _, label in dataset})
    mean = std = None
    if num_channels == 3:
        if num_classes == 10 and (height, width) == (32, 32) and num_images in (50000, 10000):
            # CIFAR10 (train / test split sizes)
            mean = [0.4914, 0.4822, 0.4465]
            std = [0.2470, 0.2435, 0.2616]
        elif num_classes == 10 and (height, width) == (32, 32) and num_images in (73257, 26032):
            # SVHN
            mean = [0.4377, 0.44378, 0.4728]
            std = [0.1980, 0.2010, 0.19704]
        elif num_classes == 43 and num_images in (39209, 26640, 12630):
            # GTSRB (only counts are checked; image size is not fixed here)
            mean = [0.3417, 0.3126, 0.3217]
            std = [0.2768, 0.2646, 0.2706]
        elif num_classes == 100 and (height, width) == (32, 32) and num_images in (50000, 10000):
            # CIFAR100
            mean = [0.5071, 0.4867, 0.4408]
            std = [0.2675, 0.2565, 0.2761]
        elif num_classes == 200 and (height, width) == (64, 64) and num_images in (100000, 10000):
            # Tiny ImageNet (conventionally reuses ImageNet statistics)
            mean = [0.485, 0.456, 0.406]
            std = [0.229, 0.224, 0.225]
    elif num_channels == 1:
        if num_classes == 10 and (height, width) == (28, 28) and num_images in (60000, 10000):
            # MNIST
            mean = [0.1307]
            std = [0.3081]
    if mean is None:
        # Unrecognized dataset: compute per-channel statistics from the data.
        # (Generalizes the previous fallback, which only ever looked at
        # channel 0 for non-RGB inputs.)
        mean = imgs.mean(axis=(0, 2, 3)).tolist()
        std = imgs.std(axis=(0, 2, 3)).tolist()
    normalize_values["mean"] = mean
    normalize_values["std"] = std
    return mean, std
def check_normalize(dataloader):
    """
    Heuristically decide whether a dataloader yields normalized images.

    Inspects a single batch and checks that its overall mean lies in
    [-0.1, 0.1] and its overall standard deviation lies in [0.9, 1.1].

    Parameters:
        dataloader (torch.utils.data.DataLoader): A PyTorch DataLoader containing the dataset.

    Returns:
        bool: True if the data is normalized (mean close to 0, std close to 1), False otherwise.
    """
    images = next(iter(dataloader))[0]
    batch_mean = images.mean()
    batch_std = images.std()
    mean_ok = bool(-0.1 <= batch_mean <= 0.1)
    std_ok = bool(0.9 <= batch_std <= 1.1)
    return mean_ok and std_ok
def normalize_datasets(dataset_test, dataset_train=None):
    """
    Attach the normalizing transform to the given datasets in place.

    Parameters:
        dataset_test (dataset): The testing dataset; its ``transform``
            attribute is replaced.
        dataset_train (dataset, optional): The training dataset (Default is None).

    Returns:
        tuple: ``(dataset_test, dataset_train)`` — the same objects that were
        passed in, now carrying the new transform; ``dataset_train`` is None
        when it was not provided.
    """
    # NOTE(review): get_transforms() is defined elsewhere in this module;
    # presumably it builds a transform from normalize_values — confirm.
    if dataset_train is not None:  # BUGFIX: truthiness skipped empty (len 0) datasets
        dataset_train.transform = get_transforms()
    dataset_test.transform = get_transforms()
    return dataset_test, dataset_train
def normalize_and_check_datasets(num_workers, batch_size_test, batch_size_train, dataset_test, dataset_train):
    """
    Build data loaders for the datasets, normalizing them first if needed.

    Computes normalization statistics (cached in the module-level
    ``normalize_values``), builds the loaders, and — when a sample batch does
    not look normalized — attaches a normalizing transform via
    normalize_datasets() and rebuilds the loaders.

    Parameters:
        num_workers (int): Number of workers for data loading.
        batch_size_test (int): Batch size for the test dataset.
        batch_size_train (int): Batch size for the training dataset (if provided).
        dataset_test: The test dataset.
        dataset_train: The training dataset, or None for test-only use.

    Returns:
        tuple: ``(dataset_test, dataset_train, dataloader_test,
        dataloader_train)``. ``dataloader_train`` is None when no training
        dataset was given.
    """
    dataloader_train = None
    # Explicit None check: a truthiness test would misclassify an empty dataset.
    if dataset_train is not None:
        get_mean_std(dataset_train)
        dataloader_train = DataLoader(
            dataset_train,
            batch_size=batch_size_train,
            shuffle=True,
            num_workers=num_workers,
        )
        dataloader_test = DataLoader(
            dataset_test,
            batch_size=batch_size_test,
            shuffle=False,
            num_workers=num_workers,
        )
        if not check_normalize(dataloader_test) or not check_normalize(dataloader_train):
            dataset_test, dataset_train = normalize_datasets(
                dataset_test, dataset_train)
            # BUGFIX: the rebuilt loaders previously used the swapped
            # datasets (train loader over the test set and vice versa).
            dataloader_train = DataLoader(
                dataset_train,
                batch_size=batch_size_train,
                shuffle=True,
                num_workers=num_workers,
            )
            dataloader_test = DataLoader(
                dataset_test,
                batch_size=batch_size_test,
                shuffle=False,
                num_workers=num_workers,
            )
    else:
        get_mean_std(dataset_test)
        dataloader_test = DataLoader(
            dataset_test,
            batch_size=batch_size_test,
            shuffle=False,
            num_workers=num_workers,
        )
        if not check_normalize(dataloader_test):
            dataset_test, _ = normalize_datasets(dataset_test)
            dataloader_test = DataLoader(
                dataset_test,
                batch_size=batch_size_test,
                shuffle=False,
                num_workers=num_workers,
            )
    return dataset_test, dataset_train, dataloader_test, dataloader_train
def denormalize(batch):
    """
    Undo channel-wise normalization on a batch of images.

    Reads the mean/std statistics most recently stored in the module-level
    ``normalize_values`` dict (populated by get_mean_std()).

    Parameters:
        batch (torch.Tensor): A batch of normalized data, shaped
            ``(N, C, H, W)`` with C matching the cached statistics.

    Returns:
        torch.Tensor: The denormalized batch of data with the same shape as the input.
    """
    channel_mean = torch.Tensor(normalize_values["mean"]).reshape(1, -1, 1, 1)
    channel_std = torch.Tensor(normalize_values["std"]).reshape(1, -1, 1, 1)
    return batch * channel_std + channel_mean