import argparse
import json
import os
import subprocess
import time

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from src.models import *
from src.utils import *
from src.vqa_score import *

def get_free_gpu():
    """Return the index of the GPU with the least memory currently in use.

    Queries ``nvidia-smi`` for per-GPU used memory (MiB) and returns the index
    of the smallest value; ties go to the lowest index. If ``nvidia-smi`` is
    missing, fails, times out, or reports nothing parseable, falls back to 0.

    NOTE(review): ``nvidia-smi`` enumerates GPUs in its own (PCI bus) order,
    while ``cuda:<i>`` indexes the devices visible to this process; the two
    can disagree when ``CUDA_VISIBLE_DEVICES`` / ``CUDA_DEVICE_ORDER`` are
    set — confirm on multi-GPU hosts.
    """
    try:
        # Machine-readable query: one integer (MiB used) per line, one line per GPU.
        # No shell and no temp file, so concurrent runs cannot clobber each other.
        completed = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"],
            capture_output=True,
            text=True,
            check=True,
            timeout=60,
        )
        used_mib = [int(line) for line in completed.stdout.splitlines() if line.strip()]
    except (OSError, subprocess.SubprocessError, ValueError) as err:
        print("could not query nvidia-smi ({}); defaulting to GPU0".format(err))
        used_mib = []

    # First GPU with the minimum used memory (same tie-break as argmax over -used).
    gpu_id = min(range(len(used_mib)), key=used_mib.__getitem__) if used_mib else 0
    print("using GPU{}".format(gpu_id))
    return gpu_id

def batch_answer(args, questions, image_paths, model, processor, q_to_instrut, response_to_answer):
    """Run one batched generation pass and return a parsed answer per question.

    Args:
        args: namespace providing ``model`` (model name; selects the Qwen /
            fuyu code paths below) and ``device`` (where inputs are moved).
        questions: question strings, aligned one-to-one with ``image_paths``.
        image_paths: image file paths, one per question.
        model: object exposing ``generate(**inputs, max_new_tokens=...)``.
        processor: processor/tokenizer paired with ``model``; on the Qwen path
            it must also provide ``from_list_format``.
        q_to_instrut: maps a question string to the model-specific prompt.
        response_to_answer: maps a decoded, whitespace-stripped response to
            the final answer.

    Returns:
        list of answers, one per decoded sequence, in input order.
    """
    prompts = [q_to_instrut(question) for question in questions]

    if 'Qwen' in args.model:
        # Qwen-VL receives image *paths* embedded in the query, so the images
        # are not decoded here at all.
        queries = [
            processor.from_list_format([{'image': image_path}, {'text': prompt}])
            for image_path, prompt in zip(image_paths, prompts)
        ]
        inputs = processor(queries, return_tensors='pt').to(args.device)
    else:
        images = []
        for image_path in image_paths:
            # `with` closes the source file promptly; convert() returns a new,
            # fully loaded image, so it stays valid after the close.
            with Image.open(image_path) as img:
                images.append(img.convert("RGB"))
        inputs = processor(
            text=prompts, images=images, return_tensors="pt", padding=True, truncation=True
        ).to(args.device, dtype=torch.float16)

    generate_kwargs = {"max_new_tokens": 30}
    if 'fuyu' in args.model:
        # presumably fuyu has no pad token configured, so EOS is reused — TODO confirm
        generate_kwargs["pad_token_id"] = model.config.eos_token_id
    generate_ids = model.generate(**inputs, **generate_kwargs)

    decoded = processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return [response_to_answer(text.strip()) for text in decoded]

def main(args):
    """Answer every record in ``args.json_path`` with the model and save the results.

    For each batch of records this asks:
      * ``question`` and ``h_perception_question`` on ``{image_dir}/full/{id}.png``;
      * the first three entries of ``l_perception_question_tuple`` on
        ``{image_dir}/low_level/{id}.png``.
    Answers are written back into each record as ``prediction``,
    ``h_perception_prediction`` and ``l_perception_prediction_tuple`` (the
    latter in the same order as the question tuple). The annotated data is
    dumped to ``{data_dir}/prediction/{model}.json``.
    """
    output_dir = f'{args.data_dir}/prediction'
    # Create the output directory up front: otherwise a missing folder only
    # surfaces at the final open() after the whole (expensive) inference run.
    os.makedirs(output_dir, exist_ok=True)

    model, processor, q_to_instrut, response_to_answer = load_model(args)

    with open(args.json_path, "r") as f:
        data = json.load(f)

    def answer(questions, image_paths):
        return batch_answer(args, questions, image_paths, model, processor, q_to_instrut, response_to_answer)

    num_low_level = 3  # only the first three tuple entries are asked
    batches = [data[i:i + args.batch_size] for i in range(0, len(data), args.batch_size)]

    for batch in tqdm(batches, ncols=100, desc=f"Running {args.model}"):
        # Gather every question list before any inference so malformed records
        # fail before spending GPU time on the batch.
        questions = [d["question"] for d in batch]
        h_perception_questions = [d["h_perception_question"] for d in batch]
        # low_level_questions[k][j] is batch[j]["l_perception_question_tuple"][k].
        low_level_questions = [
            [d["l_perception_question_tuple"][k] for d in batch]
            for k in range(num_low_level)
        ]

        image_paths = [f"{args.image_dir}/full/{d['id']}.png" for d in batch]
        low_image_paths = [f"{args.image_dir}/low_level/{d['id']}.png" for d in batch]

        answers = answer(questions, image_paths)
        h_perception_answers = answer(h_perception_questions, image_paths)
        # One batched pass per tuple position, preserving the tuple order.
        low_level_answers = [answer(qs, low_image_paths) for qs in low_level_questions]

        for j, d in enumerate(batch):
            d["prediction"] = answers[j]
            d["h_perception_prediction"] = h_perception_answers[j]
            d["l_perception_prediction_tuple"] = [per_position[j] for per_position in low_level_answers]

    with open(f'{output_dir}/{args.model}.json', "w") as f:
        json.dump(data, f, indent=4)
        
    
if __name__ == "__main__":
    # Command-line entry point: parse options, derive dataset paths from
    # --data_dir, choose a device (least-used GPU when CUDA exists), then run.
    model_choices = [
        "instructblip-vicuna-13b",
        "llava-1.5-13b-hf",
        "blip2-flan-t5-xxl",
        "fuyu-8b",
        "Qwen-VL-Chat",
        "Qwen-VL",
    ]

    cli = argparse.ArgumentParser()
    cli.add_argument("--model", type=str, required=True, choices=model_choices)
    cli.add_argument("--data_dir", type=str, default="./data")
    cli.add_argument("--batch_size", type=int, default=1)
    args = cli.parse_args()

    data_root = args.data_dir
    args.json_path = f"{data_root}/data.json"
    args.image_dir = f"{data_root}/images"

    if torch.cuda.is_available():
        args.device = f"cuda:{get_free_gpu()}"
    else:
        args.device = "cpu"

    main(args)