import torch
import torch.nn as nn
import utils
import numpy as np
from .partition_utils import get_env_factorization, get_action_dependency


class Actor(nn.Module):
	"""
	Baseline actor: a trunk (Linear -> LayerNorm -> Tanh) feeding an MLP policy head.
	- sac=True: the head emits 2 * action_dim values (mean and log-std) and
		forward() returns utils.SquashedNormal
	- sac=False: the head emits action_dim values and forward() returns
		utils.TruncatedNormal around the tanh-ed mean, using the caller's std
	"""
	def __init__(self, obs_type, obs_dim, action_dim, feature_dim, hidden_dim, sac, log_std_bounds, domain):
		"""
		obs_type: 'pixels' keeps feature_dim as the trunk width and adds one extra
			hidden layer; any other value uses hidden_dim as the trunk width
		obs_dim: size of the last axis of the input given to the trunk
		action_dim: number of action dimensions
		feature_dim / hidden_dim: trunk width (pixels only) / policy-head width
		sac: selects the output head and the returned distribution
		log_std_bounds: (log_std_min, log_std_max) pair, only read when sac is set
		domain: stored on the module; not used inside this class
		"""
		super().__init__()

		trunk_width = feature_dim if obs_type == 'pixels' else hidden_dim
		head_width = 2 * action_dim if sac else action_dim

		self.sac = sac
		self.domain = domain
		self.log_std_bounds = log_std_bounds

		self.trunk = nn.Sequential(
			nn.Linear(obs_dim, trunk_width),
			nn.LayerNorm(trunk_width),
			nn.Tanh(),
		)

		layers = [nn.Linear(trunk_width, hidden_dim), nn.ReLU(inplace=True)]
		if obs_type == 'pixels':
			# pixel observations get one more hidden layer
			layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)])
		layers.append(nn.Linear(hidden_dim, head_width))
		self.policy = nn.Sequential(*layers)

		# NOTE(review): utils.weight_init is defined outside this file - presumably
		# a custom initialisation applied to every sub-module
		self.apply(utils.weight_init)

	def forward(self, obs, std):
		"""
		Build the action distribution for obs.
		std is only used on the non-SAC path, where it is broadcast to the shape
		of the mean; on the SAC path the std comes from the network itself.
		"""
		policy_out = self.policy(self.trunk(obs))

		if not self.sac:
			mean = torch.tanh(policy_out)
			return utils.TruncatedNormal(mean, torch.ones_like(mean) * std)

		mean, raw_log_std = policy_out.chunk(2, dim=-1)

		# tanh maps to [-1, 1], which is then rescaled to [log_std_min, log_std_max]
		squashed = torch.tanh(raw_log_std)
		lower, upper = self.log_std_bounds
		bounded_log_std = lower + 0.5 * (squashed + 1) * (upper - lower)
		return utils.SquashedNormal(mean, bounded_log_std.exp())


class SkillActor(nn.Module):
	"""
	Variant of the normal Actor that encodes the observation and the latent
	skill with two separate trunks and concatenates both embeddings before
	the policy head
	- rarely used
	"""
	def __init__(self, obs_type, obs_dim, action_dim, feature_dim, hidden_dim, sac, log_std_bounds, skill_dim):
		"""
		obs_type: 'pixels' keeps feature_dim as the trunk width and adds one extra
			hidden layer; any other value uses hidden_dim as the trunk width
		obs_dim: size of the combined [obs, skill] input; the skill part is
			subtracted to get the pure observation size
		skill_dim: should be the total skill dim (i.e. skill * channel)
		sac: selects the output head and the returned distribution
		log_std_bounds: (log_std_min, log_std_max) pair, only read when sac is set
		"""
		super().__init__()

		width = feature_dim if obs_type == 'pixels' else hidden_dim
		head_width = 2 * action_dim if sac else action_dim

		self.sac = sac
		self.skill_dim = skill_dim
		self.obs_dim = obs_dim - skill_dim
		self.log_std_bounds = log_std_bounds

		self.obs_trunk = nn.Sequential(
			nn.Linear(self.obs_dim, width),
			nn.LayerNorm(width),
			nn.Tanh(),
		)
		self.skill_trunk = nn.Sequential(
			nn.Linear(skill_dim, width),
			nn.LayerNorm(width),
			nn.Tanh(),
		)

		# the head consumes both embeddings side by side, hence twice the width
		layers = [nn.Linear(2 * width, hidden_dim), nn.ReLU(inplace=True)]
		if obs_type == 'pixels':
			# pixel observations get one more hidden layer
			layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)])
		layers.append(nn.Linear(hidden_dim, head_width))
		self.policy = nn.Sequential(*layers)

		self.apply(utils.weight_init)

	def forward(self, obs, std):
		"""
		obs is split on its last axis into [obs_dim, skill_dim] and each part is
		encoded by its own trunk. std is only used on the non-SAC path.
		"""
		obs_part, skill_part = obs.split([self.obs_dim, self.skill_dim], dim=-1)
		obs_embed = self.obs_trunk(obs_part)
		skill_embed = self.skill_trunk(skill_part)
		policy_out = self.policy(torch.cat([obs_embed, skill_embed], dim=-1))

		if not self.sac:
			mean = torch.tanh(policy_out)
			return utils.TruncatedNormal(mean, torch.ones_like(mean) * std)

		mean, raw_log_std = policy_out.chunk(2, dim=-1)

		# tanh maps to [-1, 1], which is then rescaled to [log_std_min, log_std_max]
		squashed = torch.tanh(raw_log_std)
		lower, upper = self.log_std_bounds
		bounded_log_std = lower + 0.5 * (squashed + 1) * (upper - lower)
		return utils.SquashedNormal(mean, bounded_log_std.exp())


class MCPActor(nn.Module):
	"""
	Actor that merges one Gaussian primitive per skill channel.
	- the observation part of the input goes through a shared state encoder
	- every primitive receives the encoded state plus its own encoded skill
		vector z and outputs a mean and a std per action dimension
	- the primitives are merged with one weight per primitive, either from a
		learned gate or a constant
	# Maybe the state should also include previous actions?
	"""
	def __init__(self, obs_type, obs_skill_dim, action_dim, feature_dim, hidden_dim, sac, log_std_bounds,
				 skill_channel, skill_dim, use_gate):
		"""
		obs_type: 'pixels' keeps feature_dim as the layer width, otherwise hidden_dim is used
		obs_skill_dim: size of the combined [obs, skills] input
		skill_channel / skill_dim: number of skill vectors and the size of each
		use_gate: use the learned gate for the primitive weights instead of a constant
		sac: selects the returned distribution in forward()
		log_std_bounds: (log_std_min, log_std_max) pair used for every primitive
		"""
		super().__init__()

		self.sac = sac
		self.use_gate = use_gate
		self.skill_dim = skill_dim
		self.skill_channel = skill_channel
		self.action_dim = action_dim
		self.log_std_bounds = log_std_bounds
		self.obs_dim = obs_skill_dim - skill_dim * skill_channel

		width = feature_dim if obs_type == 'pixels' else hidden_dim
		# Primitives always output mean and log-std, for ddpg as well as sac
		primitive_out = 2 * self.action_dim

		self.primitive_state_encoder = nn.Sequential(
			nn.Linear(self.obs_dim, width),
			nn.ReLU(),
			nn.Linear(width, width),
			nn.ReLU(),
		)

		# The gate sees the full [obs, skills] input and yields one weight in (0, 1) per primitive
		self.gate = nn.Sequential(
			nn.Linear(obs_skill_dim, width),
			nn.ReLU(),
			nn.Linear(width, skill_channel),
			nn.Sigmoid(),
		)

		self.skill_encoders = nn.ModuleList(
			[nn.Linear(skill_dim, width) for _ in range(skill_channel)])
		self.primitives = nn.ModuleList([
			nn.Sequential(
				nn.Linear(2 * width, width),
				nn.ReLU(),
				nn.Linear(width, primitive_out),
			)
			for _ in range(skill_channel)
		])

		self.apply(utils.weight_init)

	def forward_primitives(self, obs_features, skill, idx):
		"""
		Run primitive idx on the shared state features and its skill vector.
		Returns (mu, sigma): the tanh-ed mean and the std whose log is bounded
		by log_std_bounds.
		"""
		skill_embed = self.skill_encoders[idx](skill)
		primitive_in = torch.cat([obs_features, skill_embed], dim=-1)
		raw_mu, raw_log_std = self.primitives[idx](primitive_in).split(self.action_dim, dim=-1)
		mu = torch.tanh(raw_mu)

		# tanh maps to [-1, 1], which is then rescaled to [log_std_min, log_std_max]
		squashed = torch.tanh(raw_log_std)
		lower, upper = self.log_std_bounds
		bounded_log_std = lower + 0.5 * (squashed + 1) * (upper - lower)

		return mu, torch.ones_like(mu) * bounded_log_std.exp()

	def forward_weights(self, obs_skill, device):
		"""
		Per-primitive mixing weights.
		With the gate: gate output with an extra axis inserted before the last
		one so that it broadcasts over action dimensions. Without: a constant
		0.5 vector of length skill_channel on device.
		"""
		if not self.use_gate:
			# TODO: change this division to be based on skill channel
			return torch.full((self.skill_channel,), 0.5, device=device)

		# Expand on the action dimension level
		return self.gate(obs_skill).unsqueeze(dim=-2)

	def forward(self, obs_skill, std):
		"""
		Merge all primitives into one action distribution.
		A 3-d input is accepted only when its leading axis has size 1; that
		axis is dropped. std is only used on the non-SAC path.
		"""
		if obs_skill.dim() == 3:
			assert obs_skill.shape[0] == 1
			obs_skill = obs_skill.squeeze(0)

		total_skill_dim = self.skill_dim * self.skill_channel
		obs, skills = obs_skill.split([self.obs_dim, total_skill_dim], dim=-1)
		per_channel_skill = skills.split(self.skill_dim, dim=-1)

		state_features = self.primitive_state_encoder(obs)

		primitive_outs = [
			self.forward_primitives(state_features, per_channel_skill[channel], channel)
			for channel in range(self.skill_channel)
		]
		mu_seq, sigma_seq = zip(*primitive_outs)

		# primitives are stacked on a new trailing axis
		mus = torch.stack(mu_seq, -1)
		sigmas = torch.stack(sigma_seq, -1)
		weights = self.forward_weights(obs_skill, sigmas.device)

		# mean = sum_k(w_k / sigma_k * mu_k) / sum_k(w_k / sigma_k)
		denom = (weights / sigmas).sum(-1)
		unnorm_mu = (weights / sigmas * mus).sum(-1)
		mean = unnorm_mu / denom
		combined_sigma = 1 / denom

		# Intermediate results are kept on the module (original note: "For calculating sum")
		self.mus = mus
		self.sigmas = sigmas
		self.cmb_sigma = combined_sigma
		self.gate_weights = weights
		self.unnorm_mu = unnorm_mu
		self.act_mean = mean

		if self.sac:
			return utils.SquashedNormal(mean, combined_sigma)
		return utils.TruncatedNormal(mean, torch.ones_like(mean) * std)


class SeparateSkillActor(nn.Module):
	"""
	Use separate actor networks, one for each action part
	- one primitive network per skill channel (originally written for two)
	- each primitive outputs a single action dimension, so the concatenated
		mean has skill_channel entries
	- only the non-SAC (TruncatedNormal) head is implemented
	TODO: pass in the actions
	"""
	def __init__(self, obs_type, obs_dim, action_dim, feature_dim, hidden_dim, sac,
				 log_std_bounds, skill_channel, skill_dim):
		"""
		obs_type: 'pixels' keeps feature_dim as the layer width, otherwise hidden_dim is used
		obs_dim: size of the combined [obs, skills] input
		action_dim: stored on the module; the output size is set by skill_channel
		skill_channel / skill_dim: number of skill vectors and the size of each
		sac: must be falsy, otherwise NotImplementedError is raised
		log_std_bounds: stored on the module; not read by this class
		"""
		super().__init__()

		# Reject the unsupported mode before any layer is allocated
		if sac:
			raise NotImplementedError

		self.sac = sac
		feature_dim = feature_dim if obs_type == 'pixels' else hidden_dim

		obs_dim = self.obs_dim = obs_dim - skill_dim * skill_channel
		self.skill_dim = skill_dim
		self.skill_channel = skill_channel
		self.action_dim = action_dim

		self.obs_trunk = nn.Sequential(
			nn.Linear(obs_dim, feature_dim),
			nn.LayerNorm(feature_dim),
			nn.Tanh(),
		)

		self.skill_encoders = nn.ModuleList(
			[nn.Linear(skill_dim, feature_dim) for _ in range(skill_channel)])

		# forward() indexes primitives by skill channel, so there must be one per
		# channel; a fixed count of 2 raised IndexError for skill_channel > 2
		self.primitives = nn.ModuleList([
			nn.Sequential(
				nn.Linear(feature_dim + feature_dim, feature_dim),
				nn.ReLU(),
				nn.Linear(feature_dim, 1),
			)
			for _ in range(skill_channel)
		])

		self.log_std_bounds = log_std_bounds

		self.apply(utils.weight_init)

	def forward(self, obs, std):
		"""
		obs is split on its last axis into the observation and skill_channel
		skill vectors. Each primitive contributes one entry of the mean, which
		is tanh-ed and wrapped in utils.TruncatedNormal with std broadcast to
		the shape of the mean.
		"""
		obs, skill = torch.split(obs, [self.obs_dim, self.skill_dim * self.skill_channel], dim=-1)
		skill_list = torch.split(skill, self.skill_dim, dim=-1)

		prim_embed = self.obs_trunk(obs)

		outs = [self.forward_primitives(prim_embed, skill_list[i], i) for i in range(self.skill_channel)]

		mu = torch.tanh(torch.concat(outs, dim=-1))
		# broadcast std explicitly, as the other actors in this file do
		std = torch.ones_like(mu) * std
		return utils.TruncatedNormal(mu, std)

	def forward_primitives(self, obs_features, skill, idx):
		"""Encode skill with encoder idx, append it to obs_features and run primitive idx."""
		skill_ft = self.skill_encoders[idx](skill)
		primitive_in = torch.concat([obs_features, skill_ft], dim=-1)
		return self.primitives[idx](primitive_in)


class IndepActor(nn.Module):
	"""
	One independent MLP per action part; each MLP only sees the obs slices and
	skill (z) slices listed for that action part
	Requires full domain knowledge
	- currently implemented for 2d moma and particle
	"""
	def __init__(self, obs_type, obs_dim, action_dim, feature_dim, hidden_dim, sac,
				 log_std_bounds, skill_channel, skill_dim, ind_type, domain):
		"""
		obs_type / feature_dim: accepted but not used; every layer is hidden_dim wide
		obs_dim: size of the combined [obs, skills] input
		skill_channel / skill_dim: number of skill vectors and the size of each
		sac: doubles each primitive's output (mean and log-std) and selects the distribution
		ind_type, domain: forwarded to get_action_dependency / get_env_factorization,
			which provide the partition sizes and the per-action dependency indices
		"""
		super().__init__()

		self.sac = sac
		self.skill_dim = skill_dim
		self.skill_channel = skill_channel
		self.action_dim = action_dim
		self.log_std_bounds = log_std_bounds
		self.obs_dim = obs_dim - skill_dim * skill_channel

		width = hidden_dim

		self.obs_partition, self.skill_partition, self.action_partition = get_env_factorization(
			domain, skill_dim, skill_channel)
		self.action_obs_dependency, self.action_skill_dependency = get_action_dependency(
			domain, ind_type, skill_channel)

		primitives = []
		for part in range(len(self.action_partition)):
			part_size = self.action_partition[part]
			out_size = 2 * part_size if self.sac else part_size

			# input width = all obs slices plus all skill slices this action part depends on
			in_size = 0
			for obs_idx in self.action_obs_dependency[part]:
				in_size += self.obs_partition[obs_idx]
			for skill_idx in self.action_skill_dependency[part]:
				in_size += self.skill_partition[skill_idx]

			primitives.append(nn.Sequential(
				nn.Linear(in_size, width),
				nn.ReLU(),
				nn.Linear(width, width),
				nn.ReLU(),
				nn.Linear(width, out_size),
			))

		self.primitives = nn.ModuleList(primitives)

		self.apply(utils.weight_init)

	def forward(self, obs, std):
		"""
		Split the input into obs and skill slices, run every primitive on its
		own slices and concatenate the results on the last axis.
		std is only used on the non-SAC path and is passed on unchanged.
		"""
		obs, skill = torch.split(obs, [self.obs_dim, self.skill_dim * self.skill_channel], dim=-1)

		# Next, partition both the obs and the skill
		skill_slices = torch.split(skill, self.skill_partition, dim=-1)
		obs_slices = torch.split(obs, self.obs_partition, dim=-1)

		raw_outs = []
		for part in range(len(self.action_partition)):
			part_obs = [obs_slices[k] for k in self.action_obs_dependency[part]]
			part_skill = [skill_slices[k] for k in self.action_skill_dependency[part]]
			raw_outs.append(self.forward_primitives(part_obs, part_skill, part))

		if not self.sac:
			mean = torch.tanh(torch.cat(raw_outs, dim=-1))
			return utils.TruncatedNormal(mean, std)

		# every primitive output is halved into (mean, log-std) before concatenation
		mean_parts, log_std_parts = zip(*[torch.chunk(out, 2, dim=-1) for out in raw_outs])
		mean = torch.cat(mean_parts, dim=-1)
		raw_log_std = torch.cat(log_std_parts, dim=-1)

		# tanh maps to [-1, 1], which is then rescaled to [log_std_min, log_std_max]
		squashed = torch.tanh(raw_log_std)
		lower, upper = self.log_std_bounds
		bounded_log_std = lower + 0.5 * (squashed + 1) * (upper - lower)
		return utils.SquashedNormal(mean, bounded_log_std.exp())

	def forward_primitives(self, obs_features, skill, idx):
		"""Concatenate the given obs and skill slices (obs first) and run primitive idx."""
		pieces = list(obs_features) + list(skill)
		return self.primitives[idx](torch.cat(pieces, dim=-1))


class AttnActor(nn.Module):
	"""
	This policy tokenizes the obs and skill, and then uses attention to predict the next action
	This is not finished yet:
	- the per-action dependency tables (action_obs_dependency /
		action_skill_dependency) are never assigned by this class
	- the attention stack used by forward() does not exist
	Both gaps are reported with NotImplementedError instead of failing later
	with an AttributeError / NameError.
	"""

	def __init__(self, obs_type, obs_dim, action_dim, feature_dim, hidden_dim, sac,
				 log_std_bounds, skill_channel, skill_dim, ind_type):
		"""
		obs_type / feature_dim / ind_type: accepted but not used yet
		obs_dim: size of the combined [obs, skills] input
		skill_channel / skill_dim: number of skill vectors and the size of each
		sac: must be falsy, SAC is not implemented for this actor

		Raises NotImplementedError when sac is set, or when the dependency
		tables are not available on the instance (always the case unless a
		subclass provides them as class attributes).
		"""
		super().__init__()

		if sac:
			raise NotImplementedError('AttnActor does not support SAC')

		self.sac = sac
		feature_dim = hidden_dim

		self.skill_dim = skill_dim
		self.skill_channel = skill_channel
		self.action_dim = action_dim
		self.obs_dim = obs_dim - skill_dim * skill_channel
		self.log_std_bounds = log_std_bounds

		# PARTITION IS USED TO CREATE TOKENS
		self.obs_partition = [4, 4, 4, 2, 3, 1]  # base, arm, view, base, arm, view
		self.skill_partition = [self.skill_dim] * self.skill_channel  # base, arm, view
		self.action_partition = [2, 3, 1]  # base, arm, view

		# Nothing in this class defines the dependency tables, so reading them
		# directly raised AttributeError on every construction.
		obs_dependency = getattr(self, 'action_obs_dependency', None)
		skill_dependency = getattr(self, 'action_skill_dependency', None)
		if obs_dependency is None or skill_dependency is None:
			raise NotImplementedError(
				'AttnActor is not finished: action_obs_dependency and '
				'action_skill_dependency are not defined')

		policy_list = []
		for i, action_length in enumerate(self.action_partition):
			# input width = all obs tokens plus all skill tokens this action part depends on
			input_length = sum(self.obs_partition[obs_idx] for obs_idx in obs_dependency[i])
			input_length += sum(self.skill_partition[skill_idx] for skill_idx in skill_dependency[i])

			policy_list.append(nn.Sequential(
				nn.Linear(input_length, feature_dim),
				nn.ReLU(),
				nn.Linear(feature_dim, feature_dim),
				nn.ReLU(),
				nn.Linear(feature_dim, action_length),
			))

		self.primitives = nn.ModuleList(policy_list)

		self.apply(utils.weight_init)

	def forward(self, obs, std):
		"""
		Not implemented; always raises NotImplementedError.

		Intended pipeline, as sketched by the unfinished code this replaces:
		split obs into observation and skill parts, tokenize both with
		obs_partition / skill_partition, pass the tokens through a stack of
		attention layers, concatenate the per-action outputs, apply tanh and
		return utils.TruncatedNormal(mu, std).
		TODO: build the attention layers and the token features they consume
		"""
		raise NotImplementedError(
			'AttnActor.forward is not finished: the attention layers and their '
			'input features are not implemented')

