from pettingzoo.mpe import simple_push_v3
import random
import argparse
import itertools
import torch
import numpy as np
import itertools
import datetime
import time
from sac.sac import SAC
from torch.utils.tensorboard import SummaryWriter
from sac.replay_memory import ReplayMemory

def _str2bool(value):
    """Parse a boolean command-line value.

    ``type=bool`` is a trap in argparse: ``bool('False')`` is ``True``, so any
    non-empty string would switch the option on. This accepts the usual
    spellings (true/false, yes/no, 1/0, ...) and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
parser.add_argument('--env_name', default="simple_push_v3",
                    help='environment (default: simple_push_v3)')
parser.add_argument('--policy', default="Gaussian",
                    help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
parser.add_argument('--eval', type=_str2bool, default=False,
                    help='Evaluate the policies every 10 episodes (default: False)')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor for reward (default: 0.99)')
parser.add_argument('--tau', type=float, default=0.005, metavar='G',
                    help='target smoothing coefficient(τ) (default: 0.005)')
parser.add_argument('--lr', type=float, default=0.0003, metavar='G',
                    help='learning rate (default: 0.0003)')
parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
                    help='Temperature parameter α determines the relative importance of the entropy '
                         'term against the reward (default: 0.2)')
parser.add_argument('--automatic_entropy_tuning', type=_str2bool, default=False, metavar='G',
                    help='Automatically adjust α (default: False)')
parser.add_argument('--seed', type=int, default=123456, metavar='N',
                    help='random seed (default: 123456)')
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
                    help='batch size (default: 256)')
parser.add_argument('--num_steps', type=int, default=500001, metavar='N',
                    help='maximum number of steps (default: 500001)')
parser.add_argument('--hidden_size', type=int, default=256, metavar='N',
                    help='hidden size (default: 256)')
parser.add_argument('--updates_per_step', type=int, default=1, metavar='N',
                    help='model updates per simulator step (default: 1)')
parser.add_argument('--start_steps', type=int, default=10000, metavar='N',
                    help='Steps sampling random actions (default: 10000)')
parser.add_argument('--target_update_interval', type=int, default=1, metavar='N',
                    help='Value target update per no. of updates per step (default: 1)')
parser.add_argument('--replay_size', type=int, default=1000000, metavar='N',
                    help='size of replay buffer (default: 1000000)')
# CUDA stays on by default (as before); --no_cuda is the only way to turn it
# off, because a store_true flag whose default is already True is a no-op.
parser.add_argument('--cuda', action="store_true", default=True,
                    help='run on CUDA (default: True)')
parser.add_argument('--no_cuda', dest='cuda', action='store_false',
                    help='run on CPU instead of CUDA')
parser.add_argument('--kernel', default=None,
                    help='kernel Type: rf | nystrom | None (default: None)')
# type=float so a value given on the command line is numeric, not a string;
# the default is left as the int 0 so the default run name is unchanged.
parser.add_argument('--sigma', type=float, default=0,
                    help='sigma of noise (default: 0)')
parser.add_argument('--m', type=int, default=64,
                    help='number of features (default: 64)')
args = parser.parse_args()

# Seed every RNG this script draws from so runs are reproducible: numpy,
# Python's `random` (the epsilon-greedy coin flip in the training loop) and
# torch (presumably used inside SAC for network init / action sampling).
np.random.seed(args.seed)
random.seed(args.seed)
torch.manual_seed(args.seed)

# Epsilon-greedy exploration: probability of replacing the policy's action
# with a uniformly sampled one during training.
eps = 0.1
env = simple_push_v3.env(render_mode="human", continuous_actions=True)
# NOTE(review): the environment seed is hard-coded and does not follow --seed.
env.reset(seed=42)





# One SAC learner per agent, keyed by agent name. Every learner is sized from
# the flattened global state returned by env.state(), not a per-agent observation.
players = {
    world_agent.name: SAC(env.state().shape[0], env.action_space(world_agent.name), env, args)
    for world_agent in env.world.agents
}



# TensorBoard: one run directory per launch, named
# "<timestamp>_SAC_<env>_<kernel>_<policy>_<m>_<sigma>" when --kernel is set,
# otherwise "<timestamp>_SAC_<env>_<policy>_<sigma>".
run_stamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
if not args.kernel:
    run_dir = 'tensorboards/{}_SAC_{}_{}_{}'.format(
        run_stamp, args.env_name, args.policy, args.sigma)
else:
    run_dir = 'tensorboards/{}_SAC_{}_{}_{}_{}_{}'.format(
        run_stamp, args.env_name, args.kernel, args.policy, args.m, args.sigma)
writer = SummaryWriter(run_dir)

# Replay buffer; capacity and sampling seed come from the command line.
memory = ReplayMemory(args.replay_size, args.seed)

# Agent names in env.world.agents order; agent_names[0] is "player 1"
# (update_parameters1, the *_1 TensorBoard tags), agent_names[1] is "player 2".
agent_names = [world_agent.name for world_agent in env.world.agents]


def _log_update_stats(suffix, stats, step):
    """Write one learner's (critic loss, policy loss, entropy loss, alpha) to TensorBoard."""
    critic_loss, policy_loss, ent_loss, alpha = stats
    writer.add_scalar('loss/critic_{}'.format(suffix), critic_loss, step)
    writer.add_scalar('loss/policy_{}'.format(suffix), policy_loss, step)
    writer.add_scalar('loss/entropy_loss_{}'.format(suffix), ent_loss, step)
    writer.add_scalar('entropy_temprature/alpha_{}'.format(suffix), alpha, step)


def _evaluate_policies(num_episodes=10):
    """Play ``num_episodes`` full episodes with every player acting via ``evaluate=True``.

    Uses the PettingZoo AEC loop (agent_iter / last / step). Rewards reported
    by ``env.last()`` are summed per agent.

    Returns:
        dict mapping each agent name to its mean per-episode return.
    """
    totals = dict.fromkeys(agent_names, 0.0)
    for _ in range(num_episodes):
        env.reset()
        for acting_agent in env.agent_iter():
            _, reward, termination, truncation, _ = env.last()
            totals[acting_agent] += reward
            if termination or truncation:
                # Agents that are finished must be stepped with None.
                eval_action = None
            else:
                eval_action = players[acting_agent].select_action(env.state(), evaluate=True)
            env.step(eval_action)
    return {name: total / num_episodes for name, total in totals.items()}


# Training Loop
total_numsteps = 0
updates = 0
try:
    for i_episode in itertools.count(1):
        episode_steps = 0
        env.reset()
        for agent in env.agent_iter():
            state = env.state()
            observation, reward, termination, truncation, info = env.last()
            done = termination or truncation

            if done:
                # Agents that are finished must be stepped with None.
                action = None
            elif random.random() > eps:
                action = players[agent].select_action(state)
            else:
                # Epsilon-greedy exploration: uniformly random action.
                action = env.action_space(agent).sample()

            # NOTE(review): next_idx()/mask() appear to be custom env methods
            # (not standard PettingZoo AEC API). next_idx == 0 is treated below
            # as "this step completes a round of all agents".
            next_idx = env.next_idx()

            env.step(action)
            next_state = env.state()
            observation, reward, termination, truncation, info = env.last()
            # NOTE(review): env.step has already advanced the agent selection,
            # so this env.last() presumably reports the *next* agent's reward,
            # not `agent`'s. Confirm r1/r2 are attributed to the intended players.
            if agent == agent_names[0]:
                a1 = action
                r1 = reward
            else:
                a2 = action
                r2 = reward
                r1 = - r2  # zero-sum: player 1's reward is the negation of r2

            if next_idx == 0:
                episode_steps += 1
                total_numsteps += 1

                done = termination or truncation
                mask = env.mask(episode_steps, done)
                memory.push(state, a1, a2, r2, r1, next_state, mask)  # Append transition to memory
                if done:
                    break
                if len(memory) > args.batch_size:
                    # Number of updates per step in environment
                    for _ in range(args.updates_per_step):
                        # Both learners update (each sees the other) before either is logged.
                        stats_1 = players[agent_names[0]].update_parameters1(
                            memory, args.batch_size, updates, players[agent_names[1]])
                        stats_2 = players[agent_names[1]].update_parameters2(
                            memory, args.batch_size, updates, players[agent_names[0]])
                        _log_update_stats(1, stats_1, updates)
                        _log_update_stats(2, stats_2, updates)
                        updates += 1

        if total_numsteps > args.num_steps:
            break
        if i_episode % 2000 == 0:
            for world_agent in env.world.agents:
                players[world_agent.name].save_checkpoint(
                    agent_name=world_agent.name, suffix=str(i_episode))
            print("Episode: {}, total numsteps: {}, episode steps: {}, updates: {}".format(
                i_episode, total_numsteps, episode_steps, updates))

        if args.eval and i_episode % 10 == 0:
            eval_episodes = 10
            avg_rewards = _evaluate_policies(eval_episodes)
            for suffix, name in enumerate(agent_names, start=1):
                writer.add_scalar('avg_reward/test_{}'.format(suffix), avg_rewards[name], i_episode)

            print("----------------------------------------")
            print("Test Episodes: {}, Avg. Reward: {}".format(
                eval_episodes, {name: round(float(r), 2) for name, r in avg_rewards.items()}))
            print("----------------------------------------")
finally:
    # Flush TensorBoard events and release the environment even when training
    # is interrupted (e.g. Ctrl-C) or raises.
    writer.close()
    env.close()