import torch
from diffusers import StableDiffusionPipeline
from diffusers.image_processor import VaeImageProcessor
import MyDiffusers
from copy import deepcopy
import numpy as np
from PIL import Image
import inspect


class StableDiffusionManager:
    """Wrap a Stable Diffusion pipeline with swappable generation/inversion schedulers.

    The pipeline is loaded once; ``MyDiffusers.get_all_schedulers`` builds a dict of
    named schedulers, and each public ``*_generation`` / ``*_inversion`` method installs
    one of them on the pipeline before running it with a ``MyDiffusers.Monitor``
    step-end callback.
    """

    def __init__(self, device, tau) -> None:
        """Load the pipeline and build the custom schedulers.

        Args:
            device: Device for the pipeline, the schedulers and ``self.generator``.
            tau: Forwarded to ``MyDiffusers.get_all_schedulers``; its meaning is defined
                there (presumably the step splitting partial phases — TODO confirm).
        """
        # Save parameters
        self.device = device
        self.tau = tau

        # Initialize the pipeline
        self.pipeline = StableDiffusionPipeline.from_pretrained(
            # "CompVis/stable-diffusion-v1-4",
            # "stabilityai/stable-diffusion-2-1",
            "runwayml/stable-diffusion-v1-5",
            use_safetensors=False,
            safety_checker=None,
        ).to(self.device)

        # Argument names accepted by pipeline.__call__, used to validate forwarded kwargs.
        self.pipeline_call_params = set(
            inspect.signature(self.pipeline.__call__).parameters
        )

        # Untouched copy: self.pipeline.scheduler is replaced by the methods below.
        self.original_scheduler = deepcopy(self.pipeline.scheduler)

        self.schedulers = MyDiffusers.get_all_schedulers(
            self.original_scheduler.config,
            tau=tau,
            num_inference_steps=1000,
            device=device,
        )

        # Cumulative alpha products (alpha-bar) of the full-generation scheduler.
        self.alphabar = self.schedulers["scheduler_full_generation"].alphas_cumprod

        # NOTE(review): not used by any method of this class; callers must forward it
        # explicitly (generator=manager.generator) if they want seeded sampling.
        self.generator = torch.Generator(device=device).manual_seed(0)

    @torch.no_grad()
    def image_to_latent(self, image: "Image.Image") -> torch.Tensor:
        """Encode a PIL image into a scaled VAE latent.

        The image goes through the pipeline's ``VaeImageProcessor``, is encoded by the
        VAE, and one sample is drawn from the latent distribution using a fixed-seed
        generator, so the same image always yields the same latent. The sample is then
        multiplied by the VAE ``scaling_factor``.

        Runs without autograd: no VAE activations are retained and the returned latent
        does not require grad.
        """
        vae = self.pipeline.vae
        image_proc: VaeImageProcessor = self.pipeline.image_processor
        pixels = image_proc.preprocess(image).to(device=self.device, dtype=vae.dtype)
        # Dedicated fixed-seed generator: deterministic encoding that does not
        # consume state from self.generator.
        encode_gen = torch.Generator(device=self.device).manual_seed(184750)
        latent = vae.encode(pixels).latent_dist.sample(generator=encode_gen)
        return latent * vae.config.scaling_factor

    @torch.no_grad()
    def latent_to_image(self, z: torch.Tensor) -> "Image.Image":
        """Decode a scaled latent back into a PIL image.

        Args:
            z: Latent produced in the same scaled space as ``image_to_latent``. Only a
               singleton leading batch axis is removed (``squeeze(0)``), so a batch of
               more than one latent is not supported.

        Returns:
            A ``uint8`` RGB ``PIL.Image.Image``.
        """
        vae = self.pipeline.vae
        z = z.to(device=self.device, dtype=vae.dtype)
        image = vae.decode(z / vae.config.scaling_factor).sample
        # Map from [-1, 1] to [0, 1] and clip anything outside that range.
        image = (image + 1).div(2).clamp(0, 1)
        # float() so the numpy conversion also works for half/bfloat16 VAEs.
        image = image.cpu().float().squeeze(0).permute(1, 2, 0).numpy()
        image = (image * 255).round().astype("uint8")
        return Image.fromarray(image)

    @staticmethod
    def _expand_prompt(prompt, batch_size):
        """Return one prompt per latent in a batch of ``batch_size``.

        A single string is repeated ``batch_size`` times. ``None`` is passed through
        unchanged so callers can supply ``prompt_embeds`` instead. Any other iterable is
        treated as per-sample prompts and must contain exactly ``batch_size`` items.

        Raises:
            ValueError: if a per-sample prompt list has the wrong length.
        """
        if prompt is None:
            return None
        if isinstance(prompt, str):
            return [prompt] * batch_size
        prompts = list(prompt)
        if len(prompts) != batch_size:
            raise ValueError(
                f"got {len(prompts)} prompts for a batch of {batch_size} latents"
            )
        return prompts

    def _run_monitored_pipeline(self, z, prompt, **kwargs):
        """Run the pipeline on latents ``z`` while recording per-step data.

        Args:
            z: 4-D batch of initial latents, passed as ``latents``.
            prompt: A string (shared by the whole batch), a per-sample list, or None.
            **kwargs: Extra ``pipeline.__call__`` arguments (e.g. num_inference_steps,
                eta, guidance_scale).

        Returns:
            ``(output, data)``: the pipeline output and ``monitor.coalesce()``.
        """
        assert len(z.shape) == 4, f"expected a 4-D latent batch, got shape {tuple(z.shape)}"
        unknown = set(kwargs) - self.pipeline_call_params
        assert not unknown, f"unsupported pipeline arguments: {sorted(unknown)}"
        monitor = MyDiffusers.Monitor()
        output = self.pipeline(
            prompt=self._expand_prompt(prompt, len(z)),
            latents=z,
            **kwargs,  # i.e. num_inference_steps, eta, etc.
            callback_on_step_end=monitor,
            callback_on_step_end_tensor_inputs=self.pipeline._callback_tensor_inputs,
        )
        return output, monitor.coalesce()

    def _run_with_scheduler(self, scheduler_name, z, prompt, kwargs, inversion=False):
        """Emit warnings, install ``self.schedulers[scheduler_name]`` and run the pipeline."""
        if inversion:
            self.print_guidance_scale_inversion(kwargs)
        self.print_eta(kwargs)
        self.pipeline.scheduler = self.schedulers[scheduler_name]
        return self._run_monitored_pipeline(z=z, prompt=prompt, **kwargs)

    def full_generation(self, z, prompt, **kwargs):
        """Run the pipeline with the ``scheduler_full_generation`` scheduler."""
        return self._run_with_scheduler("scheduler_full_generation", z, prompt, kwargs)

    def partial_generation(self, z, prompt, **kwargs):
        """Run the pipeline with the ``scheduler_partial_generation`` scheduler."""
        return self._run_with_scheduler("scheduler_partial_generation", z, prompt, kwargs)

    def partial_generation_remaining(self, z, prompt, **kwargs):
        """Run the pipeline with the ``scheduler_partial_generation_remaining`` scheduler."""
        return self._run_with_scheduler(
            "scheduler_partial_generation_remaining", z, prompt, kwargs
        )

    def print_guidance_scale_inversion(self, kwargs):
        """Warn when the guidance scale (default 7.5) used for inversion exceeds 1.0."""
        guidance_scale = kwargs.get("guidance_scale", 7.5)
        if guidance_scale > 1.0:
            print(
                f"Warning: An high guidance scale ({guidance_scale}) may affect the inversion quality."
            )
            # cfr https://arxiv.org/pdf/2211.09794.pdf page 4 "We observe that such a guidance scale amplifies the accumulated error"

    def print_eta(self, kwargs):
        """Warn when eta (default 0.0) is >= 0.9999.

        A tensor eta (spatially varying) is summarised by its mean. ``isinstance`` is
        used so tensor subclasses such as ``nn.Parameter`` are handled too.
        """
        eta = kwargs.get("eta", 0.0)
        if isinstance(eta, torch.Tensor):
            # We are calling the function using a spatial eta
            eta = eta.mean().item()
        if eta >= 0.9999:
            print(
                f"Warning: A high eta ({eta}) may affect the quality of the generated image."
            )

    def full_inversion(self, z, prompt, **kwargs):
        """Run the pipeline with the ``scheduler_full_inversion`` scheduler."""
        return self._run_with_scheduler(
            "scheduler_full_inversion", z, prompt, kwargs, inversion=True
        )

    def partial_inversion(self, z, prompt, **kwargs):
        """Run the pipeline with the ``scheduler_partial_inversion`` scheduler."""
        return self._run_with_scheduler(
            "scheduler_partial_inversion", z, prompt, kwargs, inversion=True
        )
