import numpy as np

from .torch_replay_buffer import TorchReplayBuffer


class HERReplayBuffer(TorchReplayBuffer):
    """Replay buffer that relabels goals with Hindsight Experience Replay (HER).

    Storage layout (as used by ``sample``): transitions from ``n_envs`` parallel
    environments are interleaved, so flat index ``i`` holds timestep
    ``(i // n_envs) % horizon`` of its episode and the same environment's next
    timestep is at ``i + n_envs``. The stored length must be a multiple of
    ``horizon * n_envs`` so that episodes are aligned.

    On sampling, each transition is relabelled with probability ``p_hindsight``
    using the "future" strategy. The new goal is the achieved goal of a state
    reached later in the same episode. Rewards are then recomputed with
    ``reward_fn`` for the whole batch.
    """

    def __init__(self, env, max_buf_size, horizon, reward_fn, get_achieved_goal, set_goal,
                 *, p_hindsight=0.8, n_envs=1):
        """
        Args:
            env: forwarded to ``TorchReplayBuffer.__init__``.
            max_buf_size: forwarded to ``TorchReplayBuffer.__init__``.
            horizon: number of steps per episode; defines episode alignment.
            reward_fn: callable ``reward_fn(state, action, next_state)`` whose
                result is stored as the batch's ``reward``.
            get_achieved_goal: callable mapping a batch of states to the goals
                those states achieve.
            set_goal: callable ``set_goal(states, goals)`` returning ``states``
                with their goal replaced by ``goals``.
            p_hindsight: probability that a sampled transition is relabelled.
            n_envs: number of parallel environments whose steps are interleaved.
        """
        super().__init__(env, max_buf_size)
        self.p_hindsight = p_hindsight
        self.reward_fn = reward_fn
        self.get_achieved_goal = get_achieved_goal
        self.set_goal = set_goal

        # how the replay buffer is aligned.
        self.horizon = horizon
        self.n_envs = n_envs

    def sample(self, n_samples=1, *, indices=None):
        """Sample a batch and relabel a random subset of it with hindsight goals.

        Args:
            n_samples: number of transitions to draw when ``indices`` is None.
            indices: optional array of flat buffer indices to use instead of
                drawing uniformly at random.

        Returns:
            The batch from ``TorchReplayBuffer.sample``. For the relabelled
            subset, the goals in ``state`` and ``next_state`` are replaced.
            ``reward`` is recomputed with ``reward_fn`` for every transition.
        """
        assert self.length % (self.horizon * self.n_envs) == 0, 'bad alignment'

        if indices is None:
            indices = np.random.randint(len(self), size=(n_samples,), dtype=np.int64)
        batch = super().sample(n_samples, indices=indices)

        # Positions (within the batch) of the transitions chosen for relabelling.
        batch_indices = np.flatnonzero(np.random.random(len(indices)) < self.p_hindsight)

        # Skip relabelling entirely when nothing was selected, so the user
        # callbacks never have to cope with empty inputs.
        if batch_indices.size > 0:
            hindsight_indices = indices[batch_indices]

            # Step t of each selected transition within its own episode.
            cur_steps_in_episodes = hindsight_indices // self.n_envs % self.horizon
            # Offset d is uniform over [0, horizon - 1 - t], so step t + d is
            # still inside the same episode.
            delta = (np.random.rand(len(hindsight_indices))
                     * (self.horizon - cur_steps_in_episodes)).astype(np.int64)
            future_indices = hindsight_indices + delta * self.n_envs

            # HER "future" strategy: the goal must come from a state observed
            # AFTER this transition, i.e. one of s_{t+1} ... s_horizon. Taking
            # next_state at step t + d gives exactly that range. Using 'state'
            # instead would include s_t itself and never reach the final state.
            # NOTE(review): assumes self.data['next_state'][i] is the successor
            # of self.data['state'][i], consistent with batch['next_state'].
            new_goals = self.get_achieved_goal(self.data['next_state'][future_indices])

            batch['state'][batch_indices] = self.set_goal(batch['state'][batch_indices], new_goals)
            batch['next_state'][batch_indices] = self.set_goal(batch['next_state'][batch_indices], new_goals)

        batch['reward'] = self.reward_fn(batch['state'], batch['action'], batch['next_state'])
        return batch
