import os
import math
import torch
import random
import numpy as np
from torch.optim import SGD, Adam
import datetime

def set_seed(seed: int = 42) -> None:
    """Seed every random number generator used by the project.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (CPU and
    CUDA), switches cuDNN to deterministic, non-autotuned kernels and exports
    ``PYTHONHASHSEED``.

    Args:
        seed: value used for every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # cuDNN: force deterministic algorithms and turn off the benchmark
    # autotuner, whose kernel choice may differ between runs.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
    # PYTHONHASHSEED is read at interpreter start-up, so this only influences
    # Python subprocesses launched after this call.
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

def get_timestamp():
    """Return the current local time formatted as ``DD_MM_HH_MM_SS`` (no year)."""
    return datetime.datetime.now().strftime("%d_%m_%H_%M_%S")

def learning_rate(init, epoch):
    """Step-decay schedule: scale ``init`` by 0.2 for every milestone passed.

    Milestones are strict (``epoch > milestone``): epochs up to 60 use
    ``init``, 61-120 use ``init * 0.2``, 121-160 use ``init * 0.2**2`` and
    anything above 160 uses ``init * 0.2**3``.

    Args:
        init: base learning rate.
        epoch: current epoch number.

    Returns:
        The decayed learning rate.
    """
    # Number of milestones strictly exceeded by ``epoch`` (0..3).
    decay_steps = sum(epoch > milestone for milestone in (60, 120, 160))
    return init * math.pow(0.2, decay_steps)

def make_directories(path):
    """Create directory ``path`` (and any missing parents) if it is absent.

    ``exist_ok=True`` replaces the old separate existence check. With that
    check, a second process creating the same directory between the check and
    ``os.makedirs`` made this call fail with ``FileExistsError``.

    Args:
        path: directory to create. If it already exists as a directory this is
            a no-op; if it exists as a regular file, ``os.makedirs`` raises
            ``FileExistsError``.
    """
    os.makedirs(path, exist_ok=True)

# def log_scalar(key, value, type):
#     if type == "metric":
#         mlflow.log_metric(key, value)
#     elif type == "param":
#         mlflow.log_param(key, value)
    
def save_param(key, value, root_dir, device):
    """Append ``value`` to the tensor persisted at ``<root_dir>/<key>.pt``.

    The directory is created if needed. On the first call the file simply
    holds ``value``. Later calls load the stored tensor, concatenate ``value``
    to it along dim 0 and overwrite the file.

    Args:
        key: file stem of the ``.pt`` file.
        value: tensor to append; a 0-d tensor is promoted to shape ``(1,)``.
        root_dir: directory holding the ``.pt`` files.
        device: ``map_location`` for loading the existing tensor. ``value`` must
            be on that same device, otherwise ``torch.cat`` raises.
    """
    make_directories(root_dir)

    # torch.cat cannot concatenate zero-dimensional tensors, so lift scalars
    # to a 1-element vector first.
    if value.dim() == 0:
        value = value.unsqueeze(dim=0)

    target_path = os.path.join(root_dir, key + ".pt")
    if not os.path.isfile(target_path):
        torch.save(value, target_path)
        return

    previous = torch.load(target_path, map_location=device)
    torch.save(torch.cat([previous, value]), target_path)

def fetch_dim(x):
    """Return the element count of one sample of the batch ``x``.

    Multiplies every dimension after the leading one. For example, a tensor of
    shape ``(N, C, H, W)`` yields ``C * H * W``. The result is a 0-d tensor.
    """
    # TODO: batches whose samples have differing shapes (e.g. later queries in
    # a multi-query setup) are not handled; a uniform per-sample shape is
    # assumed.
    per_sample_shape = torch.as_tensor(x.shape[1:])
    return per_sample_shape.prod()

def get_image_size(args):
    """Return the image side length for ``args.dataset``, including padding.

    Base sizes: 32 for ``"cifar10"`` and 64 for ``"tiny_imagenet"``. When
    ``args.pad_size`` is positive, ``2 * args.pad_size`` is added, one pad on
    each side.

    Args:
        args: namespace providing ``dataset`` and ``pad_size``.

    Returns:
        The (padded) side length as an int.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset.
    """
    base_sizes = {"cifar10": 32, "tiny_imagenet": 64}
    try:
        image_size = base_sizes[args.dataset]
    except KeyError:
        # Without this, an unknown dataset left ``image_size`` unbound and the
        # function failed with an obscure UnboundLocalError.
        raise ValueError(f"Unsupported dataset: {args.dataset!r}") from None
    if args.pad_size > 0:
        image_size += 2 * args.pad_size
    return image_size

def get_mask_shape(args, image_size):
    """Return the mask shape selected by ``args.wm_across_channels``.

    ``"different"`` gives a 3-channel mask ``(3, image_size, image_size)``.
    ``"same"`` gives a single-plane mask ``(image_size, image_size)``,
    presumably broadcast over channels by the caller (confirm at call site).

    Args:
        args: namespace providing ``wm_across_channels``.
        image_size: side length of the (square) mask.

    Returns:
        A shape tuple.

    Raises:
        ValueError: if ``args.wm_across_channels`` is neither ``"different"``
            nor ``"same"``.
    """
    mode = args.wm_across_channels
    if mode == "different":
        return (3, image_size, image_size)
    if mode == "same":
        return (image_size, image_size)
    # Previously this path hit an unbound local and raised UnboundLocalError.
    raise ValueError(f"Unsupported wm_across_channels value: {mode!r}")

def get_save_directory_path(args):
    """Build a timestamped log-directory path for this run.

    The path has the form
    ``logs/<YYYY-MM-DDTHH-MM-SS>_<second_query_mask_model>_<run_description>``.
    Only the path is built; the directory is not created here. ``args.seed``
    is not part of the name.

    Args:
        args: namespace providing ``second_query_mask_model`` and
            ``run_description`` (both converted with ``str``).

    Returns:
        The relative path as a string.
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    folder_name = "_".join(
        [timestamp, str(args.second_query_mask_model), str(args.run_description)]
    )
    return os.path.join("logs", folder_name)

class AverageMeter:
    """Keep the latest value plus a running, sample-weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, running mean, running sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` new samples.

        ``val`` is weighted by ``n`` in the running sum, so passing a batch
        mean together with ``n=batch_size`` yields a per-sample average.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Divides by zero if the accumulated count is still 0 (e.g. a first
        # update with n=0).
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for every ``k`` in ``topk``.

    Args:
        output: score matrix indexed ``[sample, class]``.
        target: 1-D tensor with the true class index of each sample.
        topk: the values of ``k`` to evaluate.

    Returns:
        One shape-``(1,)`` float tensor per ``k``, each holding the percentage
        of samples whose target is among the ``k`` highest scores.
    """
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)

        # Indices of the best ``largest_k`` classes, transposed so row ``r``
        # holds every sample's (r+1)-th best prediction.
        _, ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)
        ranked = ranked.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / num_samples)
            for k in topk
        ]

def init_logfile(filename: str, text: str):
    """Create (or truncate) ``filename`` and write ``text`` as its first line.

    Args:
        filename: path of the log file; any existing content is discarded.
        text: header line; a trailing newline is appended.
    """
    # ``with`` guarantees the handle is closed even if ``write`` raises.
    with open(filename, "w") as f:
        f.write(text + "\n")

def log(filename: str, text: str):
    """Append ``text`` plus a newline to ``filename`` (created if missing).

    Args:
        filename: path of the log file.
        text: line to append.
    """
    # ``with`` guarantees the handle is closed even if ``write`` raises.
    with open(filename, "a") as f:
        f.write(text + "\n")

def get_optimizer(params, optim_dict):
    """Build a torch optimizer from a config dict.

    Args:
        params: parameters (or param groups) to optimize.
        optim_dict: must provide ``"type"`` (``"SGD"`` or ``"Adam"``), ``"lr"``
            and ``"weight_decay"``. SGD additionally reads ``"momentum"`` and
            Adam reads ``"eps"``.

    Returns:
        A configured ``SGD`` or ``Adam`` instance.

    Raises:
        ValueError: if ``optim_dict["type"]`` is not a supported optimizer.
        KeyError: if a required config key is missing.
    """
    optim_type = optim_dict["type"]
    if optim_type == "SGD":
        return SGD(params,
                   lr=optim_dict["lr"],
                   momentum=optim_dict["momentum"],
                   weight_decay=optim_dict["weight_decay"])
    if optim_type == "Adam":
        return Adam(params,
                    lr=optim_dict["lr"],
                    eps=optim_dict["eps"],
                    weight_decay=optim_dict["weight_decay"])
    # Fail loudly here instead of returning None and crashing later at the
    # first optimizer.step().
    raise ValueError(f"Unsupported optimizer type: {optim_type!r}")
