import json
import os
from io import BytesIO

import requests
from clip_retrieval.clip_client import ClipClient, Modality
# import replicate
from icrawler.builtin import BingImageCrawler, GoogleImageCrawler
from PIL import Image

# Base URL of the test images in the clip-retrieval GitHub repository.
# NOTE(review): not referenced anywhere in the visible part of this file.
IMAGE_BASE_URL = "https://github.com/rom1504/clip-retrieval/raw/main/tests/test_clip_inference/test_images/"

# Shared CLIP retrieval client, created once at import time and used by
# download_images_clip() to run its text queries.
# NOTE(review): the exact meaning of the keyword arguments below (index name,
# aesthetic filtering, number of results) is defined by clip_retrieval's
# ClipClient -- confirm against its documentation before changing them.
client = ClipClient(
    url="https://knn.laion.ai/knn-service",
    indice_name="laion5B-L-14",
    aesthetic_score=9,
    aesthetic_weight=0.5,
    modality=Modality.IMAGE,
    num_images=20,
)


def traverse_dict(dictionary):
    """
    Flatten a nested prompt structure into a single list of entries.

    A list is returned as a shallow copy of itself. A dictionary is walked
    key by key: the items stored under the keys "Concepts" and "Substance"
    are taken as-is (without descending into them), while the value of any
    other key is flattened recursively. Anything that is neither a dict nor
    a list contributes nothing.

    Args:
        dictionary: A dictionary (or list) to traverse.

    Returns:
        A new list with the collected entries, in traversal order.
    """
    # A list is a leaf: hand back its items in a fresh list.
    if isinstance(dictionary, list):
        return list(dictionary)

    # Scalars (strings, numbers, None, ...) are ignored.
    if not isinstance(dictionary, dict):
        return []

    collected = []
    for name, child in dictionary.items():
        if name in ("Concepts", "Substance"):
            # These keys hold the entries directly; do not recurse.
            collected.extend(child)
        else:
            collected.extend(traverse_dict(child))
    return collected


def download_images_clip(prompt_file, save_dir, timeout=30):
    """
    Query the CLIP retrieval service for every prompt and save the results.

    The prompt file is parsed as JSON; each top-level key becomes a category
    folder under ``save_dir`` and each sentence extracted by
    ``traverse_dict`` becomes a sub-folder holding ``image<i>.jpg`` files.
    Queries go through the module-level ``client``.

    Downloading is best-effort: an image that cannot be fetched, decoded or
    written is reported with a message and skipped, so one dead URL does not
    abort the whole run.

    Args:
        prompt_file: A string specifying the path to a JSON file containing
            the prompts to use.
        save_dir: A string specifying the directory where the images should
            be saved.
        timeout: Seconds to wait for each image request before giving up.

    Returns:
        None
    """
    with open(prompt_file) as file:
        data = json.load(file)

    # Generate sentences for each attribute
    sentences = {key: traverse_dict(value) for key, value in data.items()}

    # Download and save images for each sentence
    for category, sentence_list in sentences.items():
        # Create a folder for the category if it does not exist
        category_dir = os.path.join(save_dir, category)
        os.makedirs(category_dir, exist_ok=True)
        for sentence in sentence_list:
            results = client.query(text=sentence)
            sentence_dir = os.path.join(category_dir, sentence)
            os.makedirs(sentence_dir, exist_ok=True)
            for i, result in enumerate(results):
                # The request itself is inside the try block: network errors
                # and timeouts are as expected as undecodable payloads here.
                # `Exception` (not a bare except) keeps Ctrl-C working.
                try:
                    response = requests.get(result['url'], timeout=timeout)
                    response.raise_for_status()
                    img = Image.open(BytesIO(response.content)).convert('RGB')
                    # Save the image to the sentence folder
                    img.save(os.path.join(sentence_dir, f"image{i}.jpg"))
                except Exception:
                    print("Save Image Fail.")


def download_images_sd(prompt_file, save_dir, api_token):
    """
    Generate images with Stable Diffusion via Replicate and save them to disk.

    The prompt file is parsed as JSON; each top-level key becomes a category
    folder under ``save_dir`` and each sentence extracted by
    ``traverse_dict`` becomes a sub-folder holding ``image<i>.jpg`` files.
    A sentence whose folder already exists and is non-empty is skipped
    before the model is queried, so no generation is requested for it.

    Downloading is best-effort: an output that cannot be fetched, decoded or
    written is reported with a message and skipped.

    Args:
        prompt_file: A string specifying the path to a JSON file containing the prompts to use.
        save_dir: A string specifying the path to the directory where the images should be saved.
        api_token: A string specifying the Replicate API token to use.

    Returns:
        None
    """
    # Imported inside the function so `replicate` is only required when this
    # function is actually called.
    import replicate

    # Initialize the Replicate client (named so it does not shadow the
    # module-level CLIP `client`)
    replicate_client = replicate.Client(api_token=api_token)

    # Load the prompts from the JSON file
    with open(prompt_file) as file:
        data = json.load(file)

    # Generate sentences for each attribute
    sentences = {key: traverse_dict(value) for key, value in data.items()}

    # Query the model for each prompt and save the resulting images
    for category, sentence_list in sentences.items():
        # Create a folder for the category if it does not exist
        category_dir = os.path.join(save_dir, category)
        os.makedirs(category_dir, exist_ok=True)
        for sentence in sentence_list:
            sentence_dir = os.path.join(category_dir, sentence)
            # Skip folders that already exist and contain images. This check
            # runs before the model call so skipped prompts cost nothing.
            if os.path.isdir(sentence_dir) and os.listdir(sentence_dir):
                print(f"Skipping existing folder {sentence_dir}")
                continue

            # Query the model for image URLs
            urls = replicate_client.run(
                "stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf",
                input={"prompt": sentence,
                       "num_outputs": 4,
                       "image_dimensions": "512x512"}
            )

            # Download and save the images
            os.makedirs(sentence_dir, exist_ok=True)
            for i, url in enumerate(urls):
                # The request itself is inside the try block so a network
                # failure only loses this one image. `Exception` (not a bare
                # except) keeps Ctrl-C working.
                try:
                    response = requests.get(url)
                    img = Image.open(BytesIO(response.content)).convert('RGB')
                    # Save the image to the sentence folder
                    img.save(os.path.join(sentence_dir, f"image{i}.jpg"))
                except Exception:
                    print("Save Image Fail.")


def download_images_google(prompt_file, save_dir):
    """
    Crawl Google Images for every prompt in a JSON file and store the results.

    Each top-level key of the prompt file becomes a category folder under
    ``save_dir``; each sentence extracted by ``traverse_dict`` becomes a
    sub-folder that a ``GoogleImageCrawler`` writes into. A sentence whose
    folder already holds more than 50 entries is skipped.

    Args:
        prompt_file: A string specifying the name of the prompt file to read.
        save_dir: A string specifying the directory to save the downloaded images.

    Returns:
        None
    """
    with open(prompt_file) as handle:
        prompts = json.load(handle)

    # Flatten every top-level entry into its list of sentences up front
    sentences_by_category = {
        name: traverse_dict(subtree) for name, subtree in prompts.items()
    }

    for category, category_sentences in sentences_by_category.items():
        # Make sure the category folder is there before filling it
        category_dir = os.path.join(save_dir, category)
        os.makedirs(category_dir, exist_ok=True)

        for sentence in category_sentences:
            sentence_dir = os.path.join(category_dir, sentence)

            # A folder with more than 50 entries counts as already downloaded
            already_filled = (
                os.path.isdir(sentence_dir)
                and len(os.listdir(sentence_dir)) > 50
            )
            if already_filled:
                print(f"Skipping existing folder {sentence_dir}")
                continue

            os.makedirs(sentence_dir, exist_ok=True)
            crawler_options = {
                "feeder_threads": 2,
                "parser_threads": 2,
                "downloader_threads": 8,
                "storage": {"root_dir": sentence_dir},
            }
            crawl_options = {
                "keyword": sentence,
                "offset": 0,
                "max_num": 100,
                "min_size": (100, 100),
                "max_size": None,
                "file_idx_offset": 0,
            }
            GoogleImageCrawler(**crawler_options).crawl(**crawl_options)


def download_images_bing(prompt_file, save_dir, number_image=10):
    """
    Download images using icrawler's Bing crawler and save them to disk.

    Each top-level key of the prompt file becomes a category folder under
    ``save_dir``; each sentence extracted by ``traverse_dict`` becomes a
    sub-folder that a ``BingImageCrawler`` writes into. A sentence whose
    folder already holds more than ``number_image - 3`` entries is skipped.
    An entry that cannot be turned into a folder name (not a string) is
    reported and skipped.

    Args:
        prompt_file: A string specifying the name of the prompt file to read.
        save_dir: A string specifying the directory to save the downloaded images.
        number_image: Maximum number of images requested per sentence
            (passed to the crawler as ``max_num``).

    Returns:
        None
    """
    with open(prompt_file) as file:
        data = json.load(file)

    # Generate sentences for each attribute
    sentences = {key: traverse_dict(value) for key, value in data.items()}

    # Download and save images for each sentence
    for category, sentence_list in sentences.items():
        # Create a folder for the category if it does not exist
        category_dir = os.path.join(save_dir, category)
        os.makedirs(category_dir, exist_ok=True)
        for sentence in sentence_list:
            try:
                sentence_dir = os.path.join(category_dir, sentence)
            except TypeError:
                # os.path.join rejects non-string components. Report the
                # offending entry and move on; without the `continue` the
                # code below would run with an undefined or stale
                # `sentence_dir` from the previous iteration.
                print(category_dir, sentence)
                continue
            if os.path.isdir(sentence_dir):
                # Skip folders that already exist and contain images
                if len(os.listdir(sentence_dir)) > number_image - 3:
                    print(f"Skipping existing folder {sentence_dir}")
                    continue
            os.makedirs(sentence_dir, exist_ok=True)
            crawler = BingImageCrawler(feeder_threads=4,
                                       parser_threads=4,
                                       downloader_threads=4,
                                       storage={"root_dir": sentence_dir})
            crawler.crawl(keyword=sentence, filters=None,
                          offset=0, max_num=number_image)


if __name__ == "__main__":
    # Example usage: fetch up to 8 Bing images per ImageNet prompt.
    prompt_path = 'imagenet/imagenet_prompt.json'
    output_root = '../imagenet_images'
    download_images_bing(prompt_path, output_root, number_image=8)
