import torch.nn as nn


class OracleUncertainty(nn.Module):
    """Score (state, action) pairs by applying ``lyapunov`` to an oracle's output.

    The forward pass computes ``lyapunov(oracle(states, actions)) + eps``.

    Args:
        oracle: Callable taking ``(states, actions)``. Its result is passed
            straight to ``lyapunov``.
        lyapunov: Callable applied to the oracle's output.
        eps: Constant added to the final result. It defaults to ``0.01``,
            the value previously hard-coded here.
    """

    def __init__(self, oracle, lyapunov, *, eps=0.01):
        super().__init__()
        self.oracle = oracle
        self.lyapunov = lyapunov
        # NOTE(review): presumably a small floor so the output stays strictly
        # positive when lyapunov is non-negative. Confirm against the callers.
        self.eps = eps

    def forward(self, states, actions):
        """Return ``lyapunov(oracle(states, actions)) + eps``."""
        return self.lyapunov(self.oracle(states, actions)) + self.eps


class OracleModel(nn.Module):
    """Expose an environment's ``trans_fn`` as an ``nn.Module`` forward pass.

    Args:
        make_env: Zero-argument factory. It is called exactly once, at
            construction, and the result is kept as ``self.env``.
    """

    def __init__(self, make_env):
        super().__init__()
        environment = make_env()
        self.env = environment

    def forward(self, states, actions):
        """Delegate to ``self.env.trans_fn(states, actions)`` and return its result.

        NOTE(review): presumably ``trans_fn`` is the environment's transition
        function. Its exact input/output contract is not visible here.
        """
        transition = self.env.trans_fn
        return transition(states, actions)
