from dataclasses import dataclass, field
from typing import Optional, Tuple


@dataclass()
class TransformerConfig:
    """Hyperparameters for a transformer model.

    All fields have defaults, so ``TransformerConfig()`` is a valid
    configuration; fields can also be given positionally in declaration order.
    """

    d_model: int = 256  # model width
    n_head: int = 4  # number of attention heads
    n_hidden: int = 1_024  # presumably the feed-forward hidden size — confirm
    dropout: float = 0.1  # dropout probability
    n_layer: int = 4  # number of layers
    # Positional-embedding selector; None by default. Accepted string values
    # are interpreted by the model code, not here.
    pos_embedding: Optional[str] = None


@dataclass()
class AcqfConfig:
    """Settings for the acquisition-function step.

    NOTE(review): "acqf" presumably means acquisition function (the default
    name 'EI' suggests Expected Improvement) — confirm against the consumer.
    """

    name: str = 'EI'  # acquisition function identifier
    acqf_opt: str = 'adam'  # optimizer identifier used for this step
    lr: float = 1e-2  # learning rate for that optimizer
    epochs: int = 300  # number of optimization epochs


@dataclass()
class OptimConfig:
    """Optimizer settings.

    Every field has a default, so ``OptimConfig()`` is a valid configuration.
    """

    lr: float = 3e-4  # learning rate
    warmup_steps: int = 0  # number of warmup steps
    weight_decay: float = 0.0  # weight-decay coefficient


@dataclass
class PretrainConfig:
    """Settings for a pretraining run.

    ``optim_config`` may be passed either as an ``OptimConfig`` instance or as
    a plain ``dict`` of ``OptimConfig`` keyword arguments; a dict is converted
    to an ``OptimConfig`` in ``__post_init__``.
    """

    epochs: int
    bs: int  # presumably the batch size — confirm against the training loop
    # A tuple of ints of any length. ``Tuple[int]`` would mean a 1-tuple.
    # NOTE(review): presumably (min_len, max_len) — confirm with the sampler.
    seq_len_range: Tuple[int, ...]
    save_path: str
    shifting: float = 0.0
    eval_intervals: int = 1
    save_best: bool = True
    # default_factory gives every PretrainConfig its own OptimConfig. A plain
    # ``= OptimConfig()`` default would be one shared instance aliased across
    # all configs, and Python 3.11+ rejects such unhashable dataclass defaults
    # with ValueError when the class is defined.
    optim_config: OptimConfig = field(default_factory=OptimConfig)

    def __post_init__(self):
        # Accept a mapping of OptimConfig kwargs (e.g. parsed from a config
        # file) as well as an OptimConfig instance.
        if isinstance(self.optim_config, dict):
            self.optim_config = OptimConfig(**self.optim_config)
    