"""
load_set.py
This module loads the training and test datasets used for model training and evaluation.
"""
from aiml.surrogate_model.utils import load_cifar10
from datasets import list_datasets
from datasets import load_dataset
import torchvision as tv
def load_test_set(test_set):
    """
    Load a test dataset.

    Parameters:
        test_set (dataset or str): Either an already-loaded dataset, which is
            returned unchanged, or a dataset name. "cifar10", "mnist" and
            "cifar100" are handled locally (the torchvision ones are
            downloaded into ./data); any other name is looked up with
            ``datasets.load_dataset(..., split='test')``.
    Returns:
        dataset: The loaded test dataset, or None if a named dataset could
        not be loaded via ``load_dataset``.
    """
    # Anything that is not a name is treated as a ready-made dataset.
    if not isinstance(test_set, str):
        return test_set

    if test_set == "cifar10":
        # Project helper, called with require_normalize=True.
        return load_cifar10(train=False, require_normalize=True)
    if test_set == "mnist":
        return tv.datasets.MNIST('./data', download=True, train=False)
    if test_set == "cifar100":
        return tv.datasets.CIFAR100('./data', download=True, train=False)

    # Fall back to the Hugging Face Hub. Failure is best-effort: report it
    # and return None rather than propagating, but let KeyboardInterrupt /
    # SystemExit through (a bare except would swallow them).
    try:
        dataset = load_dataset(path=test_set, split='test')
    except Exception:
        print("Currently we cannot find the test dataset you input. "
              "you can call huggingface_hub.list_datasets to see the whole "
              "available set")
        return None
    print("we find the test set on huggingface")
    return dataset


def load_train_set(train_set):
    """
    Load a training dataset.

    Parameters:
        train_set (dataset or str): Either an already-loaded dataset, which
            is returned unchanged, or a dataset name. "cifar10", "mnist" and
            "cifar100" are downloaded via torchvision into ./data; any other
            name is looked up with ``datasets.load_dataset(..., split='train')``.
    Returns:
        dataset: The loaded training dataset, or None if a named dataset
        could not be loaded via ``load_dataset``.
    """
    # Anything that is not a name is treated as a ready-made dataset.
    if not isinstance(train_set, str):
        return train_set

    if train_set == "cifar10":
        # NOTE(review): load_test_set uses load_cifar10(require_normalize=True)
        # for cifar10, while this returns the raw torchvision dataset with no
        # transform — confirm the asymmetry is intended.
        return tv.datasets.CIFAR10('./data', download=True, train=True)
    if train_set == "mnist":
        return tv.datasets.MNIST('./data', download=True, train=True)
    if train_set == "cifar100":
        return tv.datasets.CIFAR100('./data', download=True, train=True)

    # Fall back to the Hugging Face Hub. Failure is best-effort: report it
    # and return None rather than propagating, but let KeyboardInterrupt /
    # SystemExit through (a bare except would swallow them).
    try:
        dataset = load_dataset(path=train_set, split='train')
    except Exception:
        print("Currently we cannot find the train dataset you input. "
              "you can call huggingface_hub.list_datasets to see the whole "
              "available set")
        return None
    print("we find the train set on huggingface")
    return dataset