# encoding: utf-8
# Builds a list of SFT (supervised fine-tuning) samples, each of the form:
# {id: str, image: path, fmri: path, vision_embeds: path, conversations: List[dict{from: str, value: str}], conversation_type: str}
# conversation types: conversation, complex_reasoning, detail_description and briefly_descriptions

import os
import sys
import json
import numpy as np
import random
import argparse

from tqdm import tqdm

# Fixed seed so the randomly sampled question prompts are identical across runs.
random.seed(42)

# Command-line options: the dataset and subject to process, and whether to build
# the training split (--train) instead of the validation split.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default='nsd')
parser.add_argument("--subject", type=str, default='subj01')
parser.add_argument("--train", action='store_true')
args = parser.parse_args()


# Request phrasings sampled (via the seeded module-level RNG) as the human turn of
# the short-caption samples.
_BRIEF_QUESTIONS = (
    'Describe the image concisely.',
    'Provide a brief description of the given image.',
    'Offer a succinct explanation of the picture presented.',
    'Summarize the visual content of the image.',
    'Provide a brief description of the image.',
    'Describe the image briefly.',
    'Summarize the image.',
    'Give a short and clear explanation of the subsequent image.',
    'Share a concise interpretation of the image provided.',
    'Present a compact description of the photo\'s key features.',
    'Relay a brief, clear account of the picture shown.',
    'Render a clear and concise summary of the photo.',
    'Write a terse but informative summary of the picture.',
    'Create a compact narrative representing the image presented.',
)

# Request phrasings sampled as the human turn of the detailed-description samples.
_DETAILED_QUESTIONS = (
    'Describe the following image in detail.',
    'Provide a detailed description of the given image.',
    'Give an elaborate explanation of the image you see.',
    'Share a comprehensive rundown of the presented image.',
    'Offer a detailed description of the image.',
    'Describe the image in detail.',
    'Offer a thorough analysis of the image.',
    'Provide a detailed explanation of the subsequent image.',
    'Explain the various aspects of the image before you.',
    'Clarify the contents of the displayed image with great detail.',
    'Characterize the image using a well-detailed description.',
    'Break down the elements of the image in a detailed manner.',
    'Walk through the important details of the image.',
    'Portray the image with a rich, descriptive narrative.',
    'Narrate the contents of the image with precision.',
    'Analyze the image in a comprehensive and detailed manner.',
    'Illustrate the image through a descriptive explanation.',
    'Explain the image in detail.',
    'Examine the image closely and share its details.',
    'Write an exhaustive depiction of the given image.',
)

# Lines shorter than this many characters (trailing newline included) are treated
# as blank and ignored when parsing GPT conversation files.
_MIN_LINE_LEN = 2


def _load_json(path):
    """Return the parsed JSON content of *path*, closing the file afterwards."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _parse_dialogue(lines):
    """Turn the raw lines of a GPT conversation file into alternating human/gpt turns.

    A new turn starts at every line whose lowercase form begins with 'human:'.
    The 'human:' / 'gpt:' prefixes are stripped, and continuation lines are
    appended to whichever speaker spoke last. Lines shorter than _MIN_LINE_LEN
    are skipped, and so is anything before the first 'human:' line. The first
    human turn is prefixed with '<image>\\n'.

    If no line starts with 'human:', the first non-blank line becomes the
    question and the remaining non-blank lines, concatenated, become the answer.

    Args:
        lines: lines of the file, as returned by ``readlines()``.

    Returns:
        A list of ``{'from': 'human'|'gpt', 'value': str}`` dicts. The list is
        empty when *lines* contains no non-blank line.
    """
    turn_starts = [i for i, line in enumerate(lines) if line.lower().startswith('human:')]

    if turn_starts:
        bounds = turn_starts + [len(lines)]
        conversation = []
        for turn_idx, (start, end) in enumerate(zip(bounds, bounds[1:])):
            # The image token belongs to the first turn, even when the file has
            # leading blank/preamble lines before its first 'human:' line.
            human = '<image>\n' if turn_idx == 0 else ''
            gpt = ''
            speaker = None
            for line in lines[start:end]:
                if len(line) < _MIN_LINE_LEN:
                    continue
                lowered = line.lower()
                if lowered.startswith('human:'):
                    human += line[len('human:'):]
                    speaker = 'human'
                elif lowered.startswith('gpt:'):
                    gpt += line[len('gpt:'):]
                    speaker = 'gpt'
                elif speaker == 'human':
                    human += line
                elif speaker == 'gpt':
                    gpt += line
            conversation.append({'from': 'human', 'value': human})
            conversation.append({'from': 'gpt', 'value': gpt})
        return conversation

    # Fallback for files without speaker prefixes.
    content = [line for line in lines if len(line) >= _MIN_LINE_LEN]
    if not content:
        return []
    question, *answer = content
    return [
        {'from': 'human', 'value': '<image>\n' + question},
        {'from': 'gpt', 'value': ''.join(answer)},
    ]


def main():
    """Build the SFT json for one subject and split, as selected by the command-line ``args``.

    Reads, under ``/mnt/NSD_dataset/datasets/<dataset>``:
      * ``<dataset>_captions.json`` and ``<dataset>_coco_captions.json`` (caption sources);
      * ``fmris/<subject>/<dataset>_fmri2image.json``. Its 'train' or 'val' entry
        is enumerated, so the position is the fMRI trial index and the value is
        the image id;
      * the GPT conversation / complex_reasoning / detail_description text files.

    For every trial it emits:
      * one multi-turn record for each of 'conversation' and 'complex_reasoning'
        (a trial whose file has no usable content is skipped with a warning);
      * one 'detail_description' record;
      * one 'briefly_descriptions' record per COCO and BLIP caption.

    All records go to ``sft_data/<subject>/sft_<subject>_<tr|te>.json`` together
    with an 'info' header.
    """
    root_dir = f'/mnt/NSD_dataset/datasets/{args.dataset}'
    images_dir = f'{root_dir}/images'
    conversations_dir = f'{root_dir}/{args.dataset}_gpt_conversation'
    vision_embeds_dir = f'{root_dir}/vision_embeds'

    blip_captions = _load_json(f'{root_dir}/{args.dataset}_captions.json')
    coco_captions = _load_json(f'{root_dir}/{args.dataset}_coco_captions.json')

    fmri_dir = f'{root_dir}/fmris/{args.subject}'
    split = 'train' if args.train else 'val'
    fmri2image = _load_json(f'{fmri_dir}/{args.dataset}_fmri2image.json')[split]
    train_flag = 'tr' if args.train else 'te'

    def make_entry(fmri_id, image_id, id_suffix, conversation, conv_type):
        """Build one SFT record pointing at the fMRI/image/embedding files of a trial."""
        return {
            'id': f'{args.dataset}_{args.subject}_{fmri_id}_{image_id}_{id_suffix}',
            'image': f'{images_dir}/{args.dataset}_image_{image_id:06}.png',
            'fmri': f'{fmri_dir}/whole/{args.dataset}_betas_{train_flag}_{fmri_id:06}.npy',
            'vision_embeds': f'{vision_embeds_dir}/vision_{image_id:06}.npy',
            'conversations': conversation,
            'conversation_type': conv_type,
        }

    def trials(desc):
        """Iterate (fmri_id, image_id) pairs with a sized progress bar."""
        return tqdm(enumerate(fmri2image), total=len(fmri2image), desc=desc)

    conversations = []

    # Multi-turn GPT dialogues: one record per trial for each dialogue type.
    for conv_type in ['conversation', 'complex_reasoning']:
        for fmri_id, image_id in trials(conv_type):
            conv_path = (f'{conversations_dir}/{args.dataset}_{conv_type}/'
                         f'{args.dataset}_{conv_type}_{image_id:06}.txt')
            with open(conv_path, 'r', encoding='utf-8') as f:
                conversation = _parse_dialogue(f.readlines())
            if not conversation:
                # Empty/blank file: there is nothing to train on, so skip it.
                tqdm.write(f'warning: skipping {conv_path}: no dialogue content', file=sys.stderr)
                continue
            conversations.append(make_entry(fmri_id, image_id, conv_type, conversation, conv_type))

    # Detailed descriptions: a randomly phrased request answered by the GPT text.
    for fmri_id, image_id in trials('detail_description'):
        desc_path = (f'{conversations_dir}/{args.dataset}_detail_description/'
                     f'{args.dataset}_detail_description_{image_id:06}.txt')
        with open(desc_path, 'r', encoding='utf-8') as f:
            description = f.read()
        conversation = [
            {'from': 'human', 'value': '<image>\n' + random.choice(_DETAILED_QUESTIONS) + '\n'},
            {'from': 'gpt', 'value': description},
        ]
        conversations.append(
            make_entry(fmri_id, image_id, 'detail_description', conversation, 'detail_description'))

    # Brief descriptions: one record per COCO caption, then per BLIP caption.
    for fmri_id, image_id in trials('sample_descriptions'):
        captions = [*coco_captions[image_id]['coco_caption'], *blip_captions[image_id]['captions']]
        for idx, caption in enumerate(captions):
            conversation = [
                {'from': 'human', 'value': '<image>\n' + random.choice(_BRIEF_QUESTIONS) + '\n'},
                {'from': 'gpt', 'value': caption},
            ]
            conversations.append(
                make_entry(fmri_id, image_id, f'briefly_descriptions_{idx}', conversation,
                           'briefly_descriptions'))

    output_dir = f'{root_dir}/sft_data/{args.subject}'
    os.makedirs(output_dir, exist_ok=True)
    output_json = {
        'info': {
            'dataset': args.dataset,
            'subject': args.subject,
            'train': args.train,
            'fmri_mean': f'{fmri_dir}/whole/{args.dataset}_whole_betas_mean.npy',
            'fmri_std': f'{fmri_dir}/whole/{args.dataset}_whole_betas_std.npy',
            'atlas': f'{fmri_dir}/atlas.json'
        },
        'conversations': conversations,
    }
    with open(f'{output_dir}/sft_{args.subject}_{train_flag}.json', 'w', encoding='utf-8') as f:
        json.dump(output_json, f, indent=4)


if __name__ == '__main__':
    main()
