from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import List
from omegaconf import MISSING


@dataclass
class DataConfig:
    """Base dataset configuration shared by every dataset-specific config.

    ``name`` and ``num_views`` default to omegaconf's ``MISSING`` marker, so a
    subclass (or the composed config) has to provide them.
    """

    name: str = MISSING
    num_workers: int = 40
    num_views: int = MISSING  # number of views per sample


@dataclass
class LogConfig:
    """Logging settings: the ``wandb_*`` options plus the log output directory."""

    wandb_entity: str = "XXX"
    wandb_group: str = "XXX"  # e.g. mv_mimic_mps (presumably an example group)
    wandb_run_name: str = ""
    wandb_project_name: str = "XXX"
    wandb_log_freq: int = 30
    wandb_offline: bool = True
    wandb_local_instance: bool = False

    # Directory where logs are written.
    dir_logs: str = "/XXX/XXX/XXX/XXX"


@dataclass
class MimicCXRDataConfig(DataConfig):
    """Configuration for the MIMIC-CXR dataset.

    Two views are used: lateral (LATERAL + LL) and frontal (AP + PA).
    """

    name: str = "mimic_cxr"
    num_views: int = 2
    dir_data: str = "XXX"
    dir_cache: str = "XXX"
    dir_clfs_base: str = "XXX"
    suffix_clfs: str = "mimic_clf"
    use_cache: bool = True

    # --- data splitting ---
    splitting_method: str = "random"
    train_val_split: float = 0.8
    test_val_split: float = 0.5
    split_seed: int = 0
    # Accepted values: "one_frontal_one_lateral" or "all_combi_no_missing".
    studies_policy: str = "all_combi_no_missing"
    reduced_dataset: bool = False

    # --- labels ---
    # A fresh list is produced per instance via default_factory.
    target_list: List[str] = field(
        default_factory=lambda: [
            "Atelectasis",
            "Cardiomegaly",
            "Consolidation",
            "Edema",
            "Enlarged Cardiomediastinum",
            "Fracture",
            "Lung Lesion",
            "Lung Opacity",
            "No Finding",
            "Pleural Effusion",
            "Pleural Other",
            "Pneumonia",
            "Pneumothorax",
            "Support Devices",
        ]
    )
    # Both counts equal len(target_list) (14); keep them in sync if the list changes.
    n_clfs_outputs: int = 14
    num_labels: int = 14

    img_size: int = 224
    image_channels: int = 1

    # Carried over from the CelebA config; the text settings are unused here.
    num_layers_img: int = 5
    filter_dim_img: int = 64
    filter_dim_text: int = 64
    beta_img: float = 1.0
    beta_text: float = 1.0
    skip_connections_img_weight_a: float = 1.0
    skip_connections_img_weight_b: float = 1.0

    use_rec_weight: bool = True
    include_channels_rec_weight: bool = False

    # --- image handling ---
    img_RGB: bool = False  # False selects grayscale images
    pre_load_images: bool = False


@dataclass
class PolyMNISTConfig(DataConfig):
    """Settings shared by every PolyMNIST variant.

    ``name`` is left unset here; each concrete variant below provides it
    together with its data and classifier suffixes.
    """

    num_views: int = 3
    num_workers: int = 8
    dir_data_base: str = "INSERT PATH"
    dir_clfs_base: str = "XXX"


@dataclass
class PMvanillaDataConfig(PolyMNISTConfig):
    """Vanilla PolyMNIST variant."""

    name: str = "PM_vanilla"
    suffix_data_train: str = "PolyMNIST_vanilla/train"
    suffix_data_test: str = "PolyMNIST_vanilla/test"
    suffix_clfs: str = "vanilla_resnet"


@dataclass
class PMtranslatedData50Config(PolyMNISTConfig):
    """PolyMNIST variant stored under ``PolyMNIST_translated_50``."""

    name: str = "PM_translated_50"
    suffix_data_train: str = "PolyMNIST_translated_50/train"
    suffix_data_test: str = "PolyMNIST_translated_50/test"
    # NOTE(review): "translatedl50" contains an extra "l" compared with e.g.
    # "translated60_resnet"; confirm it matches the classifier directory on disk.
    suffix_clfs: str = "translatedl50_resnet"


@dataclass
class PMtranslatedData55Config(PolyMNISTConfig):
    """PolyMNIST variant stored under ``PolyMNIST_translated_55``."""

    name: str = "PM_translated_55"
    suffix_data_train: str = "PolyMNIST_translated_55/train"
    suffix_data_test: str = "PolyMNIST_translated_55/test"
    # NOTE(review): "translatedl55" contains an extra "l" compared with e.g.
    # "translated60_resnet"; confirm it matches the classifier directory on disk.
    suffix_clfs: str = "translatedl55_resnet"


@dataclass
class PMtranslatedData60Config(PolyMNISTConfig):
    """PolyMNIST variant stored under ``PolyMNIST_translated_60``."""

    name: str = "PM_translated_60"
    suffix_data_train: str = "PolyMNIST_translated_60/train"
    suffix_data_test: str = "PolyMNIST_translated_60/test"
    suffix_clfs: str = "translated60_resnet"


@dataclass
class PMtranslatedData65Config(PolyMNISTConfig):
    """PolyMNIST variant stored under ``PolyMNIST_translated_65``."""

    name: str = "PM_translated_65"
    suffix_data_train: str = "PolyMNIST_translated_65/train"
    suffix_data_test: str = "PolyMNIST_translated_65/test"
    suffix_clfs: str = "translated65_resnet"


@dataclass
class PMtranslatedData70Config(PolyMNISTConfig):
    """PolyMNIST variant stored under ``translated_70``."""

    name: str = "PM_translated_70"
    # NOTE(review): unlike the other variants, these suffixes have no
    # "PolyMNIST_" prefix; confirm against the on-disk data layout.
    suffix_data_train: str = "translated_70/train"
    suffix_data_test: str = "translated_70/test"
    suffix_clfs: str = "translated70_resnet"


@dataclass
class PMtranslatedData75Config(PolyMNISTConfig):
    """PolyMNIST variant stored under ``PolyMNIST_translated_scale075``."""

    # NOTE(review): the name has no underscore before "75" (other variants use
    # "PM_translated_<n>"); confirm nothing looks it up by the underscored form.
    name: str = "PM_translated75"
    suffix_data_train: str = "PolyMNIST_translated_scale075/train"
    suffix_data_test: str = "PolyMNIST_translated_scale075/test"
    suffix_clfs: str = "translated75_resnet"


@dataclass
class PMtranslatedData50FixedConfig(PolyMNISTConfig):
    """PolyMNIST variant stored under ``PolyMNIST_translated_50_fixed``."""

    name: str = "PM_translated_50_fixed"
    suffix_data_train: str = "PolyMNIST_translated_50_fixed/train"
    suffix_data_test: str = "PolyMNIST_translated_50_fixed/test"
    suffix_clfs: str = "translated_50_fixed_resnet"


@dataclass
class PMrotatedDataConfig(PolyMNISTConfig):
    """PolyMNIST variant stored under ``PolyMNIST_rotated``."""

    name: str = "PM_rotated"
    suffix_data_train: str = "PolyMNIST_rotated/train"
    suffix_data_test: str = "PolyMNIST_rotated/test"
    suffix_clfs: str = "rotated_resnet"


@dataclass
class CelebADataConfig(DataConfig):
    """Configuration for the two-view CelebA dataset (image and text settings)."""

    name: str = "celeba"
    num_views: int = 2
    dir_data: str = "INSERT PATH"
    dir_alphabet: str = "XXX"
    dir_clf: str = "XXX"

    # --- text sequence handling ---
    len_sequence: int = 256
    random_text_ordering: bool = True
    random_text_startindex: bool = False

    # --- image / label geometry ---
    img_size: int = 64
    image_channels: int = 3
    crop_size_img: int = 148
    n_clfs_outputs: int = 40
    num_labels: int = 40

    # --- network sizes ---
    num_features: int = 41  # equals len(alphabet)
    num_layers_text: int = 7
    num_layers_img: int = 5
    filter_dim_img: int = 128
    filter_dim_text: int = 128
    # NOTE(review): the MIMIC config names the equivalent fields
    # skip_connections_img_weight_a/b; confirm consumers expect these names.
    skip_connections_weight_a: float = 1.0
    skip_connections_weight_b: float = 1.0

    use_rec_weight: bool = True
    include_channels_rec_weight: bool = True


@dataclass
class ModelConfig:
    """Training hyper-parameters: target device, batch size, learning rate, epochs."""

    device: str = "cuda"
    batch_size: int = 128
    lr: float = 1e-3
    epochs: int = 250


@dataclass
class MyClfConfig:
    """Top-level classifier config composing model, log and dataset sub-configs.

    The three sub-configs default to omegaconf's ``MISSING`` marker and are
    expected to be provided when the config is composed.
    """

    seed: int = 1
    checkpoint_metric: str = "val/loss/mean_auroc"
    model: ModelConfig = MISSING
    log: LogConfig = MISSING
    dataset: DataConfig = MISSING
