import gym
import numpy as np

from action_masking.benchmark.benchmark_seeker import BenchmarkSeeker
from action_masking.provably_safe_env.envs.seeker_circle_env import (
    SeekerCircleEnv,
)
from action_masking.sb3_contrib.common.maskable.utils import (
    array_to_generator_center,
    generator_center_to_array,
)
from action_masking.util.util import (
    Algorithm,
    Approach,
    ContMaskingMode,
    TransitionTuple,
    get_policy,
    load_configs_from_dir,
)

# NOTE: all of the imports above are required to avoid a circular-import error,
# including those not used directly in this script.


def main(*, verbose: bool = False):
    """Render a trained PPO agent acting in ``SeekerCircleEnv-v0``.

    Loads the checkpoint
    ``models/SeekerCircleEnv/Masking/Naive/Continuous/Generator/PPO/1``,
    runs it deterministically for 1000 environment steps and prints the
    outcome and total reward of every episode that finishes in that window.

    Args:
        verbose: If True, print the agent's action and the observed change of
            ``obs[0, :2]`` for every step (debug output).
    """
    alg = Algorithm.PPO
    env_name = "SeekerCircleEnv"
    masking_mode = ContMaskingMode.Generator

    # Temporary env used only to obtain the action space. Close it immediately
    # so it does not keep render resources alive for the whole run.
    _env = gym.make(
        "SeekerCircleEnv-v0", seed=1, render_mode="human", render_safe_input_set=True
    )
    action_space = _env.action_space
    _env.close()

    config = load_configs_from_dir(path="hyperparams/")
    benchmark = BenchmarkSeeker(config=config[env_name])
    env = benchmark.create_env(
        "SeekerCircleEnv-v0",
        space=action_space,
        approach=Approach.Masking,
        transition_tuple=TransitionTuple.Naive,
        continuous_action_masking_mode=masking_mode,
        render_mode="human",
        render_safe_input_set=True,
        randomize=True,
    )

    try:
        model = alg.value.load(
            "models/SeekerCircleEnv/Masking/Naive/Continuous/Generator/PPO/1",
            env=env,
        )
        obs = env.reset()

        rewards = []
        for _ in range(1000):
            # Queried every step, even when the mask is not passed to predict().
            # Kept unconditional in case get_safe_space() has side effects
            # (e.g. for rendering) -- TODO confirm.
            safe_set_zono = env.envs[0].get_safe_space()
            safe_set = None
            if masking_mode == ContMaskingMode.ConstrainedNormal:
                # Only this mode hands the mask to predict() explicitly.
                safe_set = generator_center_to_array(safe_set_zono.G, safe_set_zono.c)

            action, _states = model.predict(
                obs, deterministic=True, action_masks=safe_set
            )
            # Copy so the value is independent of any buffer reuse by the env.
            last_pos = obs[0, :2].copy()
            obs, reward, done, info = env.step(action)

            if verbose:
                print(f"Agent action {np.array2string(action, precision=3)}")
                print(
                    f"Mask  action {np.array2string(obs[0, :2] - last_pos, precision=3)}"
                )

            rewards.append(reward)

            if done:
                # The outcome is classified from the final step's reward:
                # positive -> goal, exactly -10 -> collision, otherwise time limit.
                if reward > 0:
                    print("Goal reached!")
                elif reward == -10:
                    print("Collision!")
                else:
                    print("Time limit reached!")
                print(f"Total reward: {np.sum(rewards)}")
                rewards = []
                print()
                # Use the observation of the fresh episode. Discarding it would
                # make the next predict() act on a state the env is no longer in.
                obs = env.reset()
    finally:
        env.close()


if __name__ == "__main__":
    main()
