import torch
from transformers import CLIPProcessor, CLIPModel
from transformers import CLIPImageProcessor
from transformers import CLIPVisionModelWithProjection, CLIPTextModelWithProjection, CLIPVisionModel, CLIPTextModel
from PIL import Image

from torchvision import transforms

def test_image_processor():
    """Compare pixel_values from the full CLIP processor and the SD feature extractor.

    The same image goes through both processors. The function prints the
    summed absolute difference of the two ``pixel_values`` tensors, then the
    keys returned by the joint text+image processor.
    """
    clip_processor = CLIPProcessor.from_pretrained('/data/workspace/hf_pre/clip-vit-large-patch14/processor')
    sd_extractor = CLIPImageProcessor.from_pretrained('/data/workspace/hf_pre/stable-diffusion-v1-5/feature_extractor')

    image = Image.open('/data/workspace/tmp/000_unclip.png')

    prompts = ["a photo of a cat", "a photo of a dog"]
    joint = clip_processor(text=prompts, images=image, return_tensors="pt", padding=True)
    vision_only = sd_extractor(images=image, return_tensors="pt")

    pixel_diff = joint['pixel_values'] - vision_only['pixel_values']
    print(pixel_diff.abs().sum())
    print(joint.keys())

def test_clip_vision_proj():
    """Check CLIPModel's towers + projections against the standalone *WithProjection models.

    Under ``torch.no_grad()``, this function:
    - prints CLIPModel's own logits and its logit_scale;
    - rebuilds the image and text embeddings by hand (pooled output of each
      tower -> projection) and prints their mean absolute difference from
      CLIPVisionModelWithProjection / CLIPTextModelWithProjection;
    - recomputes image-to-text logits from the standalone models' embeddings
      (L2-normalize, cosine similarity, times exp(logit_scale)).
    """
    with torch.no_grad():
        processor = CLIPProcessor.from_pretrained('/data/workspace/hf_pre/clip-vit-large-patch14/processor')
        model_dir = '/data/workspace/hf_pre/clip-vit-large-patch14/model'
        clip = CLIPModel.from_pretrained(model_dir)
        vision_proj = CLIPVisionModelWithProjection.from_pretrained(model_dir)
        text_proj = CLIPTextModelWithProjection.from_pretrained(model_dir)

        image = Image.open('/data/workspace/tmp/000_unclip.png')
        batch = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)
        pixels = batch['pixel_values']
        ids = batch['input_ids']
        mask = batch['attention_mask']

        # Reference: the joint model's own logits.
        full_out = clip(**batch)
        print('model logits: ', full_out.logits_per_image)
        print(clip.logit_scale)

        # Image side: pooled vision output (index 1) -> visual_projection.
        pooled_img = clip.vision_model(pixel_values=pixels)[1]
        print(pooled_img.shape)
        img_proj_manual = clip.visual_projection(pooled_img)

        img_proj_out = vision_proj(pixel_values=pixels)
        print((img_proj_manual - img_proj_out.image_embeds).abs().mean())
        print(img_proj_manual[0][0], img_proj_out.image_embeds[0][0])

        # Text side: pooled text output (index 1) -> text_projection.
        pooled_txt = clip.text_model(input_ids=ids, attention_mask=mask)[1]
        txt_proj_manual = clip.text_projection(pooled_txt)

        txt_proj_out = text_proj(input_ids=ids, attention_mask=mask)
        print(txt_proj_manual.shape)
        print((txt_proj_manual - txt_proj_out.text_embeds).abs().mean())
        print(txt_proj_manual[0][0], txt_proj_out.text_embeds[0][0])

        # Recompute the logits from the standalone models' embeddings.
        img_unit = img_proj_out.image_embeds
        img_unit = img_unit / img_unit.norm(p=2, dim=-1, keepdim=True)
        txt_unit = txt_proj_out.text_embeds
        txt_unit = txt_unit / txt_unit.norm(p=2, dim=-1, keepdim=True)

        scale = clip.logit_scale.exp()
        logits_per_image = (torch.matmul(txt_unit, img_unit.t()) * scale).t()
        print('proj logits: ', logits_per_image)

def test_config():
    """Load the CLIP-L/14 config and print its raw dict, leftover kwargs, and parsed form."""
    from transformers import CLIPConfig

    model_dir = '/data/workspace/hf_pre/clip-vit-large-patch14/model'
    config = CLIPConfig.from_pretrained(model_dir)
    raw_dict, leftover_kwargs = config.get_config_dict(model_dir)
    print(raw_dict, leftover_kwargs)
    print(config)

def check_vision_model():
    """Check that CLIPVisionModel + visual_projection reproduces CLIPVisionModelWithProjection.

    Prints the mean absolute difference between the two image embeddings.
    """
    model_dir = '/data/workspace/hf_pre/clip-vit-large-patch14/model'
    with_proj = CLIPVisionModelWithProjection.from_pretrained(model_dir)
    bare = CLIPVisionModel.from_pretrained(model_dir)

    extractor = CLIPImageProcessor.from_pretrained('/data/workspace/hf_pre/stable-diffusion-v1-5/feature_extractor')
    image = Image.open('/data/workspace/tmp/000_unclip.png')
    pixels = extractor(images=image, return_tensors="pt")['pixel_values']

    with torch.no_grad():
        reference = with_proj(pixel_values=pixels).image_embeds
        # Index 1 of the bare model's output is the pooled representation.
        pooled = bare(pixel_values=pixels)[1]
        rebuilt = with_proj.visual_projection(pooled)
        print((reference - rebuilt).abs().mean())

def test_batch():
    """Print the pixel_values shape for a two-image batch from the SD feature extractor."""
    extractor = CLIPImageProcessor.from_pretrained('/data/workspace/hf_pre/stable-diffusion-v1-5/feature_extractor')
    first = Image.open('/data/workspace/tmp/000_unclip.png')
    batch = [first, first.copy()]
    print(extractor(images=batch, return_tensors="pt")['pixel_values'].shape)

def save_vision_model():
    """Split the LAION CLIP-ViT-L/14 checkpoint into standalone vision and text models.

    Each tower is loaded from the full checkpoint and saved with
    ``save_pretrained`` into its own directory, vision first, then text.
    """
    source = '/data/workspace/hf_pre/CLIP-ViT-L-14-laion2B-s32B-b82K'
    targets = (
        (CLIPVisionModel, '/data/workspace/hf_pre/CLIP-ViT-L-14-laion2B-s32B-b82K_models/vision'),
        (CLIPTextModel, '/data/workspace/hf_pre/CLIP-ViT-L-14-laion2B-s32B-b82K_models/text'),
    )
    for model_cls, out_dir in targets:
        tower = model_cls.from_pretrained(source)
        tower.save_pretrained(out_dir)

def save_vision_model_with_proj():
    """Re-save the ViT-H/14 vision tower (with projection), loaded with torch_dtype=float16."""
    CLIPVisionModelWithProjection.from_pretrained(
        '/data/workspace/hf_pre/CLIP-ViT-H-14-laion2B-s32B-b79K',
        torch_dtype=torch.float16,
    ).save_pretrained('/data/workspace/hf_pre/unclip_hubery/clip3/vision_model')

def check_save_vision_model():
    """Compare the split vision/text checkpoints with towers loaded from the full LAION model.

    For each tower, prints the mean absolute difference of the pooled outputs
    (index 1), then the shapes of the full model's outputs at index 0 and 1.
    """
    full_ckpt = '/data/workspace/hf_pre/CLIP-ViT-L-14-laion2B-s32B-b82K'
    split_vision = CLIPVisionModel.from_pretrained('/data/workspace/hf_pre/CLIP-ViT-L-14-laion2B-s32B-b82K_models/vision')
    full_vision = CLIPVisionModel.from_pretrained(full_ckpt)

    split_text = CLIPTextModel.from_pretrained('/data/workspace/hf_pre/CLIP-ViT-L-14-laion2B-s32B-b82K_models/text')
    full_text = CLIPTextModel.from_pretrained(full_ckpt)

    processor = CLIPProcessor.from_pretrained('/data/workspace/hf_pre/clip-vit-large-patch14/processor')

    image = Image.open('/data/workspace/tmp/000_unclip.png')
    batch = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)
    pixels = batch['pixel_values']
    text_kwargs = dict(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])

    with torch.no_grad():
        split_img_pooled = split_vision(pixel_values=pixels)[1]
        full_img_out = full_vision(pixel_values=pixels)
        split_txt_pooled = split_text(**text_kwargs)[1]
        full_txt_out = full_text(**text_kwargs)

        print((split_img_pooled - full_img_out[1]).abs().mean())
        print(full_img_out[0].shape, full_img_out[1].shape)
        print((split_txt_pooled - full_txt_out[1]).abs().mean())
        print(full_txt_out[0].shape, full_txt_out[1].shape)

def check_save_vision_model_proj():
    """Compare the re-saved ViT-H vision tower with CLIPModel.get_image_features.

    Both models are loaded with ``torch_dtype=torch.float32``. The function
    prints the two embedding shapes, their mean absolute difference, and the
    first embedding of each.
    """
    img = Image.open('/data/workspace/tmp/000_unclip.png')
    pro = CLIPImageProcessor.from_pretrained('/data/workspace/hf_pre/unclip_hubery/clip3/feature_extractor')

    inputs = pro(images=img, return_tensors="pt")
    pixels = inputs['pixel_values']

    vom = CLIPModel.from_pretrained('/data/workspace/hf_pre/CLIP-ViT-H-14-laion2B-s32B-b79K', torch_dtype=torch.float32)
    vm = CLIPVisionModelWithProjection.from_pretrained('/data/workspace/hf_pre/unclip_hubery/clip3/vision_model', torch_dtype=torch.float32)

    # Inference only. Without no_grad, both forward passes would record an
    # autograd graph and keep every activation alive, unlike the other checks
    # in this file.
    with torch.no_grad():
        voout = vom.get_image_features(pixel_values=pixels)
        vout = vm(pixel_values=pixels).image_embeds

    print(voout.shape, vout.shape)
    print((voout - vout).abs().mean(), voout[0], vout[0])

def check_image_preprocessor(
    processor_path='/Users/hubery/projects/9_share/hf_pre/stable-diffusion-v1-5/feature_extractor',
    root='/Users/hubery/projects/4_data/raw_human/d/',
    atol=1e-5,
):
    """Check a PIL/torchvision re-implementation of CLIPImageProcessor against the original.

    For every file returned by ``getAllImageFiles(root)``, a batch of two
    copies of the image is preprocessed twice:

    - by the Hugging Face ``CLIPImageProcessor``;
    - by an equivalent pipeline: RGB convert -> shortest-edge resize ->
      center crop to ``crop_size`` -> ToTensor -> Normalize.

    Diagnostics are printed per file. The loop stops at the first file whose
    outputs differ element-wise by more than ``atol``.

    Args:
        processor_path: directory passed to ``CLIPImageProcessor.from_pretrained``.
        root: directory scanned with ``getAllImageFiles``; file names are joined onto it.
        atol: largest per-element absolute difference still treated as a match.
            The two pipelines may round differently in float32 (rescale/normalize
            order), so a bit-exact match is not expected. Pass 0 to require one.
    """
    import os
    from app.guofeng.guofeng_webuiapi import getAllImageFiles

    pro = CLIPImageProcessor.from_pretrained(processor_path)

    class PIL_torgb():
        """Convert a PIL image to RGB mode."""
        def __call__(self, img):
            return img.convert('RGB')

    class PIL_Resize():
        """Resize so the shorter side equals ``size``, keeping aspect ratio.

        The long side is truncated with ``int()``. Images whose short side
        already equals ``size`` are returned unchanged.
        """
        def __init__(self, size, interpolation=3):
            self.size = size
            self.interpolation = interpolation

        def __call__(self, img):
            w, h = img.size
            if min(w, h) == self.size:
                return img
            if w < h:
                return img.resize((self.size, int(self.size * h / w)), self.interpolation)
            return img.resize((int(self.size * w / h), self.size), self.interpolation)

    class PIL_CenterCrop():
        """Crop a centered window of ``size`` (an int, or a (height, width) pair).

        Offsets are floored.
        """
        def __init__(self, size):
            if isinstance(size, (tuple, list)):
                self.height, self.width = (int(s) for s in size)
            else:
                self.height = self.width = int(size)

        def __call__(self, img):
            orig_width, orig_height = img.size
            top = (orig_height - self.height) // 2
            left = (orig_width - self.width) // 2
            return img.crop((left, top, left + self.width, top + self.height))

    def get_pro_transform(pro):
        """Build the torchvision pipeline that mirrors ``pro``'s preprocessing."""
        crop = pro.crop_size
        return transforms.Compose(
            [
                PIL_torgb(),
                PIL_Resize(pro.size['shortest_edge'], pro.resample),
                # Crop to the processor's crop_size, not to shortest_edge:
                # the two settings are independent and need not be equal.
                PIL_CenterCrop((crop['height'], crop['width'])),
                transforms.ToTensor(),
                transforms.Normalize(pro.image_mean, pro.image_std),
            ]
        )

    trans = get_pro_transform(pro)

    for f in getAllImageFiles(root):
        # Context manager so each file handle is released deterministically.
        with Image.open(os.path.join(root, f)) as img:
            imgs = [img, img]
            a = pro(imgs, return_tensors='pt').pixel_values
            b = torch.cat([trans(i)[None, ...] for i in imgs], 0)

        diff = (a - b).abs()
        print(f)
        print(type(a), type(b))
        print(a.shape, b.shape)
        print(a.dtype, b.dtype, a.max(), b.max())
        print(diff.sum())
        print(diff.max())
        # a and b are already normalized float tensors, so no /255 rescaling.
        print(diff.mean())

        if diff.max() > atol:
            print('#'*20)
            print('error ' + f)
            print('#'*20)
            break
    

if __name__ == '__main__':
    # Experiments to run, in order. Re-enable others by uncommenting them;
    # each one loads its own models and images from hard-coded paths.
    experiments = (
        # test_image_processor,
        # test_clip_vision_proj,
        # test_config,
        # check_vision_model,
        # save_vision_model,
        # check_save_vision_model,
        # test_batch,
        # save_vision_model_with_proj,
        # check_save_vision_model_proj,
        check_image_preprocessor,
    )
    for run_experiment in experiments:
        run_experiment()