# Source code for aiml.load_data.load_model

"""
load_model.py

This module provides ``load_model``, a helper that obtains a machine learning
model from either an existing model object or a model name.
"""


#import detectors
from robustbench.utils import load_model


def load_model(model, device, *, loader=None):
    """
    Load a machine learning model.

    Parameters:
        model (model or string): A model object, which is returned unchanged,
            or the name of a model to look up. Names are resolved with
            ``loader`` (by default ``robustbench.utils.load_model``).
        device (string): Device the looked-up model is moved to via
            ``model.to(device)``, e.g. ``'cpu'`` or ``'gpu'``. ``'gpu'`` is
            translated to ``'cuda'`` before the call.
        loader (callable, optional): Function taking a model name and
            returning a model object. Defaults to
            ``robustbench.utils.load_model``.

    Returns:
        model: The loaded machine learning model, or None if the named model
        could not be loaded (a message is printed in that case).
    """
    if not isinstance(model, str):
        # NOTE(review): model objects are returned as-is and are NOT moved to
        # `device`; this preserves the previous behaviour — confirm callers
        # do not expect a .to(device) here.
        return model

    # 'gpu' is the documented spelling, but the device string passed to .to()
    # is presumably a torch device, which expects 'cuda' — TODO confirm.
    if device == "gpu":
        device = "cuda"

    try:
        if loader is None:
            # Import under an alias: this function shadows the module-level
            # robustbench `load_model`, so calling that bare name here would
            # recurse into this function instead of loading the model.
            from robustbench.utils import load_model as loader
        loaded = loader(model)
        return loaded.to(device)
    except Exception:
        # Best-effort lookup: any failure (missing package, unknown name,
        # bad device) is reported and signalled with None, as before.
        print("We can't find your model.")
        return None