from dataclasses import dataclass, field
from typing import List, Optional


@dataclass()
class MeanConfig:
    """Settings selecting and shaping the mean component.

    All fields have defaults, so ``MeanConfig()`` is valid on its own.
    """

    # Identifier of the mean to build; defaults to 'constant'.
    name: str = 'constant'
    # Optional list of layer sizes; presumably only used by learned
    # (network-based) means — TODO confirm against the consumer.
    hidden_features: Optional[List] = None


@dataclass()
class KernelConfig:
    """Settings selecting the kernel and an optional wrapper around it.

    Every field has a default, so ``KernelConfig()`` is a complete config.
    """

    # Identifier of the base kernel; defaults to 'matern'.
    name: str = 'matern'
    # Identifier of a wrapper applied around the kernel; 'identity'
    # presumably means no wrapping — TODO confirm with the consumer.
    wrapper: str = 'identity'
    # Optional list of layer sizes (assumed to configure a learned wrapper).
    hidden_features: Optional[List] = None
    # Optional output width (assumed to belong to the wrapper) — verify.
    out_features: Optional[int] = None


@dataclass()
class TrainConfig:
    """Settings for the model-fitting stage.

    Optional fields default to ``None``; how ``None`` is interpreted
    (e.g. "use the optimizer's own default") is decided by the caller.
    """

    # Optimizer identifier for the MLL fit; defaults to 'l-bfgs'.
    mll_opt: str = 'l-bfgs'
    # Optional learning rate for that optimizer.
    mll_opt_lr: Optional[float] = None
    # Optional number of optimization epochs.
    mll_opt_epochs: Optional[int] = None
    # Optional path to load state from (semantics defined by the caller).
    load_path: Optional[str] = None


@dataclass()
class AcqfConfig:
    """Settings choosing the acquisition function and how it is optimized."""

    # Acquisition-function identifier; defaults to 'EI' (presumably
    # Expected Improvement — confirm against the consumer).
    name: str = 'EI'
    # Optimizer identifier used to maximize the acquisition; defaults to 'l-bfgs'.
    acqf_opt: str = 'l-bfgs'


@dataclass()
class OptimConfig:
    """Optimizer hyper-parameters; currently only the learning rate."""

    # Learning rate; defaults to 0.01.
    lr: float = 0.01


@dataclass
class PretrainConfig:
    """Settings for the pre-training stage.

    ``optim_config`` may be given either as an ``OptimConfig`` or as a plain
    ``dict`` of its fields (e.g. straight from a parsed config file).
    ``__post_init__`` converts the dict form into an ``OptimConfig``.
    """

    # Required: number of training epochs.
    epochs: int
    # Required: batch size (assumed from the name — confirm with the trainer).
    bs: int
    # Required: destination path for saved state.
    save_path: str
    shifting: float = 0.0
    eval_intervals: int = 1
    save_best: bool = True
    # default_factory gives each instance its own OptimConfig. A shared
    # instance default is rejected by dataclasses on Python >= 3.11 (it is
    # unhashable), and on older versions it would be shared across instances.
    optim_config: OptimConfig = field(default_factory=OptimConfig)

    def __post_init__(self):
        """Coerce a dict-valued ``optim_config`` into an ``OptimConfig``.

        Raises:
            TypeError: if the dict contains keys that ``OptimConfig`` does
                not accept.
        """
        if isinstance(self.optim_config, dict):
            self.optim_config = OptimConfig(**self.optim_config)