from dataclasses import dataclass, field
from typing import List, Optional, Any

from omegaconf import OmegaConf, MISSING

from .decoder_with_xattn.configs import DecoderWithCrossAttentionConfig, MultiBandDecoderWithCrossAttentionConfig
from .data_encoders.configs import DataEncoderConfig

# Public API of this config module: the names exported by `from ... import *`.
# Two of them are re-exported from the sub-packages imported above.
# NOTE(review): AttentionBlockConfig is not listed even though TransformerConfig.block
# is typed with it; confirm whether it should be exported as well.
__all__ = [
    "DataEncoderConfig",
    "GridCoordSamplerConfig",
    "RayCoordSamplerConfig",
    "DecoderWithCrossAttentionConfig",
    "MultiBandDecoderWithCrossAttentionConfig",
    "TransformerConfig",
]


@dataclass
class GridCoordSamplerConfig:
    """Structured (OmegaConf) config for the grid coordinate sampler.

    Fields defaulting to ``MISSING`` are mandatory and must be supplied
    (e.g. from a YAML/CLI config) before they are read.
    """

    # Data type tag for this sampler; defaults to "image".
    data_type: str = "image"
    # NOTE(review): presumably the [min, max] bounds of the generated coordinates — confirm with the sampler.
    coord_range: List[float] = MISSING
    # Strategy names used for training / validation sampling. Typed Optional but
    # defaulting to MISSING, so they must be set explicitly (None is allowed).
    train_strategy: Optional[str] = MISSING
    val_strategy: Optional[str] = MISSING


@dataclass
class RayCoordSamplerConfig:
    """Structured (OmegaConf) config for the ray coordinate sampler.

    Fields defaulting to ``MISSING`` are mandatory and must be supplied
    (e.g. from a YAML/CLI config) before they are read.
    """

    # Data type tag for this sampler; defaults to "nvs"
    # (presumably novel view synthesis — confirm).
    data_type: str = "nvs"
    # NOTE(review): presumably the [min, max] bounds of the generated coordinates — confirm with the sampler.
    coord_range: List[float] = MISSING
    # Number of points sampled along each ray (default 128).
    num_points_per_ray: int = 128
    # Strategy names used for training / validation sampling. Typed Optional but
    # defaulting to MISSING, so they must be set explicitly (None is allowed).
    train_strategy: Optional[str] = MISSING
    val_strategy: Optional[str] = MISSING


@dataclass
class AttentionBlockConfig:
    """Structured (OmegaConf) config for a single attention block.

    ``embed_dim`` and ``n_head`` default to ``MISSING`` and must be supplied
    before they are read; the remaining fields have concrete defaults.
    """

    # Width of the block and number of attention heads (both mandatory).
    embed_dim: int = MISSING
    n_head: int = MISSING
    # Presumably toggle bias terms in the MLP and attention projections — confirm in the block implementation.
    mlp_bias: bool = True
    attn_bias: bool = True
    # Presumably dropout probabilities for attention weights and the residual path — confirm.
    attn_pdrop: float = 0.0
    resid_pdrop: float = 0.1
    # GELU variant selector; the accepted values are defined by the consuming module — TODO confirm.
    gelu: str = "v1"


@dataclass
class TransformerConfig:
    """Structured (OmegaConf) config for a transformer stack.

    ``n_layer`` defaults to ``MISSING`` and must be supplied before it is read.
    ``embed_dim`` here and ``block.embed_dim`` are separate fields; this class
    does not keep them in sync.
    """

    n_layer: int = MISSING
    embed_dim: int = 768
    # NOTE(review): presumably enables attention masking (e.g. causal) — confirm.
    mask: bool = False
    # Per-block settings. default_factory gives each TransformerConfig its own
    # AttentionBlockConfig rather than one instance shared by all configs.
    # A plain instance default is also rejected by dataclasses on Python 3.11+,
    # because non-frozen dataclass instances are unhashable.
    block: AttentionBlockConfig = field(default_factory=AttentionBlockConfig)

    use_cross_attention: bool = False
    shared_block: bool = False
    block_type: str = "cross_self"  # cross_self | self_cross | cross | cross_via_qk_fusion
    # NOTE(review): presumably applies LayerNorm to the cross-attention context — confirm.
    ln_ctx: bool = False

    # NOTE(review): presumably enables a positional encoding on the input — confirm.
    use_input_pe: bool = True
