
import torch
import torch.nn as nn
import math
import numpy as np
from utilities.GroupLinearLayer import GroupLinearLayer
from utilities.attention_rim import MultiHeadAttention
import itertools
from utilities.attention import SelectAttention
#from pertubed_topk import PerturbedTopK

class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal position table to the input.

    The table is built once in the constructor and registered as the
    non-trainable buffer ``pe`` with shape ``(1, max_len, d_model)``. Even
    feature columns hold sines and odd feature columns hold cosines of
    ``position * div_term``.

    Args:
        d_model: Size of the feature dimension. Odd sizes are supported.
        dropout: Probability used to build ``self.dropout``. The dropout
            module is created but ``forward`` does not apply it.
        max_len: Number of positions in the table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # drop the surplus frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x + self.pe``.

        The whole table is added without slicing, so ``x`` has to broadcast
        against ``(1, max_len, d_model)``.
        """
        return x + self.pe


class Identity(torch.autograd.Function):
    """Identity autograd function that prints the gradient flowing through it.

    The forward pass returns ``input * 1.0``. The backward pass prints the
    incoming gradient (a debugging aid) and passes it on unchanged in value.
    Use it as ``Identity.apply(tensor)``.
    """

    @staticmethod
    def forward(ctx, input):
        """Return a copy of ``input`` scaled by 1.0."""
        return input * 1.0

    # torch.autograd.Function requires both forward and backward to be static
    # methods; the decorator was missing here.
    @staticmethod
    def backward(ctx, grad_output):
        """Print ``grad_output`` and return it scaled by 1.0."""
        print(grad_output)
        return grad_output * 1.0

class ArgMax(torch.autograd.Function):
    """Hard arg-max along dimension 1 with a masked straight-through gradient.

    Forward returns a float one-hot tensor marking the arg-max position along
    dimension 1. Backward multiplies the incoming gradient by that one-hot
    tensor, so only the selected entries receive gradient.
    """

    @staticmethod
    def forward(ctx, input):
        """Return the one-hot encoding of ``argmax(input, 1)``."""
        # Metadata of the input is recorded on the context object.
        ctx._input_shape = input.shape
        ctx._input_dtype = input.dtype
        ctx._input_device = input.device

        winners = torch.argmax(input, 1).unsqueeze(1)
        one_hot = torch.zeros(input.size()).to(input.device)
        one_hot.scatter_(1, winners, 1)

        # The one-hot output doubles as the gradient mask in backward.
        ctx.save_for_backward(one_hot)
        return one_hot

    @staticmethod
    def backward(ctx, grad_output):
        """Let gradient through only at the selected positions."""
        (one_hot,) = ctx.saved_tensors
        return grad_output * one_hot

class GroupMLP(nn.Module):
    """Two-layer MLP built from ``GroupLinearLayer`` blocks.

    ``forward`` computes ``group_mlp2(relu(dropout(group_mlp1(x))))``. The
    layers ``group_mlp3``, ``group_mlp4`` and ``layer_norm_0`` are constructed
    (and so appear among the module's sub-modules) but are not used by
    ``forward``.

    Args:
        in_dim: Input feature size passed to the first grouped layer.
        out_dim: Output feature size of the second grouped layer.
        num: Third argument forwarded to every ``GroupLinearLayer``.
        dropout: Dropout probability applied after the first layer.
        rule_mlp_dim: Width of the hidden layer.
    """

    def __init__(self, in_dim, out_dim, num, dropout = 0.1, rule_mlp_dim = 32):
        super().__init__()
        print('Rule MLP Dim:' + str(rule_mlp_dim))
        width = rule_mlp_dim
        # Construction order is kept: it fixes both the parameter
        # initialisation sequence and the sub-module registration order.
        self.group_mlp1 = GroupLinearLayer(in_dim, width, num)
        self.group_mlp2 = GroupLinearLayer(width, out_dim, num)
        self.group_mlp3 = GroupLinearLayer(width, width, num)
        self.group_mlp4 = GroupLinearLayer(width, out_dim, num)
        self.dropout = nn.Dropout(p = dropout)
        self.layer_norm_0 = nn.LayerNorm(in_dim)

    def forward(self, x):
        """Apply the first layer, dropout, ReLU, then the second layer."""
        projected = self.group_mlp1(x)
        activated = torch.relu(self.dropout(projected))
        return self.group_mlp2(activated)

class MLP(nn.Module):
    """Two-layer perceptron with a 128-unit ReLU hidden layer.

    ``forward`` computes ``mlp2(relu(mlp1(x)))``. ``mlp3`` and ``mlp4`` are
    constructed but not used by ``forward``.

    Args:
        in_dim: Input feature size.
        out_dim: Output feature size.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        width = 128
        # Construction order is kept so initialisation under a fixed seed and
        # the parameter names stay the same.
        self.mlp1 = nn.Linear(in_dim, width)
        self.mlp2 = nn.Linear(width, out_dim)
        self.mlp3 = nn.Linear(width, width)
        self.mlp4 = nn.Linear(width, out_dim)

    def forward(self, x):
        """Map ``x`` through the hidden layer and the output projection."""
        activated = self.mlp1(x).relu()
        return self.mlp2(activated)

class Hook():
    """Gradient hook that multiplies a tensor's gradient by ``self.mask``.

    The hook is registered on ``inp`` at construction time. ``mask`` starts as
    ``None`` and has to be assigned before the backward pass reaches the
    hooked tensor.
    """

    def __init__(self, inp):
        self.mask = None
        self.hook = inp.register_hook(self.hook_fn)

    def hook_fn(self, grad):
        """Return the incoming gradient scaled element-wise by the mask."""
        return grad * self.mask

    def close(self):
        """Remove the hook from the tensor."""
        self.hook.remove()

def masked_gumbel_softmax(logits, mask, hard = True, tau = 1.0, dim = 1, grad_flow = False):
    """Gumbel-softmax sample in which masked positions can never be selected.

    Args:
        logits: Unnormalised scores.
        mask: Tensor broadcastable to ``logits``. Positions where it is
            non-zero are EXCLUDED: their perturbed logit is set to ``-inf``
            before the softmax. A caller holding a keep-mask (1 = allowed),
            for example a top-k mask, has to pass its complement.
        hard: If true, return a straight-through discretised sample;
            otherwise return the soft sample.
        tau: Temperature dividing the perturbed logits.
        dim: Dimension the softmax and the arg-max run over.
        grad_flow: With ``hard=True``, choose the straight-through form.
            True gives ``y_hard - y_soft.detach() + y_soft``. False gives
            ``y_hard - y_soft.detach() + y_soft * y_hard``.

    Returns:
        A tensor shaped like ``logits``.

    Note:
        If every position along ``dim`` is masked the softmax is taken over
        ``-inf`` only and yields NaN.
    """
    gumbel_noise = (
        -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
    )  # ~Gumbel(0,1)
    perturbed = (logits + gumbel_noise) / tau  # ~Gumbel(logits,tau)
    # masked_fill is out-of-place: the original call discarded its result, so
    # the mask silently had no effect. Keep the returned tensor.
    perturbed = perturbed.masked_fill(mask.to(torch.bool), float("-inf"))
    y_soft = perturbed.softmax(dim)

    if not hard:
        # Reparametrization trick.
        return y_soft

    # Straight through.
    index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
    if grad_flow:
        return y_hard - y_soft.detach() + y_soft
    # NOTE(review): in this form the forward value at a non-selected position
    # is -y_soft there rather than 0, so the result is only exactly one-hot
    # when y_soft is already one-hot. Left unchanged; confirm whether
    # (y_soft * y_hard).detach() was intended.
    return y_hard - y_soft.detach() + y_soft * y_hard


def topk(x, k = 4):
    """Return a float 0/1 mask marking the ``k`` largest entries along dim 1.

    Args:
        x: Tensor with at least two dimensions.
        k: Number of entries to mark per slice along dimension 1.

    Returns:
        A float tensor with the shape of ``x``, placed on ``x``'s device,
        holding 1 at the top-k positions and 0 elsewhere.
    """
    winners = torch.topk(x, k = k, dim = 1).indices
    selection = torch.zeros(x.size()).to(x.device)
    selection.scatter_(1, winners, 1)
    return selection

class RuleNetwork(nn.Module):
    """Select one (rule, variable) pair per sample and apply the rule's MLP.

    ``forward`` scores every pairing of learned rule embedding and input
    variable with ``variable_rule_select``. It picks a single pairing per
    batch element (hard Gumbel-softmax in training mode, arg-max otherwise),
    runs the grouped rule MLP, and returns an update that is multiplied by
    the variable-selection mask.

    Args:
        hidden_dim: Feature size of each input variable.
        num_variables: Number of variables per sample.
        num_transforms: Forwarded to ``shared_query`` and
            ``dummy_rule_selector``.
        num_rules: Number of learned rule embeddings.
        rule_dim: Size of each rule embedding.
        query_dim, value_dim, key_dim, num_heads: Stored on the instance; not
            otherwise used in this class.
        dropout: Dropout probability for ``self.dropout`` and ``rule_mlp``.
        design_config: Mapping providing ``'selection'`` and
            ``'application_option'``. Required in practice: the constructor
            subscripts it, so leaving the default ``None`` raises
            ``TypeError``.
        topk1, topk2, grad_flow1, grad_flow2, entropy_weight: Stored and
            printed; not used by ``forward``.
        rule_mlp_dim: Hidden width of ``rule_mlp``.

    Attributes that ``forward`` appends to on every call
    (``rule_probabilities``, ``rule_activation``, ``variable_activation``)
    grow without bound until ``reset_activations`` is called.
    """

    def __init__(self, hidden_dim, num_variables, num_transforms = 3,  num_rules = 4, rule_dim = 64, query_dim = 32, value_dim = 64, key_dim = 32, num_heads = 4, dropout = 0.1, design_config = None, topk1 = 4, topk2 = 2, grad_flow1 = False, grad_flow2 = False, entropy_weight = 1e-2, rule_mlp_dim = 32):
        super().__init__()
        self.rule_dim = rule_dim
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.value_dim = value_dim
        self.query_dim = query_dim
        self.hidden_dim = hidden_dim
        self.design_config = design_config
        self.entropy_weight = entropy_weight
        self.topk1 = topk1
        self.topk2 = topk2
        self.grad_flow1 = grad_flow1
        self.grad_flow2 = grad_flow2

        print('entropy_weight:' + str(self.entropy_weight))
        print('topk1:' + str(self.topk1))
        print('topk2:' + str(self.topk2))
        print('grad_flow1:' + str(self.grad_flow1))
        print('grad_flow2:' + str(self.grad_flow2))

        # Per-call logs filled by forward(); cleared by reset_activations().
        self.rule_activation = []
        self.variable_activation = []
        self.rule_prob = None
        self.softmax = []
        self.masks = []

        print('RULE DIM:' + str(rule_dim))

        # NOTE: the sub-modules below are created in the original order on
        # purpose. Reordering or dropping any of them would change parameter
        # initialisation under a fixed seed and the state_dict keys, even for
        # layers that forward() in this class never calls.
        w =  torch.randn(1, num_rules, rule_dim).to(self.device)
        h_dim = 48
        self.variable_transform = nn.Linear(hidden_dim, h_dim)
        self.rule_transform = nn.Linear(rule_dim, h_dim)

        self.share_key_value = False
        self.shared_query = GroupLinearLayer(num_transforms, hidden_dim, 1)
        self.shared_key = GroupMLP(rule_dim, hidden_dim, num_rules)

        self.dummy_transform_rule = nn.Linear(rule_dim, hidden_dim)
        self.rule_embeddings = nn.Parameter(w)
        self.biases = np.zeros((num_rules, num_variables))
        self.use_biases = True
        self.transform_src = nn.Linear(300, 60)

        self.dummy_rule_selector = SelectAttention(num_transforms, rule_dim, d_k = 32, num_read = 1, num_write = num_rules, share_query = True, share_key = True)
        print('dropout?'+str(dropout))
        self.dropout = nn.Dropout(p = dropout)
        self.positional_encoding = PositionalEncoding(rule_dim, max_len = num_rules)

        self.transform_variable = nn.Linear(hidden_dim, 1024)
        self.transform_rule = nn.Linear(rule_dim, 1024)

        # Scores every (rule, variable) pairing; used by forward().
        self.variable_rule_select = SelectAttention(rule_dim, hidden_dim , d_k=32, num_read = num_rules, num_write = num_variables, share_query = True, share_key = True)

        self.encoder_transform = nn.Linear(num_variables * hidden_dim, hidden_dim)
        print(hidden_dim + rule_dim)
        # Input width 4 and output width 2 are hard-coded. In application
        # variant '1' forward() feeds it 2 * (hidden_dim // 3) features.
        self.rule_mlp = GroupMLP(4, 2, num_rules, dropout = dropout, rule_mlp_dim = rule_mlp_dim)
        self.rule_linear = GroupLinearLayer(rule_dim + hidden_dim, hidden_dim, num_rules)
        self.rule_relevant_variable_mlp = GroupMLP(2 * hidden_dim, hidden_dim, num_rules)
        self.interaction_mlp = GroupMLP(2*hidden_dim, hidden_dim, num_rules)
        self.variables_select = MultiHeadAttention(n_head=4, d_model_read= hidden_dim, d_model_write = hidden_dim , d_model_out = hidden_dim,  d_k=32, d_v=32, num_blocks_read = 1, num_blocks_write = num_variables, topk = 3, grad_sparse = False)

        self.variables_select_1 = SelectAttention(hidden_dim, hidden_dim, d_k = 16, num_read = 1, num_write = num_variables)

        self.phase_1_mha = MultiHeadAttention(n_head = 1, d_model_read = 2 * hidden_dim * num_variables, d_model_write = hidden_dim, d_model_out = hidden_dim, d_k = 64, d_v = 64, num_blocks_read = 1, num_blocks_write = num_rules, topk = num_rules, grad_sparse = False)

        self.variable_mlp = MLP(2 * hidden_dim, hidden_dim)
        num_comb = len(list(itertools.combinations(range(num_variables), r = 2)))
        self.phase_2_mha = MultiHeadAttention(n_head = 1, d_model_read = hidden_dim, d_model_write = hidden_dim, d_model_out = hidden_dim, d_k = 32, d_v = 32, num_blocks_read = num_comb, num_blocks_write = 1, topk = 1, grad_sparse = False )
        self.variable_mlp_2 = GroupMLP(3 * hidden_dim, hidden_dim, num_variables)

        #--------Compositonal Search Based Rule Application---------------------------------------
        r = 2
        self.rule_probabilities = []
        self.variable_probabilities = []
        self.r = r
        self.variable_combinations = torch.combinations(torch.tensor(list(range(num_variables))), r = r, with_replacement = True)
        self.variable_combinations_mlp = MLP(r * hidden_dim, hidden_dim)
        self.variable_rule_mlp = MLP(3 * hidden_dim, hidden_dim)
        self.selecter = SelectAttention(hidden_dim, hidden_dim, d_k = 16, num_read = num_rules, num_write = len(self.variable_combinations))
        self.use_rules = MLP(num_variables * hidden_dim, 2)
        self.transform_combinations = MLP(len(self.variable_combinations) * hidden_dim, hidden_dim)
        self.selecter_1 = SelectAttention(hidden_dim, hidden_dim, d_k = 16, num_read = 1, num_write = num_rules)
        self.selecter_2 = SelectAttention(hidden_dim, hidden_dim, d_k = 16, num_read = 1, num_write = len(self.variable_combinations))
        self.variable_rule_group_mlp = GroupMLP( hidden_dim, hidden_dim, num_rules)
        # design_config is subscripted here, so it must be a mapping.
        if self.design_config['selection'] == 'gumble':
            print('using gumble for rule selection')
        else:
            print('using ArgMax for rule selction')

        print('Using application option ' + str(self.design_config['application_option']))

        self.gumble_temperature = 1.0

        ### MULTIMNIST stuff
        self.rule_select_ = SelectAttention(3 * hidden_dim, rule_dim, d_k = 32, num_read = 1, num_write = num_rules, share_query = True, share_key = True)
        self.variables_select_ = SelectAttention(rule_dim, hidden_dim, d_k = 32, num_read = 1, num_write = num_variables, share_key = False)
        self.project_rule_ = nn.Linear(rule_dim, hidden_dim)

    def transpose_for_scores(self, x, num_attention_heads, attention_head_size):
        """Split the last dimension into heads and move heads before dim 1.

        The last dimension of ``x`` is viewed as
        ``(num_attention_heads, attention_head_size)`` and the result is
        permuted with ``(0, 2, 1, 3)``, so ``x`` must be three-dimensional.
        """
        new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def _record_selection(self, mask):
        """Derive per-variable and per-rule masks and log the chosen indices.

        ``mask`` is indexed ``(batch, variable, rule)``. Summing over rules
        gives the variable mask and summing over variables gives the rule
        mask, each with a trailing singleton dimension. The arg-max index of
        each is appended to ``variable_activation`` / ``rule_activation`` as
        a NumPy array.
        """
        variable_mask = torch.sum(mask, dim = 2).unsqueeze(-1)
        rule_mask = torch.sum(mask, dim = 1).unsqueeze(-1)

        rule_mask_print = torch.sum(mask, dim = 1).detach()
        variable_mask_print = torch.sum(mask, dim = 2).detach()
        self.rule_activation.append(torch.argmax(rule_mask_print, dim = 1).detach().cpu().numpy())
        self.variable_activation.append(torch.argmax(variable_mask_print, dim = 1).detach().cpu().numpy())
        return variable_mask, rule_mask

    def forward(self, hidden1, hidden, prev_mask = None, message_to_rule_network = None, rule_mask = None, gumbel = True):
        """Pick a (rule, variable) pair per sample and compute its update.

        Args:
            hidden1: Not used.
            hidden: Tensor unpacked as ``(batch, num_variables, features)``.
            prev_mask, message_to_rule_network, gumbel: Not used.
            rule_mask: Ignored; the name is overwritten by the rule mask
                derived from the selection.

        Returns:
            Depends on ``design_config['application_option']``, read as a
            dot-separated string whose first field must be ``'3'``:

            * fourth field ``'0'``: ``(rule_mlp_output, entropy)``
            * fourth field ``'1'``:
              ``(rule_mlp_output, entropy, variable_mask, probs_)``
            * anything else: ``None``

            ``entropy`` is ``num_rules * num_variables`` times the summed
            batch mean of the softmax probabilities at the selected positions
            in training mode, and ``0`` in eval mode. ``probs_`` is ``None``
            in eval mode.

        Side effects:
            Overwrites ``self.rule_prob`` and appends to
            ``self.rule_probabilities`` on every call. In training mode the
            appended tensor is the one-hot selection laid out
            ``(batch, rule, variable)``; in eval mode it is a softmax laid
            out ``(batch, variable, rule)``. The two branches also append to
            ``rule_activation`` and ``variable_activation``.
        """
        batch_size, num_variables, _ = hidden.size()
        num_rules = self.rule_embeddings.size(1)
        rule_emb = self.rule_embeddings.repeat(batch_size, 1, 1)

        scores = self.variable_rule_select(rule_emb, hidden)
        scores = self.dropout(scores)
        # Softmax over every score of the first batch element, kept for inspection.
        self.rule_prob = torch.softmax(scores.detach().clone()[0].reshape(-1), dim = 0).cpu().numpy()

        # assumes scores is laid out (batch, num_rules, num_variables), as the
        # reshape of the selection mask below implies - TODO confirm against
        # SelectAttention.
        flat_scores = scores.reshape(batch_size, -1)

        if self.training:
            # One hard sample over all rule/variable pairings per batch element.
            mask = torch.nn.functional.gumbel_softmax(flat_scores, dim = 1, tau = 1.0, hard = True)
            self.rule_probabilities.append(mask.clone().reshape(batch_size, num_rules, num_variables).detach())

            probs_ = torch.softmax(scores.sum(dim = 1), dim = 1)

            selected_probs = torch.softmax(flat_scores, dim = 1) * mask
            selected_probs = torch.mean(selected_probs, dim = 0)
            entropy = num_rules * num_variables * torch.sum(selected_probs)

            mask = mask.reshape(batch_size, num_rules, num_variables).permute(0, 2, 1)
        else:
            probs_ = None
            # Call apply on the class: instantiating an autograd Function is deprecated.
            mask = ArgMax.apply(flat_scores).reshape(batch_size, num_rules, num_variables)
            mask = mask.permute(0, 2, 1)
            eval_scores = scores.permute(0, 2, 1).float()
            self.rule_probabilities.append(torch.softmax(eval_scores.reshape(batch_size, -1), dim = 1).reshape(batch_size, num_variables, num_rules).clone().detach())
            entropy = 0

        # The fourth field is only read when the first one is '3', so shorter
        # option strings with a different first field do not raise.
        option = str(self.design_config['application_option']).split('.')

        if option[0] == '3' and option[3] == '0':
            variable_mask, rule_mask = self._record_selection(mask)

            # Feed the selected variable to every rule, keep the chosen rule's output.
            selected_variable = torch.sum(hidden * variable_mask, dim = 1).unsqueeze(1).repeat(1, mask.size(2), 1)
            rule_mlp_output = self.rule_mlp(selected_variable)
            rule_mlp_output = torch.sum(rule_mlp_output * rule_mask, dim = 1).unsqueeze(1)

            relevant_variables, _, _ = self.variables_select(rule_mlp_output, hidden, hidden)
            rule_mlp_output = torch.cat((rule_mlp_output, relevant_variables), dim = 2)
            rule_mlp_output = rule_mlp_output.repeat(1, num_rules, 1)
            rule_mlp_output = self.rule_relevant_variable_mlp(rule_mlp_output)
            rule_mlp_output = torch.sum(rule_mlp_output * rule_mask, dim = 1).unsqueeze(1)

            # Broadcast to every variable slot, then zero all but the selected one.
            rule_mlp_output = rule_mlp_output.repeat(1, hidden.size(1), 1)
            rule_mlp_output = rule_mlp_output * variable_mask
            return rule_mlp_output, entropy

        elif option[0] == '3' and option[3] == '1':
            variable_mask, rule_mask = self._record_selection(mask)

            selected_variable = torch.sum(hidden * variable_mask, dim = 1).unsqueeze(1)

            # Choose a second, "relevant" variable conditioned on the selected one.
            relevant_variables_attn = self.variables_select_1(selected_variable, hidden)
            relevant_variables_attn = self.dropout(relevant_variables_attn)
            relevant_variables_attn = relevant_variables_attn.squeeze(1)
            if self.training:
                relevant_variables_mask = torch.nn.functional.gumbel_softmax(relevant_variables_attn, dim = 1, hard = True, tau = 1.0)
            else:
                relevant_variables_mask = ArgMax.apply(relevant_variables_attn)

            relevant_variable = hidden * relevant_variables_mask.unsqueeze(-1)
            relevant_variable = torch.sum(relevant_variable, dim = 1)
            selected_variable = selected_variable.squeeze(1)

            # Only the first third of each feature vector is passed to the rule MLP.
            h_dim = selected_variable.size(1)
            relevant_variable = relevant_variable[:, :h_dim // 3].unsqueeze(1).repeat(1, rule_emb.size(1), 1)
            selected_variable = selected_variable[:, :h_dim // 3].unsqueeze(1).repeat(1, rule_emb.size(1), 1)

            mlp_input = torch.cat((selected_variable, relevant_variable), dim = -1)

            rule_mlp_output = torch.sum(self.rule_mlp(mlp_input) * rule_mask, dim = 1).unsqueeze(1)
            rule_mlp_output = rule_mlp_output.repeat(1, hidden.size(1), 1)
            rule_mlp_output = rule_mlp_output * variable_mask

            return rule_mlp_output, entropy, variable_mask, probs_

        # No supported application option matched.
        return None

    def reset_activations(self):
        """Clear the per-call logs accumulated by ``forward``."""
        self.rule_activation = []
        self.variable_activation = []
        self.rule_probabilities = []
        self.variable_probabilities = []

    def reset_bias(self):
        """Replace ``self.biases`` with zeros of the same shape.

        The original body referenced the constructor-local names
        ``num_rules`` and ``num_variables`` and raised ``NameError``; the
        shape is taken from the existing array instead.
        """
        self.biases = np.zeros_like(self.biases)

if __name__ == '__main__':
    # Smoke test: build a small RuleNetwork, run one forward pass and
    # back-propagate through the returned update.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # RuleNetwork.__init__ subscripts design_config, and forward() only returns
    # a result when 'application_option' has '3' as its first dot-separated
    # field and '0' or '1' as its fourth.
    # NOTE(review): these values are chosen only to satisfy those lookups -
    # confirm against the configuration the project actually uses.
    design_config = {'selection': 'gumble', 'application_option': '3.0.0.1'}
    model = RuleNetwork(6, 4, design_config = design_config).to(device)

    # Created directly on the target device so it is a leaf tensor and
    # receives .grad without retain_grad().
    hiddens = torch.randn(3, 4, 6, device = device, requires_grad = True)

    # forward() takes two positional tensors (the first is unused) and returns
    # a tuple whose first element is the update.
    new_hiddens = model(hiddens, hiddens)[0]

    new_hiddens.backward(torch.ones_like(new_hiddens))

	#print(model.rule_embeddings.grad)
	#print(model.query_layer.w.grad)
