import torch
import numpy as np
import sys
def bestPossible(eval_list, k, data, device='cuda'):
    """Return the summed error of the optimal rank-k approximations.

    For every matrix A in ``eval_list`` the rank-``k`` truncated SVD A_k is
    formed and ||A - A_k||_F (not squared) is added to the total.  Truncated
    SVD is the optimal rank-k approximation in Frobenius norm, so this is the
    lower bound that the sketch-based ``evaluate*`` errors are compared with.

    Args:
        eval_list: matrices to evaluate.  When ``data == 'tech'`` each item is
            a dict holding the matrix under key 'M'; otherwise each item is
            the matrix itself.
        k: target rank.
        data: dataset name; only the value 'tech' changes how items are read.
        device: torch device to compute on (defaults to 'cuda', as before).

    Returns:
        A 0-d tensor on ``device`` with the summed error (zero when
        ``eval_list`` is empty), so callers can always use tensor methods
        such as ``.tolist()``.
    """
    total_loss = 0.0
    with torch.no_grad():  # pure evaluation, no gradients needed
        for A in eval_list:
            AM = (A['M'] if data == 'tech' else A).to(device)
            U, S, V = AM.svd()
            # torch.svd returns V (not V^T), so V.t()[:k] are the top-k right singular vectors.
            approx = U[:, :k].mm(torch.diag(S[:k])).mm(V.t()[:k])
            # totLoss += torch.norm(approx - AM) ** 2   # squared variant
            total_loss += torch.norm(approx - AM)
    # With an empty eval_list total_loss is still a Python float; wrap it.
    return torch.as_tensor(total_loss, device=device)


def evaluate(sparse, eval_list, sketch_vector, sketch_value, m, k, n, d, device='cuda'):  # evaluate the test/train performance
    """Return the summed error of the count-sketch rank-k approximation.

    For each matrix A, row i of A is multiplied by its sketch weight and
    added into row ``sketch_vector[j]`` of an ``m``-row matrix SA, where j is
    the sketch index of that row.  A is then restricted to the row space of
    SA (through SA's right singular vectors) and the best rank-``k``
    approximation inside that space is compared with A.

    Args:
        sparse: if True each item of ``eval_list`` is a dict with the matrix
            under 'M', its row/column counts under 'n'/'d', and 'Map', where
            ``Map[i]`` is the sketch index used for row i of 'M'.  Otherwise
            each item is a matrix whose first ``n`` rows (``d`` columns) are
            sketched, row i using sketch index i.
        eval_list: matrices to evaluate.
        sketch_vector: target sketch row per sketch index; values must lie in
            [0, m).
        sketch_value: weight that multiplies each row before it is added.
        m: number of rows of the sketch.
        k: target rank.
        n, d: row/column counts for dense items (unused when ``sparse``).
        device: torch device to compute on (defaults to 'cuda', as before).

    Returns:
        float: sum over ``eval_list`` of ||A - approx||_F (not squared).
    """
    total_loss = 0
    count = 0
    with torch.no_grad():  # evaluation only; .item() below drops any graph anyway
        # Loop-invariant: convert the sketch description once, not per matrix.
        targets_all = torch.as_tensor(sketch_vector, device=device).long()
        weights_all = torch.as_tensor(sketch_value, device=device)

        for A in eval_list:
            if sparse:
                AM = A['M'].to(device)
                num_rows, num_cols = A['n'], A['d']
                # A['Map'][i] is the sketch index of row i of A['M'].
                src = torch.tensor([int(A['Map'][i]) for i in range(num_rows)],
                                   dtype=torch.long, device=device)
            else:
                AM = A.to(device)
                num_rows, num_cols = n, d
                src = torch.arange(num_rows, device=device)

            # SA[targets[i]] += weights[i] * AM[i] for every row i, done in a
            # single vectorized call.
            SA = torch.zeros(m, num_cols, dtype=AM.dtype, device=device)
            weights = weights_all[src].to(AM.dtype).unsqueeze(1)
            SA.index_add_(0, targets_all[src], AM[:num_rows] * weights)

            # Best rank-k approximation of A within the row space of SA.
            _, _, V2 = SA.svd()
            AU = AM.mm(V2)
            U3, Sigma3, V3 = AU.svd()
            approx = U3[:, :k].mm(torch.diag(Sigma3[:k])).mm(V3.t()[:k]).mm(V2.t())
            # totLoss += (torch.norm(approx - AM) ** 2).item()   # squared variant
            total_loss += torch.norm(approx - AM).item()

            count += 1
            if count % 10 == 0:  # progress indicator
                print(count, end=",")
                sys.stdout.flush()
    return total_loss

def evaluate_dense(sparse, eval_list, sketch, m, k, device='cuda'):  # evaluate the test/train performance
    """Return the summed error of the rank-k approximation from a dense sketch.

    SA = sketch @ A is formed, A is restricted to the row space of SA
    (through SA's right singular vectors) and the best rank-``k``
    approximation inside that space is compared with A.

    Args:
        sparse: if True each item of ``eval_list`` is a dict with the matrix
            under 'M', its row/column counts under 'n'/'d', and 'Map', where
            ``Map[i]`` selects the column of ``sketch`` applied to row i of
            'M'.  Otherwise each item is multiplied by ``sketch`` directly.
        eval_list: matrices to evaluate.
        sketch: 2-D sketching tensor, expected to have ``m`` rows.
        m: number of sketch rows (sizes SA in the sparse case).
        k: target rank.
        device: torch device to compute on (defaults to 'cuda', as before).

    Returns:
        float: sum over ``eval_list`` of ||A - approx||_F (not squared).
    """
    total_loss = 0
    count = 0
    with torch.no_grad():  # evaluation only; .item() below drops any graph anyway
        sketch = sketch.to(device)
        for A in eval_list:
            if sparse:
                AM = A['M'].to(device)
                num_rows = A['n']
                cols = torch.tensor([int(A['Map'][i]) for i in range(num_rows)],
                                    dtype=torch.long, device=device)
                # sum_i outer(sketch[:, Map[i]], AM[i]) == sketch[:, Map] @ AM[:n]
                SA = torch.zeros(m, A['d'], dtype=AM.dtype, device=device)
                SA += sketch[:, cols].mm(AM[:num_rows])
            else:
                AM = A.to(device)
                SA = sketch.mm(AM)

            # Best rank-k approximation of A within the row space of SA.
            _, _, V2 = SA.svd()
            AU = AM.mm(V2)
            U3, Sigma3, V3 = AU.svd()
            approx = U3[:, :k].mm(torch.diag(Sigma3[:k])).mm(V3.t()[:k]).mm(V2.t())
            # totLoss += (torch.norm(approx - AM) ** 2).item()   # squared variant
            total_loss += torch.norm(approx - AM).item()

            count += 1
            if count % 10 == 0:  # progress indicator
                print(count, end=",")
                sys.stdout.flush()
    return total_loss


def evaluate_both(eval_list, sketch_vector, sketch_value, m, k, n, d, sparse=False, device='cuda'):  # evaluate the test/train performance
    """Return the summed error of the count-sketch rank-k approximation.

    Same computation as ``evaluate`` with ``sparse`` taken as a keyword.
    Previously this function read ``sparse`` from a module-level global that
    this file does not define, so every call raised NameError.  It is now an
    explicit keyword argument that defaults to the dense layout.

    For each matrix A, row i of A is multiplied by its sketch weight and
    added into row ``sketch_vector[j]`` of an ``m``-row matrix SA, where j is
    the sketch index of that row.  A is then restricted to the row space of
    SA and the best rank-``k`` approximation inside it is compared with A.

    Args:
        eval_list: matrices to evaluate (layout controlled by ``sparse``).
        sketch_vector: target sketch row per sketch index; values must lie in
            [0, m).
        sketch_value: weight that multiplies each row before it is added.
        m: number of rows of the sketch.
        k: target rank.
        n, d: row/column counts for dense items (unused when ``sparse``).
        sparse: if True each item is a dict with the matrix under 'M', its
            row/column counts under 'n'/'d', and 'Map' (``Map[i]`` is the
            sketch index of row i of 'M'); otherwise each item is a matrix
            whose first ``n`` rows are sketched with sketch index i.
        device: torch device to compute on (defaults to 'cuda').

    Returns:
        float: sum over ``eval_list`` of ||A - approx||_F (not squared).
    """
    total_loss = 0
    count = 0
    with torch.no_grad():  # evaluation only; .item() below drops any graph anyway
        targets_all = torch.as_tensor(sketch_vector, device=device).long()
        weights_all = torch.as_tensor(sketch_value, device=device)

        for A in eval_list:
            if sparse:
                AM = A['M'].to(device)
                num_rows, num_cols = A['n'], A['d']
                src = torch.tensor([int(A['Map'][i]) for i in range(num_rows)],
                                   dtype=torch.long, device=device)
            else:
                AM = A.to(device)
                num_rows, num_cols = n, d
                src = torch.arange(num_rows, device=device)

            # SA[targets[i]] += weights[i] * AM[i] for every row i, in one call.
            SA = torch.zeros(m, num_cols, dtype=AM.dtype, device=device)
            weights = weights_all[src].to(AM.dtype).unsqueeze(1)
            SA.index_add_(0, targets_all[src], AM[:num_rows] * weights)

            # Best rank-k approximation of A within the row space of SA.
            _, _, V2 = SA.svd()
            AU = AM.mm(V2)
            U3, Sigma3, V3 = AU.svd()
            approx = U3[:, :k].mm(torch.diag(Sigma3[:k])).mm(V3.t()[:k]).mm(V2.t())
            total_loss += torch.norm(approx - AM).item()

            count += 1
            if count % 10 == 0:  # progress indicator
                print(count, end=",")
                sys.stdout.flush()
    return total_loss
def evaluate_extra(sparse, eval_list, sketch_vector, sketch_value, sketch_vector2, sketch_value2, m, mextra, k, n, d, device='cuda'):
    """Return the summed rank-k error using two stacked count sketches.

    SA has ``m + mextra`` rows: rows [0, m) receive the first sketch
    (``sketch_vector``/``sketch_value``) and rows [m, m + mextra) receive the
    second (``sketch_vector2``/``sketch_value2``, whose targets are offset by
    ``m``).  A is then restricted to the row space of SA and the best
    rank-``k`` approximation inside that space is compared with A.

    Args:
        sparse: if True each item of ``eval_list`` is a dict with the matrix
            under 'M', its row/column counts under 'n'/'d', and 'Map', where
            ``Map[i]`` is the sketch index used for row i of 'M'.  Otherwise
            each item is a matrix whose first ``n`` rows (``d`` columns) are
            sketched, row i using sketch index i.
        eval_list: matrices to evaluate.
        sketch_vector, sketch_value: first sketch (targets in [0, m)).
        sketch_vector2, sketch_value2: second sketch (targets in [0, mextra)).
        m, mextra: row counts of the first and second sketch.
        k: target rank.
        n, d: row/column counts for dense items (unused when ``sparse``).
        device: torch device to compute on (defaults to 'cuda', as before).

    Returns:
        float: sum over ``eval_list`` of ||A - approx||_F (not squared).
    """
    total_loss = 0
    count = 0
    with torch.no_grad():  # evaluation only; .item() below drops any graph anyway
        targets1 = torch.as_tensor(sketch_vector, device=device).long()
        weights1 = torch.as_tensor(sketch_value, device=device)
        # The second sketch is stacked below the first one.
        targets2 = torch.as_tensor(sketch_vector2, device=device).long() + m
        weights2 = torch.as_tensor(sketch_value2, device=device)

        for A in eval_list:
            if sparse:
                AM = A['M'].to(device)
                num_rows, num_cols = A['n'], A['d']
                src = torch.tensor([int(A['Map'][i]) for i in range(num_rows)],
                                   dtype=torch.long, device=device)
            else:
                AM = A.to(device)
                num_rows, num_cols = n, d
                src = torch.arange(num_rows, device=device)

            rows = AM[:num_rows]
            SA = torch.zeros(m + mextra, num_cols, dtype=AM.dtype, device=device)
            SA.index_add_(0, targets1[src], rows * weights1[src].to(AM.dtype).unsqueeze(1))
            SA.index_add_(0, targets2[src], rows * weights2[src].to(AM.dtype).unsqueeze(1))

            # Best rank-k approximation of A within the row space of SA.
            _, _, V2 = SA.svd()
            AU = AM.mm(V2)
            U3, Sigma3, V3 = AU.svd()
            approx = U3[:, :k].mm(torch.diag(Sigma3[:k])).mm(V3.t()[:k]).mm(V2.t())
            total_loss += torch.norm(approx - AM).item()

            count += 1
            if count % 10 == 0:  # progress indicator
                print(count, end=",")
                sys.stdout.flush()
    return total_loss

def getAvgDim(A_list):
    """Print the mean row count ('n') and mean column count ('d') of A_list.

    Each item of ``A_list`` is a dict providing integer-like 'n' and 'd'
    entries; nothing is returned.
    """
    # Single pass over A_list (it may be a one-shot iterable).
    dims = [(entry['n'], entry['d']) for entry in A_list]
    heights = [height for height, _ in dims]
    widths = [width for _, width in dims]
    print('Avg height', np.average(heights), 'Avg width', np.average(widths))

def getbest(A_train, A_test, k, data, best_file):
    """Compute, print and save the optimal rank-k errors for train and test.

    The per-matrix averages ``[avg_train, avg_test]`` are printed and written
    to ``best_file`` with ``torch.save``.

    Args:
        A_train, A_test: matrix lists, in the layout ``bestPossible`` expects
            for ``data``.
        k: target rank.
        data: dataset name forwarded to ``bestPossible``.
        best_file: path (or file object) passed to ``torch.save``.

    Returns:
        (best_train, best_test): the *summed* errors as Python floats.

    Raises:
        ValueError: if either list is empty (the averages would be undefined).
    """
    # Fail fast, before the expensive SVDs, instead of dividing by zero later.
    if len(A_train) == 0 or len(A_test) == 0:
        raise ValueError('getbest needs non-empty A_train and A_test')
    # float() accepts both a 0-d tensor and a plain number.
    best_train = float(bestPossible(A_train, k, data))
    best_test = float(bestPossible(A_test, k, data))
    avg_train = best_train / len(A_train)
    avg_test = best_test / len(A_test)
    print('best', avg_train, avg_test)
    torch.save([avg_train, avg_test], best_file)
    return best_train, best_test