import json
import tkinter as tk
from PIL import Image, ImageTk
import torchvision
from tkinter import ttk
import torch
from tkinter import messagebox
from tkinter import filedialog

class EditorApp:
    """Tkinter editor for a JSON list of per-image records.

    The record at the current position is shown as editable JSON, next to the
    dataset image with the same index and a read-only template that is looked
    up through the record's ``"pred"`` class index.
    """

    def __init__(self, data_filename, template_filename, classes_filename, dataset='cifar10'):
        """Load all inputs, build the window and display the first record.

        Args:
            data_filename: JSON file holding the list of records to edit; it
                is overwritten on every successful save.
            template_filename: JSON file mapping a class name to its template.
            classes_filename: Text file with one class name per line; a
                record's ``"pred"`` value indexes into these lines.
            dataset: One of ``'cifar10'``, ``'cifar100'`` or ``'imagenet'``.

        Raises:
            NotImplementedError: If ``dataset`` is not a supported name.
            ValueError: If the data file contains no records.
        """
        self.data_filename = data_filename
        self.template_filename = template_filename
        self.classes_filename = classes_filename
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize((400, 400)),
            torchvision.transforms.ToTensor(),
        ])

        self.data = self.load_data(dataset)
        if not self.data:
            # Fail before a window exists instead of with a bare IndexError
            # from the first show_dict() call.
            raise ValueError(f"{data_filename} contains no records to edit")
        self.template_data = self.load_template_data()
        self.current_index = 0

        self.root = tk.Tk()
        self.root.title("Dictionary Editor")
        self._build_widgets()
        self.root.bind("<Control-s>", self.save_changes_shortcut)

        self.show_dict()

    def _build_widgets(self):
        """Create and pack every widget; pack order determines the layout."""
        self.image_label = tk.Label(self.root)
        self.image_label.pack(side=tk.LEFT, padx=20, pady=20)

        self.key_label = tk.Label(self.root, text="")
        self.key_label.pack(pady=20)

        self.progress_label = tk.Label(self.root, text="")
        self.progress_label.pack(pady=20)
        # Notifications recolour this label; remember the colour to restore.
        self._default_progress_fg = self.progress_label.cget("fg")

        self.value_text = tk.Text(self.root, height=50, width=80, font=("Consolas", 10), wrap="none")
        self.value_text.pack(side=tk.LEFT, padx=20, pady=10)

        self.value_scrollbar = tk.Scrollbar(self.root, command=self.value_text.yview)
        self.value_scrollbar.pack(side=tk.LEFT, fill="y")
        self.value_text.config(yscrollcommand=self.value_scrollbar.set)

        self.template_label = tk.Label(self.root, text="Template Data:")
        self.template_label.pack(pady=10)

        self.template_text = tk.Text(self.root, height=50, width=80, font=("Consolas", 10), state="disabled")
        self.template_text.pack(side=tk.LEFT, padx=20, pady=10)

        self.template_scrollbar = tk.Scrollbar(self.root, command=self.template_text.yview)
        self.template_scrollbar.pack(side=tk.LEFT, fill="y")
        self.template_text.config(yscrollcommand=self.template_scrollbar.set)

        self.save_image_button = tk.Button(self.root, text="Save Image", command=self.save_current_image_prompt)
        self.save_image_button.pack(side=tk.RIGHT, padx=20, pady=20)

        self.prev_button = tk.Button(self.root, text="Previous", command=self.show_previous_dict)
        self.prev_button.pack(side=tk.LEFT, padx=20, pady=20)

        self.next_button = tk.Button(self.root, text="Next", command=self.show_next_dict)
        self.next_button.pack(side=tk.LEFT, padx=20, pady=20)

        self.save_button = tk.Button(self.root, text="Save", command=self.save_changes)
        self.save_button.pack(side=tk.RIGHT, padx=20, pady=20)

        self.format_button = tk.Button(self.root, text="Format JSON", command=self.format_json)
        self.format_button.pack(side=tk.RIGHT, padx=20, pady=20)

    def set_size(self):
        """Resize the window to 80% of the screen width and height."""
        width = self.root.winfo_screenwidth()
        height = self.root.winfo_screenheight()
        self.root.geometry(f"{int(width * 0.8)}x{int(height * 0.8)}")

    def _image_at(self, index):
        """Return dataset item ``index`` as a PIL image.

        Relies on ``self.transform`` ending in ``ToTensor``, i.e. the dataset
        yielding a channel-first float tensor scaled to [0, 1].
        """
        tensor, _ = self.dataset[index]
        # Round before the cast so float error cannot truncate a channel
        # value down by one.
        pixels = (tensor.numpy() * 255).round().astype('uint8')
        # PIL expects height x width x channels.
        return Image.fromarray(pixels.transpose(1, 2, 0))

    def save_current_image(self, save_directory):
        """Write the current image to ``save_directory`` as ``image_<index>.jpg``.

        A failure to write is reported in an error dialog instead of being
        lost as a traceback on stderr from the Tk callback.
        """
        image_path = f"{save_directory}/image_{self.current_index}.jpg"
        try:
            self._image_at(self.current_index).save(image_path)
        except OSError as exc:
            messagebox.showerror("Error", f"Could not save image: {exc}")
            return
        self.show_notification(f"Image saved: {image_path}", "green")

    def save_current_image_prompt(self):
        """Ask for a target directory and save the current image into it."""
        save_directory = filedialog.askdirectory(title="Select Save Directory")
        # An empty result means the dialog was cancelled.
        if save_directory:
            self.save_current_image(save_directory)

    def load_data(self, dataset):
        """Read the record file and set ``self.dataset`` to the image dataset.

        Args:
            dataset: ``'cifar10'``, ``'cifar100'`` or ``'imagenet'``.

        Returns:
            The parsed content of ``self.data_filename``.

        Raises:
            NotImplementedError: If ``dataset`` is any other name.
        """
        with open(self.data_filename, 'r', encoding='utf-8') as f:
            data = json.load(f)
        if dataset == 'cifar10':
            self.dataset = torchvision.datasets.CIFAR10(root="../data/data", train=False, download=True, transform=self.transform)
        elif dataset == 'cifar100':
            self.dataset = torchvision.datasets.CIFAR100(root="../data/data", train=False, download=True, transform=self.transform)
        elif dataset == 'imagenet':
            self.dataset = torchvision.datasets.ImageFolder(root="../data/imagenet_val/val", transform=self.transform)
        else:
            # Previously the exception was built but never raised, leaving
            # self.dataset unset and failing much later with AttributeError.
            raise NotImplementedError(f"Unsupported dataset: {dataset!r}")
        return data

    def load_template_data(self):
        """Read the template file and the class-name list.

        Side effect: stores the lines of ``self.classes_filename`` in
        ``self.classes``.

        Returns:
            The parsed content of ``self.template_filename``.
        """
        with open(self.template_filename, 'r', encoding='utf-8') as f:
            template_data = json.load(f)

        with open(self.classes_filename, 'r', encoding='utf-8') as f:
            self.classes = f.read().splitlines()
        return template_data

    def load_image(self, index):
        """Show dataset image ``index`` in the image label."""
        photo = ImageTk.PhotoImage(self._image_at(index))
        self.image_label.config(image=photo)
        # Tk does not hold a reference to the PhotoImage; without this
        # attribute it would be garbage-collected and the label go blank.
        self.image_label.image = photo

    def save_data(self):
        """Write all records back to ``self.data_filename`` and notify.

        A failure to write is reported in an error dialog.
        """
        try:
            with open(self.data_filename, 'w', encoding='utf-8') as f:
                json.dump(self.data, f, indent=4)
        except OSError as exc:
            messagebox.showerror("Error", f"Could not save changes: {exc}")
            return
        self.show_notification("Changes saved", "green")

    def format_json(self):
        """Re-indent the editor content, or show an error if it is not JSON."""
        try:
            parsed = json.loads(self.value_text.get("1.0", tk.END))
        except ValueError:
            messagebox.showerror("Error", "Invalid JSON")
            return
        self.value_text.delete("1.0", tk.END)
        self.value_text.insert("1.0", json.dumps(parsed, indent=4))

    def show_dict(self):
        """Refresh labels, editor, template pane and image for the current record."""
        current_dict = self.data[self.current_index]
        position = self.current_index + 1

        # Restore the label colour in case a notification changed it.
        self.progress_label.config(text=f"Progress: {position}/{len(self.data)}",
                                   fg=self._default_progress_fg)
        self.key_label.config(text=f"Item {position}")

        self.value_text.delete('1.0', tk.END)
        self.value_text.insert('1.0', json.dumps(current_dict, indent=4))

        # The template is selected through the record's "pred" class index.
        # A missing or out-of-range value must not abort the refresh, or the
        # record could not be displayed and repaired.
        try:
            template_dict = self.template_data[self.classes[current_dict["pred"]]]
        except (KeyError, IndexError, TypeError):
            template_body = 'No template available for this item\'s "pred" value.'
        else:
            template_body = json.dumps(template_dict, indent=4)

        # The template pane is read-only, so unlock it just for the update.
        self.template_text.config(state="normal")
        self.template_text.delete('1.0', tk.END)
        self.template_text.insert('1.0', template_body)
        self.template_text.config(state="disabled")

        self.load_image(self.current_index)

    def show_next_dict(self):
        """Advance to the next record; do nothing on the last one."""
        if self.current_index >= len(self.data) - 1:
            return
        self.current_index += 1
        self.show_dict()

    def show_previous_dict(self):
        """Step back to the previous record; do nothing on the first one."""
        if self.current_index <= 0:
            return
        self.current_index -= 1
        self.show_dict()

    def save_changes(self):
        """Replace the current record with the editor content and save.

        The editor always shows the complete record, so the parsed object
        replaces it outright; merging would silently keep keys the user
        deleted. Invalid JSON or a non-object value is rejected with an
        error dialog and nothing is written.
        """
        try:
            edited = json.loads(self.value_text.get('1.0', tk.END))
        except ValueError:
            messagebox.showerror("Error", "Invalid JSON")
            return
        if not isinstance(edited, dict):
            # e.g. a list or a bare string parses fine but is not a record.
            messagebox.showerror("Error", "Top-level JSON value must be an object")
            return
        self.data[self.current_index] = edited
        self.save_data()

    def save_changes_shortcut(self, event):
        """Key-binding adapter: Tk passes an event that saving does not need."""
        self.save_changes()

    def show_notification(self, message, color):
        """Show ``message`` in the progress label using text colour ``color``."""
        self.progress_label.config(text=message, fg=color)

    def run(self):
        """Size the window and enter the Tk main loop."""
        self.set_size()
        self.root.mainloop()

if __name__ == "__main__":
    # Launch the editor on the ImageNet validation annotations.
    launch_options = {
        "data_filename": "imagenet_test_tree_vit_base_patch16_224.json",
        "template_filename": "imagenet/imagenet_prompt.json",
        "classes_filename": "imagenet/imagenet-classes.txt",
        "dataset": "imagenet",
    }
    editor = EditorApp(**launch_options)
    editor.run()
