import numpy as np
from skimage.transform import iradon, radon, resize
import torch
from torch.utils.data import Dataset
import SimpleITK as sitk
from glob import glob
from typing import Any
import os


def PSNR(x, y, data_range=1.0):
    """Return the peak signal-to-noise ratio between ``x`` and ``y`` in dB.

    PSNR = 20 * log10(data_range / sqrt(MSE)).

    Args:
        x: Array-like signal/image.
        y: Array-like signal/image with the same (or a broadcastable) shape.
        data_range: Peak value of the signal. Defaults to 1.0, which matches
            images scaled to [0, 1].

    Returns:
        PSNR in decibels as a float; ``inf`` when ``x`` and ``y`` are identical.
    """
    # Cast to float64 first: subtracting unsigned-integer images (e.g. uint8)
    # would otherwise wrap around instead of going negative.
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        # Identical inputs: PSNR is unbounded; avoid a divide-by-zero warning.
        return float("inf")
    return float(20 * np.log10(data_range / np.sqrt(mse)))


class DebugDataset(Dataset):
    """Small in-memory dataset of CT slices paired with sparse-view FBP measurements.

    For each of the first ``num_files`` ``.mhd`` volumes found (recursively)
    under ``root``, the middle slice along axis 0 is taken. The target is that
    slice clamped below -1000, rescaled with ``(x + 1000) / 4000`` and resized
    to 128x128. The measurement is produced once, at construction time, by
    ``to_measurement``. ``__getitem__`` returns both as float32 tensors min-max
    scaled to [-1, 1].

    Args:
        noise_var: Forwarded to ``to_measurement`` as ``noise_std``. NOTE:
            despite the name, it is used as a standard deviation.
        sparse_factor: Angular subsampling factor forwarded to ``to_measurement``.
        root: Directory searched recursively for ``*.mhd`` files.
        num_files: Maximum number of volumes to load.
    """

    def __init__(
        self,
        noise_var,
        sparse_factor,
        root="/XXXX-2/XXXX-1/scratch/Luna16/images",
        num_files=20,
    ):
        super().__init__()

        # Sort so the selected subset does not depend on filesystem listing order.
        files = sorted(glob(os.path.join(root, "**", "*.mhd"), recursive=True))[
            :num_files
        ]

        self.xs = []  # targets: 128x128 float arrays
        self.ys = []  # measurements: 128x128 float32 tensors from to_measurement
        for file in files:
            volume = sitk.GetArrayFromImage(sitk.ReadImage(file, sitk.sitkInt32))
            # Middle slice along axis 0 (presumably the axial axis for
            # SimpleITK arrays — confirm against the data).
            mid_slice = volume[len(volume) // 2]
            self.ys.append(
                to_measurement(
                    mid_slice, noise_std=noise_var, sparse_factor=sparse_factor
                )
            )

            # Same intensity clamp/rescale that to_measurement applies.
            target = np.maximum(mid_slice, -1000)
            target = (target + 1000) / 4000
            target = resize(
                target,
                (128, 128),
                mode="reflect",
                anti_aliasing=True,
                preserve_range=True,
            )
            self.xs.append(target)

    def __len__(self):
        return len(self.xs)

    @staticmethod
    def _to_signed_unit_range(tensor):
        """Min-max scale ``tensor`` to [-1, 1]; a constant tensor maps to all -1."""
        lo, hi = tensor.min(), tensor.max()
        span = hi - lo
        if span > 0:
            tensor = (tensor - lo) / span
        else:
            # Avoid 0/0 -> NaN for constant inputs.
            tensor = torch.zeros_like(tensor)
        return tensor * 2 - 1

    def __getitem__(self, idx: int):
        """Return ``(target, measurement)`` as float32 tensors scaled to [-1, 1]."""
        target = self._to_signed_unit_range(torch.as_tensor(self.xs[idx]))
        measurement = self._to_signed_unit_range(torch.as_tensor(self.ys[idx]))
        return target.float(), measurement.float()


def create_circular_mask(h, w, center=None, radius=None):
    """Build an ``(h, w)`` boolean mask that is True inside a disk.

    Args:
        h: Number of rows in the mask.
        w: Number of columns in the mask.
        center: ``(x, y)`` — i.e. (column, row) — of the disk centre.
            Defaults to ``(int(w / 2), int(h / 2))``.
        radius: Disk radius in pixels. Defaults to the distance from the
            centre to the nearest image border.

    Returns:
        Boolean ndarray where ``mask[row, col]`` is True when the Euclidean
        distance from ``(col, row)`` to the centre is ``<= radius``.
    """
    if center is None:
        center = (int(w / 2), int(h / 2))
    col_c = center[0]
    row_c = center[1]

    if radius is None:
        # Largest disk around the centre that stays inside the image.
        radius = min(col_c, row_c, w - col_c, h - row_c)

    rows, cols = np.ogrid[:h, :w]
    distance = np.sqrt((cols - col_c) ** 2 + (rows - row_c) ** 2)
    return distance <= radius


def to_measurement(slice, noise_std=0.4, sparse_factor=4, ood=False, rng=None):
    """Simulate a noisy sparse-view CT reconstruction of a 2-D slice.

    Pipeline:
        1. Clamp intensities below -1000 and rescale with ``(x + 1000) / 4000``.
        2. Resize to 256x256.
        3. If ``ood``, paint two disks of value 1 as an artifact.
        4. Compute a parallel-beam sinogram with 256 angles over [0, 180).
        5. Keep every ``sparse_factor``-th angle and add Gaussian noise.
        6. Reconstruct with Shepp-Logan filtered back-projection.
        7. Resize to 128x128 and min-max scale to [-1, 1].

    Args:
        slice: 2-D numpy array of raw intensities (not a SimpleITK image).
            Presumably Hounsfield units, given the -1000 clamp — confirm
            against callers.
        noise_std: Standard deviation of the additive Gaussian sinogram noise.
        sparse_factor: Keep one projection angle out of every ``sparse_factor``.
        ood: If True, insert two radius-8 disks of value 1 at (x, y) = (100, 200)
            and (140, 200) in the 256x256 image as an out-of-distribution artifact.
        rng: Optional ``numpy.random.Generator`` for reproducible noise. When
            None (default), the global ``np.random`` state is used.

    Returns:
        ``torch.float32`` tensor of shape (128, 128) with values in [-1, 1]
        (all -1 if the reconstruction is constant).
    """
    image = np.maximum(slice, -1000)
    image = (image + 1000) / 4000
    image = resize(
        image,
        (256, 256),
        mode="reflect",
        anti_aliasing=True,
        preserve_range=True,
    )

    # Optional out-of-distribution artifact.
    if ood:
        for artifact_center in ((100, 200), (140, 200)):
            mask = create_circular_mask(256, 256, radius=8, center=artifact_center)
            image[mask] = 1

    # Full-view sinogram: one projection angle per image row/column.
    theta = np.linspace(0.0, 180.0, max(image.shape), endpoint=False)
    sinogram = radon(image, theta=theta, circle=False, preserve_range=True)

    # Angular subsampling (columns of the sinogram are angles) plus noise.
    sparse_sinogram = sinogram[:, ::sparse_factor]
    sparse_theta = theta[::sparse_factor]
    if rng is None:
        noise = np.random.randn(*sparse_sinogram.shape)
    else:
        noise = rng.standard_normal(sparse_sinogram.shape)
    noisy_sinogram = sparse_sinogram + noise * noise_std

    fbp = iradon(
        noisy_sinogram,
        theta=sparse_theta,
        filter_name="shepp-logan",
        circle=False,
        preserve_range=True,
    )
    out = resize(
        fbp, (128, 128), mode="reflect", anti_aliasing=True, preserve_range=True
    )

    # Min-max normalise to [-1, 1]; guard against a constant reconstruction.
    lo, hi = out.min(), out.max()
    out = (out - lo) / (hi - lo) if hi > lo else np.zeros_like(out)
    out = out * 2 - 1
    return torch.from_numpy(out).float()
