import os
import warnings
from glob import glob
from pathlib import Path

import numpy as np
import SimpleITK as sitk
from skimage.transform import iradon, radon, resize
from tqdm import tqdm

# Escalate every warning to an exception so that the processing loop below can
# catch UserWarning / RuntimeWarning and skip the offending slice.
# NOTE: this applies to all warning categories; any category that is not
# caught below (e.g. DeprecationWarning) will abort the script.
warnings.filterwarnings("error")


# def snr(x, y):
#     return 20 * np.log10(np.linalg.norm(x) / np.linalg.norm(x - y))


def normalize_intensity(ct_slice):
    """Shift and scale raw slice intensities.

    Values below -1000 are raised to -1000, then the result is mapped with
    ``(x + 1000) / 4000``. There is no upper clamp, so inputs above 3000
    produce values greater than 1.

    Args:
        ct_slice: Array of raw intensities (the script passes one slice of a
            volume that was read as int32).

    Returns:
        Floating-point array with the same shape as ``ct_slice``.
    """
    return (np.maximum(ct_slice, -1000) + 1000) / 4000


def sparse_view_fbp(target, theta, noise_std, view_step=4):
    """Simulate a noisy sparse-view measurement of ``target`` and reconstruct it.

    The full sinogram of ``target`` is computed over ``theta``; every
    ``view_step``-th projection angle is kept, Gaussian noise with standard
    deviation ``noise_std`` is added, and the result is reconstructed with
    ramp-filtered back projection.

    Args:
        target: 2-D image to project.
        theta: Projection angles in degrees.
        noise_std: Standard deviation of the additive sinogram noise.
        view_step: Keep one out of every ``view_step`` projection angles.

    Returns:
        The filtered-back-projection reconstruction as a 2-D array.

    Raises:
        Warning: Any warning issued by ``radon`` / ``iradon`` / numpy, when
            warnings have been escalated to exceptions by the caller.
    """
    sinogram = radon(target, theta=theta, circle=False, preserve_range=True)

    # Columns of the sinogram correspond to projection angles, so the angle
    # list must be subsampled with the same stride.
    sparse_sinogram = sinogram[:, ::view_step]
    noisy_sinogram = (
        sparse_sinogram + np.random.randn(*sparse_sinogram.shape) * noise_std
    )

    return iradon(
        noisy_sinogram,
        theta=theta[::view_step],
        filter_name="ramp",
        circle=False,
        preserve_range=True,
    )


def main(
    image_glob="/XXXX-2/XXXX-1/scratch/Luna16/images/**/*.mhd",
    save_dir="/XXXX-2/XXXX-1/scratch/uq_diffusion/luna/data",
    noise_std=0.4,
    max_files=12000,
    image_size=256,
):
    """Build (target, measurement) pairs from every slice of the input volumes.

    For each slice, ``target_{i}.npy`` holds the normalized, resized slice and
    ``measurement_{i}.npy`` holds its noisy sparse-view FBP reconstruction.
    ``i`` counts successfully written pairs across all files.

    Slices whose simulation raises ``UserWarning`` or ``RuntimeWarning`` are
    skipped. This relies on the module-level warnings filter that turns
    warnings into exceptions.

    Args:
        image_glob: Recursive glob pattern selecting the ``.mhd`` volumes.
        save_dir: Directory that receives the ``.npy`` files; created if
            missing, including parents.
        noise_std: Standard deviation of the noise added to the sinogram.
        max_files: Upper bound on the number of volumes processed.
        image_size: Side length the slices are resized to; also the number of
            projection angles of the full sinogram.
    """
    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # glob yields paths in arbitrary order; sort so that both the truncation
    # and the sample numbering are reproducible between runs.
    files = sorted(glob(image_glob, recursive=True))[:max_files]

    # Every slice is resized to the same square shape, so the projection
    # angles are loop-invariant.
    theta = np.linspace(0.0, 180.0, image_size, endpoint=False)

    sample_idx = 0
    for file in tqdm(files, total=len(files)):
        volume = sitk.GetArrayFromImage(sitk.ReadImage(file, sitk.sitkInt32))
        for slice_idx, raw_slice in enumerate(volume):
            target = resize(
                normalize_intensity(raw_slice),
                (image_size, image_size),
                mode="reflect",
                anti_aliasing=True,
                preserve_range=True,
            )

            try:
                fbp = sparse_view_fbp(target, theta, noise_std)
            except (UserWarning, RuntimeWarning) as warning:
                # Only this slice is dropped; the rest of the volume is kept.
                print(f"Skipping slice {slice_idx} of {file}: {warning}")
                continue

            # Saved outside the try so a pair is always written together and
            # the index only advances for complete pairs.
            np.save(out_dir / f"target_{sample_idx}.npy", target)
            np.save(out_dir / f"measurement_{sample_idx}.npy", fbp)
            sample_idx += 1


if __name__ == "__main__":
    main()
