import torch
import numpy as np
from data import Data
from model import ArithmeticModel, Baseline
from tqdm import tqdm
from utilities.rule_stats import get_stats
from termcolor import colored
import argparse
import random
import os

def str2bool(v):
    """Parse a command-line flag value into a bool (used as an argparse ``type``).

    Booleans pass through unchanged. Strings are matched case-insensitively
    against yes/true/t/y/1 (-> True) and no/false/f/n/0 (-> False).

    Raises:
        argparse.ArgumentTypeError: if the string matches neither set.
    """
    if isinstance(v, bool):
        return v
    normalized = v.lower()
    accepted = (
        (('yes', 'true', 't', 'y', '1'), True),
        (('no', 'false', 'f', 'n', '0'), False),
    )
    for spellings, value in accepted:
        if normalized in spellings:
            return value
    raise argparse.ArgumentTypeError('Boolean value expected.')

# Command-line configuration. The previous description was copied from an
# unrelated language-model example; this one describes what the script does.
parser = argparse.ArgumentParser(
    description='Train an ArithmeticModel (NPS) or a Baseline model on X/Y addition/subtraction data')
# --num_slots and --seed have no sensible default: without them the script used
# to crash later (range(None) / torch.manual_seed(None)), so require them up front.
parser.add_argument('--num_slots', type=int, required=True,
                    help='number of slots; passed to the datasets and the model')
# The options below are forwarded positionally to the model constructor;
# see model.py for their meaning.
parser.add_argument('--topk1', type=int)
parser.add_argument('--topk2', type=int)
parser.add_argument('--grad_flow1', type=str2bool)
parser.add_argument('--grad_flow2', type=str2bool)
parser.add_argument('--entropy_weight', type=float)
parser.add_argument('--use_entropy', type=str2bool)
parser.add_argument('--seed', type=int, required=True,
                    help='seed for the python, numpy and torch RNGs; also used in the output file name')
parser.add_argument('--rule_dim', type=int, default=3)
parser.add_argument('--dropout', type=float, default=0.2)
parser.add_argument('--rule_mlp_dim', type=int, default=32)
parser.add_argument('--baseline', type=str2bool, default=False,
                    help='train Baseline instead of ArithmeticModel')
args = parser.parse_args()


NUM_SLOTS = args.num_slots
# Matches the four operation ids (0-3) that are counted per rule and mapped to plot names.
NUM_RULES = 4


def set_seed(seed):
    """Seed every RNG this script relies on so that runs are reproducible.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all
    CUDA devices), exports ``PYTHONHASHSEED`` and configures cuDNN for
    deterministic behaviour.

    Args:
        seed: integer seed (the ``--seed`` command-line value).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Only affects Python interpreters started after this point; the hash seed
    # of the already-running interpreter is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning can select different kernels between runs, which defeats
    # the deterministic flag above.
    torch.backends.cudnn.benchmark = False

# Seed before building the datasets — presumably so the generated examples are
# reproducible per --seed (assumes Data draws from the seeded RNGs; confirm in data.py).
set_seed(args.seed)

train_dataset = Data(num_operations=1, num_slots=NUM_SLOTS, num_examples=10000)
eval_dataset = Data(num_operations=1, num_slots=NUM_SLOTS, num_examples=2000)

# Operation id (key of operation_to_rule) -> name of the plot series it feeds.
opertion_to_plot_key = {0: 'X Addition', 1: "Y Addition", 2: "X Subtraction", 3: "Y Subtraction" }

# One series per operation; each series holds one list per rule, to which the
# validation selection count of that rule is appended after every epoch.
plot_data = {
    series: [[] for _ in range(NUM_RULES)]
    for series in opertion_to_plot_key.values()
}

def collate_fn(x):
    """Batch dataset samples for the DataLoader.

    Each sample is indexed as ``sample[0]`` (number tensor), ``sample[1]``
    (operation description) and ``sample[2]`` (mask tensor).

    Returns:
        (numbers, operations, masks): ``numbers`` and ``masks`` are the
        per-sample tensors stacked along a new leading batch dimension;
        ``operations`` is a plain list, left unstacked.
    """
    samples = list(x)  # materialise once so the three passes below see the same items
    operations = [sample[1] for sample in samples]
    numbers = torch.stack([sample[0] for sample in samples], dim=0)
    masks = torch.stack([sample[2] for sample in samples], dim=0)
    return numbers, operations, masks

# Training and validation loaders share the same settings; collate_fn stacks the
# number and mask tensors and keeps the operations as a plain list.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)
val_dataloader = torch.utils.data.DataLoader(
    eval_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)

# --baseline selects which model class to instantiate (this holds a class, not a string).
model_name = Baseline if args.baseline else ArithmeticModel

# Constructor arguments are positional; their order must match the model signature.
model = model_name(
    64,  # NOTE(review): meaning of this leading constant is defined in model.py — confirm
    NUM_SLOTS,
    NUM_RULES,
    args.topk1,
    args.topk2,
    args.grad_flow1,
    args.grad_flow2,
    args.entropy_weight,
    args.use_entropy,
    args.rule_dim,
    args.dropout,
    args.rule_mlp_dim,
).cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

def rule_stats(rule_activations, variable_activations, operations, number_to_slot, operation_to_rule, data = None, preds = None):
    """Accumulate slot- and rule-selection counts for one batch.

    For every example ``b`` in the batch and every step ``t``:

    * ``number_to_slot[operations[b][t][0]][variable_activations[t][b]]`` is
      incremented, and
    * ``operation_to_rule[operations[b][t][-1]][rule_activations[t][b]]`` is
      incremented.

    Args:
        rule_activations: per-step sequence; ``rule_activations[t][b]`` is the
            rule chosen for example ``b`` at step ``t``. The batch size is taken
            from ``len(rule_activations[0])``.
        variable_activations: same layout, holding the chosen slot. Selected
            values are used as dict keys, so they must hash/compare like the
            int keys of the count tables (presumably Python or NumPy ints —
            confirm against what the model returns).
        operations: per-example sequence; ``operations[b][t][0]`` is the key into
            ``number_to_slot`` and ``operations[b][t][-1]`` the key into
            ``operation_to_rule``.
        number_to_slot: nested count dict, updated in place.
        operation_to_rule: nested count dict, updated in place.
        data, preds: unused; accepted so existing call sites keep working.

    Returns:
        ``(number_to_slot, operation_to_rule)`` — the same, now-updated dicts.
    """
    # Nothing to count; previously this raised IndexError on rule_activations[0].
    if len(rule_activations) == 0:
        return number_to_slot, operation_to_rule

    num_steps = len(rule_activations)
    batch_size = len(rule_activations[0])
    for b in range(batch_size):
        for t in range(num_steps):
            # Use the parameter, not the module-level loop variable of the same
            # stem, so the function works independently of the training script.
            slot_selected = variable_activations[t][b]
            number_key = operations[b][t][0]
            number_to_slot[number_key][slot_selected] += 1

            rule_selected = rule_activations[t][b]
            operation_key = operations[b][t][-1]
            operation_to_rule[operation_key][rule_selected] += 1

    return number_to_slot, operation_to_rule


def _empty_stats():
    """Return fresh zeroed count tables: (number -> slot counts, operation -> rule counts)."""
    number_to_slot = {k: {j: 0 for j in range(NUM_SLOTS)} for k in range(NUM_SLOTS)}
    operation_to_rule = {k: {j: 0 for j in range(NUM_RULES)} for k in range(NUM_RULES)}
    return number_to_slot, operation_to_rule


def _print_stats(number_to_slot, operation_to_rule):
    """Print the slot-usage and rule-usage tables, one row per key."""
    print('Number to slot:')
    for k, counts in number_to_slot.items():
        print(k, counts)
    print('Operation to Rule:')
    for k, counts in operation_to_rule.items():
        print(k, counts)


# Lowest mean validation loss seen so far; inf so the first epoch always sets it.
best_val_loss = float('inf')

for n in range(300):
    print(colored(str(n), "blue"))

    # ---- training ----
    model.train()
    loss = 0
    count = 0
    number_to_slot, operation_to_rule = _empty_stats()
    # Kept at module level: variable_activation is also a module-level name here.
    for x, ops, masks in tqdm(train_dataloader):
        masks = masks.cuda()
        x = x.float().cuda()
        loss_, preds, rule_activations, variable_activation = model(x, gumbel = True, true_mask = masks)

        number_to_slot, operation_to_rule = rule_stats(rule_activations, variable_activation, ops, number_to_slot, operation_to_rule)

        optimizer.zero_grad()
        loss_.backward()
        optimizer.step()

        loss += loss_.item()
        count += 1

    print(colored('train_loss:' + str(loss / count), "red"))
    _print_stats(number_to_slot, operation_to_rule)

    # ---- validation ----
    model.eval()
    loss = 0
    count = 0
    number_to_slot, operation_to_rule = _empty_stats()

    # Validation losses are only logged, so skip autograd. Accumulating the raw
    # loss tensor (as before) kept every batch's graph alive for the whole pass.
    with torch.no_grad():
        for x, ops, masks in tqdm(val_dataloader):
            masks = masks.cuda()
            x = x.float().cuda()
            loss_, preds, rule_activations, variable_activation = model(x, true_mask = masks)

            number_to_slot, operation_to_rule = rule_stats(rule_activations, variable_activation, ops, number_to_slot, operation_to_rule, data = x, preds = preds)

            targets = x[:, -1]

            loss += loss_.item()
            count += 1

    # Show one prediction/target pair from the last validation batch.
    print(preds[0])
    print(targets[0])

    val_loss = loss / count
    print(colored('eval_loss:' + str(val_loss), "green"))
    if val_loss < best_val_loss:
        best_val_loss = val_loss

    print(colored('best_eval_loss:' + str(best_val_loss), "yellow"))
    _print_stats(number_to_slot, operation_to_rule)

    # Record this epoch's validation rule-selection counts for later plotting.
    for k, counts in operation_to_rule.items():
        key = opertion_to_plot_key[k]
        for j, c in counts.items():
            plot_data[key][j].append(c)

import pickle

if args.baseline:
    file_name = 'baseline/baseline_' + str(args.seed)
else:
    file_name = 'NPS/nps_' + str(args.seed)

# Create the output directory if needed and close the file deterministically.
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, "wb") as f:
    pickle.dump(plot_data, f)
		



