import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicModel(nn.Module):
    """Base ``nn.Module`` with checkpoint save/load helpers.

    Checkpoints are written as ``<path>/<name>_epoch_<epoch>.pth``, where
    ``name`` is taken from ``args.model``. This class does not create
    ``self.optimizer``; it must be assigned (e.g. by a subclass) before
    calling :meth:`save_model`, or :meth:`load_model` with
    ``load_optimizer=True``.
    """

    def __init__(self, args):
        super().__init__()
        # Prefix used for every checkpoint file name.
        self.name = args.model

    def _checkpoint_path(self, path, epoch):
        """Return the checkpoint file path for ``epoch`` inside directory ``path``."""
        return os.path.join(path, "{}_epoch_{}.pth".format(self.name, epoch))

    def load_model(self, path, epoch, load_optimizer=True, map_location=None):
        """Restore model (and optionally optimizer) state from a checkpoint.

        Args:
            path: Directory containing the checkpoint files.
            epoch: Epoch number identifying which checkpoint to load.
            load_optimizer: If True, also restore ``self.optimizer``'s state.
            map_location: Forwarded to ``torch.load``; e.g. ``"cpu"`` to load
                a checkpoint saved on GPU onto a CPU-only machine. ``None``
                (the default) keeps torch's default behaviour.
        """
        checkpoint = torch.load(
            self._checkpoint_path(path, epoch), map_location=map_location
        )
        self.load_state_dict(checkpoint["state_dict"])
        if load_optimizer:
            self.optimizer.load_state_dict(checkpoint["optimizer"])

    def save_model(self, path, epoch, symlink_latest=True, remove_old=None):
        """Save model and optimizer state for ``epoch`` under directory ``path``.

        Args:
            path: Existing directory to write the checkpoint into.
            epoch: Epoch number, stored in the checkpoint and used in its name.
            symlink_latest: If True, (re)point ``<path>/checkpoint_latest.pth``
                at the new checkpoint.
            remove_old: If a positive int ``k``, delete the checkpoint from
                epoch ``epoch - k`` (if it exists). When called every epoch
                this keeps only the ``k`` most recent checkpoints.
        """
        checkpoint = {
            "epoch": epoch,
            "state_dict": self.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }
        ckpt_path = self._checkpoint_path(path, epoch)
        torch.save(checkpoint, ckpt_path)

        if symlink_latest:
            latest_path = os.path.join(path, "checkpoint_latest.pth")
            # Link to the bare file name: a symlink's relative target is
            # resolved against the link's own directory, so using
            # ``ckpt_path`` (which already includes ``path``) would dangle
            # whenever ``path`` is relative.
            target = os.path.basename(ckpt_path)
            # The link usually exists already, so create a uniquely named
            # temporary link next to it and atomically swap it into place.
            # os.replace (unlike os.rename) also overwrites on Windows.
            tmp_link = "{}.{}.tmp".format(ckpt_path, time.time())
            os.symlink(target, tmp_link)
            try:
                os.replace(tmp_link, latest_path)
            except OSError:
                os.remove(tmp_link)  # don't leave a stray temporary link behind
                raise

        # ``>=`` (not ``>``) so an epoch-0 checkpoint is also eligible for removal.
        if remove_old is not None and remove_old > 0 and epoch >= remove_old:
            old_path = self._checkpoint_path(path, epoch - remove_old)
            try:
                os.remove(old_path)
            except FileNotFoundError:
                pass  # already gone or never written; nothing to prune

    def compute_loss(self, output, target, meta_target):
        """Compute the training loss; this base implementation returns None.

        NOTE(review): presumably overridden by concrete subclasses — confirm.
        """
        return None
