"""
check_accuracy.py
This module defines functions for calculating the accuracy of a model
on a given dataset.
"""
import torch
def check_accuracy(model, dataloader, device):
    """
    Calculate the classification accuracy of a model over a dataset.

    The model is switched to evaluation mode with ``model.eval()`` and is
    left in that mode when the function returns. Gradient tracking is
    disabled while the batches are evaluated. The model itself is not moved
    here, so it must already live on ``device``.

    Parameters:
        model: The pre-trained machine learning model. It is called as
            ``model(x)`` and its output is reduced with ``argmax`` over
            dimension 1 to obtain the predicted class indices.
        dataloader: An iterable yielding ``(x, y)`` batches, where ``x`` is
            the model input and ``y`` holds the labels to compare against
            the predicted class indices.
        device: The device the inputs are moved to before the forward pass,
            e.g. 'cpu' or 'cuda' (PyTorch does not accept 'gpu').
    Returns:
        float: The fraction of samples whose predicted class equals the
        label, or 0.0 if the dataloader yields no samples.
    """
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for x, y in dataloader:
            outputs = model(x.to(device))
            predictions = outputs.argmax(dim=1)
            # Compare on the device the predictions live on; the labels may
            # arrive on a different device than the model output.
            y = y.to(predictions.device)
            total += y.size(0)
            correct += (predictions == y).sum().item()
    if total == 0:
        # Nothing was evaluated: avoid ZeroDivisionError on an empty loader.
        return 0.0
    return correct / total
def check_accuracy_with_flags(model, dataloader, device):
    """
    Calculate a model's accuracy and record which samples it got right.

    The model is switched to evaluation mode with ``model.eval()`` and is
    left in that mode when the function returns. Gradient tracking is
    disabled while the batches are evaluated. The model itself is not moved
    here, so it must already live on ``device``.

    Parameters:
        model: The machine learning model for evaluation. It is called as
            ``model(x)`` and its output is reduced with ``argmax`` over
            dimension 1 to obtain the predicted class indices.
        dataloader: An iterable yielding ``(x, y)`` batches, where ``x`` is
            the model input and ``y`` holds the labels to compare against
            the predicted class indices.
        device: The device the inputs are moved to before the forward pass,
            e.g. 'cpu' or 'cuda' (PyTorch does not accept 'gpu').
    Returns:
        float: The fraction of samples whose predicted class equals the
        label, or 0.0 if the dataloader yields no samples.
        list: One entry per batch, in iteration order. Each entry is itself
        a list of booleans, True where the sample was classified correctly
        and False otherwise (so the result is nested, not flat).
    """
    correct = 0
    total = 0
    model.eval()
    correct_image_bool = []
    with torch.no_grad():
        for x, y in dataloader:
            outputs = model(x.to(device))
            predictions = outputs.argmax(dim=1)
            # Compare on the device the predictions live on; the labels may
            # arrive on a different device than the model output.
            y = y.to(predictions.device)
            # Boolean tensor: True where the prediction matches the label.
            correct_bool = predictions == y
            total += y.size(0)
            correct += correct_bool.sum().item()
            # Appended (not extended) to keep the per-batch grouping that
            # callers of this function already receive.
            correct_image_bool.append(correct_bool.tolist())
    # Guard against ZeroDivisionError when no samples were evaluated.
    accuracy = correct / total if total else 0.0
    return accuracy, correct_image_bool