"""
load_model.py
This script is responsible for loading the model.
"""
#import detectors
from robustbench.utils import load_model
def load_model(model, device):
    """
    Resolve a model reference to a ready-to-use model.

    Parameters:
        model (model or string): Either an already constructed model, which
            is returned unchanged, or a model name, which is looked up with
            ``robustbench.utils.load_model`` using that function's default
            arguments.
        device (string): Target device. It is passed to ``model.to(device)``
            after a model has been loaded by name; it is not applied to a
            model object that is passed in directly.
    Returns:
        model: The loaded model, or None when ``model`` is a name that could
            not be loaded (a message is printed in that case).
    """
    # Anything that is not a name is assumed to be a model already.
    if not isinstance(model, str):
        return model

    try:
        # This function shadows the module-level ``load_model`` import, so the
        # RobustBench loader is imported here under a distinct name. Calling
        # ``load_model(model)`` would recurse into this function instead.
        from robustbench.utils import load_model as _robustbench_load_model

        return _robustbench_load_model(model).to(device)
    except Exception:
        # Best-effort contract: report the failure and return None rather
        # than raising. Catching ``Exception`` instead of using a bare
        # ``except`` lets KeyboardInterrupt and SystemExit propagate.
        print("We can't find your model.")
        return None