import torch 
import torch.nn as nn 

from var_red_gfn.utils import Environment 

class LogReward(nn.Module):
    """Fixed, seed-determined log-reward over padded token sequences.

    Two lookup tables are drawn once at construction time:

    * ``val``     -- one score in [-1, 1) per token id, plus a trailing entry
      (index ``src_size``, the padding id) that is forced to zero;
    * ``pos_val`` -- one weight in [-1, 1) per sequence position.

    The log-reward of a sequence is
    ``3 * sum_i val[token_i] * pos_val[i]``, with padded slots contributing
    nothing.

    Args:
        src_size: number of real tokens; the id ``src_size`` marks padding.
        seq_size: maximum sequence length (tables cover ``seq_size + 1`` slots).
        seed: seed for ``val``; ``seed + 1`` is used for ``pos_val``.
        device: device on which both tables are created.
    """

    def __init__(self, src_size, seq_size, seed, device='cpu'):
        super().__init__()
        self.device = device
        self.src_size = src_size
        self.seq_size = seq_size

        # Per-token scores; the last slot belongs to the padding id and must
        # not contribute to the reward.
        self.val = self._signed_uniform(self.src_size + 1, seed)
        self.val[-1] = 0.

        # Per-position weights, drawn from an independent stream.
        self.pos_val = self._signed_uniform(self.seq_size + 1, seed + 1)

    def _signed_uniform(self, count, seed):
        """Draw ``count`` values uniformly from [-1, 1) using a fresh generator."""
        rng = torch.Generator(device=self.device)
        rng.manual_seed(seed)
        return 2 * torch.rand(count, device=self.device, generator=rng) - 1

    @torch.no_grad()
    def forward(self, batch_state):
        """Return the log-reward of every sequence in ``batch_state``.

        Args:
            batch_state: object exposing ``state``, an integer tensor of token
                ids in ``[0, src_size]`` (presumably shaped
                ``(batch, seq_size + 1)`` so it lines up with ``pos_val``).

        Returns:
            A tensor with one log-reward per row of ``batch_state.state``.
        """
        tokens = batch_state.state
        not_padding = tokens.ne(self.src_size).long()
        # Score each slot by token, silence padded slots, weight by position.
        weighted = self.val[tokens] * not_padding * self.pos_val
        return 3 * weighted.sum(dim=1)

class Sequences(Environment):
    """Batched environment that fills token sequences one slot at a time.

    Each row of ``state`` has ``seq_size + 1`` slots that initially all hold
    ``src_size``, the end-of-sequence / padding id. ``apply`` writes a token
    at the row's cursor (``curr_idx``) or, for action ``src_size``, stops the
    row; ``backward`` undoes one step.

    NOTE(review): ``batch_ids``, ``is_initial``, ``stopped``, ``batch_size``
    and ``device`` are never created here -- presumably they come from
    ``Environment``; confirm against the base class.

    Args:
        seq_size: maximum number of tokens per sequence.
        src_size: number of real tokens; ``src_size`` itself is the stop id.
        batch_size: number of sequences handled in parallel.
        log_reward: forwarded unchanged to ``Environment``.
        device: device on which all tensors are allocated.
    """

    def __init__(self, seq_size, src_size, batch_size, log_reward, device='cpu'):
        super().__init__(batch_size, seq_size + 1, log_reward, device)
        self.seq_size = seq_size
        self.src_size = src_size

        # Every slot starts out as the EoS / padding token.
        self.state = torch.full(
            (self.batch_size, self.seq_size + 1),
            self.src_size,
            dtype=torch.long,
            device=self.device,
        )

        # Per-row write cursor and the largest value it may take.
        self.curr_idx = torch.zeros(
            (self.batch_size,), dtype=torch.long, device=self.device
        )
        self.max_idx = torch.full_like(self.curr_idx, self.seq_size - 1)

        # One column per token plus a final column for the stop action.
        self.forward_mask = torch.ones(
            (self.batch_size, self.src_size + 1), device=self.device
        )
        self.backward_mask = torch.zeros((self.batch_size, 1), device=self.device)
        # Default (float) dtype here; apply/backward later rebind this to
        # ``curr_idx + 1``, which is a long tensor.
        self.traj_size = torch.ones((self.batch_size,), device=self.device)

    @torch.no_grad()
    def apply(self, indices):
        """Advance every row by one forward action.

        Args:
            indices: per-row action ids; ``src_size`` means "stop", any other
                value is the token written at the row's cursor.

        Returns:
            ``self.stopped < 2.`` -- since ``stopped`` has just been set to
            0/1 values, this is an all-True boolean tensor.
        """
        stop_rows = indices == self.src_size
        active_rows = ~stop_rows

        # Write the chosen token at the cursor of each row that did not stop.
        rows = self.batch_ids[active_rows]
        cols = self.curr_idx[active_rows]
        self.state[rows, cols] = indices[active_rows]

        # Rows that stopped, or whose cursor was already on the last token
        # slot before this step, may only choose the stop action from now on.
        next_mask = self.forward_mask.clone()
        next_mask[stop_rows, :-1] = 0.
        next_mask[self.curr_idx == self.seq_size - 1, :-1] = 0.
        self.forward_mask = next_mask

        self.stopped = stop_rows.long()
        self.curr_idx += active_rows.long()
        # Trajectory length is taken before the cursor is clamped.
        self.traj_size = self.curr_idx + 1
        self.curr_idx = torch.minimum(self.curr_idx, self.max_idx)

        self.is_initial[:] = (self.curr_idx == 0).long()
        self.backward_mask = (1 - self.is_initial).view(-1, 1)
        return (self.stopped < 2.)

    @torch.no_grad()
    def backward(self, indices):
        """Undo one step on every row and report what was undone.

        Args:
            indices: accepted for interface compatibility; not read here.

        Returns:
            For each row, a copy of the slot under the (possibly moved-back)
            cursor as it was before being cleared.
        """
        not_initial = self.is_initial != 1
        # The cursor only moves back on rows that are neither initial nor
        # flagged as stopped.
        self.curr_idx -= (not_initial & (self.stopped != 1)).long()

        # Read the slot under the cursor before erasing it.
        forward_actions = self.state[self.batch_ids, self.curr_idx].clone()

        rows = self.batch_ids[not_initial]
        cols = self.curr_idx[not_initial]
        self.state[rows, cols] = self.src_size

        # After stepping back, every token action is allowed again.
        self.forward_mask[:, :-1] = 1

        self.stopped[:] = 0.
        # A row is initial again once all of its slots hold the padding id.
        all_padding = (self.state == self.src_size).all(dim=1)
        self.is_initial = all_padding.long()
        self.backward_mask = (1 - self.is_initial).view(-1, 1)
        self.traj_size = self.curr_idx + 1

        return forward_actions

    @torch.no_grad()
    def merge(self, batch_state):
        """Append ``batch_state``'s rows to this batch.

        The base-class ``merge`` runs first; ``state`` is then extended
        row-wise with ``batch_state.state``.
        """
        super().merge(batch_state)
        self.state = torch.cat((self.state, batch_state.state), dim=0)

    @property
    def unique_input(self):
        """The raw ``state`` tensor."""
        return self.state
