import torch
import torch.nn as nn
import numpy as np
from RuleNetwork import RuleNetwork
from utilities.GroupLinearLayer import GroupLinearLayer

class ArithmeticModel(nn.Module):
	"""Arithmetic-task model that repeatedly updates slot states with a ``RuleNetwork``.

	At every step the current state, its difference to the last frame of the
	input and that last frame itself are concatenated and handed to the rule
	network. The rule output and the previous state are then blended with the
	mask the rule network returns.
	"""

	def __init__(self, slot_dim, num_slots, num_rules, topk1 = 4, topk2 = 2, grad_flow1 = False, grad_flow2 = False, entropy_weight = 1e-2, use_entropy = False, rule_dim = 3, dropout = 0.1, rule_mlp_dim = 32):
		"""Build the number encoder/decoder and the rule network.

		Args:
			slot_dim: output width of ``number_encoder`` / input width of
				``number_decoder``.
			use_entropy: if true, ``forward`` adds the entropy term returned by
				the rule network to the per-step loss.

		All remaining arguments are forwarded unchanged to ``RuleNetwork``;
		their meaning is defined there.
		"""
		super().__init__()
		self.slot_dim = slot_dim

		# NOTE: forward() does not call the encoder/decoder. They are still
		# constructed so the registered parameters (and the RNG consumption
		# during initialisation) stay the same as before.
		self.number_encoder = nn.Sequential(
											nn.Linear(1, 32),
											nn.ReLU(),
											nn.Linear(32, 32),
											nn.ReLU(),
											nn.Linear(32, 32),
											nn.ReLU(),
											nn.Linear(32, slot_dim))

		self.number_decoder = nn.Sequential(
											nn.Linear(slot_dim, 32),
											nn.ReLU(),
											nn.Linear(32, 32),
											nn.ReLU(),
											nn.Linear(32, 32),
											nn.ReLU(),
											nn.Linear(32, 1))

		self.use_entropy = use_entropy
		print('use_entropy:'+str(self.use_entropy))

		# These strings are interpreted by RuleNetwork, so they are kept
		# verbatim (including the 'gumble' spelling).
		self.design_config = {'comm': True, 'grad': False,
			'transformer': True, 'application_option': '3.0.1.1', 'selection': 'gumble'}

		# The first argument (6) matches the width of the concatenation built in
		# forward(): three tensors of (presumably) width 2 each -- TODO confirm.
		self.rule_network = RuleNetwork(6, num_slots, num_rules = num_rules, rule_dim = rule_dim, query_dim = 32, value_dim = 64, key_dim = 32, num_heads = 4, dropout = dropout, design_config = self.design_config, topk1 = topk1, topk2 = topk2, grad_flow1 = grad_flow1, grad_flow2 = grad_flow2, entropy_weight = entropy_weight, rule_mlp_dim = rule_mlp_dim)

		# Report the size of the two selection sub-modules of the rule network.
		num_params = sum(p.numel() for p in self.rule_network.variable_rule_select.parameters())
		num_params += sum(p.numel() for p in self.rule_network.variables_select_1.parameters())
		print('parameters:' + str(num_params))

		self.mse_loss = nn.MSELoss()
		self.ce_loss = nn.CrossEntropyLoss()

	def forward(self, x, gumbel = True, true_mask = None):
		"""Run ``T - 1`` rule applications and accumulate the reconstruction loss.

		Args:
			x: 4-D tensor whose second axis has length ``T``. ``x[:, 0]`` is the
				initial state and ``x[:, -1]`` is both the auxiliary input and
				the regression target. The last axis is presumably of size 2
				(the update keeps ``[:, :, :2]`` of the concatenated state).
			gumbel: forwarded to the rule network.
			true_mask: optional ground-truth mask. It only feeds an auxiliary
				cross-entropy loss that is currently disabled, so ``None`` is
				accepted.

		Returns:
			Tuple ``(loss, predicted_number, rule_activations,
			variable_activations)``. ``loss`` is the sum over steps of the MSE
			between state and target, plus the rule network's entropy term when
			``use_entropy`` is set. It is the int ``0`` when ``T == 1``.
		"""
		_, T, _, _ = x.size()
		aux_input = x[:, -1]
		state = x[:, 0]

		mse_loss = 0

		# Previously this ran unconditionally and crashed with the default
		# ``true_mask=None``.
		if true_mask is not None:
			true_mask = torch.argmax(true_mask.squeeze(-1), dim = -1)

		self.rule_network.reset_activations()

		# With a single frame the loop below never runs; the initial state is
		# then returned as the prediction instead of raising UnboundLocalError.
		predicted_number = state
		for _ in range(T - 1):
			diff = aux_input - state
			rule_input = torch.cat((state, diff, aux_input), dim = -1)

			rule_output, ent, mask, _ = self.rule_network(rule_input, rule_input, gumbel = gumbel)

			# The first two features of ``rule_input`` are the previous state.
			state = mask * rule_output + (1 - mask) * rule_input[:, :, :2]

			step_loss = self.mse_loss(state, aux_input)
			if self.use_entropy:
				step_loss = step_loss + ent
			mse_loss += step_loss

			predicted_number = state

		rule_activations = self.rule_network.rule_activation
		variable_activations = self.rule_network.variable_activation
		return mse_loss, predicted_number, rule_activations, variable_activations



class GroupMLP(nn.Module):
	"""Two-layer MLP built from ``GroupLinearLayer`` blocks.

	``num`` is passed to every ``GroupLinearLayer``; presumably it is the number
	of independent groups -- confirm against ``GroupLinearLayer``.
	"""

	def __init__(self, in_dim, out_dim, num, dropout = 0.1, rule_mlp_dim = 32):
		"""Create the layers; ``rule_mlp_dim`` is the hidden width."""
		super().__init__()
		print(f'Rule MLP Dim:{rule_mlp_dim!s}')
		hidden_dim = rule_mlp_dim

		# Layers used by forward().
		self.group_mlp1 = GroupLinearLayer(in_dim, hidden_dim, num)
		self.group_mlp2 = GroupLinearLayer(hidden_dim, out_dim, num)

		# Constructed but not used by forward(). Kept so the module's
		# parameters and construction order stay unchanged.
		self.group_mlp3 = GroupLinearLayer(hidden_dim, hidden_dim, num)
		self.group_mlp4 = GroupLinearLayer(hidden_dim, out_dim, num)

		self.dropout = nn.Dropout(p = dropout)
		# Also unused by forward().
		self.layer_norm_0 = nn.LayerNorm(in_dim)

	def forward(self, x):
		"""Apply ``group_mlp1`` -> dropout -> ReLU -> ``group_mlp2`` to ``x``."""
		hidden = self.group_mlp1(x)
		hidden = self.dropout(hidden)
		hidden = torch.relu(hidden)
		return self.group_mlp2(hidden)


class Baseline(nn.Module):
	"""MLP baseline that picks two operand slots and one of four rule MLPs.

	An encoder over the flattened input and target frames feeds three linear
	heads: one choosing the first operand slot (which is also the slot that gets
	rewritten), one choosing the second operand slot and one choosing the rule.
	"""

	def __init__(self, slot_dim, num_slots, num_rules, topk1 = 4, topk2 = 2, grad_flow1 = False, grad_flow2 = False, entropy_weight = 1e-2, use_entropy = False, rule_dim = 3, dropout = 0.1, rule_mlp_dim = 32):
		"""Build the rule MLPs, the encoder and the three selection heads.

		Only ``rule_mlp_dim`` is used here; the other arguments are accepted but
		ignored (the signature mirrors ``ArithmeticModel``). All layer sizes are
		hard-coded: 2 slots, 4 rules, 8 encoder inputs.
		"""
		super().__init__()

		# 4 rule MLPs, each mapping two concatenated operands (4 values) to a
		# new slot value (2 values).
		self.rules = GroupMLP(4, 2, 4, rule_mlp_dim = rule_mlp_dim)

		self.encoder = nn.Sequential(nn.Linear(8, 32), nn.ReLU(),
									nn.Linear(32, 32), nn.ReLU(),
									nn.Linear(32, 32), nn.ReLU(),
									nn.Linear(32, 32))

		# The constants presumably stand for the weight counts of the three
		# heads created below (32*2, 32*2, 32*4); biases are not counted.
		num_params = sum(p.numel() for p in self.encoder.parameters())
		print('parameters:' + str(num_params + 64 + 64 + 128))

		self.variable_1_head = nn.Linear(32, 2)
		self.variable_2_head = nn.Linear(32, 2)
		self.operation_head = nn.Linear(32, 4)

		self.mse_loss = nn.MSELoss()

	def _select(self, scores):
		"""Return a one-hot mask over dim 1 of ``scores`` and the argmax index.

		In training mode the mask is a hard Gumbel-softmax sample, so it may
		differ from the returned (greedy) index. In eval mode the mask is the
		one-hot encoding of that index.
		"""
		index = torch.argmax(scores, dim = 1)
		if self.training:
			mask = torch.nn.functional.gumbel_softmax(scores, dim = 1, tau = 1.0, hard = True)
		else:
			mask = torch.zeros_like(scores).scatter_(1, index.unsqueeze(-1), 1)
		return mask, index

	def forward(self, x, gumbel = True, true_mask = None):
		"""Predict the target frame from the input frame with one rule application.

		Args:
			x: tensor with ``x[:, 0]`` as input frame and ``x[:, 1]`` as target
				frame. Given the hard-coded layer sizes each frame is presumably
				``(batch, 2, 2)`` -- TODO confirm against the data loader.
			gumbel: unused; sampling is controlled by ``self.training``.
			true_mask: unused.

		Returns:
			Tuple ``(mse_loss, rule_output, [rule_activation],
			[variable_activation])``. The activations are numpy arrays holding
			the argmax rule index and the argmax first-operand slot index.
		"""
		inputs = x[:, 0]
		targets = x[:, 1]
		batch_size = inputs.shape[0]

		mlp_inp = torch.cat((inputs.reshape(batch_size, -1), targets.reshape(batch_size, -1)), dim = 1)
		features = self.encoder(mlp_inp)

		# Selection order (operand 1, operand 2, rule) is kept so the Gumbel
		# noise is drawn in the same order as before.
		variable_1_mask, variable_activation = self._select(self.variable_1_head(features))
		variable_2_mask, _ = self._select(self.variable_2_head(features))
		rule_mask, rule_activation = self._select(self.operation_head(features))

		slot_mask_1 = variable_1_mask.unsqueeze(-1)
		rule_input_variable_1 = (inputs * slot_mask_1).sum(dim = 1)
		rule_input_variable_2 = (inputs * variable_2_mask.unsqueeze(-1)).sum(dim = 1)

		# Bug fix: the second operand used to be replaced by a second copy of
		# the first one, so ``variable_2_head`` never influenced the output.
		rule_input = torch.cat((rule_input_variable_1, rule_input_variable_2), dim = 1)
		rule_input = rule_input.unsqueeze(1).repeat(1, rule_mask.size(1), 1)

		rule_output = self.rules(rule_input)

		# Keep only the selected rule's output and copy it to every slot.
		rule_output = (rule_output * rule_mask.unsqueeze(-1)).sum(dim = 1)
		rule_output = rule_output.unsqueeze(1).repeat(1, inputs.size(1), 1)

		# Only the slot chosen as first operand is rewritten.
		rule_output = (1 - slot_mask_1) * inputs + slot_mask_1 * rule_output

		mse_loss = self.mse_loss(rule_output, targets)

		return mse_loss, rule_output, [rule_activation.cpu().numpy()], [variable_activation.cpu().numpy()]




