from parse_tree.make_tree import Parse_Tree
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import DataLoader
import torchvision
import torch
import numpy as np
import os
import json
import argparse
from dataset.cxr import ChestXray
from medclip import MedCLIPModel, MedCLIPVisionModelViT
from medclip import MedCLIPProcessor
from PIL import Image


def load_model(device):
    """Create the MedCLIP ViT model with pretrained weights plus its processor.

    Args:
        device: Target device string (e.g. ``"cuda"`` or ``"cpu"``) the model
            is moved to.

    Returns:
        A ``(model, processor)`` tuple: the MedCLIP model on ``device`` and a
        fresh ``MedCLIPProcessor``.
    """
    processor = MedCLIPProcessor()
    net = MedCLIPModel(vision_cls=MedCLIPVisionModelViT)
    # from_pretrained() is called for its effect on ``net`` itself; its return
    # value is not used.
    net.from_pretrained()
    return net.to(device), processor


def load_data(root, batch_size=64, num_workers=2):
    """Build a non-shuffled DataLoader over the ChestXray ``train_val`` split.

    Args:
        root: Dataset root directory. The label file
            ``train_val_list_labels.txt`` is joined onto it. The split file
            ``train_val_list.txt`` is passed as a bare name, presumably resolved
            relative to ``root`` by ``ChestXray`` (TODO confirm).
        batch_size: Images per batch (default 64, the previous hard-coded value).
        num_workers: DataLoader worker processes (default 2, the previous
            hard-coded value).

    Returns:
        A ``DataLoader`` yielding batches in split-file order.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Single-value mean/std: presumably the dataset's grayscale pixel
        # statistics. TODO confirm where these constants came from.
        transforms.Normalize(mean=[0.5862785803043838],
                             std=[0.27950088968644304]),
    ])
    # NOTE: this is the train_val split, even though callers use it for scoring.
    dataset = ChestXray(root=root,
                        split_txt='train_val_list.txt',
                        transform=transform,
                        labeled_file=os.path.join(
                            root, 'train_val_list_labels.txt'))
    # shuffle=False keeps batch order aligned with the split file.
    return DataLoader(dataset, batch_size=batch_size,
                      shuffle=False, num_workers=num_workers)


def load_classes(path):
    """Read class names from a text file, one per line.

    Surrounding whitespace is stripped and blank lines are skipped. Without
    this, a stray trailing blank line or padded name becomes a bogus class
    that no JSON key can match.

    Args:
        path: Path to a UTF-8 text file.

    Returns:
        List of class names in file order.
    """
    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    with open(path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f if line.strip()]


def load_json(path):
    """Deserialize and return the JSON document stored at ``path``."""
    with open(path) as handle:
        return json.load(handle)


def create_trees(json_obj, classes, model, device):
    """Build one parse tree per class and collect each class's prompt sentences.

    Args:
        json_obj: Mapping from class name to a nested dict/list structure.
            Each value is passed to ``Parse_Tree.from_dict``, and its string
            leaves are gathered by ``traverse_json``.
        classes: Ordered class names. Slot ``i`` of the returned tree list
            belongs to ``classes[i]``.
        model: Unused; kept for call-site compatibility.
        device: Unused; kept for call-site compatibility.

    Returns:
        ``(trees, all_sentences)``:

        - ``trees`` is a list aligned with ``classes``, holding ``None`` for
          classes absent from ``json_obj``.
        - ``all_sentences`` maps each JSON key to its string leaves in
          traversal order.

    Raises:
        ValueError: If a key of ``json_obj`` is not listed in ``classes``.
    """
    # First-occurrence index per name (same semantics as list.index), built
    # once so each lookup is O(1) instead of a linear scan per key.
    class_index = {}
    for position, name in enumerate(classes):
        class_index.setdefault(name, position)

    trees = [None] * len(classes)
    all_sentences = {}
    for key in json_obj:
        if key not in class_index:
            raise ValueError(
                f"JSON key {key!r} is not in the class list")
        node = json_obj[key]
        sentences = []
        traverse_json(node, sentences)
        all_sentences[key] = sentences
        trees[class_index[key]] = Parse_Tree.from_dict(node)
    return trees, all_sentences


def compute_sim(device, model, testloader, classes, trees, all_sentences, processor):
    """Score every image in ``testloader`` against every class prompt sentence.

    Prompt sentences are concatenated class by class in ``classes`` order, so
    the columns of each logits tensor are grouped by class.

    Args:
        device: Device the model lives on; processor outputs are moved there.
        model: MedCLIP model. It is switched to eval mode here.
        testloader: Iterable of ``(images, labels)`` batches.
        classes: Class names that fix the order of the prompt sentences.
        trees: Unused; kept for call-site compatibility.
        all_sentences: Mapping from class name to its list of prompt sentences.
        processor: MedCLIP processor that builds the model inputs.

    Returns:
        ``(sims, targets)``: the per-batch ``outputs['logits']`` tensors (moved
        to CPU) and the matching per-batch ``labels``.
    """
    text = [sentence for c in classes for sentence in all_sentences[c]]

    # nn.Module defaults to training mode. eval() disables train-only behaviour
    # such as dropout, so the similarity scores are deterministic.
    model.eval()

    sims = []
    targets = []
    with torch.no_grad():
        for images, labels in tqdm(testloader):
            # NOTE(review): images arrive already resized/normalized by the
            # DataLoader's transform. Confirm the processor does not
            # re-normalize them.
            inputs = processor(
                text=text,
                images=images,
                return_tensors="pt",
                padding=True
            )
            # Tensors must sit on the model's device, otherwise a CUDA model
            # fails on CPU inputs.
            inputs = {k: v.to(device) if torch.is_tensor(v) else v
                      for k, v in inputs.items()}
            outputs = model(**inputs)
            # Move results off the GPU so memory does not grow with the dataset.
            sims.append(outputs['logits'].cpu())
            targets.append(labels)
    return sims, targets


def traverse_json(json_obj, sentences):
    """Append every string leaf of a nested dict/list structure to ``sentences``.

    Dict values are visited in key order and list items in sequence order.
    Leaves that are neither dicts, lists nor strings (numbers, ``None``,
    tuples, ...) are ignored. ``sentences`` is mutated in place; nothing is
    returned.
    """
    if isinstance(json_obj, str):
        sentences.append(json_obj)
        return
    if isinstance(json_obj, dict):
        children = (json_obj[k] for k in json_obj)
    elif isinstance(json_obj, list):
        children = json_obj
    else:
        return
    for child in children:
        traverse_json(child, sentences)


def parse_args():
    """Parse the script's command-line options from ``sys.argv``.

    All options are strings. ``--output-file`` has no default and is ``None``
    when omitted.
    """
    cli = argparse.ArgumentParser(description='CLI for training a model.')
    option_specs = (
        ('--dataset', 'cifar10',
         "The dataset to use. Supported options: 'cifar10', 'cifar100'."),
        ('--data-root', './data/cifar10',
         'Root directory of the dataset.'),
        ('--json-file', 'cifar10_prompt.json',
         'JSON file containing the parse trees.'),
        ('--classes-file', 'cifar10/cifar10_classes.txt',
         'Text file containing the class names.'),
        ('--output-file', None,
         'Output JSON file containing the updated parse trees.'),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)
    return cli.parse_args()


def main(args):
    """Run the end-to-end pipeline configured by the parsed CLI ``args``.

    Steps, in order:

    1. Load the MedCLIP model and processor.
    2. Build the data loader, then read the class list and the prompt JSON.
    3. Build the per-class parse trees.
    4. Score the images against all prompt sentences.
    5. Print a completion banner.

    NOTE(review): nothing is written to ``args.output_file`` despite the
    banner text.
    """
    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    net, processor = load_model(run_device)
    loader = load_data(args.data_root)
    class_names = load_classes(args.classes_file)
    prompt_json = load_json(args.json_file)

    tree_list, sentences_by_class = create_trees(
        prompt_json, class_names, net, run_device)
    all_sims, targets = compute_sim(
        run_device, net, loader, class_names, tree_list,
        sentences_by_class, processor)

    banner = "=" * 20
    for line in (banner, "Parse trees updated", banner):
        print(line)


if __name__ == '__main__':
    # Script entry point: parse CLI options and run the pipeline.
    main(parse_args())
