import base64
import requests
import time
from openai import OpenAI
import anthropic
import io
from PIL import Image
import re
import google.generativeai as genai

# API credentials and shared model clients.
#
# NOTE(review): `{OPENAI_API_KEY}`, `{CLAUDE_API_KEY}` and `{GEMINI_API_KEY}`
# look like template placeholders that are meant to be replaced with quoted
# key strings before this module is run -- confirm. Left exactly as written,
# each one parses as a set literal over an undefined name, so importing the
# module raises NameError.

# OpenAI API key; interpolated into the "Authorization: Bearer ..." header
# built by gpt4v() and gpt4o().
openai_api_key = {OPENAI_API_KEY}

# Anthropic client shared by claude3() and opus().
claude_client = anthropic.Anthropic(
    api_key={CLAUDE_API_KEY},
)
# Hand the Gemini key to the google.generativeai package.
genai.configure(api_key={GEMINI_API_KEY})

# Gemini model handle used by gemini().
gemini_model = genai.GenerativeModel('gemini-pro-vision')

# Function to encode the image
def encode_image(image_path, resize_ratio=0.4):
    """Downscale an image file and return it as a base64 string.

    Args:
        image_path: Path of the image file to read.
        resize_ratio: Factor applied to both the width and the height
            (default 0.4).

    Returns:
        The resized image, re-saved in the source file's own format and
        base64-encoded, as a ``str``.
    """
    with open(image_path, "rb") as image_file:
        image = Image.open(image_file)

        # int() truncates, so a small image or a small ratio could produce a
        # 0-pixel side, which Pillow's resize() rejects. Clamp to 1 pixel.
        new_width = max(1, int(image.width * resize_ratio))
        new_height = max(1, int(image.height * resize_ratio))

        resized_image = image.resize((new_width, new_height))

        # Serialize in memory using the format detected when the file was
        # opened, so the output keeps the input's encoding.
        buffer = io.BytesIO()
        resized_image.save(buffer, format=image.format)

    return base64.b64encode(buffer.getvalue()).decode('utf-8')

def gpt4v(input):
    """Send an interleaved text/image question to ``gpt-4-vision-preview``.

    Args:
        input: Mapping with two keys:
            "question": prompt text in which every ``<image>`` marker
                stands for one image, in order of appearance.
            "image_path": list of image file paths, one per ``<image>``
                marker.

    Returns:
        The decoded JSON body of the chat-completions HTTP response.

    Raises:
        ValueError: If the number of ``<image>`` markers differs from the
            number of image paths.
    """
    question = input['question']
    image_paths = input['image_path']

    # Keep empty segments: segment i is the text that precedes image i, and
    # the last segment is whatever follows the final image. Dropping empties
    # would lose the position of back-to-back or leading/trailing markers.
    segments = question.split("<image>")
    if len(segments) - 1 != len(image_paths):
        raise ValueError("The number of <image> elements does not match the number of image paths provided.")

    encoded_images = [encode_image(path) for path in image_paths]

    content = []
    for index, segment in enumerate(segments):
        if segment:
            content.append({"type": "text", "text": segment})
        if index < len(encoded_images):
            # NOTE(review): the data URL always declares image/jpeg, while
            # encode_image() keeps the source file's format -- confirm the
            # API tolerates the mismatch for non-JPEG inputs.
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{encoded_images[index]}",
                    "detail": "low",
                },
            })

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai_api_key}",
    }

    payload = {
        "model": "gpt-4-vision-preview",
        "messages": [
            {
                "role": "user",
                "content": content,
            }
        ],
        "max_tokens": 1024,
        'temperature': 0,
    }

    start = time.time()
    response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
    elapsed = time.time() - start
    # Pace callers: make every call take at least 4 seconds in total.
    if elapsed < 4:
        time.sleep(4 - elapsed)

    return response.json()

def gpt4o(input):
    """Send an interleaved text/image question to ``gpt-4o``.

    Args:
        input: Mapping with two keys:
            "question": prompt text in which every ``<image>`` marker
                stands for one image, in order of appearance.
            "image_path": list of image file paths, one per ``<image>``
                marker.

    Returns:
        The decoded JSON body of the chat-completions HTTP response.

    Raises:
        ValueError: If the number of ``<image>`` markers differs from the
            number of image paths.
    """
    question = input['question']
    image_paths = input['image_path']

    # Keep empty segments: segment i is the text that precedes image i, and
    # the last segment is whatever follows the final image. Dropping empties
    # would lose the position of back-to-back or leading/trailing markers.
    segments = question.split("<image>")
    if len(segments) - 1 != len(image_paths):
        raise ValueError("The number of <image> elements does not match the number of image paths provided.")

    encoded_images = [encode_image(path) for path in image_paths]

    content = []
    for index, segment in enumerate(segments):
        if segment:
            content.append({"type": "text", "text": segment})
        if index < len(encoded_images):
            # NOTE(review): the data URL always declares image/jpeg, while
            # encode_image() keeps the source file's format -- confirm the
            # API tolerates the mismatch for non-JPEG inputs.
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{encoded_images[index]}",
                    "detail": "low",
                },
            })

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai_api_key}",
    }

    payload = {
        "model": "gpt-4o",
        "messages": [
            {
                "role": "user",
                "content": content,
            }
        ],
        "max_tokens": 1024,
        'temperature': 0,
    }

    start = time.time()
    response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
    elapsed = time.time() - start
    # Pace callers: make every call take at least 4 seconds in total.
    if elapsed < 4:
        time.sleep(4 - elapsed)

    return response.json()

def claude3(input):
    """Send an interleaved text/image question to ``claude-3-sonnet-20240229``.

    Args:
        input: Mapping with two keys:
            "question": prompt text in which every ``<image>`` marker
                stands for one image, in order of appearance.
            "image_path": list of image file paths, one per ``<image>``
                marker.

    Returns:
        The object returned by ``claude_client.messages.create``.

    Raises:
        ValueError: If the number of ``<image>`` markers differs from the
            number of image paths.
    """
    question = input['question']
    image_paths = input['image_path']

    # Keep empty segments: segment i is the text that precedes image i, and
    # the last segment is whatever follows the final image. Dropping empties
    # would lose the position of back-to-back or leading/trailing markers.
    segments = question.split("<image>")
    if len(segments) - 1 != len(image_paths):
        raise ValueError("The number of <image> elements does not match the number of image paths provided.")

    encoded_images = [encode_image(path) for path in image_paths]
    # The pacing window below starts once the images are encoded.
    start = time.time()

    content = []
    for index, segment in enumerate(segments):
        if segment:
            content.append({"type": "text", "text": segment})
        if index < len(encoded_images):
            # NOTE(review): media_type is fixed to image/png, while
            # encode_image() keeps the source file's format -- confirm that
            # inputs are PNG or that the API accepts the mismatch.
            content.append({
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": 'image/png',
                    "data": encoded_images[index],
                },
            })

    message = claude_client.messages.create(
        model="claude-3-sonnet-20240229",
        max_tokens=1024,
        temperature=0,
        messages=[
            {
                "role": "user",
                "content": content,
            }
        ],
    )
    elapsed = time.time() - start
    # Pace callers: make every call take at least 4 seconds in total.
    if elapsed < 4:
        time.sleep(4 - elapsed)

    return message

def opus(input):
    """Send an interleaved text/image question to ``claude-3-opus-20240229``.

    Args:
        input: Mapping with two keys:
            "question": prompt text in which every ``<image>`` marker
                stands for one image, in order of appearance.
            "image_path": list of image file paths, one per ``<image>``
                marker.

    Returns:
        The object returned by ``claude_client.messages.create``.

    Raises:
        ValueError: If the number of ``<image>`` markers differs from the
            number of image paths.
    """
    question = input['question']
    image_paths = input['image_path']

    # Keep empty segments: segment i is the text that precedes image i, and
    # the last segment is whatever follows the final image. Dropping empties
    # would lose the position of back-to-back or leading/trailing markers.
    segments = question.split("<image>")
    if len(segments) - 1 != len(image_paths):
        raise ValueError("The number of <image> elements does not match the number of image paths provided.")

    encoded_images = [encode_image(path) for path in image_paths]
    # The pacing window below starts once the images are encoded.
    start = time.time()

    content = []
    for index, segment in enumerate(segments):
        if segment:
            content.append({"type": "text", "text": segment})
        if index < len(encoded_images):
            # NOTE(review): media_type is fixed to image/png, while
            # encode_image() keeps the source file's format -- confirm that
            # inputs are PNG or that the API accepts the mismatch.
            content.append({
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": 'image/png',
                    "data": encoded_images[index],
                },
            })

    message = claude_client.messages.create(
        model="claude-3-opus-20240229",
        max_tokens=1024,
        temperature=0,
        messages=[
            {
                "role": "user",
                "content": content,
            }
        ],
    )
    elapsed = time.time() - start
    # Pace callers: make every call take at least 4 seconds in total.
    if elapsed < 4:
        time.sleep(4 - elapsed)

    return message

def gemini(input):
    """Send an interleaved text/image question to the shared Gemini model.

    Args:
        input: Mapping with two keys:
            "question": prompt text in which every ``<image>`` marker
                stands for one image, in order of appearance.
            "image_path": list of image file paths, one per ``<image>``
                marker.

    Returns:
        The object returned by ``gemini_model.generate_content``.

    Raises:
        ValueError: If the number of ``<image>`` markers differs from the
            number of image paths.
    """
    question = input['question']
    image_paths = input['image_path']

    # Keep empty segments: segment i is the text that precedes image i, and
    # the last segment is whatever follows the final image. Dropping empties
    # would lose the position of back-to-back or leading/trailing markers.
    segments = question.split("<image>")
    if len(segments) - 1 != len(image_paths):
        raise ValueError("The number of <image> elements does not match the number of image paths provided.")

    # Images are passed as PIL objects at their original size (no resize or
    # base64 step here).
    images = [Image.open(path) for path in image_paths]

    content = []
    for index, segment in enumerate(segments):
        if segment:
            content.append(segment)
        if index < len(images):
            content.append(images[index])

    start = time.time()
    message = gemini_model.generate_content(
        content,
        generation_config=genai.types.GenerationConfig(max_output_tokens=1024, temperature=0),
    )
    elapsed = time.time() - start
    # Pace callers: make every call take at least 5 seconds in total.
    if elapsed < 5:
        time.sleep(5 - elapsed)

    return message