from diffusers import DDIMScheduler
from zeroscope_fifo.pipelines import TextToVideoSDPipeline
from zeroscope_fifo.models import UNet3DConditionModel
import torch
import os
import math
from tqdm import trange, tqdm
from PIL import Image
from functools import partial

from fifo_utils.latents_utils import prepare_latents, shift_latents
from fifo_utils.export_utils import export_to_gif

def get_pipeline(device, cache_dir=None, use_device_map=False):
    """Load the zeroscope_v2_576w text-to-video pipeline in float16 with a DDIM scheduler.

    Args:
        device: Device the whole pipeline is moved to when ``use_device_map`` is
            False; the reloaded UNet is always moved here.
        cache_dir: Optional cache directory forwarded to every ``from_pretrained`` call.
        use_device_map: If True, load the pipeline with a hand-written device map
            that spreads submodules over GPU indices 0-3 instead of moving it to
            ``device``. Normally False; useful when a single GPU lacks the memory.

    Returns:
        The ``TextToVideoSDPipeline`` whose ``unet`` has been replaced by a
        ``UNet3DConditionModel`` loaded from the checkpoint's ``unet`` subfolder,
        and whose scheduler has been replaced by a ``DDIMScheduler`` built from the
        original scheduler config.

    NOTE(review): the UNet is moved wholesale to ``device`` even when
    ``use_device_map`` is True, so the UNet-looking entries of the device map
    (``conv_in``, ``down_blocks.*``, ...) appear not to take effect — confirm intended.
    """
    model_id = "cerspense/zeroscope_v2_576w"
    shared_kwargs = dict(torch_dtype=torch.float16, cache_dir=cache_dir)

    if not use_device_map:
        pipeline = TextToVideoSDPipeline.from_pretrained(model_id, **shared_kwargs).to(device)
    else:
        # Explicit submodule -> GPU index placement (four GPUs: 0-3).
        placement = {
            "conv_in": 0,
            "time_proj": 0,
            "time_embedding": 0,
            "transformer_in": 0,
            "down_blocks.0": 1,
            "down_blocks.1": 0,
            "down_blocks.2": 0,
            "down_blocks.3": 0,
            "up_blocks.0": 2,
            "up_blocks.1": 2,
            "up_blocks.2": 2,
            "up_blocks.3": 3,
            "mid_block": 0,
            "conv_out": 0,
            "conv_act": 0,
            "conv_norm_out": 0,
            "decoder": 1,
            "post_quant_conv": 3,
            "quant_conv": 3,
            "text_model": 3,
            "encoder": 3,
        }
        pipeline = TextToVideoSDPipeline.from_pretrained(model_id, device_map=placement, **shared_kwargs)

    # Swap in the project's UNet3DConditionModel, loaded from the same checkpoint.
    unet = UNet3DConditionModel.from_pretrained(model_id, subfolder="unet", **shared_kwargs)
    pipeline.unet = unet.to(device)
    pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

    return pipeline

# Generate the first N (= args.video_length) frames to prepare the latents.
# If the latents have already been generated, skip this step by setting args.skip_base = True.
def run_base(args, pipe, directories, generator, prompt):
    """Generate the first ``args.video_length`` frames and save their latents.

    Runs the pipeline once with ``save_latents=True`` and
    ``latents_dir=directories["latents_dir"]`` (the directory ``run_fifo`` later
    reads its starting latents from). The returned frames are exported to
    ``directories["base_dir"] + "/origin.gif"``.

    Args:
        args: Run configuration; reads ``video_length``, ``num_inference_steps``,
            ``height``, ``width`` and ``eta``.
        pipe: Pipeline to call (presumably the one built by ``get_pipeline``).
        directories: Mapping providing ``"latents_dir"`` and ``"base_dir"``.
        generator: Random generator forwarded to the pipeline.
        prompt: Text prompt forwarded to the pipeline.
    """
    print(f"generate first {args.video_length} frames ...")

    generation_kwargs = dict(
        prompt=prompt,
        num_inference_steps=args.num_inference_steps,
        height=args.height,
        width=args.width,
        eta=args.eta,
        num_frames=args.video_length,
        generator=generator,
        save_latents=True,
        latents_dir=directories["latents_dir"],
    )
    frames = pipe(**generation_kwargs).frames

    gif_path = directories["base_dir"] + "/origin.gif"
    print(f"Saving as {gif_path}")
    export_to_gif(frames, gif_path)

    # Release the local reference to the frames once they are written out.
    del frames

# generate longer video through fifo method
def run_fifo(args, pipe, directories, generator, prompt):
    """Generate a longer video with the FIFO method, one output frame per outer step.

    The latent queue is initialised by ``prepare_latents`` from the latents saved
    in ``directories["latents_dir"]`` (see ``run_base``). Each outer iteration:

    * calls ``pipe.fifo_onestep`` once per window (rank), from the last rank down
      to rank 0, and writes the returned latents back into the queue;
    * passes the queue through ``shift_latents`` (see fifo_utils for exactly how
      the queue is advanced);
    * collects ``frame[0]`` from the final (rank 0) ``fifo_onestep`` call.

    The last ``args.new_video_length`` collected frames are exported to
    ``directories["base_dir"] + "/fifo.gif"``. If ``args.save_frames`` is set,
    every collected frame is also written to ``<base_dir>/fifo/<i>.png``; that
    directory is created on demand.

    Args:
        args: Run configuration; reads ``num_inference_steps``, ``video_length``,
            ``new_video_length``, ``num_partitions``, ``lookahead_denoising``,
            ``height``, ``width``, ``eta`` and ``save_frames``.
        pipe: Pipeline exposing ``scheduler``, ``_execution_device`` and ``fifo_onestep``.
        directories: Mapping providing ``"latents_dir"`` and ``"base_dir"``.
        generator: Random generator forwarded to ``fifo_onestep``.
        prompt: Text prompt forwarded to ``fifo_onestep``.
    """
    print("generate longer video through fifo method ...")
    # prepare scheduler
    pipe.scheduler.set_timesteps(args.num_inference_steps, device=pipe._execution_device)

    # prepare first latents from the saved latents
    latents = prepare_latents(args, directories["latents_dir"], scheduler=pipe.scheduler)

    # Per-frame PNGs go here. Create the directory only when frames are actually
    # saved; otherwise Image.save below fails if nothing else created it.
    fifo_dir = directories["base_dir"] + "/fifo"
    if args.save_frames:
        os.makedirs(fifo_dir, exist_ok=True)
    fifo_video_frames = []

    # Reverse the scheduler's timestep order so index k of `timesteps` pairs with
    # queue position k of `latents` (assumption based on the slicing below).
    timesteps = torch.flip(pipe.scheduler.timesteps, [0])
    if args.lookahead_denoising:
        # Prepend video_length // 2 copies of the first timestep so the
        # half-overlapping lookahead windows stay in range. Slicing + repeat keeps
        # the original dtype and device.
        timesteps = torch.cat([timesteps[:1].repeat(args.video_length // 2), timesteps])

    # Loop-invariant window geometry.
    num_frames_per_gpu = args.video_length
    num_inference_steps_per_gpu = args.num_inference_steps // args.num_partitions
    half_window = num_inference_steps_per_gpu // 2
    # With lookahead denoising, windows start every half window, so there are
    # twice as many of them.
    if args.lookahead_denoising:
        num_ranks = 2 * args.num_partitions
        window_stride = half_window
    else:
        num_ranks = args.num_partitions
        window_stride = num_inference_steps_per_gpu

    # main loop
    for i in trange(args.new_video_length + args.num_inference_steps - args.video_length):
        # One inference step on every window, processed from the last rank down to rank 0.
        for rank in reversed(range(num_ranks)):
            start_idx = rank * window_stride
            midpoint_idx = start_idx + half_window
            end_idx = start_idx + num_inference_steps_per_gpu

            t = timesteps[start_idx:end_idx]
            input_latents = latents[:, :, start_idx:end_idx].clone()

            output_latents, frame = pipe.fifo_onestep(
                prompt=prompt,
                height=args.height,
                width=args.width,
                num_frames=num_frames_per_gpu,
                timesteps=t,
                rank=rank,
                lookahead_denoising=args.lookahead_denoising,
                generator=generator,
                latents=input_latents,
                eta=args.eta,
            )

            if args.lookahead_denoising:
                # Only the second half of each overlapping window is written back.
                # The first half presumably serves as lookahead context only.
                latents[:, :, midpoint_idx:end_idx] = output_latents[:, :, -half_window:]
            else:
                latents[:, :, start_idx:end_idx] = output_latents
            del output_latents

        latents = shift_latents(latents, i, args, pipe.scheduler)

        # `frame` holds what the final (rank 0) fifo_onestep call returned.
        if args.save_frames:
            output_path = fifo_dir + f"/{i}.png"
            Image.fromarray(frame[0]).save(output_path)

        fifo_video_frames.append(frame[0])

    # export to gif: keep only the last new_video_length frames
    output_path = directories["base_dir"] + "/fifo.gif"
    export_to_gif(fifo_video_frames[-args.new_video_length:], output_path)