import math

import numpy as np
from matplotlib import pyplot as plt
from util.utils import get_preprocess
from torch.utils.data import DataLoader
import torch
from util.utils import feature_pca, remove_axes

from dataset_nights.dataset import TwoAFCDataset

from training.train import LightningPerceptualModel
import os

def normalize(features_pca):
    """Min-max scale PCA maps to [0, 1], independently per sample and channel.

    The minimum and maximum are taken over axes 1 and 2 only, so every index
    of axis 0 and of the trailing axes is rescaled on its own.

    Args:
        features_pca: array-like of rank >= 3; presumably (B, H, W, C) PCA
            projections returned by ``feature_pca`` — TODO confirm.

    Returns:
        A new floating-point array with the same shape. Constant (flat)
        maps become all zeros instead of NaN. The input is not modified.
    """
    features_pca = np.asarray(features_pca)
    shifted = features_pca - features_pca.min(axis=(1, 2), keepdims=True)
    value_range = shifted.max(axis=(1, 2), keepdims=True)
    # A flat map has range 0; dividing its all-zero values by 1 avoids 0/0 -> NaN.
    value_range = np.where(value_range == 0, 1, value_range)
    # Out-of-place division also works for integer inputs, where `/=` would raise.
    return shifted / value_range

# Fine-tuned (LoRA) runs available for visualization, keyed by backbone name.
#   root       -> checkpoint directory of the training run (lightning_logs/.../checkpoints)
#   epoch      -> epoch whose checkpoint and LoRA adapter are loaded
#   stride     -> patch stride passed to LightningPerceptualModel (kept as a string)
#   lora_alpha -> LoRA alpha passed to LightningPerceptualModel for this run
model_dict = dict(
    dinov2_vitb14=dict(
        root='/home/fus/repos/dreamsim/output/new_backbones/lora_single_dinov2_vitb14_cls_lora_lr_0.0003_batchsize_32_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_4.0_loradropout_0.0/lightning_logs/version_0/checkpoints',
        epoch=7,
        stride='14',
        lora_alpha=4,
    ),
    dino_vitb16=dict(
        root='/home/fus/repos/dreamsim/output/orig/5_None_dino_vitb16_dist_cls_lora_lr_0.0003_batchsize_32_wd_0.0_psnr_hiddensize_1024_margin_0.05_lorar_16_loraalpha_0.5_loradropout_0.3_lorabias_False/lightning_logs/version_0/checkpoints',
        epoch=6,
        stride='16',
        lora_alpha=0.5,
    ),
    open_clip_vitb32=dict(
        root='/home/fus/repos/dreamsim/output/orig/6_None_open_clip_vitb32_dist_embedding_lora_lr_0.0003_batchsize_128_wd_0.0_psnr_hiddensize_1024_margin_0.05_lorar_16_loraalpha_0.5_loradropout_0.3_lorabias_False/lightning_logs/version_0/checkpoints',
        epoch=7,
        stride='32',
        lora_alpha=0.5,
    ),
)

# Which entry of `model_dict` to visualize.
model_type = 'open_clip_vitb32'

root = model_dict[model_type]['root']
epoch = model_dict[model_type]['epoch']
stride = model_dict[model_type]['stride']
lora_alpha = model_dict[model_type]['lora_alpha']

# The LoRA adapter for this epoch lives in a sub-directory of `root`.
lora_model_path = f'epoch_{epoch}_{model_type}'
# Zero-pad to two digits (epoch=07.ckpt); the old f'epoch=0{epoch}' broke for epoch >= 10.
ckpt_path = f'epoch={epoch:02d}.ckpt'

model = LightningPerceptualModel(
    feat_type='cls_patch', model_type=model_type, stride=stride, lora_r=16,
    lora_alpha=lora_alpha, device='cuda', use_lora=True
).cuda()
model.eval()

# Lightning checkpoints nest the weights under 'state_dict'. Passing the whole
# checkpoint dict with strict=False matches no parameter and fails silently,
# so unwrap it (falling back to the raw dict for plain state-dict files).
ckpt = torch.load(os.path.join(root, ckpt_path))
msg = model.load_state_dict(ckpt.get('state_dict', ckpt), strict=False)
print(f'Loaded {ckpt_path}: {len(msg.missing_keys)} missing / '
      f'{len(msg.unexpected_keys)} unexpected keys')

# Apply the saved LoRA adapter weights to the extractor backbone.
# NOTE(review): PEFT's save_pretrained appears to drop the adapter name
# ('.default') from keys, so both the live key and the stripped form are
# tried — confirm against the keys in adapter_model.bin.
lora_sd = torch.load(os.path.join(root, lora_model_path, 'adapter_model.bin'))
extractor = model.perceptual_model.extractor_list[0].model
lora_keys = [k for k in extractor.state_dict() if 'lora' in k]
lora_to_load = {}
for key in lora_keys:
    for saved_key in (key, key.replace('.default', '')):
        if saved_key in lora_sd:
            lora_to_load[key] = lora_sd[saved_key]
            break
unmatched = [k for k in lora_keys if k not in lora_to_load]
if unmatched:
    print(f'WARNING: {len(unmatched)}/{len(lora_keys)} LoRA tensors not found in '
          f'adapter_model.bin; they keep their initial values (e.g. {unmatched[0]})')
extractor.load_state_dict(lora_to_load, strict=False)

# Reference model: same backbone without LoRA fine-tuning.
orig_model = LightningPerceptualModel(
    feat_type='cls_patch', model_type=model_type, stride=stride, device='cuda'
).cuda()
orig_model.eval()

val_dataset = TwoAFCDataset(root_dir='/datasets/nights_2024-04-04_1804', split="val", preprocess=get_preprocess(model_type))
val_loader = DataLoader(val_dataset, batch_size=16, num_workers=2, shuffle=False)

# When True, apply the optional masking step: zero every pixel of a normalized
# PCA map whose raw first PCA component is negative.
MASK_NEGATIVE_FIRST_COMPONENT = False
# Maximum number of triplets from the batch to plot per figure.
MAX_ROWS = 6
FIG_DIR = 'figs/frozen'


def _to_grid(patches, idx):
    """Return sample `idx` of a (B, H*W, C) patch tensor as a (1, C, H, W) grid.

    Raises:
        ValueError: if the number of patches is not a perfect square.
    """
    _, num_patches, channels = patches.shape
    side = math.isqrt(num_patches)
    if side * side != num_patches:
        raise ValueError(f'expected a square patch grid, got {num_patches} patches')
    return patches[idx:idx + 1].reshape(1, side, side, channels).permute(0, 3, 1, 2)


def _pca_maps(ref_patches, patches_0, patches_1, idx):
    """Project triplet `idx` onto one PCA basis fit jointly on ref, img_0 and img_1.

    Returns [ref, img_0, img_1] PCA maps, each min-max normalized with
    `normalize` (and masked when MASK_NEGATIVE_FIRST_COMPONENT is set).
    """
    grids = [_to_grid(p, idx) for p in (ref_patches, patches_0, patches_1)]
    # A single basis fitted on all three images gives them a shared colour mapping.
    _, pca = feature_pca(torch.cat(grids), return_pca=True)
    maps = []
    for grid in grids:
        projected = feature_pca(grid, pcas=pca, normalize=False)
        normalized = normalize(projected)
        if MASK_NEGATIVE_FIRST_COMPONENT:
            normalized[projected[..., 0] < 0] = 0
        maps.append(normalized)
    return maps


os.makedirs(FIG_DIR, exist_ok=True)

for n_iter in [3]:
    # Restart the loader for every entry so n_iter is an absolute batch index
    # (a shared iterator would make later entries skip cumulatively).
    val_iter = iter(val_loader)
    for _ in range(n_iter):
        next(val_iter)
    img_ref, img_0, img_1, target, _ = next(val_iter)
    img_ref = img_ref.cuda()
    img_0 = img_0.cuda()
    img_1 = img_1.cuda()

    # Inference only: no autograd graph is needed for either forward pass.
    with torch.no_grad():
        _, _, patch_ref_all, patch_0_all, patch_1_all = model.forward(
            img_ref, img_0, img_1, return_patch=True)
        _, _, orig_ref_all, orig_0_all, orig_1_all = orig_model.forward(
            img_ref, img_0, img_1, return_patch=True)

    # The last batch of the loader may hold fewer than MAX_ROWS triplets.
    n = min(MAX_ROWS, img_ref.shape[0])
    f, ax = plt.subplots(n * 3, 3, figsize=(3 * 2, n * 6))
    for i in range(n):
        # Rows per triplet: input images, original-model PCA, LoRA-model PCA.
        # Columns: img_0 | ref | img_1.
        ax[i * 3][1].set_title(f'gt={target[i].item()}', fontsize=18)
        for col, img in enumerate((img_0, img_ref, img_1)):
            ax[i * 3][col].imshow(img[i].permute(1, 2, 0).detach().cpu())

        for row_offset, patches in ((1, (orig_ref_all, orig_0_all, orig_1_all)),
                                    (2, (patch_ref_all, patch_0_all, patch_1_all))):
            ref_map, map_0, map_1 = _pca_maps(*patches, i)
            for col, pca_map in enumerate((map_0, ref_map, map_1)):
                ax[i * 3 + row_offset][col].imshow(pca_map[0])
    remove_axes(ax)
    plt.tight_layout()
    f.savefig(os.path.join(FIG_DIR, f'{model_type}_patch_{n_iter}.png'))
    # Release the figure so repeated entries do not accumulate open figures.
    plt.close(f)