import torch
import torch.nn as nn
import exp_utils as PQ


class SSampleOptimizer(nn.Module):
    """Hold a fixed-size pool of state samples and log statistics about it.

    The pool ``s`` is registered as a non-trainable ``nn.Parameter``. Being a
    registered parameter means it is included in ``state_dict()`` and moved
    by ``.to(...)``. ``requires_grad=False`` keeps it out of autograd.

    Args:
        crabs: Object exposing ``obj_eval(s)``, which returns a mapping with a
            ``'hard_obj'`` entry, and ``L(s)``. Both are called on the whole
            sample pool. NOTE(review): the shapes they return are defined by
            the ``crabs`` implementation, which is not visible here.
        state_box: Object with a ``shape`` attribute, used as the per-sample
            state shape, and a ``fill_(tensor)`` method called on the pool
            before each evaluation.
        n_samples: Number of states in the pool. Defaults to 100_000, the
            value that was previously hard-coded.

    Raises:
        ValueError: If ``n_samples`` is not positive. An empty pool would make
            the ``.max()`` reduction in ``evaluate`` fail.
    """

    def __init__(self, crabs, state_box, n_samples: int = 100_000):
        super().__init__()
        if n_samples <= 0:
            raise ValueError(f"n_samples must be positive, got {n_samples}")
        self.crabs = crabs
        # Standard-normal initial values. NOTE(review): presumably these are
        # placeholders that state_box.fill_ overwrites in evaluate(); confirm.
        self.s = nn.Parameter(torch.randn(n_samples, *state_box.shape), requires_grad=False)
        self.state_box = state_box

    def hardD(self, s):
        """Return the ``'hard_obj'`` entry of ``self.crabs.obj_eval(s)``."""
        result = self.crabs.obj_eval(s)
        return result['hard_obj']

    @torch.no_grad()
    def evaluate(self, *, step):
        """Refresh the sample pool and log summary statistics for it.

        Runs without gradient tracking.

        Args:
            step: Passed as ``global_step`` to ``PQ.writer.add_scalar``.
        """
        # NOTE(review): presumably fill_ resamples ``self.s`` in place from the
        # state box; confirm against the state_box implementation.
        self.state_box.fill_(self.s)
        s = self.s

        # Largest hard-objective value over the whole pool.
        hardD_sample_s = self.hardD(s).max().item()
        # Number of pool entries with L(s) <= 0; only reported in the debug log.
        inside = (self.crabs.L(s) <= 0).sum().item()
        PQ.writer.add_scalar('L/sample_s/hardD', hardD_sample_s, global_step=step)
        PQ.log.debug(f"[S sampler]: hardD = {hardD_sample_s:.6f}, inside = {inside}")
