import argparse
from datetime import datetime
import gym
import numpy as np
import torch
import json
import pickle
import random
from d4rl.infos import REF_MIN_SCORE, REF_MAX_SCORE
import wandb
from tqdm import trange
from diffuser.datasets.normalization import CDFNormalizer
from diffuser.utils.arrays import to_np
from diffuser.models.diffusion import GaussianInvDynDiffusion
from diffuser.models.temporal import TemporalUnet, AttTemporalUnet, TransformerNoise
import os

# Command-line configuration for the evaluation run.  The parsed namespace is
# exposed as the plain dict `variant`, which the rest of the script reads.
parser = argparse.ArgumentParser()
parser.add_argument('--env_name', type=str, default='hopper-medium-expert',
                    help="environment / dataset name without the '-v2' suffix")
parser.add_argument('--K', type=int, default=20,
                    help='horizon passed to the noise predictor and the diffusion model')
parser.add_argument('--seed', type=int, default=3333,
                    help='seed for the RNGs; also part of the checkpoint directory name')
parser.add_argument('--max_iters', type=int, default=1000)
parser.add_argument('--z_dim', type=int, default=16)
parser.add_argument('--device', type=str, default='cuda',
                    help='torch device the model and inputs are moved to')
parser.add_argument('--condition_guidance_w', type=float, default=1.2,
                    help='condition_guidance_w passed to the diffusion model')
parser.add_argument('--load_code', type=int, required=True,
                    help='run identifier that is part of the checkpoint directory name')
parser.add_argument('--load_step', type=int, default=200,
                    help='step number in the diffusion checkpoint file name')
parser.add_argument('--n_timesteps', type=int, default=200,
                    help='n_timesteps passed to the diffusion model')
parser.add_argument('--repre_type', type=str, choices=['vec', 'dist', 'vq_vec'], default='dist')
# The script looks up variant['dirpath'] and variant['mode'] with fallbacks;
# expose them as flags (same defaults as the fallbacks) so they can actually
# be set from the command line.
parser.add_argument('--dirpath', type=str, default='.',
                    help='root directory containing the data/ folder')
parser.add_argument('--mode', type=str, choices=['normal', 'delayed'], default='normal',
                    help="'delayed' moves each trajectory's summed reward to its last step")
args = parser.parse_args()
variant = vars(args)

def seed(seed: int = 0):
    """Seed NumPy, PyTorch (CPU and every CUDA device) and ``random``.

    Args:
        seed: Value handed to each library's seeding function.
    """
    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
seed(variant['seed']) # seed every RNG with the --seed value

# 'device' is looked up with a 'cuda' fallback in case the key is missing.
device = variant.get('device', 'cuda')
env_name = variant['env_name']
# Single environment instance; below it is only used to read the
# observation / action dimensionalities.
env = gym.make(f'{env_name}-v2')
env.reset(seed=variant['seed'])
max_ep_len = 1000  # upper bound on the number of evaluation steps
scale = 1000.  # NOTE(review): not referenced anywhere else in the visible code
state_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

# load dataset
dir_path = variant.get('dirpath', '.') # falls back to the current directory when 'dirpath' is absent
dataset_path = f'{dir_path}/data/{env_name}-v2.pkl'
# NOTE: pickle.load can execute arbitrary code; only open dataset files
# that come from a trusted source.
with open(dataset_path, 'rb') as f:
    trajectories = pickle.load(f)

# save all path information into separate lists
# Each element of `trajectories` is indexed with the keys 'observations',
# 'actions' and 'rewards' (presumably NumPy arrays -- TODO confirm).
mode = variant.get('mode', 'normal')  # falls back to 'normal' when 'mode' is absent
states, traj_lens, returns, actions = [], [], [], []
for path in trajectories:
    if mode == 'delayed':  # delayed: all rewards moved to end of trajectory
        # In-place edit of the loaded rewards: write the sum to the last
        # step first, then zero out every earlier step.
        path['rewards'][-1] = path['rewards'].sum()
        path['rewards'][:-1] = 0.
    states.append(path['observations'])
    actions.append(path['actions'])
    traj_lens.append(len(path['observations']))
    returns.append(path['rewards'].sum())
traj_lens, returns = np.array(traj_lens), np.array(returns)

# used for input normalization
states = np.concatenate(states, axis=0) # stacked along axis 0; presumably (total_timesteps, state_dim) -- TODO confirm
# Per-dimension statistics over all stacked observations; 1e-6 keeps the
# standard deviation strictly positive.
state_mean, state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6
state_mins, state_maxs = states.min(axis=0), states.max(axis=0)
actions = np.concatenate(actions, axis=0) # stacked along axis 0; presumably (total_timesteps, act_dim) -- TODO confirm
action_mins, action_maxs = actions.min(axis=0), actions.max(axis=0)

def normalize(mins, maxs, x):
    """Linearly rescale ``x`` so that ``mins`` maps to -1 and ``maxs`` maps to 1.

    Args:
        mins: Lower bound(s) of the original range.
        maxs: Upper bound(s) of the original range.
        x: Value(s) to rescale.

    Returns:
        ``x`` expressed on the [-1, 1] scale.
    """
    # The tiny constant keeps the division defined when maxs == mins.
    span = maxs - mins + 1e-20
    unit = (x - mins) / span  # position inside [0, 1]
    return 2 * unit - 1

def unnormalize(mins, maxs, x, eps=1e-4):
    """Map ``x`` from the [-1, 1] scale back to the range [mins, maxs].

    Args:
        mins: Lower bound(s) of the target range.
        maxs: Upper bound(s) of the target range.
        x: Array of values expected to lie in [-1, 1].
        eps: Tolerance; if any value lies more than ``eps`` outside
            [-1, 1], the whole array is clipped to [-1, 1] first.

    Returns:
        The values of ``x`` expressed in the original range.
    """
    exceeds_range = x.max() > 1 + eps or x.min() < -1 - eps
    bounded = np.clip(x, -1, 1) if exceeds_range else x
    # [-1, 1] -> [0, 1] -> [mins, maxs]
    unit = (bounded + 1) / 2.0
    return unit * (maxs - mins) + mins

def eval_episodes(model, phi, state_mean, state_std, num_envs=10):
    """Roll out the diffusion planner in parallel environments and summarise
    one episode per environment.

    At every step the min-max normalised observations are given to
    ``model.conditional_sample``; the entries at index 0 and 1 along the
    second axis of the sample are concatenated and fed to ``model.inv_model``
    to obtain the action, which is un-normalised and executed.

    Args:
        model: Object providing ``conditional_sample(conditions, returns=...)``
            and ``inv_model(...)``.
        phi: Kept for interface compatibility but currently NOT used: the
            ``returns`` conditioning is a constant 0.9 for every environment.
        state_mean: Unused (standardising the conditions is disabled).
        state_std: Unused, see ``state_mean``.
        num_envs: Number of parallel environments to evaluate in.

    Returns:
        dict with the mean / std over environments of the raw episode return,
        the return normalised with ``random_score`` / ``expert_score``
        (times 100), and the episode length.
    """
    envs = gym.vector.make(f'{env_name}-v2', num_envs=num_envs)
    try:
        # NOTE: the unpacking below relies on reset() returning only the
        # observations and step() returning a 4-tuple.
        states = envs.reset()
        # NumPy accumulators: with Python lists, `+=` would *extend* the list
        # by the reward array instead of adding element-wise.
        episode_returns = np.zeros(num_envs, dtype=np.float64)
        episode_lengths = np.zeros(num_envs, dtype=np.int64)
        # An environment stays "finished" after its first done flag, so
        # anything it produces afterwards (e.g. after an automatic reset by
        # the vector env) is not counted a second time.
        finished = np.zeros(num_envs, dtype=bool)
        # Constant conditioning value, one row per environment.
        phi = torch.tensor(0.9).to(device=device, dtype=torch.float32).unsqueeze(0).repeat(num_envs, 1)
        for i in trange(max_ep_len, desc='evaluation'):
            states = normalize(state_mins, state_maxs, states)
            conditions = torch.from_numpy(states).to(device=device, dtype=torch.float32)
            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                # Sample a plan conditioned on the current state and on phi.
                samples = model.conditional_sample(conditions, returns=phi)
                obs_comb = torch.cat([samples[:, 0, :], samples[:, 1, :]], dim=-1)  # [s0, s1]
                obs_comb = obs_comb.reshape(-1, 2 * state_dim)
                # The inverse-dynamics model turns (s0, s1) into an action.
                actions = model.inv_model(obs_comb)
            actions = to_np(actions)
            actions = unnormalize(action_mins, action_maxs, actions)
            states, rewards, dones, _ = envs.step(actions)

            # Count this step (including the reward of the terminal step) only
            # for environments whose first episode is still running.
            active = ~finished
            episode_returns += rewards * active
            episode_lengths += active
            finished |= np.asarray(dones, dtype=bool)

            if i % 100 == 0:
                print('step ', i, ' return ', episode_returns, ' score ', (episode_returns - random_score) / (expert_score - random_score) * 100)

            if finished.all():
                break
    finally:
        # Always release the vector env's resources, even on error.
        envs.close()

    norm_ret = (episode_returns - random_score) / (expert_score - random_score) * 100
    return {
        'target_return_mean': np.mean(episode_returns),
        'target_return_std': np.std(episode_returns),
        'target_norm_return_mean': np.mean(norm_ret),
        'target_norm_return_std': np.std(norm_ret),
        'target_length_mean': np.mean(episode_lengths),
        'target_length_std': np.std(episode_lengths),
    }

# ---- dataset summary ----
num_timesteps = sum(traj_lens)

print('=' * 50)
print(f'Starting new experiment: {env_name}')
print(f'{len(traj_lens)} trajectories, {num_timesteps} timesteps found')
print(f'Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}')
print(f'Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}')
print('=' * 50)

K = variant['K']  # horizon of the noise predictor / diffusion model
z_dim = variant['z_dim']
print(f'z_dim is: {z_dim}')

# d4rl reference scores, used by eval_episodes to normalise returns.
expert_score = REF_MAX_SCORE[f"{variant['env_name']}-v2"]
random_score = REF_MIN_SCORE[f"{variant['env_name']}-v2"]

# ---- model construction ----
# AttTemporalUnet and TransformerNoise are imported as alternative noise
# predictors; TemporalUnet is the one used here.
noise_predictor = TemporalUnet(horizon=K, transition_dim=state_dim, cond_dim=state_dim)
model = GaussianInvDynDiffusion(noise_predictor, horizon=K, observation_dim=state_dim, action_dim=act_dim,
                                condition_guidance_w=variant['condition_guidance_w'],
                                n_timesteps=variant['n_timesteps']).to(device=device)
repre_type = variant['repre_type']

# ---- checkpoints ----
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted
# source.
# NOTE(review): the encoder checkpoint is pinned to one specific run id and
# step in the path below.
encoder_path = f'saved_models/encoder_dist/{env_name}-3333-20231013180041/params_100.pt'
saved_model = torch.load(os.path.join(dir_path, encoder_path), map_location=device)
w = saved_model[1]  # handed to eval_episodes as its `phi` argument

load_code = variant['load_code']
ss = variant['seed']
t = datetime.now().strftime('%Y%m%d%H%M%S')
condition_w = variant['condition_guidance_w']
load_step = variant['load_step']

path = f'saved_models/all_model_vec/{env_name}-{ss}-{load_code}/diffusion_model_{load_step}.pt'
# Resolve the diffusion checkpoint under dir_path, like the dataset and the
# encoder checkpoint (identical location for the default dir_path '.').
saved_model = torch.load(os.path.join(dir_path, path), map_location=device)
model.load_state_dict(saved_model[0])
# Put the network in inference mode before evaluating it.
model.eval()

# ---- evaluation ----
logs = dict()
outputs = eval_episodes(model, w, state_mean, state_std)
for k, v in outputs.items():
    logs[f'evaluation/{k}'] = v
print(logs)