import numpy as np
import torch
from tqdm import trange
from torch import nn
import copy
import time
import random 

def to_device(x, device='cuda'):
    """Move a tensor, or every tensor nested inside a dict, to ``device``.

    Args:
        x: A tensor, or a (possibly nested) dict whose leaves are tensors.
            Dict subclasses such as ``OrderedDict`` are accepted too.
        device: Target device passed to ``Tensor.to``.

    Returns:
        The tensor moved to ``device``, or a new plain ``dict`` with the same
        keys whose values were converted recursively. Any other value is
        reported on stdout and returned unchanged, so a non-tensor field is
        never silently replaced by ``None``.
    """
    if torch.is_tensor(x):
        return x.to(device)
    # isinstance (not an exact type check) so dict subclasses are handled as well.
    if isinstance(x, dict):
        return {k: to_device(v, device) for k, v in x.items()}
    print(f'Unrecognized type in `to_device`: {type(x)}')
    # Best effort: hand the value back untouched instead of dropping it.
    return x

def batch_to_device(batch, device='cuda:0'):
    """Return a copy of a namedtuple-like ``batch`` with every field moved to ``device``.

    Each name listed in ``batch._fields`` is read from ``batch``, passed through
    ``to_device`` and the results are handed, in field order, to the
    constructor of ``batch``'s own type.
    """
    moved = []
    for field_name in batch._fields:
        field_value = getattr(batch, field_name)
        moved.append(to_device(field_value, device))
    batch_cls = type(batch)
    return batch_cls(*moved)

@torch.jit.script
def compute_kernel(x, y):
    """Pairwise Gaussian-style kernel between the rows of ``x`` and ``y``.

    Args:
        x: 2-D tensor of shape (n, d).
        y: 2-D tensor of shape (m, d); must have the same ``d`` as ``x``.

    Returns:
        Tensor of shape (n, m) whose entry (a, b) equals
        ``exp(-mean_k((x[a, k] - y[b, k]) ** 2) / d)``.
    """
    x_size = x.shape[0]
    y_size = y.shape[0]
    dim = x.shape[1]

    # `expand` yields broadcast views, so the two (n, m, d) operands are never
    # materialised the way `repeat` would; the values are identical.
    tiled_x = x.view(x_size, 1, dim).expand(x_size, y_size, dim)
    tiled_y = y.view(1, y_size, dim).expand(x_size, y_size, dim)

    # The squared distance is averaged over features and then divided by `dim` again.
    return torch.exp(-torch.mean((tiled_x - tiled_y) ** 2, dim=2) / dim)

@torch.jit.script
def compute_mmd(x, y):
    """Maximum-mean-discrepancy style statistic between two sample sets.

    Computes ``mean(k(x, x)) + mean(k(y, y)) - 2 * mean(k(x, y))`` where ``k``
    is ``compute_kernel``. Both inputs are 2-D tensors with one sample per row.
    """
    k_xx = compute_kernel(x, x)
    k_yy = compute_kernel(y, y)
    k_xy = compute_kernel(x, y)
    within_sets = k_xx.mean() + k_yy.mean()
    across_sets = k_xy.mean()
    return within_sets - 2 * across_sets

class EMA():
    """Exponential moving average of model parameters.

    ``beta`` is the weight kept on the running average at each update; the
    remaining ``1 - beta`` goes to the incoming value.
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        """Blend ``current_model``'s parameters into ``ma_model``.

        Parameters are paired in the order returned by ``parameters()``; the
        ``.data`` of each averaged parameter is rebound to the blended tensor.
        """
        paired = zip(current_model.parameters(), ma_model.parameters())
        for live_param, avg_param in paired:
            avg_param.data = self.update_average(avg_param.data, live_param.data)

    def update_average(self, old, new):
        """Return ``old * beta + (1 - beta) * new``, or ``new`` when ``old`` is None."""
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        return new

class AllTrainer():
    """Joint trainer for the encoder, the diffusion model and the anchor ``w``.

    ``repre_type`` selects the per-step routine used by ``train_iteration``:

    * ``'none'`` -> ``train_dd``: diffusion model only, conditioned on returns.
    * ``'vec'``  -> ``train_with_vec``: encoder outputs one vector per sample.
    * ``'dist'`` -> ``train_with_dist``: encoder outputs a (mean, std) pair.

    An EMA copy of the diffusion model (``ema_model``) is refreshed every
    10 optimisation steps.
    """

    def __init__(
        self,
        en_model, # encoder
        de_model, # diffusion model
        optimizer,
        batch_size,
        get_batch,
        device,
        et_optimizer,
        w,
        w_std,
        w_optimizer,
        repre_type,
        phi_norm_loss_ratio,
        info_loss_weight,
    ):
        """Store the models, optimizers and loss settings.

        Args:
            en_model: Encoder; its ``forward`` is called with
                (states, timesteps, mask[, task_ids]).
            de_model: Diffusion model exposing ``loss`` (and ``generate`` in
                ``'vec'`` mode).
            optimizer: Stepped after every ``diffusion_loss.backward()``.
            batch_size: Length of the all-ones target used by the phi-norm loss.
            get_batch: Stored as ``self.dataloader``. ``'none'``/``'vec'`` call
                ``next()`` on it directly; ``'dist'`` indexes it by task first.
            device: Device the batch tensors are moved to.
            et_optimizer: Stepped together with ``optimizer`` in the ``'vec'``
                and ``'dist'`` modes (presumably the encoder's optimizer --
                confirm against the caller).
            w: Anchor representation; indexed per task in ``'dist'`` mode.
            w_std: Per-task anchor spread; ``exp(clamp(w_std))`` is used as the
                diagonal covariance, i.e. it is treated as a log-variance.
            w_optimizer: Stepped after ``w_loss.backward()``.
            repre_type: ``'none'``, ``'vec'`` or ``'dist'``.
            phi_norm_loss_ratio: Weight of the phi-norm term in the preference loss.
            info_loss_weight: Weight of the info loss in ``'vec'`` mode.
        """
        super().__init__()
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.dataloader = get_batch
        self.diagnostics = dict()
        self.en_model = en_model
        self.de_model = de_model
        self.info_loss_weight = info_loss_weight

        self.device = device
        self.ema = EMA(0.995)
        self.ema_model = copy.deepcopy(self.de_model)
        self.reset_parameters()
        self.step = 1

        self.et_optimizer = et_optimizer
        self.repre_type = repre_type
        self.w = w
        self.w_std = w_std
        self.w_optimizer = w_optimizer
        self.phi_loss = nn.MSELoss()
        self.triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
        self.phi_norm_loss_ratio = phi_norm_loss_ratio

    def reset_parameters(self):
        """Overwrite the EMA model's weights with the live diffusion model's."""
        self.ema_model.load_state_dict(self.de_model.state_dict())

    def step_ema(self):
        """Refresh the EMA model.

        While ``self.step < 2000`` the EMA model simply mirrors the live
        model; afterwards it is blended through ``self.ema``.
        """
        if self.step < 2000:
            self.reset_parameters()
            return
        self.ema.update_model_average(self.ema_model, self.de_model)

    @staticmethod
    def _to_float(loss):
        """Detach a scalar loss tensor and return it as a Python float."""
        return loss.detach().cpu().item()

    def train_iteration(self, num_steps, iter_num=0, print_logs=False):
        """Run ``num_steps`` optimisation steps and return aggregated logs.

        Args:
            num_steps: Number of calls to the per-step training routine.
            iter_num: Only used in the printed header.
            print_logs: When true, print every log entry after training.

        Returns:
            Dict with the wall-clock time, the mean/std of each per-step loss,
            and every entry of ``self.diagnostics``.

        Raises:
            ValueError: If ``self.repre_type`` is not ``'none'``, ``'vec'`` or
                ``'dist'``.
        """
        if self.repre_type == 'none': # train decision diffuser
            train_step = self.train_dd
        elif self.repre_type == 'vec': # train pref diffuser with vector representations
            train_step = self.train_with_vec
        elif self.repre_type == 'dist': # train pref diffuser with distributional representations
            train_step = self.train_with_dist
        else:
            raise ValueError(f'Unknown repre_type: {self.repre_type!r}')

        diffusion_losses, inv_losses, info_losses = [], [], []
        regress_losses, phi_norm_losses, w_losses, si_train_losses = [], [], [], []
        logs = dict()
        train_start = time.time()

        self.en_model.train()
        self.de_model.train()
        for _ in trange(num_steps, desc='train_step'):
            # All three routines return the same 7-tuple; the last slot is the
            # info loss (train_dd reports 0.0 there). The old code unpacked it
            # under a different name for 'none', leaving `info_loss` undefined.
            (diffusion_loss, inv_loss, regress_loss, phi_norm_loss,
             w_loss, si_loss, info_loss) = train_step()

            diffusion_losses.append(diffusion_loss)
            inv_losses.append(inv_loss)
            info_losses.append(info_loss)
            regress_losses.append(regress_loss)
            phi_norm_losses.append(phi_norm_loss)
            w_losses.append(w_loss)
            si_train_losses.append(si_loss)

        logs['training/time'] = time.time() - train_start
        logs['training/diffusion_loss_mean'] = np.mean(diffusion_losses)
        logs['training/diffusion_loss_std'] = np.std(diffusion_losses)
        logs['training/inv_loss_mean'] = np.mean(inv_losses)
        logs['training/inv_loss_std'] = np.std(inv_losses)
        logs['training/info_loss_mean'] = np.mean(info_losses)
        logs['training/info_loss_std'] = np.std(info_losses)
        logs['training/pref_loss_mean'] = np.mean(regress_losses)
        logs['training/pref_loss_std'] = np.std(regress_losses)
        logs['training/phi_norm_loss_mean'] = np.mean(phi_norm_losses)
        logs['training/phi_norm_loss_std'] = np.std(phi_norm_losses)
        logs['training/w_loss_mean'] = np.mean(w_losses)
        logs['training/w_loss_std'] = np.std(w_losses)
        logs['training/si_train_loss'] = np.mean(si_train_losses)
        logs['training/si_train_loss_std'] = np.std(si_train_losses)
        for k in self.diagnostics:
            logs[k] = self.diagnostics[k]

        if print_logs:
            print('=' * 80)
            print(f'Iteration {iter_num}')
            for k, v in logs.items():
                print(f'{k}: {v}')

        return logs

    def _dist_preference_terms(self, task_idx, encodings, lb, rb, detach_anchor):
        """Triplet and KL preference terms for distributional representations.

        Args:
            task_idx: Index of the anchor task into ``self.w`` / ``self.w_std``.
            encodings: Three ``(mean, std)`` encoder outputs for batch 1,
                batch 2 (same task) and batch 3 (other task).
            lb: Mask of pairs where batch 1 is preferred over batch 2.
            rb: Mask of pairs where batch 2 is preferred over batch 1.
            detach_anchor: Detach the anchor mean fed to the triplet loss so
                that this term sends no gradient to ``w``.

        Returns:
            ``(trip_loss, kl_loss)``.
        """
        (phi_1_mean, phi_1_std), (phi_2_mean, phi_2_std), (phi_3_mean, phi_3_std) = encodings

        # Preferred member of each (batch 1, batch 2) pair is the positive,
        # the other member is the negative.
        positive_mean = torch.cat((phi_1_mean[lb], phi_2_mean[rb]), 0)
        negative_mean = torch.cat((phi_2_mean[lb], phi_1_mean[rb]), 0)
        positive_std = torch.cat((phi_1_std[lb], phi_2_std[rb]), 0)
        negative_std = torch.cat((phi_2_std[lb], phi_1_std[rb]), 0)

        # Every sample of the other task is an additional negative; the
        # positives are repeated so both sides keep the same length.
        positive_mean = torch.cat((positive_mean, positive_mean), 0)
        negative_mean = torch.cat((negative_mean, phi_3_mean), 0)
        positive_std = torch.cat((positive_std, positive_std), 0)
        negative_std = torch.cat((negative_std, phi_3_std), 0)

        # exp(std) is used directly as the diagonal covariance, so the "std"
        # outputs act as log-variances here.
        positive_dist = torch.distributions.MultivariateNormal(
            loc=positive_mean, covariance_matrix=torch.diag_embed(torch.exp(positive_std)))
        negative_dist = torch.distributions.MultivariateNormal(
            loc=negative_mean, covariance_matrix=torch.diag_embed(torch.exp(negative_std)))
        w_std = torch.clamp(self.w_std[task_idx], min=-20, max=2)
        anchor_dist = torch.distributions.MultivariateNormal(
            loc=self.w[task_idx], covariance_matrix=torch.diag_embed(torch.exp(w_std)))

        positive_kl = torch.distributions.kl.kl_divergence(anchor_dist, positive_dist).mean()
        negative_kl = torch.distributions.kl.kl_divergence(anchor_dist, negative_dist).mean()
        # Small KL to positives and large KL to negatives both lower the loss.
        # NOTE(review): the reciprocal is unbounded if negative_kl approaches 0.
        kl_loss = positive_kl + 1.0 / negative_kl

        anchor_mean = self.w[task_idx].expand(positive_mean.shape[0], -1)
        if detach_anchor:
            anchor_mean = anchor_mean.detach()
        trip_loss = self.triplet_loss(anchor_mean, positive_mean, negative_mean)
        return trip_loss, kl_loss

    def train_with_dist(self):
        """One training step with distributional (mean, std) representations.

        Two batches come from a random task ``i`` and one from a different
        task ``j``. The encoder and diffusion model are updated on
        ``diff_loss + inv_loss + pref_loss``; the encoder is then re-run and
        the anchor is updated on ``w_loss``.

        Returns:
            Floats ``(diff_loss, inv_loss, pref_loss, phi_norm_loss, w_loss,
            si_loss, info_loss)``.
        """
        # NOTE(review): the task count 10 is hard-coded -- confirm it matches
        # the number of per-task loaders in self.dataloader.
        i, j = random.sample(range(10), 2)
        batch1 = next(self.dataloader[i])
        batch2 = next(self.dataloader[i])
        batch3 = next(self.dataloader[j])
        batches = (batch1, batch2, batch3)

        # Shapes noted by the original author (unverified): samples (32, 100, 11),
        # actions (32, 100, 3), returns/timesteps/masks (32, 100).
        states_1, states_2, states_3 = (b['samples'].to(self.device) for b in batches)
        actions_1, actions_2, actions_3 = (b['actions'].to(self.device) for b in batches)
        rtg_1, rtg_2, _ = (b['returns'].to(self.device) for b in batches)
        timesteps_1, timesteps_2, timesteps_3 = (b['timesteps'].to(self.device) for b in batches)
        mask_1, mask_2, mask_3 = (b['masks'].to(self.device) for b in batches)
        task_id_1, task_id_2, task_id_3 = (b['task_ids'].to(self.device) for b in batches)

        # lb: batch 1 has the higher (or equal) return; rb: batch 2 is strictly higher.
        # NOTE(review): train_with_vec compares rtg[:, 0] instead -- confirm the
        # expected shape of 'returns' in this mode.
        lb = (rtg_1 - rtg_2) >= 0
        rb = (rtg_2 - rtg_1) > 0

        encoder_inputs = (
            (states_1, timesteps_1, mask_1, task_id_1),
            (states_2, timesteps_2, mask_2, task_id_2),
            (states_3, timesteps_3, mask_3, task_id_3),
        )

        # when encoder outputs distributions
        encodings = [self.en_model.forward(*args) for args in encoder_inputs]
        trip_loss, kl_loss = self._dist_preference_terms(i, encodings, lb, rb, detach_anchor=True)
        (phi_1_mean, phi_1_std), (phi_2_mean, phi_2_std), (phi_3_mean, phi_3_std) = encodings
        # Only batch 1's means are pushed towards unit norm.
        phi_norm_loss = self.phi_loss(torch.norm(phi_1_mean, dim=1), torch.ones(self.batch_size).to(self.device))
        pref_loss = trip_loss + kl_loss + self.phi_norm_loss_ratio * phi_norm_loss

        # update diffusion
        states = torch.cat([states_1, states_2, states_3], dim=0)
        actions = torch.cat([actions_1, actions_2, actions_3], dim=0)
        phis = torch.cat([phi_1_mean, phi_2_mean, phi_3_mean], dim=0)
        phis_std = torch.cat([phi_1_std, phi_2_std, phi_3_std], dim=0)
        conditions = states[:,0,:] # condition on the first state of each segment; used for planning
        trajectories = torch.concat([actions, states], dim=-1)  # concatenate the segment's actions and states
        diff_loss, inv_loss, si_loss, info_loss = self.de_model.loss(trajectories, conditions, phis, phis_std) # compute loss

        # info_loss is only reported; it is not part of the objective in this mode.
        diffusion_loss = diff_loss + inv_loss
        diffusion_loss += pref_loss

        self.optimizer.zero_grad()
        self.et_optimizer.zero_grad()
        diffusion_loss.backward()
        self.optimizer.step()
        self.et_optimizer.step()
        if self.step % 10 == 0:
            self.step_ema()
        self.step += 1

        # update anchor: re-encode with the freshly updated encoder and let the
        # gradient reach w through both the triplet and the KL terms.
        encodings = [self.en_model.forward(*args) for args in encoder_inputs]
        trip_loss, kl_loss = self._dist_preference_terms(i, encodings, lb, rb, detach_anchor=False)
        w_loss = trip_loss + kl_loss
        self.w_optimizer.zero_grad()
        w_loss.backward()
        self.w_optimizer.step()

        return (
            self._to_float(diff_loss),
            self._to_float(inv_loss),
            self._to_float(pref_loss),
            self._to_float(phi_norm_loss),
            self._to_float(w_loss),
            self._to_float(si_loss),
            self._to_float(info_loss),
        )

    def train_with_vec(self):
        """One training step with vector representations (pref diffuser).

        Two batches are drawn from ``self.dataloader``; within each pair the
        sample whose first-step return is higher is the positive. The encoder
        and diffusion model are updated on
        ``diff_loss + inv_loss + info_loss_weight * info_loss + pref_loss``,
        then the anchor ``w`` is updated on a triplet loss.

        Returns:
            Floats ``(diff_loss, inv_loss, pref_loss, phi_norm_loss, w_loss,
            si_loss, info_loss)``.
        """
        # for training pref_diffuser
        batch1 = next(self.dataloader)
        batch2 = next(self.dataloader)
        # Shapes noted by the original author (unverified): samples (32, 100, 11),
        # actions (32, 100, 3), returns/timesteps/masks (32, 100).
        states_1, states_2 = batch1['samples'].to(self.device), batch2['samples'].to(self.device)
        actions_1, actions_2 = batch1['actions'].to(self.device), batch2['actions'].to(self.device)
        rtg_1, rtg_2 = batch1['returns'].to(self.device), batch2['returns'].to(self.device)
        timesteps_1, timesteps_2 = batch1['timesteps'].to(self.device), batch2['timesteps'].to(self.device)
        mask_1, mask_2 = batch1['masks'].to(self.device), batch2['masks'].to(self.device)

        # lb: batch 1 has the higher (or equal) first-step return; rb: batch 2 is strictly higher.
        lb = (rtg_1[:, 0] - rtg_2[:, 0]) >= 0
        rb = (rtg_2[:, 0] - rtg_1[:, 0]) > 0

        # pref loss and phi norm loss, when representation is vector
        phi_1 = self.en_model.forward(states_1, timesteps_1, mask_1)
        phi_2 = self.en_model.forward(states_2, timesteps_2, mask_2)
        phi_norm_loss = (self.phi_loss(torch.norm(phi_1, dim=1), torch.ones(self.batch_size).to(self.device))
                + self.phi_loss(torch.norm(phi_2, dim=1), torch.ones(self.batch_size).to(self.device)))
        positive = torch.cat((phi_1[lb], phi_2[rb]), 0)
        negative = torch.cat((phi_2[lb], phi_1[rb]), 0)
        # Detached: this pass must not send gradients to w.
        anchor = self.w.expand(positive.shape[0], -1).detach()
        trip_loss = self.triplet_loss(anchor, positive, negative)
        pref_loss = trip_loss + self.phi_norm_loss_ratio * phi_norm_loss

        # update diffusion
        states = torch.cat([states_1, states_2], dim=0)
        actions = torch.cat([actions_1, actions_2], dim=0)
        phis = torch.cat([phi_1, phi_2], dim=0)
        conditions = states[:,0,:] # condition on the first state of each segment; used for planning
        trajectories = torch.concat([actions, states], dim=-1)  # concatenate the segment's actions and states
        diff_loss, inv_loss, si_loss = self.de_model.loss(trajectories, conditions, phis) # compute loss

        # ood info loss: MMD between a slightly perturbed anchor and the
        # representation the diffusion model generates when conditioned on it.
        noise =  torch.randn(phis.shape, device=self.device)
        anchor_mean = self.w.expand(phis.shape[0], -1) + 0.01 * noise
        generated_phi_mean, _ = self.de_model.generate(conditions, anchor_mean)
        info_loss = compute_mmd(anchor_mean, generated_phi_mean)

        diffusion_loss = diff_loss + inv_loss
        diffusion_loss += self.info_loss_weight * info_loss
        diffusion_loss += pref_loss

        self.optimizer.zero_grad()
        self.et_optimizer.zero_grad()
        diffusion_loss.backward()
        self.optimizer.step()
        self.et_optimizer.step()
        if self.step % 10 == 0:
            self.step_ema()
        self.step += 1

        # update anchor vector: re-encode with the updated encoder; the anchor
        # is not detached this time so the triplet loss trains w.
        phi_1 = self.en_model.forward(states_1, timesteps_1, mask_1)
        phi_2 = self.en_model.forward(states_2, timesteps_2, mask_2)
        positive = torch.cat((phi_1[lb], phi_2[rb]), 0)
        negative = torch.cat((phi_2[lb], phi_1[rb]), 0)
        anchor = self.w.expand(positive.shape[0], -1)
        w_loss = self.triplet_loss(anchor, positive, negative)
        self.w_optimizer.zero_grad()
        w_loss.backward()
        self.w_optimizer.step()

        return (
            self._to_float(diff_loss),
            self._to_float(inv_loss),
            self._to_float(pref_loss),
            self._to_float(phi_norm_loss),
            self._to_float(w_loss),
            self._to_float(si_loss),
            self._to_float(info_loss),
        )

    def train_dd(self):
        """One training step of the plain decision diffuser (no encoder).

        The diffusion model is conditioned on the batch's returns and trained
        on ``diff_loss + inv_loss``.

        Returns:
            ``(diff_loss, inv_loss, 0., 0., 0., si_loss, 0.0)``; the zeros fill
            the pref, phi-norm, w and info slots so the tuple has the same
            layout as the other training routines.
        """
        # for reproducing decision diffuser
        batch = next(self.dataloader)
        states = batch['samples'].to(self.device)
        actions = batch['actions'].to(self.device)
        conditions = batch['conditions'][0].to(self.device)
        phis = batch['returns'].to(self.device)
        trajectories = torch.concat([actions, states], dim=-1)  # concatenate the segment's actions and states
        diff_loss, inv_loss, si_loss = self.de_model.loss(trajectories, conditions, phis) # compute loss

        diffusion_loss = diff_loss + inv_loss

        self.optimizer.zero_grad()
        diffusion_loss.backward()
        self.optimizer.step()
        if self.step % 10 == 0:
            self.step_ema()
        self.step += 1

        return (
            self._to_float(diff_loss),
            self._to_float(inv_loss),
            0.,
            0.,
            0.,
            self._to_float(si_loss),
            0.0,
        )