import argparse
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist

from utils  import *
from data   import *
from buffer import Buffer
from copy   import deepcopy
from pydoc  import locate
from model  import ResNet18, MLP
import numpy as np
import datetime


# This script runs the iid-online, iid-offline, and single baselines.

# Command-line arguments. The full `args` namespace is also handed to the data
# getter, CLDataLoader and MLP, so flags not read directly in this script may
# still be consumed there.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, choices=['split_mnist', 'permuted_mnist', 'split_cifar100_rahaf', 'split_cifar', 'split_cifar100_fb'], default = 'split_cifar')
parser.add_argument('--n_tasks', type=int, default=5)
parser.add_argument('--n_epochs', type=int, default=3)
parser.add_argument('--batch_size', type=int, default=10)
parser.add_argument('--buffer_batch_size', type=int, default=10)
parser.add_argument('--use_conv', type=int, default=1)
parser.add_argument('--samples_per_task', type=int, default=1000, help='if negative, full dataset is used')
parser.add_argument('--mem_size', type=int, default=5, help='controls buffer size') # mem_size in the tf repo.
parser.add_argument('--n_runs', type=int, default=2, help='number of runs to average performance')
parser.add_argument('--n_iters', type=int, default=1, help='training iterations on incoming batch')
parser.add_argument('--rehearsal', type=int, default=1, help='whether to replay previous data')
parser.add_argument('--hidden_dim', type=int, default=20)
parser.add_argument('--multiple_heads', action='store_true', help='use multiple heads (mask logits with the task class mask)')
parser.add_argument('--compare_to_old_logits', action='store_true', help='for max_loss')
parser.add_argument('--compare_to_old_ratio', type=float, default=1.0, help='ratio of old loss')
parser.add_argument('--pseudo_targets', action='store_true')
parser.add_argument('--mixup', action='store_true', help='use manifold mixup')
parser.add_argument('--subsample', type=int, help='subsample', default=50)
parser.add_argument('--mixup_buf', action='store_true', help='use manifold mixup')
parser.add_argument('--name', type=str, default='', help='name_exp')
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--ratio', type=float, default=1.0)
# NOTE(review): the help text of --age and --logit_soft duplicates --max_loss;
# confirm the intended meaning of these two flags.
parser.add_argument('--age', action='store_true',help='maximize the true loss instead of KL')
parser.add_argument('--logit_soft', action='store_true',help='maximize the true loss instead of KL')
#Added by Rahaf
parser.add_argument('--max_loss', action='store_true',help='maximize the true loss instead of KL')
parser.add_argument('--reuse_samples', action='store_true', help='reuse same samples over the iterations')
# The option name keeps its historical spelling for backward compatibility.
parser.add_argument('--diverse_retreival', action='store_true', help='retrieve diverse samples')
parser.add_argument('--entropy', type=float, default=0)
parser.add_argument('--validation', type=int, default=0,help='use validation')
args = parser.parse_args()

##################### Logs
# Log file: <dataset>_<ISO timestamp><random 0-999><experiment name>.log, opened
# in append mode; the parsed arguments are written first.
# NOTE(review): isoformat() contains ':' which is not a valid filename character
# on Windows.
time_stamp = datetime.datetime.now().isoformat()
name_log_txt = '{}_{}{}{}.log'.format(
    args.dataset, time_stamp, np.random.randint(0, 1000), args.name)
with open(name_log_txt, "a") as log_file:
    print(args, file=log_file)

# Settings not exposed on the command line, followed by per-dataset overrides
# of n_classes / n_tasks / input size, and the (replay) buffer size.
args.ignore_mask = False
args.input_size = (3, 32, 32)
# Fall back to the CPU when no GPU is visible instead of failing on .to('cuda:0').
args.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
if args.dataset == 'split_cifar100_fb':
    # 3 tasks when --validation is set, otherwise 17.
    args.n_tasks = 3 if args.validation else 17
    args.samples_per_task = 2500
    # multiple_heads is left as given on the command line for this dataset.
    buffer_size = args.n_tasks * args.mem_size * 5
    # Presumably 5 classes per task; only the classes of the tasks in use.
    args.n_classes = 5 * args.n_tasks
elif args.dataset == 'split_cifar100_rahaf':
    args.n_classes = 100
    args.n_tasks = 5
    args.samples_per_task = 5000
    args.multiple_heads = False
    buffer_size = args.n_tasks * args.mem_size * 20
elif args.dataset == 'split_mnist':
    args.n_classes = 10
    args.n_tasks = 5
    buffer_size = args.n_tasks * args.mem_size * 2
    args.input_size = (784,)
    args.use_conv = False
elif args.dataset == 'permuted_mnist':
    args.n_classes = 10
    args.n_tasks = 10
    buffer_size = args.mem_size * args.n_classes
    args.input_size = (784,)
    args.use_conv = False
    args.ignore_mask = True
else:
    # split_cifar
    args.n_classes = 10
    args.n_tasks = 5
    buffer_size = args.n_tasks * args.mem_size * 2

args.gen = False

# Distribution helpers (not used further in this script).
kl = dist.kl.kl_divergence
Cat = dist.categorical.Categorical

# fetch data: resolve the getter `data.get_<dataset>` by dotted name and call it
get_data = locate('data.get_%s' % args.dataset)
if get_data is None:
    raise ValueError('no data getter data.get_%s found' % args.dataset)
data = get_data(args)

seed = 0
torch.manual_seed(seed)
# Also seed NumPy: the per-task subsampling uses np.random.permutation, which
# was otherwise left unseeded and therefore not reproducible.
np.random.seed(seed)

# make dataloaders just like in the CL case, then merge all tasks into a single
# shuffled training stream and a single test set
train_loader, test_loader = [CLDataLoader(elem, args, train=t) for elem, t in zip(data, [True, False])]

x_s, y_s = [], []
for loader in train_loader:
    x_temp, y_temp = loader.dataset.x, loader.dataset.y
    if args.samples_per_task > 0:
        # keep a random subset of (at most) samples_per_task examples of this task
        idx = np.random.permutation(len(x_temp))[:args.samples_per_task]
        x_temp, y_temp = x_temp[idx], y_temp[idx]
    x_s.append(x_temp)
    y_s.append(y_temp)

train_loader = torch.utils.data.DataLoader(
                        XYDataset(torch.cat(x_s), torch.cat(y_s), **{'source': train_loader.loaders[0].dataset.source}),
                        batch_size=args.batch_size, shuffle=True, drop_last=False)

# NOTE: this loop leaves `loader` bound at module level to the last test-task
# loader; the multiple_heads branches of the train/test loops read it.
x_s, y_s = [], []
for loader in test_loader:
    x_s.append(loader.dataset.x)
    y_s.append(loader.dataset.y)
# Evaluate on every test sample in a fixed order: shuffle + drop_last would
# silently discard up to batch_size-1 samples (a different subset each pass).
test_loader = torch.utils.data.DataLoader(
                        XYDataset(torch.cat(x_s), torch.cat(y_s), **{'source': test_loader.loaders[0].dataset.source}),
                        batch_size=64, shuffle=False, drop_last=False)

# Factories producing a freshly initialised model / optimizer for each run.
if args.use_conv:
    # fetch model and ship to the configured device
    reset_model = lambda: ResNet18(args.n_classes, nf=args.hidden_dim).to(args.device)
else:
    reset_model = lambda: MLP(args).to(args.device)

# plain SGD over all model parameters
reset_opt = lambda model: torch.optim.SGD(model.parameters(), lr=args.lr)
# run index -> list of model snapshots
all_models = {}

# Despite its name this is KL(softmax(teacher) || softmax(student)), averaged
# over the batch; gradients do not flow into the teacher logits.
CE = lambda student, teacher: F.kl_div(F.log_softmax(student, dim=-1), F.softmax(teacher.detach(), dim=-1), reduction='batchmean')
# Per-row sum p * log p of softmax(x) along dim 1, i.e. the *negative* entropy.
# dim is explicit so it matches the reduction axis (no implicit-dim warning).
entropy_fn = lambda x: torch.sum(F.softmax(x, dim=1) * F.log_softmax(x, dim=1), dim=1)
# Train the model
# -------------------------------------------------------------------------------
# For each run: train a fresh model on the merged stream for n_epochs and keep
# a CPU snapshot of the model after every epoch in all_models[run].
for run in range(args.n_runs):
    all_models[run] = []
    model = reset_model().train()
    opt = reset_opt(model)
    # Seeds this run's data shuffling and training; the model above was
    # initialised from the RNG state reached before this call.
    torch.manual_seed(seed + run)

    for epoch in range(args.n_epochs):
        loss_, correct, deno = 0., 0., 0.
        n_updates = 0  # number of loss values accumulated in loss_
        for x_batch, y_batch in train_loader:
            x_batch, y_batch = x_batch.to(args.device), y_batch.to(args.device)
            # take n_iters optimisation steps on the same incoming batch
            for _ in range(args.n_iters):
                hid = model.return_hidden(x_batch)
                logits = model.linear(hid)
                if args.multiple_heads:
                    # NOTE(review): `loader` is the module-level name left over from
                    # building the loaders (the last test-task loader), so every
                    # sample is masked with that single task's mask - confirm.
                    logits = logits.masked_fill(loader.dataset.mask == 0, -1e9)
                loss = F.cross_entropy(logits, y_batch)

                pred = logits.argmax(dim=1, keepdim=True)
                correct += pred.eq(y_batch.view_as(pred)).sum().item()
                deno += pred.size(0)
                loss_ += loss.item()
                n_updates += 1

                opt.zero_grad()
                loss.backward()
                opt.step()

        # Mean over all accumulated losses (the old `loss_ / i` was off by one
        # batch, ignored n_iters and divided by zero for a single batch).
        print('Epoch {}\t Loss {:.6f}\t, Acc {:.6f}'.format(
            epoch, loss_ / max(n_updates, 1), correct / max(deno, 1.)))

        all_models[run] += [deepcopy(model).cpu()]


# Test the model
# -------------------------------------------------------------------------------
avgs = []
# accuracies[run, epoch] = test accuracy of the snapshot stored after `epoch`
accuracies = np.zeros((args.n_runs, args.n_epochs))
with torch.no_grad():
    for run in range(args.n_runs):
        for epoch, model in enumerate(all_models[run]):
            model = model.eval().to(args.device)
            correct, deno = 0., 0.
            for data, target in test_loader:
                data, target = data.to(args.device), target.to(args.device)

                logits = model(data)
                if args.multiple_heads:
                    # NOTE(review): same single-task mask as in training - confirm.
                    logits = logits.masked_fill(loader.dataset.mask == 0, -1e9)
                pred = logits.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                deno += data.size(0)

            accuracies[run][epoch] = correct / max(deno, 1.)
            # move the snapshot back off the device to free its memory
            model = model.cpu()


# For every epoch snapshot, log the mean test accuracy over runs and
# 2 * std / sqrt(n_runs) (about a 95% interval under a normal approximation;
# note np.std uses ddof=0). The epoch count comes from the accuracy matrix
# rather than a loop variable leaked from the evaluation loop.
with open(name_log_txt, "a") as text_file:
    for epoch in range(accuracies.shape[1]):
        avg = accuracies[:, epoch].mean()
        std = accuracies[:, epoch].std()
        print('After Epoch {} AVG over {} runs : {:.4f} +- {:.4f}'
              .format(epoch, args.n_runs, avg, std * 2. / np.sqrt(args.n_runs)), file=text_file)

