from vizier import pyvizier as vz
import numpy as np
import torch

class Sphere2D:
    """Negated squared-Euclidean-distance ("sphere") objective around a fixed center.

    Despite the name, any dimensionality ``dim`` is supported. For each input
    row ``x`` the objective is ``-sum((x - center) ** 2)``, so the maximum value
    0 is attained exactly at ``center``.
    """

    def __init__(self, center: "torch.Tensor | tuple", dim: int):
        """Store the optimum location and the expected input width.

        Args:
            center: Optimum location; a tensor or any sequence accepted by
                ``torch.as_tensor``. It is copied, so later in-place changes
                to the caller's object do not affect this objective.
            dim: Number of features each input point must have.
        """
        self.dim = dim
        # ``as_tensor`` lets plain sequences (e.g. tuples, which have no
        # ``.clone()``) be passed; for tensors it is a no-op before the copy.
        self.center = torch.as_tensor(center).clone()

    def __call__(self, arr: torch.Tensor) -> torch.Tensor:
        """Evaluate the objective on one point or a batch of points.

        Args:
            arr: Either a single point of shape ``(dim,)`` or a batch of
                shape ``(n, dim)``.

        Returns:
            A tensor of shape ``(n,)`` (``(1,)`` for a single point) holding
            the negated squared distance of each row to ``center``.

        Raises:
            ValueError: If ``arr`` cannot be viewed as ``(n, dim)``.
        """
        if arr.dim() == 1:
            # A single point becomes a batch of one, for every dimensionality
            # (previously only dim == 2 was promoted).
            arr = arr.unsqueeze(0)
        if arr.dim() != 2 or arr.size(1) != self.dim:
            raise ValueError(f"Input tensor must be 2D with shape (n, {self.dim}).")
        return -torch.sum((arr - self.center) ** 2, dim=1)
    
# 12-D center shared by the "LunarLand" and "RobotPush" entries below, which
# hold identical values.
# NOTE(review): 0.84345981001 carries more digits than its neighbours; this may
# be a typo. Confirm against wherever these values came from.
_CENTER_12D = (
    0.53787638, 0.47824028, 0.91598408, 0.77314852,
    0.96957006, 0.84345981001, 0.46749829, 0.67018698,
    0.02196906, 0.58780826, 0.0536066, 0.09198336,
)

# 60-D center used for the "Rover" entry.
_CENTER_ROVER = (
    0.01602924, 0.16642377, 0.18237113, 0.03353382, 0.41124133, 0.42245292,
    0.00707353, 0.10983336, 0.41509733, 0.22277063, 0.32910567, 0.47408313,
    0.28169449, 0.06454428, 0.1233088, 0.38724173, 0.36914349, 0.43852703,
    0.27432707, 0.16041871, 0.06474964, 0.48953182, 0.08822084, 0.53879451,
    0.17309441, 0.35760608, 0.24053643, 0.47047524, 0.47467989, 0.03343116,
    0.38284842, 0.61781459, 0.0945775, 0.06175058, 0.22795625, 0.13281982,
    0.02328893, 0.15478563, 0.22417962, 0.03578619, 0.05789552, 0.03693963,
    0.19187214, 0.40362205, 0.10665334, 0.04158815, 0.54265102, 0.31012696,
    0.33323145, 0.25422306, 0.05642867, 0.00196058, 0.57354732, 0.94900599,
    0.54319996, 0.64702857, 0.94747889, 0.17960919, 0.62659977, 0.89840602,
)

# Per-benchmark center vectors (all values lie in [0, 1]), keyed by name.
# Every entry comes from its own torch.tensor(...) call, so "LunarLand" and
# "RobotPush" are equal but independent tensors.
centers = {
    "LunarLand": torch.tensor(_CENTER_12D),
    "RobotPush": torch.tensor(_CENTER_12D),
    "Rover": torch.tensor(_CENTER_ROVER),
}
                    

def problem_statement(search_space_id, dataset_id, func="LunarLand"):
    """Build a Vizier problem over the unit cube plus a matching sphere objective.

    The objective's optimum is placed at ``1 - centers[func]`` (clamped to
    [0, 1]). One float parameter ``x{i}`` in [0, 1] is declared per dimension,
    and a single metric ``'obj'`` is maximized, consistent with ``Sphere2D``
    returning a negated squared distance.

    Args:
        search_space_id: Not used by this function; presumably kept so the
            signature matches a shared caller interface (confirm).
        dataset_id: Not used by this function either.
        func: Key into ``centers`` selecting the center vector. Defaults to
            ``"LunarLand"``, the previously hard-coded choice.

    Returns:
        A ``(problem, f)`` pair: the ``vz.ProblemStatement`` and the
        ``Sphere2D`` objective whose dimensionality matches the search space.

    Raises:
        KeyError: If ``func`` is not a key of ``centers``.
    """
    lb, ub = 0.0, 1.0
    if func not in centers:
        raise KeyError(
            f"Unknown benchmark {func!r}; expected one of {sorted(centers)}"
        )
    # ``1 - t`` already allocates a fresh tensor, so the stored center is
    # never mutated and no explicit clone is needed.
    center = torch.clamp(1 - centers[func], lb, ub)
    dim = center.shape[0]
    f = Sphere2D(center=center, dim=dim)

    problem = vz.ProblemStatement()
    root = problem.search_space.root
    for i in range(dim):
        root.add_float_param(f"x{i}", lb, ub)
    metric = vz.MetricInformation(
        name='obj', goal=vz.ObjectiveMetricGoal.MAXIMIZE,
    )
    problem.metric_information.append(metric)
    return problem, f
