import os, pdb
import sys
import os
import copy
cwd = os.getcwd()
sys.path.append(cwd)

from tqdm import tqdm
import torch.nn as nn
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import wandb
import numpy as np
import argparse
import datetime
from aesthetic_scorer import MLPDiff, MLPDiff_simple

def _str2bool(value):
    """Parse a command-line boolean value such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` calls ``bool(string)``, which is True for any
    non-empty string (including ``"False"``). This parser accepts the common
    spellings explicitly and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def train():
    """Distil the frozen ``MLPDiff`` aesthetic scorer into an ``MLPDiff_simple`` model.

    Arguments are parsed from the command line. Training inputs are the
    embeddings in ``./reward_aesthetic/data/ava_x_openclip_l14.npy``. The
    regression targets are the outputs of a frozen ``MLPDiff`` loaded from
    ``./reward_aesthetic/backup/sac+logos+ava1-l14-linearMSE.pth``, not the
    stored ``y`` labels. The ``y`` labels are loaded only to build the datasets.

    During training, Gaussian label noise with std ``--noise`` is added to the
    targets. When ``--SGLD`` is true, Gaussian noise scaled by
    ``SGLD_base_noise * sqrt(lr / train_bs)`` is added to the gradients.

    The last 5% of rows are held out for validation. Each epoch is evaluated
    with MSE and MAE against the noise-free targets. A checkpoint is written
    whenever validation MAE improves, and the LR scheduler steps on validation
    MSE. The final weights are written to the last epoch's checkpoint path.
    """
    parser = argparse.ArgumentParser()

    # Add arguments
    parser.add_argument('--num_epochs', type=int, default=100, help='number of epochs to train')
    parser.add_argument('--train_bs', type=int, default=256)
    parser.add_argument('--val_bs', type=int, default=512)
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('--noise', type=float, default=0.1, help='label noise level')
    parser.add_argument('--run_name', type=str, default='baseline_v2')

    # `type=bool` would turn the string "False" into True, so parse explicitly.
    # nargs='?' + const=True additionally lets a bare `--SGLD` enable it.
    parser.add_argument('--SGLD', type=_str2bool, nargs='?', const=True, default=False,
                        help='add Gaussian noise to gradients (true/false)')
    parser.add_argument('--SGLD_base_noise', type=float, default=0.1)

    args = parser.parse_args()

    unique_id = datetime.datetime.now().strftime("%Y.%m.%d_%H.%M.%S")
    if not args.run_name:
        args.run_name = unique_id
    else:
        args.run_name += "_" + unique_id

    wandb.init(project="offline_reward_aesthetic", name=args.run_name,
        config={
        'lr': args.lr,
        'num_epochs': args.num_epochs,
        'train_batch_size': args.train_bs,
        'val_batch_size': args.val_bs,
        # Record the remaining hyperparameters so runs are reproducible.
        'label_noise': args.noise,
        'SGLD': args.SGLD,
        'SGLD_base_noise': args.SGLD_base_noise,
    })

    # Load the precomputed embeddings and labels.
    x = np.load("./reward_aesthetic/data/ava_x_openclip_l14.npy")
    y = np.load("./reward_aesthetic/data/ava_y_openclip_l14.npy")

    val_percentage = 0.05  # the last 5% of rows are used for validation
    train_border = int(x.shape[0] * (1 - val_percentage))

    # torch.Tensor(...) converts to the default float dtype.
    train_dataset = TensorDataset(torch.Tensor(x[:train_border]), torch.Tensor(y[:train_border]))
    train_loader = DataLoader(train_dataset, batch_size=args.train_bs,
                              shuffle=True, num_workers=16)

    val_dataset = TensorDataset(torch.Tensor(x[train_border:]), torch.Tensor(y[train_border:]))
    val_loader = DataLoader(val_dataset, batch_size=args.val_bs, num_workers=16)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = MLPDiff_simple().to(device)

    # Previously the optimizer ignored --lr and always used Adam's built-in default.
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)

    criterion = nn.MSELoss()   # optimised during training
    criterion2 = nn.L1Loss()   # MAE, used for checkpoint selection

    best_loss = float('inf')

    # Frozen teacher whose outputs are the regression targets.
    eval_model = MLPDiff().to(device)
    eval_model.requires_grad_(False)
    eval_model.eval()
    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    s = torch.load("./reward_aesthetic/backup/sac+logos+ava1-l14-linearMSE.pth",
                   map_location=device)
    eval_model.load_state_dict(s)

    def adjust_noise(learning_rate, batch_size):
        """Scale the SGLD gradient-noise std by sqrt(lr / batch_size)."""
        return args.SGLD_base_noise * (learning_rate ** 0.5) / (batch_size ** 0.5)

    save_name = None
    for epoch in tqdm(range(args.num_epochs), desc="Epochs"):
        # Validation below switches to eval mode, so training mode must be
        # restored at the start of every epoch, not just once.
        model.train()
        noise_level = None
        if args.SGLD:
            # Re-read the LR each epoch because ReduceLROnPlateau may lower it.
            noise_level = adjust_noise(optimizer.param_groups[0]['lr'], args.train_bs)

        losses = []
        save_name = f'./reward_aesthetic/models/{args.run_name}_{epoch+1}.pth'

        for batch_num, (x, _) in enumerate(tqdm(train_loader,
                                desc=f"Epoch {epoch+1}/{args.num_epochs}")):
            optimizer.zero_grad()
            x = x.to(device).float()
            y_real = eval_model(x).to(device)
            noisy_y = y_real + torch.randn_like(y_real, device=device) * args.noise

            output = model(x)

            loss = criterion(output, noisy_y.detach())
            loss.backward()
            losses.append(loss.item())

            if args.SGLD:
                for param in model.parameters():  # add Gaussian noise to gradients
                    # Parameters that took no part in the forward pass have no grad.
                    if param.grad is not None:
                        param.grad += noise_level * torch.randn_like(param.grad)

            optimizer.step()

            if batch_num % 1000 == 0:
                print('\tEpoch %d | Batch %d | Loss %6.2f' % (epoch, batch_num, loss.item()))
                wandb.log({"batch_loss": loss.item()})

        mean_train_loss = sum(losses) / len(losses)
        print('Epoch %d | Loss %6.2f' % (epoch, mean_train_loss))
        wandb.log({"epoch": epoch, "mean_batch_loss": mean_train_loss})

        # Validation: compare against the noise-free teacher outputs.
        model.eval()
        losses = []
        losses2 = []
        with torch.no_grad():  # no autograd graph needed; saves memory
            for batch_num, (x, _) in enumerate(val_loader):
                x = x.to(device).float()
                y_real = eval_model(x).to(device)

                output = model(x)
                loss = criterion(output, y_real)
                lossMAE = criterion2(output, y_real)

                losses.append(loss.item())
                losses2.append(lossMAE.item())

                if batch_num % 1000 == 0:
                    print('\tValidation - Epoch %d | Batch %d | MSE Loss %6.2f' % (epoch, batch_num, loss.item()))
                    print('\tValidation - Epoch %d | Batch %d | MAE Loss %6.2f' % (epoch, batch_num, lossMAE.item()))

        val_mse = sum(losses) / len(losses)
        val_mae = sum(losses2) / len(losses2)
        print('Validation - Epoch %d | MSE Loss %6.2f' % (epoch, val_mse))
        print('Validation - Epoch %d | MAE Loss %6.2f' % (epoch, val_mae))
        if val_mae < best_loss:
            print("Best MAE Val loss so far. Saving model")
            best_loss = val_mae
            print(best_loss)

            torch.save(model.state_dict(), save_name)

        scheduler.step(val_mse)

    # Also persist the final weights under the last epoch's checkpoint name.
    if save_name is not None:
        torch.save(model.state_dict(), save_name)

    print("Best MAE loss", best_loss)

    print("training done")
    # Inference test with a few samples from the last batch seen, as a sanity check.
    print( "inferece test with dummy samples from the val set, sanity check")
    model.eval()
    with torch.no_grad():
        output = model(x[:5].to(device))
    print(output.size())
    print(output)

def compute_covariance_matrix(model, device='cuda',
                              data_path="./reward_aesthetic/data/ava_x_openclip_l14.npy",
                              batch_size=None):
    """Compute the covariance of ``model``'s penultimate-layer features over a dataset.

    Args:
        model: object exposing ``forward_up_to_second_last(tensor)``, which is
            applied to the inputs.
        device: device that the inputs are moved to before the forward pass.
        data_path: ``.npy`` file of inputs; rows are samples.
        batch_size: if given, run the forward pass in chunks of this many rows
            to bound memory use. ``None`` (the default) runs one full-dataset pass.

    Returns:
        ``(cov_mat, features)``. ``features`` is the concatenated model output
        with one row per sample. ``cov_mat`` is ``torch.cov(features.t())``,
        i.e. the feature-by-feature covariance.
    """
    x = np.load(data_path)
    # Cast to float32, as ``train`` does, so that float64 arrays do not cause a
    # dtype mismatch against the model's float weights.
    inputs = torch.from_numpy(x).float()
    with torch.no_grad():
        if batch_size is None:
            features = model.forward_up_to_second_last(inputs.to(device))
        else:
            features = torch.cat(
                [model.forward_up_to_second_last(chunk.to(device))
                 for chunk in torch.split(inputs, batch_size)],
                dim=0,
            )
        # torch.cov treats rows as variables, hence the transpose.
        cov_mat = torch.cov(features.t())
    return cov_mat, features


if __name__ == "__main__":
    # Uncomment to (re)train the distilled scorer instead.
    # train()

    # Calculating Covariance Matrix of the trained predictor's features.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    predictor = MLPDiff_simple().to(device)
    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    state_dict = torch.load("./reward_aesthetic/models/baseline_simple.pth",
                            map_location=device)
    predictor.load_state_dict(state_dict)
    predictor.eval()
    predictor.requires_grad_(False)

    cov_mat, feats = compute_covariance_matrix(model=predictor, device=device)

    # A covariance matrix is symmetric, so eigvalsh returns real eigenvalues in
    # ascending order. eigvals would return complex values with ~0 imaginary parts.
    print(torch.linalg.eigvalsh(cov_mat))
    print(cov_mat.min(), cov_mat.max())
    torch.save(cov_mat, './reward_aesthetic/models/AVA_covariances_simple.pt')

    # Manual cross-check of torch.cov (kept for reference):
    # cov_mat_2 = torch.zeros_like(cov_mat).to(cov_mat.device)
    # mean_feat = torch.mean(feats, dim=0).unsqueeze(1)
    # for idx in range(feats.shape[0]):
    #     x = feats[idx,:].unsqueeze(1) - mean_feat
    #     cov_mat_2 += torch.mm(x, x.t()).to(cov_mat.device)
    # cov_mat_2 = cov_mat_2 / (feats.shape[0] - 1)
    # print(cov_mat_2.min(), cov_mat_2.max())
    # print(cov_mat_2 - cov_mat)
    