import numpy as np
import matplotlib.pyplot as plt
import os
import time
from PIL import Image
import cv2
from skimage.metrics import structural_similarity as ssim
from tqdm import tqdm

def cal_psnr(noisy_image, original_image, data_range=None):
    """Return the peak signal-to-noise ratio (in dB) of a luminance pair.

    Both images are reduced to a single channel before comparison:
    3-D arrays go through ``cv2.cvtColor(..., cv2.COLOR_RGB2GRAY)``,
    while 2-D arrays are taken to be single-channel already and are used
    unchanged.

    Args:
        noisy_image: Image under evaluation, as a 2-D or 3-D array.
        original_image: Reference image with the same layout.
        data_range: Peak signal value used in the PSNR formula. When
            ``None`` it is the maximum representable value of the
            reference's dtype for unsigned-integer images (255 for
            ``uint8``); for any other dtype it falls back to the largest
            value present in the reference image.

    Returns:
        The PSNR as a float; ``inf`` when the two images are identical.

    Raises:
        ValueError: If the single-channel images differ in shape.
    """
    original_image = np.asarray(original_image)
    noisy_image = np.asarray(noisy_image)

    # Only multi-channel inputs need a colour conversion; passing a 2-D
    # array to the RGB->gray conversion would be rejected by OpenCV.
    if original_image.ndim == 3:
        original_image = cv2.cvtColor(original_image, cv2.COLOR_RGB2GRAY)
    if noisy_image.ndim == 3:
        noisy_image = cv2.cvtColor(noisy_image, cv2.COLOR_RGB2GRAY)

    if original_image.shape != noisy_image.shape:
        raise ValueError(
            "image shapes differ: "
            f"{noisy_image.shape} vs {original_image.shape}"
        )

    if data_range is None:
        # The peak must describe the pixel format, not the content of one
        # particular image; otherwise a dark reference lowers the score.
        if np.issubdtype(original_image.dtype, np.unsignedinteger):
            data_range = float(np.iinfo(original_image.dtype).max)
        else:
            data_range = float(np.max(original_image))

    # Work in floating point so the subtraction cannot wrap around.
    original_image = original_image.astype(np.float32)
    noisy_image = noisy_image.astype(np.float32)

    mse = np.mean((original_image - noisy_image) ** 2)
    if mse == 0:
        # Identical images: avoid the division by zero.
        return float("inf")

    return float(10 * np.log10((data_range ** 2) / mse))
def cal_ssim(noisy_image, original_image):
    """Return the structural similarity of a luminance pair.

    3-D inputs are converted with ``cv2.cvtColor(..., cv2.COLOR_RGB2GRAY)``;
    2-D inputs are taken to be single-channel already and are used as-is.
    The score itself comes from ``skimage.metrics.structural_similarity``
    called with its default settings.

    Args:
        noisy_image: Image under evaluation, as a 2-D or 3-D array.
        original_image: Reference image with the same layout.

    Returns:
        The value returned by ``structural_similarity`` for the pair.
    """
    original_image = np.asarray(original_image)
    noisy_image = np.asarray(noisy_image)

    # Only multi-channel inputs need a colour conversion; passing a 2-D
    # array to the RGB->gray conversion would be rejected by OpenCV.
    if original_image.ndim == 3:
        original_image = cv2.cvtColor(original_image, cv2.COLOR_RGB2GRAY)
    if noisy_image.ndim == 3:
        noisy_image = cv2.cvtColor(noisy_image, cv2.COLOR_RGB2GRAY)

    return ssim(original_image, noisy_image)

def get_img(path, size=(128, 128)):
    """Load an image file as an RGB array resized to ``size``.

    The image is converted to RGB before resizing so that grayscale,
    palette and RGBA files all produce a three-channel array. The resize
    uses Pillow's default resampling filter and stretches the image to
    exactly ``size`` (aspect ratio is not preserved).

    Args:
        path: Location of the image file.
        size: Target ``(width, height)`` in pixels; defaults to 128x128.

    Returns:
        The resized image as a NumPy array.
    """
    # The context manager releases the file handle as soon as the pixel
    # data has been copied into the converted, in-memory image.
    with Image.open(path) as source:
        resized = source.convert("RGB").resize(size)
    return np.array(resized)

def cal_metric(result_path, GT_path):
    """Average PSNR and SSIM over every file in ``result_path``.

    Each regular file found in ``result_path`` is paired with the file of
    the same name in ``GT_path``; both are loaded with ``get_img`` and
    scored with ``cal_psnr`` and ``cal_ssim``. Subdirectories of
    ``result_path`` are ignored.

    Args:
        result_path: Directory holding the images to evaluate.
        GT_path: Directory holding the same-named ground-truth images.

    Returns:
        A ``(mean_psnr, mean_ssim)`` tuple. Both values are ``nan`` when
        ``result_path`` contains no files.
    """
    # os.path.join works whether or not the directories end with a
    # separator, unlike plain string concatenation.
    names = sorted(
        name
        for name in os.listdir(result_path)
        if os.path.isfile(os.path.join(result_path, name))
    )
    if not names:
        # Nothing to average; skip np.mean([]) and its RuntimeWarning.
        return float("nan"), float("nan")

    psnr_all = []
    ssim_all = []
    for name in tqdm(names):
        result = get_img(os.path.join(result_path, name))
        GT = get_img(os.path.join(GT_path, name))

        psnr_all.append(cal_psnr(result, GT))
        ssim_all.append(cal_ssim(result, GT))

    return np.mean(psnr_all), np.mean(ssim_all)


if __name__ == '__main__':
    # Script entry point: score the generated images against the ground
    # truth and print the two averaged metrics.
    result_path = './CycleGAN_pro/result_glare/'
    GT_path = '../glare-dataset/Test/REAL_GLARE/GT/1/'
    # NOTE(review): the two paths below are not referenced anywhere in the
    # visible code; they are kept as-is.
    glare_path = '../glare-dataset/Test/REAL_GLARE/glare/1/'
    save_path = './CycleGAN_pro/psnr_result/'

    mean_psnr, mean_ssim = cal_metric(result_path, GT_path)

    for label, value in (
        ("mean psnr is: ", mean_psnr),
        ("mean ssim is: ", mean_ssim),
    ):
        print(label, value)

