import torch
import torch.nn
import torch.nn.functional as F
from torchvision.transforms import transforms
from IPython import embed
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import time
import argparse
import torch.nn.functional as F
import scipy
import time
import os
import pickle
from dataset import TrajDataset
from net import TransformerVision

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def count_last_step_mistakes(true_actions, pred_actions):
    """Count samples whose last-step argmax prediction differs from the label.

    Both tensors are indexed as ``[:, -1, :]`` and arg-maxed over the final
    axis, so each must be 3-D with matching leading dimension.

    Returns:
        int: number of rows where the two argmax indices disagree.
    """
    true_indices = true_actions[:, -1, :].detach().cpu().numpy().argmax(axis=1)
    pred_indices = pred_actions[:, -1, :].detach().cpu().numpy().argmax(axis=1)
    return int(np.sum(true_indices != pred_indices))


def run_epoch(model, loader, loss_fn, du, H, full, optimizer=None):
    """Run one pass over ``loader``; train when ``optimizer`` is given.

    With an optimizer the model is put in train mode and updated after every
    batch. Without one the model is put in eval mode (so dropout is disabled)
    and gradients are not tracked.

    Args:
        model: callable module taking a batch dict and returning logits.
        loader: iterable of batch dicts containing an ``'actions'`` entry.
        loss_fn: loss applied to logits and targets reshaped to ``(-1, du)``.
        du: size of the last (action) axis.
        H: divisor applied to every batch loss before accumulation.
        full: if true, targets are repeated along a new axis 1 to match
            ``pred_actions.shape[1]``.
        optimizer: optimizer to step, or ``None`` for evaluation only.

    Returns:
        tuple: (sum over batches of ``loss / H``, number of last-step
        mistakes as counted by ``count_last_step_mistakes``).
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    mistakes = 0
    with torch.set_grad_enabled(training):
        for batch in loader:
            true_actions = batch['actions']
            pred_actions = model(batch)

            if full:
                true_actions = true_actions.unsqueeze(1).repeat(1, pred_actions.shape[1], 1)

            loss = loss_fn(pred_actions.reshape(-1, du), true_actions.reshape(-1, du))
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item() / H
            mistakes += count_last_step_mistakes(true_actions, pred_actions)

    return total_loss, mistakes


if __name__ == '__main__':
    # `import scipy` alone does not guarantee that the `special` submodule is
    # loaded; import it explicitly before it is used at the end of the script.
    import scipy.special

    # exist_ok=True already tolerates directories that are present.
    os.makedirs('figs/loss', exist_ok=True)
    os.makedirs('models', exist_ok=True)

    parser = argparse.ArgumentParser()

    parser.add_argument("--envs", type=int, required=False, default=5000, help="Envs")
    parser.add_argument("--hists", type=int, required=False, default=1, help="Histories")
    parser.add_argument("--samples", type=int, required=False, default=1, help="Samples")
    parser.add_argument("--H", type=int, required=False, default=30, help="Context horizon")
    parser.add_argument("--embd", type=int, required=False, default=32, help="Embedding")
    parser.add_argument("--head", type=int, required=False, default=1, help="Number of heads")
    parser.add_argument("--sample", type=int, required=False, default=0, help="Read in the data?")
    parser.add_argument("--layer", type=int, required=False, default=3, help="Number of layers")
    parser.add_argument("--lr", type=float, required=False, default=1e-3, help="Learning rate")
    parser.add_argument("--opt", type=int, required=False, default=0, help="Optimizer type")
    parser.add_argument("--dropout", type=float, required=False, default=0, help="Dropout")
    parser.add_argument("--envname", type=str, required=False, default='mini', help="Environment name")

    parser.add_argument('--full', default=False, action='store_true')
    parser.add_argument('--shuffle', default=False, action='store_true')
    parser.add_argument('--small', default=False, action='store_true')
    parser.add_argument('--shift', default=False, action='store_true')

    args = vars(parser.parse_args())
    print("Args:")
    print(args)

    n_envs = args['envs']
    n_hists = args['hists']
    n_samples = args['samples']
    H = args['H']
    n_embd = args['embd']
    n_head = args['head']
    n_layer = args['layer']
    lr = args['lr']
    shuffle = args['shuffle']
    full = args['full']
    opt = args['opt']
    dropout = args['dropout']
    sample = args['sample']
    small = args['small']
    shift = args['shift']
    envname = args['envname']
    trans = 1
    du = 4

    # A negative --envs value only changes how often checkpoints are written.
    distr = n_envs < 0

    EPOCHS = 1000

    paths_train = [
        f'datasets/trajs_mini_four_boxes_fixed2_envs20000_hists{n_hists}_samples{n_samples}_small{small}_train.pkl',
        f'datasets/expert_trajs_mini_four_boxes_fixed2_envs10000_hists{n_hists}_samples{n_samples}_small{small}_train.pkl',
    ]
    paths_test = [
        f'datasets/trajs_mini_four_boxes_fixed2_envs20000_hists{n_hists}_samples{n_samples}_small{small}_test.pkl',
        f'datasets/expert_trajs_mini_four_boxes_fixed2_envs10000_hists{n_hists}_samples{n_samples}_small{small}_test.pkl',
    ]
    # Built while `sample` is still the raw integer flag so that output file
    # names keep the "sample0"/"sample1" form.
    filename = f'{envname}_trans{trans}_full{full}_shuf{shuffle}_opt{opt}_lr{lr}_do{dropout}_embd{n_embd}_layer{n_layer}_head{n_head}_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_sample{sample}_shift{shift}'

    batch_size = 128

    config = {
        'H': H,
        'du': du,
        'n_layer': n_layer,
        'n_embd': n_embd,
        'n_head': n_head,
        'shuffle': shuffle,
        'full': full,
        'dropout': dropout,
        'small': small,
        'shift': shift,
    }

    model = TransformerVision(config).to(device)

    # The same loader settings are used for both the train and test loaders.
    params = {
        'batch_size': batch_size,
        'shuffle': True,
    }

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Only 0 and 1 are accepted for --sample; convert to a real boolean.
    if sample not in (0, 1):
        raise NotImplementedError
    sample = bool(sample)

    ds_train = TrajDataset(paths_train, config, transform, sample=sample, mode='train')
    ds_test = TrajDataset(paths_test, config, transform, sample=sample, mode='test')

    n_train = len(ds_train)
    n_test = len(ds_test)

    train_loader = torch.utils.data.DataLoader(ds_train, **params)
    test_loader = torch.utils.data.DataLoader(ds_test, **params)

    if opt == 0:
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    elif opt == 1:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    else:
        raise NotImplementedError

    loss_fn = nn.CrossEntropyLoss(reduction='sum')

    test_loss = []
    train_loss = []

    # Checkpoints are written every 500 epochs when `distr` is set, else every 50.
    save_every = 500 if distr else 50

    for epoch in range(EPOCHS):
        print(f"Epoch: {epoch}")

        # Evaluate first, so the test numbers for epoch k describe the model
        # before that epoch's training pass.
        epoch_test_loss, test_mistakes = run_epoch(
            model, test_loader, loss_fn, du, H, full)
        test_loss.append(epoch_test_loss / n_test)
        print(f'Test Loss:        {test_loss[-1]}')

        start_time = time.time()
        epoch_train_loss, train_mistakes = run_epoch(
            model, train_loader, loss_fn, du, H, full, optimizer=optimizer)
        end_time = time.time()
        diff = end_time - start_time

        train_loss.append(epoch_train_loss / n_train)
        print(f'Train Loss:       {train_loss[-1]}')
        # NOTE: the value is the wall-clock time of the whole training pass;
        # the label is kept unchanged so existing logs remain comparable.
        print(f'Batch time:       {diff}\n\n')

        test_err = test_mistakes / (n_test)
        train_err = train_mistakes / (n_train)

        print(f'Train Error:      {train_err}')
        print(f'Test Error:       {test_err}\n\n')

        # LOGGING
        if (epoch + 1) % save_every == 0:
            torch.save(model.state_dict(), f'models/{filename}_epoch{epoch+1}.pt')

        if (epoch + 1) % 50 == 0:
            # The first recorded value of each curve is skipped in the plot.
            plt.yscale('log')
            plt.plot(train_loss[1:], label="Train final")
            plt.plot(test_loss[1:], label="Test final")
            plt.legend()
            plt.savefig(f"figs/loss/{filename}_train_loss.png")
            plt.clf()

    torch.save(model.state_dict(), f'models/{filename}.pt')
    print("Done.")

    # Gather one training batch and the model's outputs on it for interactive
    # inspection in the IPython shell opened below.
    batch = next(iter(train_loader))
    train_angles = batch['angles'].cpu().detach().numpy()
    train_actions = batch['actions'].cpu().detach().numpy().argmax(axis=1)

    # A single forward pass in eval mode keeps `probs` and `pred_actions`
    # consistent with each other.
    model.eval()
    with torch.no_grad():
        logits = model(batch).cpu().numpy()[:, -1, :]
    probs = scipy.special.softmax(logits, axis=1)
    pred_actions = logits.argmax(axis=1)

    embed()
