from glob import glob
import os
import ot
import geomloss
from scipy.spatial.distance import cdist
import torch
import numpy as np
from timeit import default_timer as timer
import csv
import heapq
import tqdm


# Load the list of embedding file names, one name per line.
with open('emb_c16_sort.txt', 'r') as f:
    split_data = f.read()

# Drop blank lines: a trailing newline in the list file would otherwise yield
# an empty name, which later gets joined onto the feature/attention
# directories and handed to torch.load as if it were a file.
emb_files = [name for name in split_data.splitlines() if name.strip()]
print(len(emb_files))

# Report whether CUDA is usable, then pick the GPU when it is and the CPU
# otherwise.
cuda_ok = torch.cuda.is_available()
print(cuda_ok)
device = torch.device("cuda") if cuda_ok else torch.device("cpu")

# Pairwise entropic-OT loss between every pair of files listed in emb_files.
#
# For each file i, the loss against every later file j > i is computed, so
# row i of c16_c16_loss_matrix has len(emb_files) - i - 1 entries (only the
# strict upper triangle is stored; the last row is empty).

# Create the output directory before the long computation starts, so a
# missing directory cannot discard the results at the final torch.save.
output_dir = 'c16_c16_20att'
os.makedirs(output_dir, exist_ok=True)

# Entropic regularisation strength passed to Sinkhorn; constant for all pairs.
entreg = 0.5

c16_c16_loss_matrix = []
for i in range(len(emb_files)):
    print(i)
    line = []
    tic = timer()
    a = torch.load(os.path.join('feature_path', emb_files[i]))
    att_a = torch.load(os.path.join('attention_path', emb_files[i])).squeeze()

    # The source marginal does not depend on j, so convert/transfer it once
    # per row.  as_tensor accepts both tensors and arrays without the
    # copy-construct warning that torch.tensor(<tensor>) raises; detach()
    # keeps the result free of autograd history, as torch.tensor() did.
    alpha = torch.as_tensor(att_a).detach().to(device)

    for j in range(i + 1, len(emb_files)):
        torch.cuda.empty_cache()
        b = torch.load(os.path.join('feature_path', emb_files[j]))
        att_b = torch.load(os.path.join('attention_path', emb_files[j])).squeeze()
        beta = torch.as_tensor(att_b).detach().to(device)

        # Pairwise Euclidean ground-cost matrix between the two feature sets.
        M = torch.as_tensor(ot.dist(a, b, metric='euclidean')).detach().to(device)

        # pw = ot.bregman.greenkhorn(alpha, beta, M, numItermax=1000, reg=entreg)
        pw = ot.bregman.sinkhorn_log(alpha, beta, M, numItermax=1000, reg=entreg)

        # Transport cost: sum of the element-wise product of cost and plan.
        # NOTE(review): the stored value is a scalar tensor living on `device`,
        # so the saved file holds device tensors — loading it on a CPU-only
        # machine needs map_location.
        loss = torch.mul(M, pw).sum()
        line.append(loss)
        print(loss)

        # Release the per-pair tensors before the next pair is loaded.
        del b, att_b, beta, M, pw

    c16_c16_loss_matrix.append(line)
    toc = timer()
    del a, att_a, alpha
    print(toc-tic,'s')

torch.save(c16_c16_loss_matrix, os.path.join(output_dir, 'c16_c16_matrix.pt'))
        
    