import numpy as np
import torch
import gym
import argparse
import os
import sys

from utils import ReplayBuffer
import TD3_BC
import torch
from tqdm import trange, tqdm
# from Because.agent.icil.icil_state import ICIL as ICIL_state
from stable_baselines3.common.vec_env import SubprocVecEnv
from robosuite import make
from robosuite import load_controller_config
from collector.gym_wrapper import GymStackWrapper, GymLiftWrapper, GymLiftCausalWrapper
# Maps a task name to the gym wrapper class that is applied to the robosuite
# environment built for that task (looked up as WRAPPER[task] by the env
# factory functions in this file).
# NOTE(review): "LiftCausal" maps to GymLiftWrapper while GymLiftCausalWrapper
# is registered under "CausalPick" -- confirm this pairing is intended.
WRAPPER = {
    "LiftCausal": GymLiftWrapper,
    "StackCausal": GymStackWrapper,
    "CausalPick": GymLiftCausalWrapper,
}
def CUDA(x):
    """Return ``x`` moved to the GPU via its ``.cuda()`` method.

    A NumPy array is first converted with ``torch.from_numpy``; any other
    object is used as-is and must itself provide ``.cuda()``.
    """
    tensor = torch.from_numpy(x) if isinstance(x, np.ndarray) else x
    return tensor.cuda()
def evaluate_on_env(model, mean, std, env, num_eval_ep=20, num_envs=16, discount=0.99):
    """Roll out ``model`` in a vectorized ``env`` until ``num_eval_ep`` episodes end.

    Every step, the batch of observations returned by ``env`` is normalized as
    ``(obs - mean) / (std + 1e-3)``, passed to ``model.select_action`` and the
    result is fed to ``env.step``. Whenever a sub-environment reports ``done``
    the episode is counted; it is a success if ``info[i]['success']`` is truthy.

    Args:
        model: Policy object exposing ``select_action(state)``.
        mean: Observation mean used for normalization.
        std: Observation standard deviation used for normalization.
        env: Vectorized environment with ``reset()`` and ``step(actions)``
            returning ``(obs, reward, done, info)``, where ``done`` and
            ``info`` are indexable per sub-environment.
            NOTE(review): finished sub-environments are assumed to reset
            themselves automatically -- confirm for the env in use.
        num_eval_ep: Number of finished episodes to collect (must be > 0).
        num_envs: Number of sub-environments in ``env``; sizes the per-env
            step counters, so it must not be smaller than ``len(done)``.
        discount: Per-step discount used for the ``eval/avg_value`` metric.

    Returns:
        dict with keys ``eval/avg_reward``, ``eval/avg_ep_len``,
        ``eval/success_rate`` and ``eval/avg_value``, each averaged over the
        counted episodes. ``eval/avg_reward`` counts 1 per successful episode
        (environment rewards are not accumulated), so it equals the success
        rate.

    Raises:
        ValueError: If ``num_eval_ep`` is not positive (the averages would
            otherwise divide by zero).

    Side effects:
        Shows a tqdm progress bar and prints the success rate and the average
        discounted value.
    """
    if num_eval_ep <= 0:
        raise ValueError(f"num_eval_ep must be positive, got {num_eval_ep}")

    # Steps taken so far by the episode currently running in each sub-env.
    envs_timestep = np.zeros(num_envs)

    total_reward = 0
    total_timesteps = 0
    total_succ = 0
    total_value = 0
    count_done = 0

    # The context manager closes the progress bar even if the rollout raises.
    with torch.no_grad(), tqdm(total=num_eval_ep) as pbar:
        running_state = env.reset()
        while count_done < num_eval_ep:
            envs_timestep += 1

            # Normalize observations the same way for every step.
            state = (running_state - mean) / (std + 1e-3)
            actions = model.select_action(state)
            running_state, _, done, info = env.step(actions)

            for i, finished in enumerate(done):
                if not finished:
                    continue
                if info[i]['success']:
                    total_reward += 1
                    total_succ += 1
                    # Discounted value of a unit reward received at the step
                    # on which the episode ended.
                    total_value += discount ** envs_timestep[i]
                count_done += 1
                total_timesteps += envs_timestep[i]
                envs_timestep[i] = 0
                pbar.update(1)
                # Stop as soon as the quota is reached so episodes finishing
                # on the same step in later sub-envs are not over-counted.
                if count_done >= num_eval_ep:
                    break

    results = {
        'eval/avg_reward': total_reward / count_done,
        'eval/avg_ep_len': total_timesteps / count_done,
        'eval/success_rate': total_succ / count_done,
        'eval/avg_value': total_value / count_done,
    }
    print(results['eval/success_rate'])
    print(results['eval/avg_value'])

    return results

def make_envs(seed, task='LiftCausal', horizon=70, control_freq=5, spurious_type='xpr'):
    """Build a seeded, gym-wrapped robosuite environment for ``task``.

    The environment is created through ``make`` with 'Kinova3' as its second
    positional argument, rendering and camera observations disabled, object
    observations enabled and an 'OSC_POSITION' controller config. It is then
    wrapped with ``WRAPPER[task]`` and seeded with ``seed``.

    Args:
        seed: Value passed to the wrapped environment's ``seed`` method.
        task: Environment name; must also be a key of ``WRAPPER``.
        horizon: Forwarded to ``make`` as ``horizon``.
        control_freq: Forwarded to ``make`` as ``control_freq``.
        spurious_type: Forwarded to ``make`` as ``spurious_type``.

    Returns:
        The wrapped, seeded environment.
    """
    sim_options = dict(
        horizon=horizon,
        control_freq=control_freq,
        has_renderer=False,
        has_offscreen_renderer=False,
        ignore_done=False,
        use_camera_obs=False,
        use_object_obs=True,
        controller_configs=load_controller_config(default_controller='OSC_POSITION'),
        spurious_type=spurious_type,
    )
    raw_env = make(task, 'Kinova3', **sim_options)
    wrapper_cls = WRAPPER[task]
    wrapped = wrapper_cls(raw_env)
    wrapped.seed(seed)
    return wrapped

def make_pickenvs(task='CausalPick', horizon=70, control_freq=5, seed=100):
    """Build a seeded, gym-wrapped robosuite pick environment.

    The environment is created through ``make`` with 'Kinova3' as its second
    positional argument, rendering and camera observations disabled, object
    observations enabled, an 'OSC_POSITION' controller config, one unmovable
    object, no random objects and three markers. It is wrapped with
    ``WRAPPER[task]`` in ``mode="train"`` and then seeded.

    Args:
        task: Environment name; must also be a key of ``WRAPPER``.
        horizon: Forwarded to ``make`` as ``horizon``.
        control_freq: Forwarded to ``make`` as ``control_freq``.
        seed: Value passed to the wrapped environment's ``seed`` method.

    Returns:
        The wrapped, seeded environment.
    """
    scene_options = dict(
        horizon=horizon,
        control_freq=control_freq,
        has_renderer=False,
        has_offscreen_renderer=False,
        ignore_done=False,
        use_camera_obs=False,
        use_object_obs=True,
        controller_configs=load_controller_config(default_controller='OSC_POSITION'),
        num_unmovable_objects=1,
        num_random_objects=0,
        num_markers=3,
    )
    raw_env = make(task, 'Kinova3', **scene_options)
    wrapper_cls = WRAPPER[task]
    wrapped = wrapper_cls(raw_env, mode="train")
    wrapped.seed(seed)
    return wrapped

def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    Without this converter argparse would hand back the raw string, and any
    non-empty string (including "False") is truthy.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _parse_args():
    """Build the command-line parser once and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    # Experiment
    parser.add_argument("--policy", default="TD3_BC")               # Policy name
    parser.add_argument("--env", default="lift")                    # Environment name (part of the model file name)
    parser.add_argument("--seed", default=0, type=int)              # Parsed but unused: seeds come from the run index
    parser.add_argument("--eval_freq", default=int(5e3), type=int)  # How often (time steps) we evaluate
    parser.add_argument("--max_timesteps", default=int(1e6), type=int)  # Max time steps to run environment
    parser.add_argument("--save_model", action="store_true")        # Save model and optimizer parameters
    parser.add_argument("--load_model", default="")                 # Model load file name, "" doesn't load, "default" uses file_name
    # TD3 -- type=float so values given on the command line are not left as str
    parser.add_argument("--expl_noise", default=0.1, type=float)    # Std of Gaussian exploration noise
    parser.add_argument("--batch_size", default=256, type=int)      # Batch size for both actor and critic
    parser.add_argument("--discount", default=0.99, type=float)     # Discount factor
    parser.add_argument("--tau", default=0.005, type=float)         # Target network update rate
    parser.add_argument("--policy_noise", default=0.2, type=float)  # Noise added to target policy during critic update
    parser.add_argument("--noise_clip", default=0.5, type=float)    # Range to clip target policy noise
    parser.add_argument("--policy_freq", default=2, type=int)       # Frequency of delayed policy updates
    # TD3 + BC
    parser.add_argument("--alpha", type=float, default=1.0)
    parser.add_argument("--normalize", default=True, type=_str2bool)
    # Default is 100 (the value this script previously hard-coded) so that
    # runs without the flag evaluate the same number of episodes as before.
    parser.add_argument("--num_eval_ep", type=int, default=100, help='number of evaluation episode length')
    parser.add_argument("--num_envs", type=int, default=16, help='number')
    parser.add_argument("--type", type=str, default='expert', help='expert/medium/random')
    parser.add_argument("--task", type=str, default='lift', help='lift/pick')
    return parser.parse_args()


def _main():
    """Evaluate ten saved TD3+BC policies and print success-rate statistics.

    For run ``i`` the policy saved under seed ``i + 1`` is loaded, evaluated
    in ``args.num_envs`` parallel environments for ``args.num_eval_ep``
    episodes, and its success rate is recorded. Finally the mean/std over all
    runs and over the runs with the lowest and highest rate removed are
    printed.
    """
    num_runs = 10
    args = _parse_args()  # loop-invariant: parse once, not once per run
    success_rate = np.zeros(num_runs)

    for i in range(num_runs):
        print(f"itr_{i}")
        seed = i + 1
        torch.manual_seed(seed)
        np.random.seed(seed)

        state_dim = 33
        action_dim = 4
        replay_buffer = ReplayBuffer(state_dim, action_dim)
        max_action = float(replay_buffer.convert_D4RL(file_path="dataset/height01", type=args.type))

        kwargs = {
            "state_dim": state_dim,
            "action_dim": action_dim,
            "max_action": max_action,
            "discount": args.discount,
            "tau": args.tau,
            # TD3
            "policy_noise": args.policy_noise * max_action,
            "noise_clip": args.noise_clip * max_action,
            "policy_freq": args.policy_freq,
            # TD3 + BC
            "alpha": args.alpha,
        }

        # Load the policy that was saved for this run's seed.
        file_name = f"{args.policy}_{args.env}_{args.type}_{seed}"
        policy = TD3_BC.TD3_BC(**kwargs)
        policy.load(f"./models/{file_name}")

        if args.normalize:
            obs_mean, obs_std = replay_buffer.normalize_states()
        else:
            obs_mean, obs_std = 0, 1

        # An int population draws the same values as sampling from
        # list(range(10000000)) without materializing ten million ints.
        seed_list = np.random.choice(10000000, args.num_envs, replace=False)
        env = SubprocVecEnv(
            [
                # Bind each seed as a default so every factory keeps its own
                # value; int() hands the env a plain Python integer.
                lambda env_seed=int(env_seed): make_envs(seed=env_seed, horizon=30, spurious_type="xpr")
                for env_seed in seed_list
            ],
            start_method="spawn",
        )
        try:
            results = evaluate_on_env(
                policy, obs_mean, obs_std, env,
                num_eval_ep=args.num_eval_ep, num_envs=args.num_envs,
            )
        finally:
            # Shut the worker processes down before the next run starts.
            env.close()

        print(results['eval/success_rate'])
        success_rate[i] = results['eval/success_rate']

    print(np.mean(success_rate), np.std(success_rate))

    # Drop exactly one lowest and one highest run. Deleting by
    # [argmax, argmin] removes only a single element when both indices
    # coincide (all rates equal).
    data_without_extremes = np.sort(success_rate)[1:-1]
    print(np.mean(data_without_extremes), np.std(data_without_extremes))


if __name__ == '__main__':
    _main()


