import torch 
import torch.nn as nn 

from var_red_gfn.utils import ForwardPolicyMeta  

class ForwardPolicy(ForwardPolicyMeta):
    """MLP forward policy over fixed-width state vectors.

    A trunk of ``num_layers`` (Linear -> LeakyReLU) pairs maps the state to a
    latent embedding. Two linear heads then read that embedding:

    * ``mlp_logits`` -> ``src_size + 1`` action logits,
    * ``mlp_gflows`` -> one scalar per sample (squeezed in ``get_pol``).

    Args:
        seq_size: the first layer expects ``seq_size + 1`` input features.
        src_size: number of action logits is ``src_size + 1``.
        hidden_dim: width of every hidden layer.
        num_layers: number of Linear/LeakyReLU pairs in the trunk. ``0`` is
            allowed; the trunk is then the identity and the heads read the
            raw state.
        device: device the submodules are moved to.
        eps: forwarded to ``ForwardPolicyMeta``.
            NOTE(review): presumably an exploration rate; confirm in the base class.
    """

    def __init__(self, seq_size, src_size, hidden_dim, num_layers, device='cpu', eps=.3):
        super(ForwardPolicy, self).__init__(eps=eps, device=device)
        self.seq_size = seq_size
        self.src_size = src_size
        self.hidden_dim = hidden_dim
        self.device = device
        self.num_layers = num_layers

        # Width of the state vector consumed by the first layer.
        in_dim = self.seq_size + 1

        self.mlp = nn.Sequential()
        for idx in range(self.num_layers):
            self.mlp.append(nn.Linear(hidden_dim if idx >= 1 else in_dim, hidden_dim))
            self.mlp.append(nn.LeakyReLU())
        self.mlp = self.mlp.to(self.device)

        # An empty nn.Sequential returns its input unchanged. With zero layers
        # the heads therefore see `in_dim` features, not `hidden_dim`.
        emb_dim = hidden_dim if self.num_layers > 0 else in_dim
        self.mlp_logits = nn.Linear(emb_dim, self.src_size + 1).to(self.device)
        self.mlp_gflows = nn.Linear(emb_dim, 1).to(self.device)

    def get_latent_emb(self, batch_state):
        """Embed ``batch_state.state`` with the MLP trunk.

        The state is first cast to ``torch.get_default_dtype()`` so that
        integer-valued states can pass through the Linear layers.
        """
        state = batch_state.state.type(torch.get_default_dtype())
        embed = self.mlp(state)
        return embed

    def get_pol(self, latent_emb, mask):
        """Turn a latent embedding into a masked action distribution and flows.

        Args:
            latent_emb: output of ``get_latent_emb``.
            mask: tensor broadcastable against the logits. Entries equal to
                1/True keep the logit; 0/False entries are replaced by
                ``self.masked_value``. Bool and numeric masks are both accepted.

        Returns:
            ``(pol, gflows)``:
            * ``pol`` is the softmax over the last axis of the masked logits.
            * ``gflows`` is the flow-head output with its trailing size-1
              axis squeezed.
        """
        logits = self.mlp_logits(latent_emb)
        gflows = self.mlp_gflows(latent_emb).squeeze(dim=-1)
        if mask.dtype == torch.bool:
            # PyTorch does not define `1 - mask` for bool tensors. Promote so a
            # boolean mask behaves exactly like a 0/1 numeric mask.
            mask = mask.to(logits.dtype)
        # Arithmetic select: masked slots become `masked_value`. Their
        # gradient w.r.t. the logits is zero.
        logits = logits * mask + self.masked_value * (1 - mask)
        pol = torch.softmax(logits, dim=-1)
        return pol, gflows
    
class BackwardPolicy(nn.Module):
    """Backward policy that assigns zero to every element of a batch.

    ``forward`` ignores ``actions`` and returns two independent all-zero
    tensors, one entry per batch element.

    NOTE(review): if the state space is tree-shaped (every state has a single
    parent), a constant zero is presumably a backward log-probability of
    log(1). Confirm against the caller.
    """

    # Class-level constant kept for interface parity; not referenced inside
    # this class. NOTE(review): presumably the logit-masking value shared
    # with the forward policies.
    masked_value = -1e5

    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device

    def forward(self, batch_state, actions=None):
        """Return a pair of fresh zero tensors of shape ``(batch_state.batch_size,)``.

        Args:
            batch_state: any object exposing an integer ``batch_size``.
            actions: accepted for call-signature compatibility; unused.

        Returns:
            Tuple of two distinct zero tensors on ``self.device``, using the
            default floating dtype.
        """
        shape = (batch_state.batch_size,)
        first, second = (torch.zeros(shape, device=self.device) for _ in range(2))
        return first, second
