from analysis import report
import os
import torch
from PIL import Image
from scipy.optimize import linear_sum_assignment


def _to_uint8_image(x):
    """Convert a pixel-scale tensor to a detached CPU ``uint8`` tensor.

    Floating-point values are rounded and clamped to [0, 255] first. Casting
    out-of-range floats directly to ``uint8`` is not well-defined (values
    typically wrap around). Plain truncation also turns float round-trip
    values such as 127.9999 into 127.
    """
    x = x.detach()
    if x.is_floating_point():
        x = x.round().clamp(0, 255)
    return x.to(torch.uint8).cpu()


def _save_png(img_chw, filename):
    """Save a (C, H, W) uint8 tensor as a PNG, expanding 1-channel images to RGB."""
    img = img_chw.permute(1, 2, 0)
    if img.shape[-1] == 1:
        img = img.repeat(1, 1, 3)
    Image.fromarray(img.numpy()).save(filename)


def report_metrics(net, X_rec, X_true, y_true, path, neptune_rec, neptune_gt, img_size, inv_transform, neptune):
    """Match reconstructions to ground truth, report metrics and save PNGs.

    Pairs reconstructions one-to-one with ground-truth samples by solving a
    linear assignment problem that maximises the absolute cosine similarity
    of the flattened tensors. Both sets are then mapped through
    ``inv_transform``. Metrics come from ``get_vision_metrcis`` and are
    printed (and logged to Neptune when enabled). Each pair is written to
    ``path`` as ``{i}_gt.png`` / ``{i}_rec.png``.

    Args:
        net: Model forwarded to the metric computation.
        X_rec: Reconstructions, one per leading index. Each one is flattened
            for matching and reshaped to ``img_size`` afterwards.
        X_true: Ground-truth batch of ``B`` samples (leading dimension).
        y_true: Ground-truth labels, forwarded to the metric computation.
        path: Output directory for the PNG files; created if missing.
        neptune_rec: Neptune field that receives the reconstruction images.
        neptune_gt: Neptune field that receives the ground-truth images.
        img_size: Per-image shape; the saving code treats it as (C, H, W).
        inv_transform: Callable mapping tensors back to pixel scale. It is
            presumably [0, 255], given the 255 constants here and in the
            metrics; confirm against callers.
        neptune: Neptune run handle, or a falsy value to disable logging.
    """
    if neptune:
        from neptune.new.types import File

    B = X_true.shape[0]
    n_rec = X_rec.shape[0]

    # |cosine similarity| for every (true, rec) pair, computed as a matmul of
    # L2-normalised rows. This avoids materialising an (n_rec, B, D) broadcast tensor.
    with torch.no_grad():
        true_unit = torch.nn.functional.normalize(X_true.reshape(B, -1).float(), dim=1)
        rec_unit = torch.nn.functional.normalize(X_rec.reshape(n_rec, -1).float(), dim=1)
        similarity = (true_unit @ rec_unit.T).abs()
    # Rows = ground truth, columns = reconstructions. Negated so the solver's
    # minimisation maximises similarity. .cpu() is needed before .numpy() on GPU.
    cost = -similarity.cpu().numpy()
    _, opt_order = linear_sum_assignment(cost)

    X_rec = X_rec[opt_order].reshape(-1, *img_size)
    X_rec = inv_transform(X_rec)
    # Per-sample peak value. Samples peaking above 300 are rescaled so their
    # maximum becomes 255 (threshold inherited from the original code).
    big = X_rec.flatten(1).amax(dim=1)
    too_big = big > 300
    X_rec[too_big] *= 255 / big[too_big, None, None, None]
    im_rec = _to_uint8_image(X_rec)

    X_true = inv_transform(X_true)
    im_true = _to_uint8_image(X_true)

    vision_metrics = get_vision_metrcis(net, X_rec, X_true, y_true)
    print("Metrics:", vision_metrics)
    if neptune:
        for name, value in vision_metrics.items():
            neptune[f'result/{name}'].log(value)

    os.makedirs(path, exist_ok=True)

    for i in range(B):
        gt_file = os.path.join(path, f'{i}_gt.png')
        rec_file = os.path.join(path, f'{i}_rec.png')
        _save_png(im_true[i], gt_file)
        _save_png(im_rec[i], rec_file)

        if neptune:
            neptune[neptune_gt].append(File(gt_file))
            neptune[neptune_rec].append(File(rec_file))
        
def get_vision_metrcis(net, X_rec, X_true, y_true):
    """Compute image-reconstruction metrics with ``analysis.report``.

    Both image sets are divided by 255 before scoring. LPIPS is only
    requested when both spatial dimensions (axes 2 and 3 of ``X_true``) are
    at least 32. Several entries of ``report``'s result that this project
    does not use are dropped.

    Args:
        net: Model whose parameters and buffers form the single-entry server
            payload passed to ``report``.
        X_rec: Reconstructed images, pixel scale (presumably [0, 255]).
        X_true: Ground-truth images, pixel scale, at least 4-D with spatial
            size on axes 2 and 3.
        y_true: Ground-truth labels.

    Returns:
        The metric dict returned by ``report``, minus the unused entries.
    """
    server_payload = [
        dict(
            parameters=list(net.parameters()),
            buffers=list(net.buffers()),
            metadata={"modality": "vision"},
        )
    ]

    # LPIPS is not computed for images smaller than 32x32.
    compute_lpips = min(X_true.shape[2], X_true.shape[3]) >= 32

    true_data = {'data': X_true / 255, 'labels': y_true}
    rec_data = {'data': X_rec / 255, 'labels': None}

    vision_metrics = report(
        rec_data, true_data, server_payload, net,
        order_batch=False, compute_rpsnr=False, compute_full_iip=False, compute_lpips=compute_lpips,
    )

    # Drop entries not reported by this project. pop() with a default keeps
    # this tolerant of report() versions that omit some of them.
    unused_keys = ('rpsnr', 'max_rpsnr', 'order', 'IIP-none', 'label_acc', 'max_ssim', 'parameters')
    for key in unused_keys:
        vision_metrics.pop(key, None)

    return vision_metrics

