import os
import argparse
import torch
import torchvision
from einops import rearrange
from diffusers import DDIMScheduler, AutoencoderKL, DDIMInverseScheduler
from transformers import CLIPTextModel, CLIPTokenizer

from models.pipeline import COVEPipeline
from models.util import save_videos_grid, read_video, sample_trajectories, extract_dift_and_cal_sim
from models.unet import UNet3DConditionModel
from PIL import Image
from torchvision.transforms import PILToTensor

def _str2bool(value):
    """Parse a command-line string into a bool.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    because every non-empty string is truthy. This converter accepts the
    usual spellings instead.

    Args:
        value: The raw argument string (an existing bool is passed through).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args():
    """Parse the command-line options of the video-editing script.

    Returns:
        argparse.Namespace with one attribute per option defined below.

    Note:
        ``--h_state`` and ``--thre`` use ``nargs='+'``. A value given on the
        command line becomes a list, while the default is a scalar, so
        consumers must accept both forms.
    """
    parser = argparse.ArgumentParser()
    # Prompts and guidance.
    parser.add_argument("--prompt", type=str, default="A Tiger, high quality", help="Textual prompt for video editing")
    parser.add_argument("--input_class", type=str, default="cat", help="input video class")
    parser.add_argument("--neg_prompt", type=str, default="a cat with big eyes, deformed", help="Negative prompt for guidance")
    parser.add_argument("--guidance_scale", default=20.0, type=float, help="Guidance scale")
    # Paths.
    parser.add_argument("--video_path", type=str, default="/cluster/home3/wjs/flatten/data/puff.mp4", help="Path to a source video")
    parser.add_argument("--sd_path", type=str, default="/cluster/home3/wjs/flatten/checkpoints/stable-diffusion-2-1-base", help="Path of Stable Diffusion")
    parser.add_argument("--output_path", type=str, default="/cluster/home3/wjs/flatten/outputs", help="Directory of output")
    parser.add_argument("--dift_save_path", type=str, default="/cluster/home3/wjs/flatten/outputs", help="Directory where DIFT results are saved")
    # Video geometry and sampling.
    parser.add_argument("--video_length", type=int, default=20, help="Length of output video")
    parser.add_argument("--old_qk", type=int, default=0, help="Whether to use old queries and keys for flow-guided attention")
    parser.add_argument("--height", type=int, default=512, help="Height of synthesized video, and should be a multiple of 32")
    parser.add_argument("--width", type=int, default=512, help="Width of synthesized video, and should be a multiple of 32")
    parser.add_argument("--sample_steps", type=int, default=50, help="Number of sampling (inference) steps")
    parser.add_argument("--inject_step", type=int, default=40, help="Steps for feature injection")
    parser.add_argument("--seed", type=int, default=66, help="Random seed of generator")
    parser.add_argument("--frame_rate", type=int, default=None, help="The frame rate of loading input video. Default rate is computed according to video length.")
    parser.add_argument("--fps", type=int, default=15, help="FPS of the output video")
    # Pipeline / correspondence tuning knobs.
    # NOTE: the default is a scalar but a user-supplied value becomes a list (nargs='+').
    parser.add_argument("--h_state", nargs='+', type=int, default=64, help="One or more integers forwarded to the pipeline as h_state")
    parser.add_argument("--attention_t", type=int, default=400, help="Value forwarded to the pipeline as attention_t")
    parser.add_argument("--occ_thre", type=float, default=0.9, help="Occlusion threshold used when computing DIFT correspondences")
    parser.add_argument("--merge_ratio", type=float, default=1, help="Fraction retained after merging")  # ratio that is kept

    # DIFT feature extraction.
    parser.add_argument("--dift_t", type=int, default=261, help="Timestep t used for DIFT feature extraction")
    # type=bool would turn the string "False" into True; parse it explicitly.
    # A bare `--load_saved_position` (no value) also means True.
    parser.add_argument("--load_saved_position", type=_str2bool, nargs="?", const=True, default=False,
                        help="Whether to load saved positions (accepts true/false)")
    parser.add_argument("--dift_up_ft_index", type=int, default=2, help="Index of the up-sampling feature map used for DIFT")
    parser.add_argument("--dift_ensemble_size", type=int, default=8, help="Ensemble size for DIFT feature extraction")

    parser.add_argument("--k_single", type=int, default=3, help="k_single passed to the DIFT similarity computation")
    # NOTE: the default is a scalar but a user-supplied value becomes a list (nargs='+').
    parser.add_argument("--thre", nargs='+', type=float, default=1, help="Threshold value(s)")

    args = parser.parse_args()
    return args

    

if __name__ == "__main__":
    # Script entry point: parse options, compute DIFT correspondences on the
    # source video, build the editing pipeline, and save the edited video.
    args = get_args()
    os.makedirs(args.output_path, exist_ok=True)
    device = "cuda"

    # Floor the requested resolution to a multiple of 32 (default is 512x512).
    args.height = (args.height // 32) * 32
    args.width = (args.width // 32) * 32

    generator = torch.Generator(device=device)
    generator.manual_seed(args.seed)

    # Read the source video as a stack of frames (frames, C, H, W), as the
    # "(b f) c h w" rearrange below requires.
    video = read_video(video_path=args.video_path, video_length=args.video_length,
                       width=args.width, height=args.height, frame_rate=args.frame_rate)

    # Add a batch axis and move frames behind channels: (1, C, F, H, W).
    original_pixels = rearrange(video, "(b f) c h w -> b c f h w", b=1)
    source_video_path = os.path.join(args.output_path, "source_video.mp4")
    save_videos_grid(original_pixels, source_video_path, rescale=True)

    # DIFT correspondences computed from the saved source video at feature level 64.
    inds = extract_dift_and_cal_sim(args, source_video_path, dift_up_ft_index=2, feature_level=64,
                                    occlusion_thre=args.occ_thre, k_single=args.k_single)

    # Map frames from the assumed [-1, 1] range to uint8 PIL images. The clamp
    # guards against values slightly outside that range, where a float -> uint8
    # cast would wrap or be undefined.
    to_pil = torchvision.transforms.ToPILImage()
    real_frames = [
        to_pil(((frame + 1) / 2 * 255).clamp(0, 255).to(torch.uint8))
        for frame in video
    ]

    # Configuration for the 3D UNet's inflated layers and temporal motion modules.
    unet_additional_kwargs = {
            'use_inflated_groupnorm': True,
            'use_motion_module': True,
            'motion_module_resolutions': [1, 2, 4, 8],
            'motion_module_mid_block': False,
            'motion_module_type': 'Vanilla',
            'motion_module_kwargs':
            {
                'num_attention_heads': 8,
                'num_transformer_block': 1,
                'attention_block_types': ["Temporal_Self", "Temporal_Self"],
                'temporal_position_encoding': True,
                'temporal_position_encoding_max_len': 32,
                'temporal_attention_dim_div': 1,
                'zero_initialize': True
            }
        }
    # Load Stable Diffusion components; the neural modules are cast to fp16.
    tokenizer = CLIPTokenizer.from_pretrained(args.sd_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(args.sd_path, subfolder="text_encoder").to(dtype=torch.float16)
    vae = AutoencoderKL.from_pretrained(args.sd_path, subfolder="vae").to(dtype=torch.float16)
    unet = UNet3DConditionModel.from_pretrained_2d(args.sd_path, subfolder="unet", unet_additional_kwargs=unet_additional_kwargs).to(dtype=torch.float16)
    scheduler = DDIMScheduler.from_pretrained(args.sd_path, subfolder="scheduler")
    inverse = DDIMInverseScheduler.from_pretrained(args.sd_path, subfolder="scheduler")

    # The pipeline class imported by this module is COVEPipeline; the previous
    # reference to FlattenPipeline raised NameError.
    pipe = COVEPipeline(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
            scheduler=scheduler, inverse_scheduler=inverse)

    pipe.enable_vae_slicing()
    pipe.enable_xformers_memory_efficient_attention()
    pipe.to(device)

    # Optical-flow trajectories are currently disabled; the pipeline receives
    # trajs=None. (sample_trajectories remains imported for re-enabling.)
    trajectories = None
    torch.cuda.empty_cache()

    sample = pipe(args.prompt, video_length=args.video_length, frames=real_frames, inds=inds,
                num_inference_steps=args.sample_steps, generator=generator, guidance_scale=args.guidance_scale,
                negative_prompt=args.neg_prompt, width=args.width, height=args.height,
                trajs=trajectories, output_dir="tmp/", inject_step=args.inject_step, old_qk=args.old_qk, h_state=args.h_state, attention_t=args.attention_t).videos

    temp_video_name = f"{args.prompt}_{args.neg_prompt}_{args.guidance_scale}"
    # Prompts are free text; a "/" would otherwise point into a (likely
    # nonexistent) subdirectory instead of output_path.
    temp_video_name = temp_video_name.replace("/", "_")
    save_videos_grid(sample, os.path.join(args.output_path, f"{temp_video_name}.mp4"), fps=args.fps)
