import logging
import math

import yaml
import pytorch_lightning as pl
from matplotlib import pyplot as plt
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

try:
    from util.train_utils import Mean, HingeLoss, seed_worker
    from util.utils import get_preprocess
    from dreamsim import PerceptualModel
    from dreamsim.feature_extraction.vit_wrapper import ViTModel, ViTConfig
except:
    from dreamsim.util.train_utils import Mean, HingeLoss, seed_worker
    from dreamsim.util.utils import get_preprocess
    from dreamsim.dreamsim.model import PerceptualModel
    from dreamsim.dreamsim.feature_extraction.vit_wrapper import ViTModel, ViTConfig

from torch.utils.data import DataLoader
import torch
from peft import get_peft_model, LoraConfig, PeftModel
import os
import configargparse
from tqdm import tqdm

from dataset_nights.dataset import TwoAFCDataset


# log = logging.getLogger("lightning.pytorch")
# log.propagate = False
# log.setLevel(logging.INFO)


def parse_args():
    """Parse command-line (and optional config-file) options for finetuning.

    Options can be given on the command line or in the file passed with
    ``-c/--config``; ``configargparse`` merges both sources.

    Returns:
        The parsed options namespace.
    """
    import argparse

    def str2bool(value):
        # ``type=bool`` is a trap: bool("False") is True, so any non-empty string
        # would switch the option on. Parse the usual spellings explicitly instead.
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = configargparse.ArgumentParser()
    parser.add_argument('-c', '--config', required=False, is_config_file=True, help='config file path')

    ## Run options
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--tag', type=str, default='', help='tag for experiments (ex. experiment name)')
    parser.add_argument('--log_dir', type=str, default="./logs", help='path to save model checkpoints and logs')
    parser.add_argument('--load_dir', type=str, default="./models", help='path to pretrained ViT checkpoints')

    ## Model options
    parser.add_argument('--model_type', type=str, default='dino_vitb16',
                        help='Which ViT model to finetune. To finetune an ensemble of models, pass a comma-separated '
                             'list of models. Accepted models: [dino_vits8, dino_vits16, dino_vitb8, dino_vitb16, '
                             'clip_vitb16, clip_vitb32, clip_vitl14, mae_vitb16, mae_vitl16, mae_vith14, '
                             'open_clip_vitb16, open_clip_vitb32, open_clip_vitl14, dinov2_vitb14, synclr_vitb16]')
    parser.add_argument('--feat_type', type=str, default='cls',
                        help='What type of feature to extract from the model. If finetuning an ensemble, pass a '
                             'comma-separated list of features (same length as model_type). Accepted feature types: '
                             '[cls, embedding, last_layer].')
    parser.add_argument('--stride', type=str, default='16',
                        help='Stride of first convolution layer of the model (should match patch size). If finetuning '
                             'an ensemble, pass a comma-separated list (same length as model_type).')
    parser.add_argument('--use_lora', type=str2bool, default=False,
                        help='Whether to train with LoRA finetuning [True] or with an MLP head [False].')
    parser.add_argument('--hidden_size', type=int, default=1, help='Size of the MLP hidden layer.')

    ## Dataset options
    parser.add_argument('--dataset_name', type=str, default="nights")
    parser.add_argument('--dataset_root', type=str, default="./dataset/nights", help='path to training dataset.')
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--n_samples', type=int, default=-1,
                        help='Number of training samples to use (-1 uses the full training split).')
    parser.add_argument('--threshold', type=int, default=6)

    ## Training options
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate for training.')
    parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay for training.')
    parser.add_argument('--batch_size', type=int, default=4, help='Dataset batch size.')
    parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs.')
    parser.add_argument('--margin', default=0.01, type=float, help='Margin for hinge loss')

    ## LoRA-specific options
    parser.add_argument('--lora_r', type=int, default=8, help='LoRA attention dimension')
    parser.add_argument('--lora_alpha', type=float, default=0.1, help='Alpha for attention scaling')
    parser.add_argument('--lora_dropout', type=float, default=0.1, help='Dropout probability for LoRA layers')

    parser.add_argument('--debug', type=str2bool, default=False,
                        help='Debug run: no TensorBoard logger and no checkpoint callback.')

    return parser.parse_args()


class LightningPerceptualModel(pl.LightningModule):
    """LightningModule that finetunes a ``PerceptualModel`` on 2AFC (two-alternative forced-choice) triplets.

    Batches are ``(img_ref, img_0, img_1, target, idx)``. The model measures the distance
    from ``img_ref`` to each candidate and is trained with ``HingeLoss`` on
    ``dist_0 - dist_1``. A prediction counts as correct when ``dist_1 < dist_0`` agrees
    with ``target >= 0.5``. With ``use_lora=True``, every backbone is wrapped with LoRA
    adapters on its ``qkv`` modules. Otherwise, the backbones are frozen and only the
    MLP head is trained.
    """

    def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stride: str = "16", hidden_size: int = 1,
                 lr: float = 0.0003, use_lora: bool = False, margin: float = 0.05, lora_r: int = 16,
                 lora_alpha: float = 0.5, lora_dropout: float = 0.3, weight_decay: float = 0.0, train_data_len: int = 1,
                 load_dir: str = "./models", device: str = "cuda",
                 **kwargs):
        """Build the perceptual model, the metric accumulators and the loss.

        Args:
            feat_type: Feature type(s) to extract; forwarded to ``PerceptualModel``.
            model_type: Backbone name(s); forwarded to ``PerceptualModel``.
            stride: Patch stride(s); forwarded to ``PerceptualModel``.
            hidden_size: MLP hidden size; forwarded to ``PerceptualModel``.
            lr: Adam learning rate.
            use_lora: Train LoRA adapters (True) or only the MLP head (False).
            margin: Margin for ``HingeLoss``.
            lora_r, lora_alpha, lora_dropout: ``LoraConfig`` settings (LoRA mode only).
            weight_decay: Adam weight decay.
            train_data_len: Number of training samples. The epoch's correct-count is
                divided by it to give training accuracy.
            load_dir: Directory of pretrained backbone checkpoints, passed to ``PerceptualModel``.
            device: Device for the metric accumulators, the perceptual model and the loss.
            **kwargs: Not used by the model; lets callers pass a whole CLI namespace.
        """
        super().__init__()
        self.save_hyperparameters()

        self.feat_type = feat_type
        self.model_type = model_type
        self.stride = stride
        self.hidden_size = hidden_size
        self.lr = lr
        self.use_lora = use_lora
        self.margin = margin
        self.weight_decay = weight_decay
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
        self.train_data_len = train_data_len

        # Becomes True at the first training epoch; before that (e.g. a standalone
        # validate() call) metrics are logged at step 0.
        self.started = False
        # Plain dicts are not registered with Lightning, so each metric is moved to `device` explicitly.
        self.val_metrics = {'loss': Mean().to(device), 'score': Mean().to(device),
                            'loss_cls': Mean().to(device), 'loss_patch': Mean().to(device)}
        self.test_metrics = {'loss': Mean().to(device), 'score': Mean().to(device),
                            'loss_cls': Mean().to(device), 'loss_patch': Mean().to(device)}
        self.__reset_val_metrics()
        self.__reset_test_metrics()

        self.perceptual_model = PerceptualModel(feat_type=self.feat_type, model_type=self.model_type, stride=self.stride,
                                                hidden_size=self.hidden_size, lora=self.use_lora, load_dir=load_dir,
                                                device=device)
        if self.use_lora:
            self.__prep_lora_model()
        else:
            self.__prep_linear_model()

        pytorch_total_params = sum(p.numel() for p in self.perceptual_model.parameters())
        pytorch_total_trainable_params = sum(p.numel() for p in self.perceptual_model.parameters() if p.requires_grad)
        print(pytorch_total_params)
        print(pytorch_total_trainable_params)

        self.criterion = HingeLoss(margin=self.margin, device=device)

        self.epoch_loss_train = 0.0
        self.train_num_correct = 0.0
        self.__reset_plotting()

    def forward(self, img_ref, img_0, img_1, return_patch=False):
        """Return the distances from ``img_ref`` to ``img_0`` and to ``img_1``.

        With ``return_patch=True``, also return ``patch_ref``, ``patch_0`` and ``patch_1``.
        This assumes ``PerceptualModel`` then returns ``(dist, patch_a, patch_b)``; confirm
        against ``PerceptualModel.forward``.
        """
        if return_patch:
            out_0 = self.perceptual_model(img_ref, img_0, return_patch=return_patch)
            out_1 = self.perceptual_model(img_ref, img_1, return_patch=return_patch)
            dist_0 = out_0[0]
            dist_1 = out_1[0]
            patch_ref = out_0[1]
            patch_0 = out_0[2]
            patch_1 = out_1[2]
            return dist_0, dist_1, patch_ref, patch_0, patch_1
        else:
            dist_0 = self.perceptual_model(img_ref, img_0, return_patch=return_patch)
            dist_1 = self.perceptual_model(img_ref, img_1, return_patch=return_patch)
            return dist_0, dist_1

    def _two_afc_step(self, batch):
        """Score one 2AFC batch; shared by the train/val/test steps.

        Returns:
            ``(dist_0, dist_1, target, loss, num_correct)``. ``loss`` is the hinge loss
            divided by the batch size. ``num_correct`` counts samples whose prediction
            (``dist_1 < dist_0``) agrees with ``target >= 0.5``.
        """
        img_ref, img_0, img_1, target, _ = batch
        dist_0, dist_1 = self.forward(img_ref, img_0, img_1, return_patch=False)
        decisions = torch.lt(dist_1, dist_0)
        logit = dist_0 - dist_1
        loss = self.criterion(logit.squeeze(), target)
        loss = loss / target.shape[0]
        num_correct = ((target >= 0.5) == decisions).sum()
        return dist_0, dist_1, target, loss, num_correct

    def training_step(self, batch, batch_idx):
        _, _, _, loss, num_correct = self._two_afc_step(batch)
        # Detach before accumulating. Summing the live loss would chain every step's
        # autograd history onto this running total for the whole epoch.
        self.epoch_loss_train += loss.detach()
        self.train_num_correct += num_correct
        return loss

    def validation_step(self, batch, batch_idx):
        dist_0, dist_1, target, loss, num_correct = self._two_afc_step(batch)

        self.__plot_scatterplot(dist_0, dist_1, target)

        self.val_metrics['loss'].update(loss, target.shape[0])
        self.val_metrics['score'].update(num_correct, target.shape[0])
        return loss

    def test_step(self, batch, batch_idx):
        dist_0, dist_1, target, loss, num_correct = self._two_afc_step(batch)

        self.__plot_scatterplot(dist_0, dist_1, target)

        self.test_metrics['loss'].update(loss, target.shape[0])
        self.test_metrics['score'].update(num_correct, target.shape[0])
        return loss

    def on_train_epoch_start(self):
        self.epoch_loss_train = 0.0
        self.train_num_correct = 0.0
        self.started = True
        # on_validation_start puts the backbones in eval mode after every validation
        # pass. The extractors may not be registered submodules, so Lightning's own
        # train/eval switching might not reach them. Restore train mode every epoch.
        for extractor in self.perceptual_model.extractor_list:
            extractor.model.train()

    def on_train_epoch_end(self):
        epoch = self.current_epoch + 1 if self.started else 0
        if self.logger is not None:
            self.logger.experiment.add_scalar('train_loss/', self.epoch_loss_train / self.trainer.num_training_batches, epoch)
            self.logger.experiment.add_scalar('train_2afc_acc/', self.train_num_correct / self.train_data_len, epoch)
        if self.use_lora:
            self.__save_lora_weights()

    def on_train_start(self):
        for extractor in self.perceptual_model.extractor_list:
            extractor.model.train()

    def on_validation_start(self):
        for extractor in self.perceptual_model.extractor_list:
            extractor.model.eval()

    def on_validation_epoch_start(self):
        self.__reset_val_metrics()

    def on_test_epoch_start(self):
        self.__reset_test_metrics()

    def on_test_epoch_end(self):
        print('Test metrics')
        score = self.test_metrics['score'].compute()
        print(score)

    def on_validation_epoch_end(self):
        """Log the validation accuracy, loss and scatter plot, then start a fresh figure."""
        epoch = self.current_epoch + 1 if self.started else 0
        score = self.val_metrics['score'].compute()
        loss = self.val_metrics['loss'].compute()

        # Logged for the checkpoint callback only (logger=False).
        self.log('val_acc_ckpt', score, logger=False)
        self.log('val_loss_ckpt', loss, logger=False)
        print(f'Epoch {epoch} - Val Acc: {score} - Val Loss: {loss}')
        fig = self.val_scatter_plots[0]
        if self.logger is not None:
            # log for tensorboard
            self.logger.experiment.add_scalar('val_2afc_acc/', score, epoch)
            self.logger.experiment.add_scalar('val_loss/', loss, epoch)
            self.logger.experiment.add_figure('val_2afc_scatter/', fig, epoch)
        # Make sure the finished figure is released even when nothing logged it.
        plt.close(fig)

        self.__reset_plotting()

        return score

    def configure_optimizers(self):
        """Adam over the perceptual model, every backbone, and ``proj`` for embedding features."""
        params = list(self.perceptual_model.parameters())
        for extractor in self.perceptual_model.extractor_list:
            params += list(extractor.model.parameters())
        for extractor, feat_type in zip(self.perceptual_model.extractor_list, self.perceptual_model.feat_type_list):
            if feat_type == 'embedding':
                params += [extractor.proj]
        # The backbones / proj may already be registered under perceptual_model.
        # Keep one entry per tensor (first occurrence wins) so Adam never holds or
        # steps the same parameter twice.
        unique_params = list({id(p): p for p in params}.values())
        optimizer = torch.optim.Adam(unique_params, lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay)
        return [optimizer]

    def load_lora_weights(self, checkpoint_root, epoch_load):
        """Wrap each backbone with adapters read from ``<checkpoint_root>/epoch_<epoch_load>_<model_type>``."""
        for extractor in self.perceptual_model.extractor_list:
            load_dir = os.path.join(checkpoint_root,
                                    f'epoch_{epoch_load}_{extractor.model_type}')
            extractor.model = PeftModel.from_pretrained(extractor.model, load_dir).to(extractor.device)

    def __reset_plotting(self):
        # (figure, axes) pair that the val/test scatter plots draw into.
        self.val_scatter_plots = plt.subplots()

    def __reset_val_metrics(self):
        for v in self.val_metrics.values():
            v.reset()

    def __reset_test_metrics(self):
        for v in self.test_metrics.values():
            v.reset()

    def __prep_lora_model(self):
        """Replace each backbone with a PEFT model carrying LoRA adapters on its ``qkv`` modules."""
        for extractor in self.perceptual_model.extractor_list:
            config = LoraConfig(
                r=self.lora_r,
                lora_alpha=self.lora_alpha,
                lora_dropout=self.lora_dropout,
                bias='none',
                target_modules=['qkv']
            )
            extractor_model = get_peft_model(ViTModel(extractor.model, ViTConfig()),
                                             config).to(extractor.device)
            extractor.model = extractor_model

    def __prep_linear_model(self):
        """Freeze every backbone (and its ``proj`` for embedding features); train only the MLP head."""
        # Use the per-extractor feature type (as configure_optimizers does). The raw
        # ``self.feat_type`` is comma-separated for ensembles and never equals "embedding".
        for extractor, feat_type in zip(self.perceptual_model.extractor_list, self.perceptual_model.feat_type_list):
            extractor.model.requires_grad_(False)
            if feat_type == "embedding":
                extractor.proj.requires_grad_(False)
        self.perceptual_model.mlp.requires_grad_(True)

    def __save_lora_weights(self):
        """Save each backbone's adapters next to the checkpoints as ``epoch_<N>_<model_type>``.

        The saved ``adapter_model.bin`` is rewritten with every key prefixed by
        ``base_model.model.``. Presumably this matches the key layout that
        ``load_lora_weights`` / ``PeftModel.from_pretrained`` expects for the ViTModel
        wrapper; confirm before changing.
        """
        checkpoint_dir = self.trainer.checkpoint_callback.dirpath
        for extractor in self.perceptual_model.extractor_list:
            save_dir = os.path.join(checkpoint_dir,
                                    f'epoch_{self.trainer.current_epoch}_{extractor.model_type}')
            extractor.model.save_pretrained(save_dir, safe_serialization=False)
            adapter_path = os.path.join(save_dir, 'adapter_model.bin')
            adapters_weights = torch.load(adapter_path)
            new_adapters_weights = {'base_model.model.' + k: v for k, v in adapters_weights.items()}
            torch.save(new_adapters_weights, adapter_path)

    def __plot_scatterplot(self, dist_0, dist_1, target):
        """Scatter ``dist_0`` vs ``dist_1``: red for target > 0.5, blue for < 0.5, green for == 0.5."""
        ax = self.val_scatter_plots[1]
        ax.scatter(
            dist_0[target > 0.5].detach().cpu().numpy(),
            dist_1[target > 0.5].detach().cpu().numpy(), c='red')
        ax.scatter(
            dist_0[target < 0.5].detach().cpu().numpy(),
            dist_1[target < 0.5].detach().cpu().numpy(), c='blue')
        ax.scatter(
            dist_0[target == 0.5].detach().cpu().numpy(),
            dist_1[target == 0.5].detach().cpu().numpy(), c='green')

def run(args, device):
    """Train a ``LightningPerceptualModel`` on a 2AFC dataset as configured by ``args``.

    Steps:
    1. Build an experiment directory name from the hyperparameters and seed all RNGs.
    2. Load the train split (optionally subsampled to ``args.n_samples``) and the val
       split (capped at 1000 random triplets).
    3. Validate the untrained model once, then fit for ``args.epochs`` epochs.

    ``config.yaml`` and ``exp.log`` are written to the logger's run directory, or to
    the experiment directory when there is no logger.

    Args:
        args: Namespace returned by ``parse_args``.
        device: Torch device string for the model, its metrics and its loss.
    """
    training_method = "lora" if args.use_lora else "mlp"
    run_name = (
        f'{args.tag}_{args.model_type}_n{args.n_samples}_{args.feat_type}_{training_method}_'
        f'lr_{args.lr}_batchsize_{args.batch_size}_wd_{args.weight_decay}'
        f'_hiddensize_{args.hidden_size}_margin_{args.margin}'
    )
    if args.use_lora:
        run_name += f'_lorar_{args.lora_r}_loraalpha_{args.lora_alpha}_loradropout_{args.lora_dropout}'
    exp_dir = os.path.join(args.log_dir, run_name)

    seed_everything(args.seed)
    g = torch.Generator()
    g.manual_seed(args.seed)

    train_dataset = TwoAFCDataset(root_dir=args.dataset_root, split="train", preprocess=get_preprocess(args.model_type), threshold=args.threshold, name=args.dataset_name)
    val_dataset = TwoAFCDataset(root_dir=args.dataset_root, split="val", preprocess=get_preprocess(args.model_type), name=args.dataset_name)
    # test_dataset = TwoAFCDataset(root_dir=args.dataset_root, split="test", preprocess=get_preprocess(args.model_type), name=args.dataset_name)

    n_samples = len(train_dataset) if args.n_samples == -1 else args.n_samples
    train_dataset = torch.utils.data.Subset(train_dataset, torch.randperm(len(train_dataset))[:n_samples])
    val_dataset = torch.utils.data.Subset(val_dataset, torch.randperm(len(val_dataset))[:1000])
    print(f"Training on {len(train_dataset)} samples")

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True,
                              worker_init_fn=seed_worker, generator=g)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)
    # test_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)

    logger = TensorBoardLogger(save_dir=exp_dir, default_hp_metric=False) if not args.debug else None
    checkpoint_callback = None
    if not args.debug:
        # save_top_k=-1 keeps every epoch; `mode` only decides which checkpoint is
        # reported as best. The monitored value is a loss, so lower is better.
        checkpoint_callback = ModelCheckpoint(monitor='val_loss_ckpt',
                                              save_top_k=-1,
                                              save_last=True,
                                              filename='{epoch:02d}',
                                              mode='min')
    trainer = Trainer(devices=1,
                      accelerator='gpu',
                      log_every_n_steps=10,
                      logger=logger,
                      max_epochs=args.epochs,
                      default_root_dir=exp_dir,
                      callbacks=checkpoint_callback,
                      num_sanity_val_steps=0,
                      )

    # Use the logger's own run directory (<exp_dir>/lightning_logs/version_<N> for
    # TensorBoardLogger), so config/logs sit next to the run's event files and
    # checkpoints. Fall back to exp_dir when no logger is attached.
    logger_dir = getattr(trainer.logger, 'log_dir', None)
    checkpoint_root = logger_dir if logger_dir else exp_dir
    os.makedirs(checkpoint_root, exist_ok=True)
    with open(os.path.join(checkpoint_root, 'config.yaml'), 'w') as f:
        # The Namespace itself is dumped (not vars(args)); downstream loaders may rely on that.
        yaml.dump(args, f)

    logging.basicConfig(filename=os.path.join(checkpoint_root, 'exp.log'), level=logging.INFO, force=True)
    logging.info("Arguments: %s", vars(args))

    model = LightningPerceptualModel(device=device, train_data_len=len(train_dataset), **vars(args))
    # model.load_lora_weights(checkpoint_root=checkpoint_root, epoch_load=7)
    logging.info("Validating before training")
    trainer.validate(model, val_loader)
    logging.info("Training")
    trainer.fit(model, train_loader, val_loader)
    # trainer.test(model, test_loader)

    print("Done :)")


if __name__ == '__main__':
    # Script entry point: read the CLI/config options, then train on the GPU when
    # CUDA is available and on the CPU otherwise.
    cli_args = parse_args()
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"
    run(cli_args, target_device)






