import argparse
import json
import os

import clip
import numpy as np
import open_clip
import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

from parse_tree.make_tree import Parse_Tree


def load_model(device, name='ViT-L-14', pretrined='datacomp_xl_s13b_b90k'):
    """Create an open_clip model plus its inference preprocessing transform.

    Args:
        device: torch device (or device string) the model is moved to.
        name: open_clip architecture name, e.g. 'ViT-L-14'.
        pretrined: pretrained-weights tag forwarded to open_clip as
            ``pretrained``. The misspelled parameter name is kept so existing
            keyword callers keep working.

    Returns:
        (model, preprocess): the model in eval mode on ``device`` and the
        third transform returned by ``create_model_and_transforms`` (the one
        this script applies to test images).
    """
    model, _, preprocess = open_clip.create_model_and_transforms(
        name, pretrained=pretrined)
    # open_clip's README notes models are created in train mode; eval() makes
    # BatchNorm use running statistics and disables dropout for inference.
    model.eval()
    model = model.to(device)
    return model, preprocess


def load_data(root, transform, dataset):
    """Return a non-shuffled DataLoader over the test/validation split.

    Args:
        root: dataset root directory. For 'imagenet' the images are read from
            ``<root>/val`` laid out as an ImageFolder tree; the CIFAR sets are
            downloaded into ``root`` if absent.
        transform: per-image transform applied by the dataset.
        dataset: one of 'cifar10', 'cifar100', 'imagenet'.

    Returns:
        DataLoader with batch_size=64, shuffle=False, num_workers=2.

    Raises:
        ValueError: for any other ``dataset`` name.
    """
    if dataset == 'cifar10' or dataset == 'cifar100':
        cifar_cls = (torchvision.datasets.CIFAR10 if dataset == 'cifar10'
                     else torchvision.datasets.CIFAR100)
        split = cifar_cls(root=root, train=False, download=True,
                          transform=transform)
    elif dataset == 'imagenet':
        split = torchvision.datasets.ImageFolder(
            root=os.path.join(root, 'val'), transform=transform)
    else:
        raise ValueError(f"Unsupported dataset: {dataset}")

    # Order must stay deterministic: callers line results up across models.
    return DataLoader(split, batch_size=64, shuffle=False, num_workers=2)


def load_classes(path):
    """Read class names from a text file, one name per line, in file order."""
    with open(path, 'r') as handle:
        text = handle.read()
    return text.splitlines()


def load_json(path):
    """Parse the JSON document stored at ``path`` and return the resulting object."""
    with open(path, 'r') as fp:
        return json.load(fp)


def extract_text_features(device, model, sentences, tokenizer_name='ViT-L-14'):
    """Encode sentences with the model's text tower and L2-normalise them.

    Args:
        device: device the token tensor is moved to (should match ``model``).
        model: open_clip model exposing ``encode_text``.
        sentences: iterable of strings; each is lower-cased before tokenising.
        tokenizer_name: open_clip model name whose tokenizer is used. Defaults
            to 'ViT-L-14', the value previously hard-coded here. Pass the name
            of the model actually in use when its tokenizer differs.

    Returns:
        Tensor of text embeddings, normalised along the last dimension,
        presumably shaped (len(sentences), embed_dim). TODO confirm for
        non-default tokenizers.
    """
    lowered = [sent.lower() for sent in sentences]
    tokenizer = open_clip.get_tokenizer(tokenizer_name)
    tokens = tokenizer(lowered).to(device)
    with torch.no_grad():
        text_features = model.encode_text(tokens)
        # Unit-norm rows so a later dot product with image features is a
        # cosine similarity.
        text_features /= text_features.norm(dim=-1, keepdim=True)
    return text_features


def create_trees(json_obj, classes, model, device):
    """Build a parse tree, sentence list and text embeddings per JSON class.

    Args:
        json_obj: mapping from class name to a nested dict/list structure.
            Its string leaves (collected by ``traverse_json``) are that
            class's sentences, and the same structure is handed to
            ``Parse_Tree.from_dict``.
        classes: ordered list of class names. ``trees`` is indexed in this
            order.
        model: open_clip model forwarded to ``extract_text_features``.
        device: device forwarded to ``extract_text_features``.

    Returns:
        (trees, all_sentences, text_features_dict):
        - ``trees[i]`` is the tree for ``classes[i]``, or ``None`` if that
          class has no entry in ``json_obj``.
        - ``all_sentences`` maps class name to its sentence list.
        - ``text_features_dict`` maps class name to the embeddings of those
          sentences, in the same order.

    Raises:
        ValueError: if ``json_obj`` contains a class name that is not in
            ``classes``. Checked before any text is encoded.
    """
    # First occurrence wins for duplicate names, matching list.index().
    class_to_idx = {}
    for idx, class_name in enumerate(classes):
        class_to_idx.setdefault(class_name, idx)

    # Validate up front so a bad key fails before any (GPU) encoding work.
    unknown = [key for key in json_obj if key not in class_to_idx]
    if unknown:
        raise ValueError(
            f"Classes in JSON not found in classes file: {unknown}")

    trees = [None] * len(classes)
    all_sentences = {}
    text_features_dict = {}
    for key in json_obj:
        subtree = json_obj[key]
        sentences = []
        traverse_json(subtree, sentences)
        text_features_dict[key] = extract_text_features(
            device, model, sentences)
        all_sentences[key] = sentences
        trees[class_to_idx[key]] = Parse_Tree.from_dict(subtree)
    return trees, all_sentences, text_features_dict


def compute_sim(device, model, testloader, classes, trees, all_sentences, text_features_dict):
    """Score every test image against the sentences of its own class.

    Each batch is encoded with ``model.encode_image`` and L2-normalised. Each
    image embedding is then multiplied with the text embeddings stored under
    its ground-truth class name in ``text_features_dict``.

    ``trees`` and ``all_sentences`` are accepted for call-site compatibility
    but are not used here.

    Returns:
        (sims, targets): one similarity tensor per image, in loader order, and
        the matching label taken from the loader.
    """
    sims, targets = [], []
    for batch_images, batch_labels in tqdm(testloader):
        with torch.no_grad():
            feats = model.encode_image(batch_images.to(device))
            feats /= feats.norm(dim=-1, keepdim=True)

        for feat, label in zip(feats, batch_labels):
            class_text = text_features_dict[classes[label]]
            sims.append(feat @ class_text.T)
            targets.append(label)
    return sims, targets


def traverse_json(json_obj, sentences):
    """Append every string leaf of ``json_obj`` to ``sentences``.

    Walks dict values (in key order) and list items depth-first, left to
    right. Non-string scalars (numbers, booleans, None) are ignored.
    Mutates ``sentences`` in place and returns None.
    """
    if isinstance(json_obj, str):
        sentences.append(json_obj)
        return
    if isinstance(json_obj, dict):
        children = (json_obj[key] for key in json_obj)
    elif isinstance(json_obj, list):
        children = json_obj
    else:
        return
    for child in children:
        traverse_json(child, sentences)


def parse_args(argv=None):
    """Parse command-line options for the parse-tree update script.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``dataset``, ``data_root``, ``json_file``,
        ``classes_file`` and ``output_file``.
    """
    parser = argparse.ArgumentParser(
        description='Score each test image against its class prompt sentences '
                    'with an ensemble of CLIP models and save the merged '
                    'parse-tree paths.')
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'cifar100', 'imagenet'],
                        help="The dataset to use. Supported options: 'cifar10', 'cifar100', 'imagenet'.")
    parser.add_argument('--data-root', type=str, default='./data/cifar10',
                        help='Root directory of the dataset.')
    parser.add_argument('--json-file', type=str, default='cifar10_prompt.json',
                        help='JSON file containing the parse trees.')
    parser.add_argument('--classes-file', type=str, default='cifar10/cifar10_classes.txt',
                        help='Text file containing the class names.')
    # Required: without it the script would run every model and only then
    # fail when opening the output file.
    parser.add_argument('--output-file', type=str, required=True,
                        help='Output JSON file containing the updated parse trees.')
    return parser.parse_args(argv)


def main(args):
    """Ensemble several open_clip models to pick the best sentences per image.

    For each model in ``models_cfgs``:
    - The test set is encoded.
    - Each image is scored against the sentences extracted from its
      ground-truth class's JSON entry.

    After all models have run:
    - Scores are averaged across models.
    - The ``top_k`` best sentences for each image are merged into a path with
      that class's parse tree.
    - All paths, each paired with its integer target, are written to
      ``args.output_file`` as JSON.

    Raises:
        ValueError: if ``args.output_file`` is empty. This is checked before
            any model is loaded, so a missing flag does not waste a full run.
    """
    if not args.output_file:
        raise ValueError("--output-file is required")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    top_k = 5

    models_cfgs = [
        ('ViT-B-32', 'laion2b_s34b_b79k'),
        # ('RN50', 'openai'),
        ('ViT-B-16', 'datacomp_l_s1b_b8k'),
        ('ViT-L-14', 'datacomp_xl_s13b_b90k'),
        # ('ViT-L-14', 'commonpool_xl_clip_s13b_b90k'),
    ]
    num_models = len(models_cfgs)

    # The class list does not depend on the model, so read it once.
    classes = load_classes(args.classes_file)

    all_sims = None
    for name, pretrained in models_cfgs:
        print(name, pretrained)
        model, preprocess = load_model(device, name=name, pretrined=pretrained)

        # Rebuilt per model because each model has its own preprocess.
        testloader = load_data(args.data_root, preprocess, args.dataset)

        # Re-read per model so each create_trees call gets freshly parsed
        # data; whether Parse_Tree.from_dict mutates its input is not visible
        # from here.
        json_obj = load_json(args.json_file)

        trees, all_sentences, text_features_dict = create_trees(
            json_obj, classes, model, device)

        sims, targets = compute_sim(
            device, model, testloader, classes, trees, all_sentences,
            text_features_dict)
        if all_sims is None:
            all_sims = sims
        else:
            # The loader is not shuffled, so entry i is the same image for
            # every model.
            all_sims = [acc + s for acc, s in zip(all_sims, sims)]

        # Release this model before the next one is loaded to lower peak memory.
        del model, text_features_dict

    all_sims = [s / num_models for s in all_sims]

    path_dicts = []
    for sim, target in zip(all_sims, targets):
        label = int(target)
        sents = all_sentences[classes[label]]
        # topk raises if k exceeds the number of sentences for the class.
        k = min(top_k, sim.numel())
        top_index = torch.topk(sim, k).indices.tolist()
        top_sents = [sents[j] for j in top_index]
        tree = trees[label]
        path = tree._build_and_merge_trees(tree, top_sents)
        path_dicts.append({'path_dict': path.to_dict(), 'target': label})

    # Save the path_dicts list to a json file
    with open(args.output_file, "w") as f:
        json.dump(path_dicts, f)

    print("="*20)
    print("Parse trees updated")
    print("="*20)


if __name__ == '__main__':
    # Script entry point: parse command-line options, then run the pipeline.
    args = parse_args()
    main(args)
