"""
create_surrogate_model.py
This module creates surrogate models for black-box attacks.
"""
import pytorch_lightning as pl
import torch
import torch.nn as nn
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from aiml.surrogate_model.models import LogSoftmaxModule, Surrogate, create_substitute_model
def get_num_classes(dataloader):
    """
    Get the number of classes from a dataloader.

    The count is read from ``dataloader.dataset.classes`` when that attribute
    is available. Otherwise the dataloader is iterated once in full and the
    distinct label values found in its batches are counted.

    Parameters:
        dataloader (torch.utils.data.DataLoader): The dataloader containing the dataset.
            On the fallback path every batch must unpack into exactly two
            items, ``(inputs, labels)``.
    Returns:
        int: The number of classes in the dataset.
    """
    try:
        return len(dataloader.dataset.classes)
    except (AttributeError, TypeError):
        # No usable ``classes`` attribute (missing, or not a sized object):
        # fall back to counting labels. Anything else (e.g. KeyboardInterrupt)
        # is deliberately allowed to propagate.
        pass

    unique_labels = set()
    for _, labels in dataloader:
        # ``Tensor.tolist()`` copies to the CPU when needed, so this also works
        # for labels living on an accelerator; ``as_tensor`` lets plain
        # sequences / arrays of labels through as well.
        unique_labels.update(torch.as_tensor(labels).tolist())
    return len(unique_labels)
def create_substitute(dataloader_train, num_classes):
    """
    Build a substitute model sized for the training data.

    The number of input channels is taken from the first dimension of the
    input of the first sample in ``dataloader_train.dataset``; the model
    itself is built by ``create_substitute_model``.

    Parameters:
        dataloader_train (torch.utils.data.DataLoader): The training dataloader.
            Its dataset must support indexing, and each sample's first element
            must expose a ``shape``.
        num_classes (int): The number of classes in the dataset.
    Returns:
        nn.Module: The created substitute model.
    """
    first_sample = dataloader_train.dataset[0]
    sample_input = first_sample[0]
    # Leading dimension of a single input is used as the channel count.
    return create_substitute_model(num_classes, sample_input.shape[0])
def create_surrogate_model(
    model,
    dataloader_train,
    dataloader_test,
    *,
    max_epochs=50,
    learning_rate=0.0005,
):
    """
    Create and train a surrogate model using PyTorch Lightning.

    The black-box ``model`` is wrapped in ``LogSoftmaxModule`` and used as the
    oracle. A substitute network sized from the training data is trained
    against it through a ``Surrogate`` module with a KL-divergence loss.
    Training logs are written with a ``TensorBoardLogger`` (save dir ``logs``,
    run name ``surrogate``), and the learning rate is logged every step.

    Parameters:
        model (nn.Module): The black-box model to create a surrogate for.
        dataloader_train (torch.utils.data.DataLoader): The training dataloader.
        dataloader_test (torch.utils.data.DataLoader): The testing dataloader,
            passed to the trainer as the validation dataloader.
        max_epochs (int): Keyword-only. Maximum number of training epochs.
            Defaults to 50.
        learning_rate (float): Keyword-only. Learning rate handed to the
            ``Surrogate`` module. Defaults to 0.0005.
    Returns:
        pytorch_lightning.LightningModule: The trained surrogate model
        (the ``Surrogate`` instance after ``trainer.fit``).
    """
    if torch.cuda.is_available():
        # Only relevant on GPU: allow faster, slightly lower-precision
        # float32 matrix multiplications.
        torch.set_float32_matmul_precision("high")

    # Construction order is kept as-is: counting classes may iterate the
    # dataloader and building the substitute initialises weights, both of
    # which can consume random state.
    oracle = LogSoftmaxModule(model)
    num_classes = get_num_classes(dataloader_train)
    substitute = create_substitute(dataloader_train, num_classes)

    # log_target=True: the loss treats the target as log-probabilities.
    # NOTE(review): presumably matches the LogSoftmaxModule oracle output -
    # confirm against Surrogate's training step.
    loss_fn = nn.KLDivLoss(reduction="batchmean", log_target=True)

    surrogate_module = Surrogate(
        lr=learning_rate,
        num_training_batches=len(dataloader_train),
        oracle=oracle,
        substitute=substitute,
        loss_fn=loss_fn,
        num_classes=num_classes,
        softmax=True,
    )

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        enable_progress_bar=True,
        logger=TensorBoardLogger(
            "logs", name="surrogate", default_hp_metric=False
        ),
        callbacks=[LearningRateMonitor(logging_interval="step")],
    )
    trainer.fit(
        surrogate_module,
        train_dataloaders=dataloader_train,
        val_dataloaders=dataloader_test,
    )
    return surrogate_module