import argparse

import numpy as np
import torch
from torchvision.datasets import VOCSegmentation, Cityscapes
from dataset.ade20k_dataset import ADE20KDataset
from dataset.coco_dataset import COCOSegmentation
from dataset.davis_dataset import DavisInference
import wandb
from torchvision.transforms import InterpolationMode
from tqdm import tqdm
import torchvision.transforms as T
from torchmetrics import JaccardIndex

import sys

import models.backbones
from dataset.dataset_util import dreamsim_transform
from scripts.util import unnorm

sys.path.append('../dinov2')
import dinov2.eval.segmentation.utils.colormaps as colormaps
from PIL import Image
import matplotlib.pyplot as plt
from util.utils import feature_pca, create_cityscapes_label_colormap
from util.train_utils import seed_worker, seed_everything
import torch.nn.functional as F

def log_image():
    """Draw a grid visualising the segmentation results of the current batch.

    Reads module-level globals set by the training/validation loop:
    ``i`` (batch index), ``images``, ``norm``, ``preds`` (head logits of
    shape (B, nc, h, w)), ``target`` (labels of shape (B, h, w)),
    ``features_reshaped``, ``dataset``, ``nc`` and ``DATASET_COLORMAPS``.

    Each row shows one sample: the input image, a PCA projection of its
    backbone features, the predicted label map and the target label map.
    Target pixels labelled 255 (ignored) are drawn white when the dataset
    has a colormap.

    Returns:
        The current matplotlib figure (``plt.gcf()``).
    """
    # At most 4 rows, but never more rows than the batch has samples.
    n = min(4, len(images))
    # squeeze=False keeps ``ax`` two-dimensional even when n == 1.
    f, ax = plt.subplots(n, 4, figsize=(4 * 3, n * 3), squeeze=False)

    ax[0][0].set_title(f'image batch {i}')
    ax[0][1].set_title('features')
    ax[0][2].set_title('pred')
    ax[0][3].set_title('target')

    images_unnorm = unnorm(images) if norm else images

    # Predicted label maps for the whole batch, computed once rather than per row.
    pred_labels = preds.argmax(dim=1).detach().cpu().numpy()
    target_labels = target.detach().cpu().numpy()

    colormap = DATASET_COLORMAPS[dataset]
    if colormap is not None:
        colormap_array = np.array(colormap, dtype=np.uint8)
        # Append white as one extra entry; ignored (255) pixels are mapped onto it.
        colormap_array = np.concatenate([colormap_array, np.array([[255, 255, 255]], dtype=np.uint8)])
        ignore_value = len(colormap_array) - 1
    else:
        colormap_array = None
        # No colormap: give ignored pixels a value one past the last class id.
        ignore_value = nc

    for j in range(n):
        # Copy so that relabelling ignored pixels never writes back into ``target``.
        target_j = target_labels[j].copy()
        target_j[target_j == 255] = ignore_value
        if colormap_array is not None:
            pred_img = colormap_array[pred_labels[j]]
            target_img = colormap_array[target_j]
        else:
            pred_img = pred_labels[j]
            target_img = target_j

        ax[j][0].imshow(images_unnorm[j].permute(1, 2, 0).detach().cpu())
        features_pca = feature_pca(features_reshaped[j:j + 1])
        ax[j][1].imshow(features_pca[0])
        ax[j][2].imshow(pred_img)
        ax[j][3].imshow(target_img)

    plt.tight_layout()
    return plt.gcf()

class SegmentHead(torch.nn.Module):
    """Linear segmentation head: bicubic resize to a square grid, then a 1x1 conv.

    Args:
        n_tokens: number of output positions; the output grid side is
            ``int(n_tokens ** 0.5)``.
        in_channels: channel count of the incoming feature map.
        nc: number of output classes (logit channels).
    """

    def __init__(self, n_tokens, in_channels, nc=1):
        super().__init__()
        self.in_channels, self.nc = in_channels, nc
        # Side length of the square output grid.
        self.s = int(n_tokens ** 0.5)
        self.conv = torch.nn.Conv2d(in_channels, nc, kernel_size=(1, 1))

    def forward(self, x):
        """Resize ``x`` (N, C, H, W) to (s, s) and project it to ``nc`` logit channels."""
        grid = (self.s, self.s)
        resized = F.interpolate(x, size=grid, mode='bicubic')
        return self.conv(resized)


# Number of classes predicted by the segmentation head for each supported dataset.
NUM_CLASSES = {
    'voc2012': 21,
    'cityscapes': 33,
    'ade20k': 151,
    'coco': 21,
    'davis': 77,
}
# Backbone names for which a feature extractor is set up further below.
SUPPORTED_BACKBONES = ('dino_vitb16', 'dinov2_vitb14', 'dreamsim_dino_vitb16', 'dreamsim_dinov2_vitb14')

parser = argparse.ArgumentParser()
parser.add_argument('-b', '--backbone', type=str, default='dino_vitb16', choices=SUPPORTED_BACKBONES,
                    help='backbone model')
# NOTE: --epoch selects which DreamSim LoRA checkpoint is loaded (only read for dreamsim_* backbones);
# the number of head-training epochs is the separate ``epochs`` constant below.
parser.add_argument('-e', '--epoch', type=int, default=9, help='epoch')
# choices= makes argparse reject unknown datasets up front instead of failing later on an undefined ``nc``.
parser.add_argument('-d', '--dataset', type=str, default='coco', choices=sorted(NUM_CLASSES), help='dataset')
args = parser.parse_args()

dataset = args.dataset  # e.g. 'coco'
model_name = args.backbone  # e.g. 'dreamsim_dino_vitb16'

seed = 1234
seed_everything(seed)

# Side length of the square grid on which the head predicts labels (32 x 32).
patch_hw = 32
nc = NUM_CLASSES[dataset]

lr = 3e-4
epochs = 10  # epochs of linear-head training
bs = 128
# 768 = feature width of the ViT-B backbones listed in SUPPORTED_BACKBONES.
head = SegmentHead(patch_hw * patch_hw, 768, nc).cuda()
optimizer = torch.optim.Adam(head.parameters(), lr=lr)
# Resize to 224x224 and normalise with the ImageNet mean/std.
dino_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Masks use nearest-neighbour resizing so label ids are never blended.
target_resize = T.Resize((224, 224), interpolation=InterpolationMode.NEAREST)
target_transform = T.Compose(
    [
        target_resize,
        T.PILToTensor(),
        # VOC masks keep their PIL-derived dtype; all other datasets are cast to int64.
        T.Lambda(lambda x: x.long()) if dataset != 'voc2012' else T.Lambda(lambda x: x),
    ]
)

# Pixels labelled 255 are excluded from the loss.
lossfn = torch.nn.CrossEntropyLoss(ignore_index=255)
# Set to True below when the backbone's transform normalises images; log_image() then
# un-normalises them for display.
norm = False

if model_name == 'dino_vitb16':
    # Frozen feature extractor: eval() so any train-only layers are disabled.
    model = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16').cuda().eval()
    # Last-layer tokens of the input batch ``x`` with the leading (CLS) token dropped.
    model_fn = lambda x: model.get_intermediate_layers(x)[0][:, 1:]
    transform = dino_transform
    norm = True
elif model_name == 'dinov2_vitb14':
    model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14').cuda().eval()
    model_fn = lambda x: model.forward_features(x)['x_norm_patchtokens']
    transform = dino_transform
    norm = True
elif 'dreamsim' in model_name:
    import os
    from dreamsim.dreamsim.model import dreamsim, PerceptualModel
    from dreamsim.dreamsim.feature_extraction.vit_wrapper import ViTConfig, ViTModel
    import json
    from peft import LoraConfig, get_peft_model, PeftModel

    # Each DreamSim variant: LoRA run directory, base backbone, patch stride, checkpoint epoch.
    if model_name == 'dreamsim_dinov2_vitb14':
        model_dir = 'lora_single_cat_dinov2_vitb14_n-1_cls_patch_lora_lr_0.0003_batchsize_16_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_16.0_loradropout_0.1/lightning_logs/version_0/checkpoints'
        backbone = 'dinov2_vitb14'
        stride = '14'
        epoch = args.epoch  #9
    elif model_name == 'dreamsim_dino_vitb16':
        model_dir = 'lora_single_cat_dino_vitb16_n-1_cls_patch_lora_lr_0.0003_batchsize_16_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_16.0_loradropout_0.1/lightning_logs/version_0/checkpoints'
        backbone = 'dino_vitb16'
        stride = '16'
        epoch = args.epoch  #3
    else:
        raise ValueError(f'Unsupported DreamSim backbone: {model_name!r}')

    load_dir = '/home/fus/repos/repalignment/dreamsim/output/new_backbones2/'
    model_dir = os.path.join(load_dir, model_dir)
    device = 'cuda:0'

    ours_model = PerceptualModel(backbone, device=device, load_dir='/home/fus/repos/repalignment/dreamsim/models/',
                                 normalize_embeds=False, stride=stride, feat_type='cls_patch', lora=True)

    # Directory of the LoRA adapter saved for the selected epoch (reuses the name load_dir).
    load_dir = os.path.join(model_dir, f'epoch_{epoch}_{backbone}')
    with open(os.path.join(load_dir, 'adapter_config.json'), 'r') as f:
        dreamsim_config = json.load(f)
    lora_config = LoraConfig(**dreamsim_config)
    print(lora_config)
    # First wrap every extractor's ViT in a fresh PEFT model built from the saved config ...
    for extractor in ours_model.extractor_list:
        model = get_peft_model(ViTModel(extractor.model, ViTConfig()), lora_config).to(device)
        extractor.model = model
    # ... then load the trained adapter weights into it and freeze it.
    for extractor in ours_model.extractor_list:
        extractor.model = PeftModel.from_pretrained(extractor.model, load_dir).to(device)
        extractor.model.eval().requires_grad_(False)

    ours_model.eval().requires_grad_(False)
    ours_model = ours_model.to(device)
    transform = dreamsim_transform

    # Embeddings with the first token dropped, keeping the patch tokens.
    model_fn = lambda x: ours_model.embed(x)[:, 1:, :]
else:
    raise ValueError(f'Unsupported backbone: {model_name!r}')

def _build_datasets(name):
    """Construct the (train, val) segmentation datasets for ``name``.

    Both splits use the module-level ``transform`` for images and
    ``target_transform`` for masks.

    Raises:
        ValueError: if ``name`` is not a supported dataset.
    """
    tfms = dict(transform=transform, target_transform=target_transform)
    if name == 'voc2012':
        root = '/datasets/pascal_voc12_2024-01-10_1001'
        return (VOCSegmentation(root, '2012', 'train', **tfms),
                VOCSegmentation(root, '2012', 'val', **tfms))
    if name == 'cityscapes':
        root = '/datasets/cityscapes_2024-01-04_1601'
        city_kwargs = dict(mode='fine', target_type='semantic', **tfms)
        return (Cityscapes(root, split='train', **city_kwargs),
                Cityscapes(root, split='val', **city_kwargs))
    if name == 'ade20k':
        root = '/datasets/ade20k_2024-01-10_1001/ADEChallengeData2016'
        return (ADE20KDataset(root, split='training', **tfms),
                ADE20KDataset(root, split='validation', **tfms))
    if name == 'coco':
        root = '/datasets/coco2017_2024-01-04_1601'
        return (COCOSegmentation(root, 'train', **tfms),
                COCOSegmentation(root, 'val', **tfms))
    if name == 'davis':
        root = '/scratch/one_month/2024_05/fus/davis'
        return (DavisInference(root, 'train', **tfms),
                DavisInference(root, 'val', **tfms))
    raise ValueError(f'Unsupported dataset: {name!r}')


train_dset, val_dset = _build_datasets(dataset)

# Per-dataset RGB colormaps used by log_image(); None means "plot raw label ids".
# The cityscapes colormap is derived from the training set, so it is only built when needed.
DATASET_COLORMAPS = {
    'ade20k': colormaps.ADE20K_COLORMAP,
    'voc2012': colormaps.VOC2012_COLORMAP,
    'cityscapes': create_cityscapes_label_colormap(train_dset) if dataset == 'cityscapes' else None,
    'coco': None,
    'davis': None
}

nw = 4  # DataLoader worker processes
# Seed the shuffle generator from ``seed``. A freshly constructed torch.Generator always starts
# from the same built-in default seed, so without this the training order ignores ``seed``.
g = torch.Generator()
g.manual_seed(seed)
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=bs, shuffle=True, worker_init_fn=seed_worker, generator=g, num_workers=nw)
val_loader = torch.utils.data.DataLoader(val_dset, batch_size=64, shuffle=False, num_workers=nw)

# Micro-averaged IoU over ``nc`` classes; label 255 is ignored.
iou_metric = JaccardIndex(task='multiclass', average='micro', num_classes=nc, ignore_index=255).cuda()

exp_name = f'fixed_seg_{model_name}_{dataset}_lr_{lr}_bs_{bs}_epochs_{epochs}'
# DreamSim runs are additionally tagged with the LoRA checkpoint epoch they were loaded from.
exp_name = f'e{epoch}_' + exp_name if 'dreamsim' in model_name else exp_name
debug = False  # when True, nothing is sent to wandb

if not debug:
    wandb.init(
        project='rep_segment',
        name=exp_name,
        config={
            "learning_rate": lr,
            "epochs": epochs,
            "batch_size": bs,
            "backbone": model_name,
            "dataset": dataset,
            "seed": seed,
        }
    )

def _prepare_target(masks):
    """Downsample a batch of masks to the head's ``patch_hw`` x ``patch_hw`` label grid.

    Assumes ``masks`` is (B, 1, H, W), as produced by batching ``T.PILToTensor``
    outputs -- TODO confirm for the custom dataset classes. Returns an int64
    tensor of shape (B, patch_hw, patch_hw).
    """
    # squeeze(1) rather than squeeze() so a final batch of size 1 keeps its batch dimension.
    target = F.interpolate(masks.float(), (patch_hw, patch_hw), mode='nearest').long().squeeze(1)
    if dataset == 'cityscapes':
        # Labels -1 and 0 become 255 (ignored); every other id shifts down by one.
        target[target == -1] = 256
        target[target == 0] = 256
        target -= 1
    return target


def _backbone_features(images):
    """Embed ``images`` with the frozen backbone and reshape the tokens to a (B, C, s, s) map."""
    with torch.no_grad():
        features = model_fn(images)
    b, hw, c = features.shape
    s = int(hw ** 0.5)
    return features.permute(0, 2, 1).reshape(b, c, s, s)


for epoch_i in range(epochs):
    print(f'Training Epoch {epoch_i}')
    head.train()
    # ``i``, ``images``, ``target``, ``features_reshaped`` and ``preds`` are kept as module-level
    # names because log_image() reads them as globals.
    for i, (images, masks) in tqdm(enumerate(train_loader), total=len(train_loader)):
        images = images.cuda()
        target = _prepare_target(masks.cuda())
        features_reshaped = _backbone_features(images)
        preds = head(features_reshaped)
        loss = lossfn(preds, target)

        # Clear the previous step's gradients; otherwise they accumulate across batches.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 10 == 0 and not debug:
            with torch.no_grad():
                preds_soft = preds.argmax(dim=1)
                iou = iou_metric(preds_soft, target)
                # Pixel accuracy over non-ignored pixels (predictions are never 255).
                acc = (preds_soft == target).sum().float() / (target != 255).sum()
            # One log call per event so all three metrics share a wandb step.
            wandb.log({'train/loss': loss.item(), 'train/iou': iou.item(), 'train/acc': acc.item()})
    plt.close()

    print(f'Val Epoch {epoch_i}')
    head.eval()
    with torch.no_grad():
        val_loss = 0.0  # per-batch mean loss, summed with batch-size weights
        val_iou = 0.0  # per-batch IoU, summed with batch-size weights
        val_correct = 0  # correctly predicted, non-ignored pixels
        val_labelled = 0  # non-ignored pixels
        val_total = 0  # images seen
        for i, (images, masks) in tqdm(enumerate(val_loader), total=len(val_loader)):
            images = images.cuda()
            target = _prepare_target(masks.cuda())
            features_reshaped = _backbone_features(images)
            preds = head(features_reshaped)
            preds_soft = preds.argmax(dim=1)

            batch_size = len(images)
            val_iou += iou_metric(preds_soft, target).item() * batch_size
            val_correct += (preds_soft == target).sum().item()
            val_labelled += (target != 255).sum().item()
            val_loss += lossfn(preds, target).item() * batch_size
            val_total += batch_size

            if i == 0 and not debug:
                fig = log_image()
                wandb.log({'val/segmentation': wandb.Image(fig)})

        # Average over the whole validation set (guarding against an empty loader).
        val_loss /= max(val_total, 1)
        val_iou /= max(val_total, 1)
        val_acc = val_correct / max(val_labelled, 1)
        if not debug:
            wandb.log({'val/loss': val_loss, 'val/iou': val_iou, 'val/acc': val_acc})
        print(f'Val Loss: {val_loss}, Val IoU: {val_iou}, Val Acc: {val_acc}')
        plt.close()