from glob import glob
import os
import ot
import geomloss
from scipy.spatial.distance import cdist
import torch
import numpy as np
from timeit import default_timer as timer
import csv
import heapq
import tqdm


def _read_name_list(list_path):
    """Return the non-blank lines of the text file at *list_path*.

    Each line is expected to hold one embedding file name. Blank lines are
    dropped: a file ending with a newline would otherwise contribute an empty
    name, which inflates the count and later resolves to the bare directory
    when joined with a directory path.
    """
    with open(list_path, 'r') as f:
        return [name for name in f.read().splitlines() if name.strip()]


# File names of the c16 embeddings, one per line of the list file.
emb_files_c16 = _read_name_list('emb_c16_sort.txt')
print(len(emb_files_c16))

# File names of the c17 embeddings, one per line of the list file.
emb_files_c17 = _read_name_list('emb_c17_sort.txt')
print(len(emb_files_c17))

# Query CUDA availability once; report it, then select the compute device.
_cuda_available = torch.cuda.is_available()
print(_cuda_available)
device = torch.device("cuda") if _cuda_available else torch.device("cpu")

# Create the output directory before the (long) computation so that a bad
# path fails immediately instead of after every pair has been processed.
output_dir = 'optimal_transport_distance/c16_c17_20att'
os.makedirs(output_dir, exist_ok=True)

# Entropic regularisation strength passed to the Sinkhorn solver.
entreg = 0.5

# c16_c17_loss_matrix[i][j] is sum(M * plan) for c16 file i against c17 file j,
# where M is the pairwise Euclidean cost and plan the Sinkhorn transport plan.
c16_c17_loss_matrix = []

# Only the values are needed, so autograd bookkeeping is disabled for the
# whole computation.
with torch.no_grad():
    for i, name_a in enumerate(emb_files_c16):
        print(i)
        line = []
        tic = timer()
        a = torch.load(os.path.join('c16_feature_path', name_a))
        att_a = torch.load(os.path.join('c16_attention_path', name_a))
        # The source weights depend only on the outer file, so move them to
        # the device once per row rather than once per pair.
        # NOTE(review): assumes the attention files hold the marginal weights
        # expected by sinkhorn_log - confirm against how they were written.
        alpha = torch.as_tensor(att_a, device=device)

        for name_b in emb_files_c17:
            torch.cuda.empty_cache()
            b = torch.load(os.path.join('c17_feature_path', name_b))
            att_b = torch.load(os.path.join('c17_attention_path', name_b))

            # Pairwise Euclidean cost between the two feature sets.
            M = ot.dist(a, b, metric='euclidean')
            # as_tensor converts arrays and reuses existing tensors without
            # the extra copy that torch.tensor() makes.
            M = torch.as_tensor(M, device=device)
            beta = torch.as_tensor(att_b, device=device)

            # Alternative solver (kept for reference):
            # ot.bregman.greenkhorn(alpha, beta, M, numItermax=1000, reg=entreg)
            pw = ot.bregman.sinkhorn_log(alpha, beta, M, numItermax=1000, reg=entreg)

            # Transport cost <M, plan>; keep it on the CPU so the accumulated
            # results do not hold device memory and the saved file can be
            # loaded without a GPU.
            loss = torch.mul(M, pw).sum().cpu()
            line.append(loss)
            print(loss)

            # Release per-pair buffers before the next pair is loaded.
            del b, att_b, beta, M, pw

        c16_c17_loss_matrix.append(line)
        toc = timer()
        del a, att_a, alpha
        print(toc-tic,'s')

torch.save(c16_c17_loss_matrix, os.path.join(output_dir, 'c16_c17_matrix.pt'))
        
    