import torch
from torch.optim import AdamW

from nnunetv2.training.lr_scheduler.polylr import WarmupPolyLRScheduler
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer


class nnUNetTrainerAdamW_WD0(nnUNetTrainer):
    """nnU-Net trainer optimised with AdamW (AMSGrad enabled) and no weight decay.

    Sets ``initial_lr = 1e-2``, ``weight_decay = 0`` and ``warmup_steps = 20``,
    and pairs the optimizer with a ``WarmupPolyLRScheduler`` whose
    ``max_steps`` is ``self.num_epochs``.
    """

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
        **kwargs
    ):
        # NOTE(review): extra keyword arguments are accepted but not forwarded
        # to the base class — confirm this is intentional.
        super().__init__(plans, configuration, fold, dataset_json,
                         unpack_dataset, device)
        # Hyperparameters consumed by configure_optimizers().
        self.initial_lr, self.weight_decay, self.warmup_steps = 1e-2, 0, 20

    def configure_optimizers(self):
        """Create the AdamW optimizer and its warmup + poly LR scheduler.

        Returns:
            A ``(optimizer, lr_scheduler)`` tuple.
        """
        optim = AdamW(
            self.network.parameters(),
            amsgrad=True,
            weight_decay=self.weight_decay,
            lr=self.initial_lr,
        )
        schedule = WarmupPolyLRScheduler(
            optimizer=optim,
            warmup_steps=self.warmup_steps,
            max_steps=self.num_epochs,
            initial_lr=self.initial_lr,
        )
        return optim, schedule


class nnUNetTrainerAdamW_QuickWD0(nnUNetTrainer):
    """Shorter-schedule AdamW (AMSGrad) trainer with no weight decay.

    Compared with the stock trainer, it sets ``num_epochs = 500``,
    ``initial_lr = 1e-2``, ``weight_decay = 0`` and ``warmup_steps = 8``.
    It uses AdamW betas of ``(0.9, 0.99)`` and a ``WarmupPolyLRScheduler``
    whose ``max_steps`` is ``self.num_epochs``.
    """

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
        **kwargs
    ):
        # NOTE(review): extra keyword arguments are accepted but not forwarded
        # to the base class — confirm this is intentional.
        super().__init__(plans, configuration, fold, dataset_json,
                         unpack_dataset, device)
        # Hyperparameters consumed by configure_optimizers(); num_epochs
        # overrides whatever the base class configured.
        self.initial_lr, self.weight_decay = 1e-2, 0
        self.warmup_steps, self.num_epochs = 8, 500

    def configure_optimizers(self):
        """Create the AdamW optimizer and its warmup + poly LR scheduler.

        Returns:
            A ``(optimizer, lr_scheduler)`` tuple.
        """
        optim = AdamW(
            self.network.parameters(),
            amsgrad=True,
            weight_decay=self.weight_decay,
            betas=(0.9, 0.99),
            lr=self.initial_lr,
        )
        schedule = WarmupPolyLRScheduler(
            optimizer=optim,
            warmup_steps=self.warmup_steps,
            max_steps=self.num_epochs,
            initial_lr=self.initial_lr,
        )
        return optim, schedule


class nnUNetTrainerAdamW_WDe5(nnUNetTrainerAdamW_WD0):
    """Variant of ``nnUNetTrainerAdamW_WD0`` with AdamW weight decay 1e-5."""

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
        **kwargs
    ):
        # NOTE(review): **kwargs is not forwarded to the parent — confirm intended.
        super().__init__(plans, configuration, fold, dataset_json,
                         unpack_dataset, device)
        # Only the weight decay differs from the parent trainer.
        self.weight_decay = 1e-5


class nnUNetTrainerAdamW_QuickWDe5(nnUNetTrainerAdamW_QuickWD0):
    """Variant of ``nnUNetTrainerAdamW_QuickWD0`` with AdamW weight decay 1e-5."""

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
        **kwargs
    ):
        # NOTE(review): **kwargs is not forwarded to the parent — confirm intended.
        super().__init__(plans, configuration, fold, dataset_json,
                         unpack_dataset, device)
        # Only the weight decay differs from the parent trainer.
        self.weight_decay = 1e-5


class nnUNetTrainerAdamW_QuickerWDe5(nnUNetTrainerAdamW_QuickWDe5):
    """Variant of ``nnUNetTrainerAdamW_QuickWDe5`` with ``num_epochs`` set to 100."""

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
        **kwargs
    ):
        # NOTE(review): **kwargs is not forwarded to the parent — confirm intended.
        super().__init__(plans, configuration, fold, dataset_json,
                         unpack_dataset, device)
        # Override the parent's epoch budget; other settings are inherited.
        self.num_epochs = 100


class nnUNetTrainerAdamW_WDe4(nnUNetTrainerAdamW_WD0):
    """Variant of ``nnUNetTrainerAdamW_WD0`` with AdamW weight decay 1e-4."""

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
        **kwargs
    ):
        # NOTE(review): **kwargs is not forwarded to the parent — confirm intended.
        super().__init__(plans, configuration, fold, dataset_json,
                         unpack_dataset, device)
        # Only the weight decay differs from the parent trainer.
        self.weight_decay = 1e-4


class nnUNetTrainerAdamW_QuickWDe4(nnUNetTrainerAdamW_QuickWD0):
    """Variant of ``nnUNetTrainerAdamW_QuickWD0`` with AdamW weight decay 1e-4."""

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
        **kwargs
    ):
        # NOTE(review): **kwargs is not forwarded to the parent — confirm intended.
        super().__init__(plans, configuration, fold, dataset_json,
                         unpack_dataset, device)
        # Only the weight decay differs from the parent trainer.
        self.weight_decay = 1e-4


class nnUNetTrainerAdamW_QuickerWDe4(nnUNetTrainerAdamW_QuickWDe4):
    """Variant of ``nnUNetTrainerAdamW_QuickWDe4`` with ``num_epochs`` set to 100."""

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
        **kwargs
    ):
        # NOTE(review): **kwargs is not forwarded to the parent — confirm intended.
        super().__init__(plans, configuration, fold, dataset_json,
                         unpack_dataset, device)
        # Override the parent's epoch budget; other settings are inherited.
        self.num_epochs = 100


class nnUNetTrainerAdamW_WDe3(nnUNetTrainerAdamW_WD0):
    """Variant of ``nnUNetTrainerAdamW_WD0`` with AdamW weight decay 1e-3."""

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
        **kwargs
    ):
        # NOTE(review): **kwargs is not forwarded to the parent — confirm intended.
        super().__init__(plans, configuration, fold, dataset_json,
                         unpack_dataset, device)
        # Only the weight decay differs from the parent trainer.
        self.weight_decay = 1e-3


class nnUNetTrainerAdamW_QuickWDe3(nnUNetTrainerAdamW_QuickWD0):
    """Variant of ``nnUNetTrainerAdamW_QuickWD0`` with AdamW weight decay 1e-3."""

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
        **kwargs
    ):
        # NOTE(review): **kwargs is not forwarded to the parent — confirm intended.
        super().__init__(plans, configuration, fold, dataset_json,
                         unpack_dataset, device)
        # Only the weight decay differs from the parent trainer.
        self.weight_decay = 1e-3


class nnUNetTrainerAdamW_QuickerWDe3(nnUNetTrainerAdamW_QuickWDe3):
    """Variant of ``nnUNetTrainerAdamW_QuickWDe3`` with ``num_epochs`` set to 100."""

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
        **kwargs
    ):
        # NOTE(review): **kwargs is not forwarded to the parent — confirm intended.
        super().__init__(plans, configuration, fold, dataset_json,
                         unpack_dataset, device)
        # Override the parent's epoch budget; other settings are inherited.
        self.num_epochs = 100


class nnUNetTrainerAdamW_WDe2(nnUNetTrainerAdamW_WD0):
    """Variant of ``nnUNetTrainerAdamW_WD0`` with AdamW weight decay 1e-2."""

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
        **kwargs
    ):
        # NOTE(review): **kwargs is not forwarded to the parent — confirm intended.
        super().__init__(plans, configuration, fold, dataset_json,
                         unpack_dataset, device)
        # Only the weight decay differs from the parent trainer.
        self.weight_decay = 1e-2


class nnUNetTrainerAdamW_WDe2beta0p99(nnUNetTrainer):
    """AdamW (AMSGrad) trainer with weight decay 1e-2 and betas ``(0.9, 0.99)``.

    Sets ``initial_lr = 1e-2``, ``weight_decay = 1e-2`` and
    ``warmup_steps = 20``, and pairs the optimizer with a
    ``WarmupPolyLRScheduler`` whose ``max_steps`` is ``self.num_epochs``.
    """

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
        **kwargs
    ):
        # NOTE(review): extra keyword arguments are accepted but not forwarded
        # to the base class — confirm this is intentional.
        super().__init__(plans, configuration, fold, dataset_json,
                         unpack_dataset, device)
        # Hyperparameters consumed by configure_optimizers().
        self.initial_lr, self.weight_decay, self.warmup_steps = 1e-2, 1e-2, 20

    def configure_optimizers(self):
        """Create the AdamW optimizer and its warmup + poly LR scheduler.

        Returns:
            A ``(optimizer, lr_scheduler)`` tuple.
        """
        optim = AdamW(
            self.network.parameters(),
            amsgrad=True,
            weight_decay=self.weight_decay,
            betas=(0.9, 0.99),
            lr=self.initial_lr,
        )
        schedule = WarmupPolyLRScheduler(
            optimizer=optim,
            warmup_steps=self.warmup_steps,
            max_steps=self.num_epochs,
            initial_lr=self.initial_lr,
        )
        return optim, schedule


class nnUNetTrainerAdamW_QuickWDe2(nnUNetTrainerAdamW_QuickWD0):
    """Variant of ``nnUNetTrainerAdamW_QuickWD0`` with AdamW weight decay 1e-2."""

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
        **kwargs
    ):
        # NOTE(review): **kwargs is not forwarded to the parent — confirm intended.
        super().__init__(plans, configuration, fold, dataset_json,
                         unpack_dataset, device)
        # Only the weight decay differs from the parent trainer.
        self.weight_decay = 1e-2


class nnUNetTrainerAdamW_QuickerWDe2(nnUNetTrainerAdamW_QuickWDe2):
    """Variant of ``nnUNetTrainerAdamW_QuickWDe2`` with ``num_epochs`` set to 100."""

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
        **kwargs
    ):
        # NOTE(review): **kwargs is not forwarded to the parent — confirm intended.
        super().__init__(plans, configuration, fold, dataset_json,
                         unpack_dataset, device)
        # Override the parent's epoch budget; other settings are inherited.
        self.num_epochs = 100


class nnUNetTrainerAdamW_WDe1(nnUNetTrainerAdamW_WD0):
    """Variant of ``nnUNetTrainerAdamW_WD0`` with AdamW weight decay 1e-1."""

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
        **kwargs
    ):
        # NOTE(review): **kwargs is not forwarded to the parent — confirm intended.
        super().__init__(plans, configuration, fold, dataset_json,
                         unpack_dataset, device)
        # Only the weight decay differs from the parent trainer.
        self.weight_decay = 1e-1


class nnUNetTrainerAdamW_QuickWDe1(nnUNetTrainerAdamW_QuickWD0):
    """Variant of ``nnUNetTrainerAdamW_QuickWD0`` with AdamW weight decay 1e-1."""

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
        **kwargs
    ):
        # NOTE(review): **kwargs is not forwarded to the parent — confirm intended.
        super().__init__(plans, configuration, fold, dataset_json,
                         unpack_dataset, device)
        # Only the weight decay differs from the parent trainer.
        self.weight_decay = 1e-1


class nnUNetTrainerAdamW_QuickerWDe1(nnUNetTrainerAdamW_QuickWDe1):
    """Variant of ``nnUNetTrainerAdamW_QuickWDe1`` with ``num_epochs`` set to 100."""

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
        **kwargs
    ):
        # NOTE(review): **kwargs is not forwarded to the parent — confirm intended.
        super().__init__(plans, configuration, fold, dataset_json,
                         unpack_dataset, device)
        # Override the parent's epoch budget; other settings are inherited.
        self.num_epochs = 100
