import torch

from timm.utils import AverageMeter

from metrics import accuracy


# Public API of this module: only the validation-epoch runner is exported.
__all__ = ["val_epoch"]


@torch.no_grad()
def val_epoch(
    model, 
    loader, 
    loss_fn,
    device: str = 'cpu', 
    amp: bool = False
) -> dict:
    """Run one evaluation pass over ``loader`` and return averaged metrics.

    The model is put in eval mode and the whole pass runs without gradient
    tracking (``torch.no_grad``). Each batch's loss and top-1 accuracy are
    fed to an ``AverageMeter`` weighted by the batch size.

    Args:
        model: Module to evaluate; ``model.eval()`` is called on it.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        loss_fn: Callable ``loss_fn(outputs, targets)`` returning a scalar
            tensor (``.item()`` is called on the result).
        device: Device each batch is moved to, e.g. ``'cpu'`` or
            ``'cuda:0'``. A ``torch.device`` instance is also accepted.
        amp: If True, the forward pass and loss run under ``torch.autocast``.

    Returns:
        ``{'val/loss': <avg loss>, 'val/acc1': <avg top-1 accuracy>}``.
    """
    model.eval()
    loss_m = AverageMeter()
    acc1_m = AverageMeter()
    # autocast wants the bare device type ('cuda', 'cpu', ...) with no index;
    # torch.device() normalizes both str and torch.device inputs.
    device_type = torch.device(device).type
    # validation epoch
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        with torch.autocast(device_type=device_type, enabled=amp):
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
        # outputs may be reduced precision under autocast; score in float32.
        # NOTE(review): targets are also cast to float here — confirm this
        # matches what metrics.accuracy expects (class-index targets are
        # usually integer tensors).
        acc1 = accuracy(outputs.float(), targets.float())
        batch_size = len(inputs)
        acc1_m.update(acc1, batch_size)
        loss_m.update(loss.item(), batch_size)

    return {'val/loss': loss_m.avg, 'val/acc1': acc1_m.avg}
    