import CRPS.CRPS as pscore
import numpy as np
import torch
from glob import glob


from skimage.transform import iradon, radon, resize
import SimpleITK as sitk


def get_crps(gt, samples):
    """Average the CRPS scores over every element of ``gt``.

    ``gt`` is flattened to 1-D and ``samples`` is flattened from its second
    dimension onward, so ``samples[:, i]`` is the set of ensemble values that
    is scored against ``gt[i]``.

    Args:
        gt: Ground-truth values; any object with a ``flatten()`` method whose
            flattened length matches the flattened trailing dimensions of
            ``samples``.
        samples: Torch tensor whose first dimension indexes the ensemble
            members.

    Returns:
        A dict with keys ``"crps"``, ``"fcrps"`` and ``"acrps"``: the three
        values returned by ``pscore(...).compute()``, each averaged over all
        elements of ``gt``.

    Raises:
        ZeroDivisionError: If ``gt`` has no elements.
    """
    gt = gt.flatten()
    # Detach and move the whole ensemble to host memory once, instead of
    # issuing one device-to-host transfer per scored element inside the loop.
    samples = samples.flatten(start_dim=1).detach().cpu()
    num_points = len(gt)

    crps_sum = 0
    fcrps_sum = 0
    acrps_sum = 0
    for i, target in enumerate(gt):
        crps, fcrps, acrps = pscore(samples[:, i], target).compute()
        crps_sum += crps
        fcrps_sum += fcrps
        acrps_sum += acrps

    return {
        "crps": crps_sum / num_points,
        "fcrps": fcrps_sum / num_points,
        "acrps": acrps_sum / num_points,
    }


def get_ood_measurement(R, idx=-100, noise_std=0.4, sparse_factor=4):
    """Build a noisy sparse-view FBP reconstruction of a CT slice with artifacts.

    Pipeline: read one ``.mhd`` volume, take its middle slice, clamp and
    rescale intensities, resize to 256x256, paint two bright discs into it,
    compute a sinogram, keep every ``sparse_factor``-th projection angle, add
    Gaussian noise, reconstruct with filtered back-projection, resize to
    128x128 and min-max normalise to [-1, 1].

    Args:
        R: Not used by this function.
        idx: Index into the list of ``.mhd`` files found by the recursive glob.
            NOTE(review): the glob result is not sorted, so the same ``idx``
            may select a different file on another machine -- confirm.
        noise_std: Standard deviation of the noise added to the sinogram.
        sparse_factor: Stride used to subsample the projection angles.

    Returns:
        A float32 torch tensor of shape (128, 128) with values in [-1, 1].
    """
    side = 256
    mhd_paths = glob("/XXXX-2/XXXX-1/scratch/Luna16/images/**/*.mhd", recursive=True)
    volume = sitk.GetArrayFromImage(sitk.ReadImage(mhd_paths[idx], sitk.sitkInt32))

    # Middle slice along the first axis; values below -1000 are clamped and
    # the result is mapped through (v + 1000) / 4000.
    # (presumably Hounsfield units -- TODO confirm)
    ct_slice = np.maximum(volume[len(volume) // 2], -1000)
    ct_slice = (ct_slice + 1000) / (4000)
    ct_slice = resize(
        ct_slice,
        (side, side),
        mode="reflect",
        anti_aliasing=True,
        preserve_range=True,
    )

    # Artifact: two radius-8 discs set to the value 1.
    for disc_center in ((100, 200), (140, 200)):
        disc = create_circular_mask(side, side, radius=8, center=disc_center)
        ct_slice[disc] = 1

    # Sinogram with one projection angle per pixel of the longer image side.
    angles = np.linspace(0.0, 180.0, max(ct_slice.shape), endpoint=False)
    full_sinogram = radon(ct_slice, theta=angles, circle=False, preserve_range=True)

    # Sparse view: keep every `sparse_factor`-th angle, then corrupt with noise.
    kept_columns = full_sinogram[:, ::sparse_factor]
    kept_angles = angles[::sparse_factor]
    noise = np.random.randn(*kept_columns.shape) * noise_std
    measured = kept_columns + noise

    reconstruction = iradon(
        measured,
        theta=kept_angles,
        filter_name="shepp-logan",
        circle=False,
        preserve_range=True,
    )
    small = resize(
        reconstruction,
        (128, 128),
        mode="reflect",
        anti_aliasing=True,
        preserve_range=True,
    )

    # Min-max normalise to [0, 1], then stretch to [-1, 1].
    lowest = small.min()
    highest = small.max()
    unit = (small - lowest) / (highest - lowest)
    scaled = unit * 2 - 1
    return torch.from_numpy(scaled).float()


def create_circular_mask(h, w, center=None, radius=None):
    """Return a boolean (h, w) array that is True inside a filled disc.

    Args:
        h: Number of rows of the mask.
        w: Number of columns of the mask.
        center: ``(x, y)`` position of the disc centre, i.e. (column, row).
            Defaults to the middle of the image.
        radius: Disc radius in pixels; points at exactly this distance are
            included. Defaults to the distance from the centre to the nearest
            image border.

    Returns:
        Boolean numpy array of shape ``(h, w)``.
    """
    if center is None:
        # Fall back to the image midpoint.
        center = (int(w / 2), int(h / 2))
    cx, cy = center[0], center[1]

    if radius is None:
        # Largest disc around the centre that still fits inside the image.
        radius = min(cx, cy, w - cx, h - cy)

    rows, cols = np.ogrid[:h, :w]
    # Broadcasting the (h, 1) rows against the (1, w) columns gives an (h, w) grid.
    distance = np.sqrt((cols - cx) ** 2 + (rows - cy) ** 2)
    return distance <= radius


def PSNR(x, y, data_range=1.0):
    """Return the peak signal-to-noise ratio between ``x`` and ``y`` in dB.

    Computed as ``20 * log10(data_range / sqrt(mean((x - y) ** 2)))``.

    Args:
        x: First image (array-like).
        y: Second image (array-like), broadcastable against ``x``.
        data_range: Peak signal value. The default of 1.0 reproduces the
            previous hard-coded behaviour; pass e.g. 255 for 8-bit images.

    Returns:
        The PSNR as a numpy floating scalar; ``inf`` when the inputs are
        identical.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    # Integer inputs (e.g. uint8 images) would wrap around in the subtraction
    # and the squaring, so promote them to float64. Inputs whose difference is
    # already floating point are left untouched to keep their exact arithmetic.
    if not np.issubdtype(np.result_type(x, y), np.inexact):
        x = x.astype(np.float64)
        y = y.astype(np.float64)

    mse = np.mean((x - y) ** 2)
    # A zero MSE legitimately yields infinite PSNR; silence the warning only.
    with np.errstate(divide="ignore"):
        return 20 * np.log10(data_range / np.sqrt(mse))
