import pandas as pd
import numpy as np
import scipy as sp
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn import model_selection, metrics, preprocessing
from sklearn.preprocessing import MinMaxScaler
import copy
import collections 
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import scipy.stats as stats

class MovieDataset:
    """Map-style dataset of (user, movie, rating) triples for a DataLoader.

    The three sequences are indexed in parallel; ``len`` is taken from
    ``users``. Each item is returned as a dict of ``torch.long`` scalars
    keyed "users", "movies", "ratings".
    """

    def __init__(self, users, movies, ratings):
        self.users = users
        self.movies = movies
        self.ratings = ratings

    def __len__(self):
        # number of samples, e.g. len(movie_dataset)
        return len(self.users)

    def __getitem__(self, item):
        # single sample, e.g. movie_dataset[1]
        columns = (
            ("users", self.users),
            ("movies", self.movies),
            ("ratings", self.ratings),
        )
        return {key: torch.tensor(values[item], dtype=torch.long) for key, values in columns}

class RecSysModel(nn.Module):
    """Two-tower recommender scoring a (user, movie) pair by an inner product.

    score(u, v) = < user_tower(user_embed(u)), movie_tower(movie_embed(v)) >

    Args:
        n_users: number of rows of the user embedding table.
        n_movies: number of rows of the movie embedding table.
        embed_dim: width of both embedding tables and of both tower outputs
            (defaults to 16, the previously hard-coded width).
    """

    def __init__(self, n_users, n_movies, embed_dim=16):
        super().__init__()
        self.embed_dim = embed_dim
        # trainable lookup matrices for shallow embedding vectors
        # (modules are created in the same order as before, so seeded
        # initialisation is unchanged for the default width)
        self.user_embed = nn.Embedding(n_users, embed_dim)
        self.movie_embed = nn.Embedding(n_movies, embed_dim)
        # one linear layer per tower on top of the lookups
        self.user_tower = nn.Linear(embed_dim, embed_dim)
        self.movie_tower = nn.Linear(embed_dim, embed_dim)
        # batched matmul used to take one inner product per row in forward()
        self.dot = torch.bmm

    def forward(self, users, movies, ratings=None):
        """Return one raw score per (user, movie) pair, shape (batch, 1).

        No sigmoid is applied. ``ratings`` is accepted for interface
        compatibility and ignored.
        """
        user_tower_out = self.user_tower(self.user_embed(users))
        movie_tower_out = self.movie_tower(self.movie_embed(movies))
        d = self.embed_dim
        # (batch, 1, d) @ (batch, d, 1) -> (batch, 1, 1): per-row inner product
        output = self.dot(user_tower_out.view(-1, 1, d), movie_tower_out.view(-1, d, 1))
        return output.view(-1, 1)

def get_embed_array(model, user_id_ls, movie_id_ls):
    """Return the L2-normalised tower outputs for the given ids as numpy arrays.

    Each id list is looked up in the model's embedding table, passed through
    the matching tower, and every row is normalised to unit length. No
    gradients are tracked.

    Returns:
        (user_embeds_array, movie_embeds_array), one row per requested id.
    """
    def _project(embedding, tower, ids):
        index = torch.tensor(ids, dtype=torch.long)
        projected = tower(embedding(index))
        unit_rows = nn.functional.normalize(projected, dim=1)
        return unit_rows.cpu().detach().numpy()

    with torch.no_grad():
        user_embeds_array = _project(model.user_embed, model.user_tower, user_id_ls)
        movie_embeds_array = _project(model.movie_embed, model.movie_tower, movie_id_ls)
    return user_embeds_array, movie_embeds_array

def binary_rating_conversion(s):
    """Map a rating to 1 when it is strictly above 3, otherwise 0."""
    return 1 if s > 3 else 0

def float_rating_conversion(s):
    """Map a rating to (s - 3) / 2, clamped from below at -1.0.

    Ratings at or above 3 are returned unclamped; lower ratings never go
    below -1.0.
    """
    scaled = (s - 3) / 2
    if s >= 3:
        return scaled
    return max(-1.0, scaled)

def normalize_vector(v):
    """ Return v scaled to unit Euclidean length (v / ||v||).
    """
    length = np.linalg.norm(v)
    return v / length

def normalize_each_row(V):
    """ Return a copy of matrix V whose rows each have unit Euclidean length.
    """
    row_lengths = np.linalg.norm(V, axis=1, keepdims=True)
    return V / row_lengths

def softmax_prob(u, V, beta):
    """ Softmax distribution over creators for one user.

      u: user feature, a (d,) vector
      V: creator features, (n, d) matrix
      beta: parameter (inverse temperature) scaling the inner products
      Returns an (n,) array of probabilities proportional to exp(beta * V @ u).
    """
    scores = beta * np.dot(V, u)
    # Subtract the maximum before exponentiating: mathematically identical,
    # but avoids exp overflow (inf / inf = NaN) for large beta.
    # `initial` keeps an empty V working (returns an empty array).
    unnormalized_prob = np.exp(scores - np.max(scores, initial=-np.inf))
    prob = unnormalized_prob / np.sum(unnormalized_prob)
    return prob

def get_binary_label(user_embed, movie_embed):
    """Return 1 if the inner product of the two embeddings is positive, else 0."""
    return 1 if user_embed @ movie_embed.T > 0 else 0
    
def get_float_label(user_embed, movie_embed):
    """Return the cosine similarity between a user and a movie embedding."""
    norms_product = np.linalg.norm(user_embed) * np.linalg.norm(movie_embed)
    return (user_embed @ movie_embed) / norms_product

def show_result(record, T, show_every, title_string):
    """ Plot 2-D snapshots of a simulation record plus the per-step feature change.

    Input:
        - record: dict with "U_record" and "V_record"; U_record must hold
          T + 1 states along axis 0 (asserted)
        - T: number of transitions to visualise
        - show_every: draw a snapshot at every t with t % show_every == 0
        - title_string: prefix for every snapshot title
    The last snapshot shows the final state U_record[T] / V_record[T].

    NOTE(review): snapshots plot row 0 against row 1 of each state
    (U[0, :] vs U[1, :]), i.e. they assume a d * m (coordinates-by-items)
    layout. The get_dynamics* functions in this file produce m * d states —
    confirm which layout this helper is meant for.
    """
    U_record, V_record = record["U_record"], record["V_record"]
    assert T == U_record.shape[0] - 1

    def _snapshot(U, V, title):
        # one scatter figure: creators as crosses, users as small dots
        fig, ax = plt.subplots()
        ax.scatter(V[0, :], V[1, :], marker='X')
        ax.scatter(U[0, :], U[1, :], s=2)
        ax.set(xlim=(-1.1, 1.1), ylim=(-1.1, 1.1))
        ax.set_title(title)

    user_deltas = [0 for _ in range(T)]
    creator_deltas = [0 for _ in range(T)]

    for t in range(T):
        if t % show_every == 0:
            _snapshot(U_record[t], V_record[t], title_string + f";  t={t}")

        # record the change of feature vectors over transition t -> t+1
        user_deltas[t] = np.linalg.norm(U_record[t+1] - U_record[t])
        creator_deltas[t] = np.linalg.norm(V_record[t+1] - V_record[t])

    # Final state after the last transition (index T). Previously this reused
    # the loop's last state (index T-1) under the "T=" label, and raised
    # NameError when T == 0.
    _snapshot(U_record[T], V_record[T], title_string + f";  T={T}")
    plt.show()

    T_list = list(range(T))

    plt.figure()
    plt.plot(T_list, user_deltas)
    plt.title('Change of user preference')

    plt.figure()
    plt.plot(T_list, creator_deltas)
    plt.title('Change of creator feature')
    
def argmax(iterable):
    """Return the index of the first maximal element of *iterable*."""
    best_index, _ = max(enumerate(iterable), key=lambda pair: pair[1])
    return best_index

def diversity_score_arr(V, prev_item_ls, curr_item_ls):
    """Diversity of each candidate item with respect to a history of items.

    For every index in curr_item_ls the score is
    1 - mean_k <V[curr], V[prev_k]>, using the rows of V.
    Returns an array with one score per candidate.
    """
    candidates = V[curr_item_ls]
    history = V[prev_item_ls]
    mean_similarity = np.sum(candidates @ history.T, axis=1) / len(prev_item_ls)
    return np.ones(len(curr_item_ls)) - mean_similarity

def diversity_score(V, prev_item_ls, curr_item_id):
    """Diversity of a single item: 1 - mean inner product with the history items (rows of V)."""
    similarities = V[curr_item_id] @ V[prev_item_ls].T
    return 1 - np.sum(similarities) / len(prev_item_ls)

def objective_w_div(V, prev_item_ls, curr_item_ls, user_embed, rho):
    """Score candidates by relevance plus a weighted diversity bonus.

    For each candidate the overall score is
    diversity_score(...) * rho + <V[candidate], user_embed>.

    Returns:
        (list of overall scores in candidate order, candidate with the
        highest score — the first one on ties).
    """
    ovr_score_ls = [
        diversity_score(V, prev_item_ls, item) * rho + V[item] @ user_embed.T
        for item in curr_item_ls
    ]
    return ovr_score_ls, curr_item_ls[argmax(ovr_score_ls)]

def get_dynamics(U_init, V_init, eta_u, eta_c, beta, T,
                 user_update_rule="inner_product", creator_update_rule="inner_product",
                 fixed_dimension=0):
    """ Simulate T - 1 steps of user/creator feature dynamics under softmax sampling.

    At each step every user j samples one creator i with probabilities
    softmax_prob(U[j], V, beta). User j then moves towards creator i, and
    every creator sampled by at least one user moves towards its audience.
    Updated vectors are re-normalised with normalize_vector.

    Input:
        - U_init: m * d matrix, each row is a user feature vector
        - V_init: n * d matrix, each row is a creator feature vector
        - eta_u, eta_c: callables; eta_u(t) / eta_c(t) give the user /
          creator step size used at step t
        - beta: softmax parameter used for sampling
        - T: number of recorded states (initial state + T - 1 updates)
        - user_update_rule: accepted but not used; users always take the step
          uj + eta_u(t) * <uj, vi> * vi
        - creator_update_rule: "average", "inner_product" or "fixed"
        - fixed_dimension: if > 0, the first `fixed_dimension` coordinates of
          each user are never changed; the remaining coordinates are
          rescaled to norm sqrt(1 - ||fixed part||^2)
    Returns a dict consisting of:
        - "U_record": a  T * m * d array
        - "V_record": a  T * n * d array
    Raises:
        ValueError: if creator_update_rule is not supported (checked before
        simulating; previously a BaseException raised mid-simulation).
    """
    if creator_update_rule not in ("average", "inner_product", "fixed"):
        raise ValueError(f"creator update rule {creator_update_rule} not supported")

    n, d = V_init.shape
    m = U_init.shape[0]
    assert U_init.shape[1] == d
    U_record = np.zeros((T, m, d))
    V_record = np.zeros((T, n, d))
    U_record[0] = np.copy(U_init)
    V_record[0] = np.copy(V_init)

    if fixed_dimension > 0:
        # Norm left for the free coordinates so each full user vector has
        # squared norm ||fixed||^2 + (1 - ||fixed||^2) = 1.
        non_fixed_norm = np.zeros((m))
        for j in range(m):
            non_fixed_norm[j] = np.sqrt( 1 - np.linalg.norm(U_init[j, :fixed_dimension])**2 )

    for t in range(T-1):
        U = U_record[t]
        V = V_record[t]

        ### Sample: one creator for each user
        user_to_creator = [None] * m
        creator_to_users = [[] for _ in range(n)]
        for j in range(m):
            prob = softmax_prob(U[j], V, beta)
            i = np.random.choice(range(n), p=prob)
            user_to_creator[j] = i
            creator_to_users[i].append(j)  # user j is in creator i's audience

        ### User update: move towards the recommended creator
        new_U = np.copy(U)
        eta_u_value = eta_u(t)
        for j in range(m):
            uj = U[j]
            vi = V[user_to_creator[j]]
            if fixed_dimension == 0:
                new_U[j] = normalize_vector( uj + eta_u_value * np.dot(uj, vi) * vi )
            else:
                new_U[j, fixed_dimension:] = normalize_vector( uj[fixed_dimension:] +  eta_u_value*np.dot(uj, vi)*vi[fixed_dimension:] ) * non_fixed_norm[j]

        ### Creator update: move towards the audience
        new_V = np.copy(V)
        eta_c_value = eta_c(t)
        for i, J in enumerate(creator_to_users):
            if not J or creator_update_rule == "fixed":
                continue
            if creator_update_rule == "average":
                avg = np.mean(U[J], axis=0)
            else:  # "inner_product": each user weighted by <u, v_i>
                tmp = ( U[J].T * np.dot(U[J], V[i]) ).T
                avg = np.mean(tmp, axis=0)
            new_V[i] = normalize_vector(V[i] + eta_c_value * avg)

        # record the new feature vectors
        U_record[t+1] = new_U
        V_record[t+1] = new_V

    return {"U_record":U_record, "V_record":V_record}

def get_dynamics_w_div_old(U_init, V_init, eta_u, eta_c, beta, T, div_len=5, rho = 0.5,
                       user_update_rule="inner_product", creator_update_rule="inner_product",
                       fixed_dimension=0):
    """ get_dynamics with a relevance/diversity re-ranking step (older variant).

    Each user keeps a history of its last `div_len` recommended creators.
    - While the history is shorter than `div_len`, the creator is sampled
      from softmax_prob and appended.
    - Otherwise up to 5 distinct candidates are sampled from softmax_prob.
      objective_w_div picks the one maximising
      rho * diversity_score + <V[i], U[j]>, and the history window slides by one.
    User and creator updates are identical to get_dynamics.

    Input:
        - U_init: m * d matrix, each row is a user feature vector
        - V_init: n * d matrix, each row is a creator feature vector
        - eta_u, eta_c: callables giving the user / creator step size at step t
        - beta: softmax parameter
        - T: number of recorded states
        - div_len: length of each user's history window
        - rho: weight of the diversity term
        - user_update_rule: accepted but not used (inner-product step)
        - creator_update_rule: "average", "inner_product" or "fixed"
        - fixed_dimension: as in get_dynamics
    Returns a dict consisting of:
        - "U_record": a  T * m * d array
        - "V_record": a  T * n * d array
    Raises:
        ValueError: if creator_update_rule is not supported.
    """
    if creator_update_rule not in ("average", "inner_product", "fixed"):
        raise ValueError(f"creator update rule {creator_update_rule} not supported")

    n, d = V_init.shape
    m = U_init.shape[0]
    assert U_init.shape[1] == d
    U_record = np.zeros((T, m, d))
    V_record = np.zeros((T, n, d))
    U_record[0] = np.copy(U_init)
    V_record[0] = np.copy(V_init)

    # Sampling without replacement needs at most n candidates; a fixed size
    # of 5 made np.random.choice raise whenever n < 5.
    n_candidates = min(5, n)

    hist_item_dict = {j: [] for j in range(m)}

    if fixed_dimension > 0:
        non_fixed_norm = np.zeros((m))
        for j in range(m):
            non_fixed_norm[j] = np.sqrt( 1 - np.linalg.norm(U_init[j, :fixed_dimension])**2 )

    for t in range(T-1):
        U = U_record[t]
        V = V_record[t]

        ### Sample: one creator for each user
        user_to_creator = [None] * m
        creator_to_users = [[] for _ in range(n)]
        for j in range(m):
            prob = softmax_prob(U[j], V, beta)
            if len(hist_item_dict[j]) < div_len:
                i = np.random.choice(range(n), p=prob)
                hist_item_dict[j].append(i)
            else:
                candidate_ls = np.random.choice(range(n), size=n_candidates, replace=False, p=prob)
                _, i = objective_w_div(V, hist_item_dict[j], candidate_ls, U[j], rho)
                # slide the history window: drop the oldest, append the newest
                hist_item_dict[j] = hist_item_dict[j][1:] + [i]
            user_to_creator[j] = i
            creator_to_users[i].append(j)

        ### User update
        new_U = np.copy(U)
        eta_u_value = eta_u(t)
        for j in range(m):
            uj = U[j]
            vi = V[user_to_creator[j]]
            if fixed_dimension == 0:
                new_U[j] = normalize_vector( uj + eta_u_value * np.dot(uj, vi) * vi )
            else:
                new_U[j, fixed_dimension:] = normalize_vector( uj[fixed_dimension:] +  eta_u_value*np.dot(uj, vi)*vi[fixed_dimension:] ) * non_fixed_norm[j]

        ### Creator update
        new_V = np.copy(V)
        eta_c_value = eta_c(t)
        for i, J in enumerate(creator_to_users):
            if not J or creator_update_rule == "fixed":
                continue
            if creator_update_rule == "average":
                avg = np.mean(U[J], axis=0)
            else:  # "inner_product"
                tmp = ( U[J].T * np.dot(U[J], V[i]) ).T
                avg = np.mean(tmp, axis=0)
            new_V[i] = normalize_vector(V[i] + eta_c_value * avg)

        # record the new feature vectors
        U_record[t+1] = new_U
        V_record[t+1] = new_V

    return {"U_record":U_record, "V_record":V_record}

def get_dynamics_w_div(U_init, V_init, eta_u, eta_c, beta, T, div_len=5, rho = 0.5,
                       user_update_rule="inner_product", creator_update_rule="inner_product",
                       fixed_dimension=0):
    """ get_dynamics with a diversity-boosted sampling distribution.

    Each user keeps a history of its last `div_len` recommended creators.
    - While the history is shorter than `div_len`, the creator is sampled
      from softmax_prob and appended.
    - Otherwise the sampling distribution is
      normalise(softmax_prob + rho * normalise(diversity_score_arr)),
      and the history window slides by one.
    User and creator updates are identical to get_dynamics.

    Input:
        - U_init: m * d matrix, each row is a user feature vector
        - V_init: n * d matrix, each row is a creator feature vector
        - eta_u, eta_c: callables giving the user / creator step size at step t
        - beta: softmax parameter
        - T: number of recorded states
        - div_len: length of each user's history window
        - rho: weight of the diversity distribution
        - user_update_rule: accepted but not used (inner-product step)
        - creator_update_rule: "average", "inner_product" or "fixed"
        - fixed_dimension: as in get_dynamics
    Returns a dict consisting of:
        - "U_record": a  T * m * d array
        - "V_record": a  T * n * d array
    Raises:
        ValueError: if creator_update_rule is not supported.
    """
    if creator_update_rule not in ("average", "inner_product", "fixed"):
        raise ValueError(f"creator update rule {creator_update_rule} not supported")

    n, d = V_init.shape
    m = U_init.shape[0]
    assert U_init.shape[1] == d
    U_record = np.zeros((T, m, d))
    V_record = np.zeros((T, n, d))
    U_record[0] = np.copy(U_init)
    V_record[0] = np.copy(V_init)

    all_items = list(range(n))
    hist_item_dict = {j: [] for j in range(m)}

    if fixed_dimension > 0:
        non_fixed_norm = np.zeros((m))
        for j in range(m):
            non_fixed_norm[j] = np.sqrt( 1 - np.linalg.norm(U_init[j, :fixed_dimension])**2 )

    for t in range(T-1):
        U = U_record[t]
        V = V_record[t]

        ### Sample: one creator for each user
        user_to_creator = [None] * m
        creator_to_users = [[] for _ in range(n)]
        for j in range(m):
            prob = softmax_prob(U[j], V, beta)
            if len(hist_item_dict[j]) < div_len:
                i = np.random.choice(range(n), p=prob)
                hist_item_dict[j].append(i)
            else:
                div_score_arr = diversity_score_arr(V, hist_item_dict[j], all_items)
                div_total = div_score_arr.sum()
                prob_w_div_arr = np.array(prob)
                # Skip the bonus when the diversity mass is zero or NaN
                # (e.g. empty history) instead of producing NaN probabilities.
                if div_total > 0:
                    prob_w_div_arr = prob_w_div_arr + rho * (div_score_arr / div_total)
                prob_w_div_norm = (prob_w_div_arr / prob_w_div_arr.sum()).tolist()
                i = np.random.choice(range(n), p=prob_w_div_norm)
                # slide the history window: drop the oldest, append the newest
                hist_item_dict[j] = hist_item_dict[j][1:] + [i]
            user_to_creator[j] = i
            creator_to_users[i].append(j)

        ### User update
        new_U = np.copy(U)
        eta_u_value = eta_u(t)
        for j in range(m):
            uj = U[j]
            vi = V[user_to_creator[j]]
            if fixed_dimension == 0:
                new_U[j] = normalize_vector( uj + eta_u_value * np.dot(uj, vi) * vi )
            else:
                new_U[j, fixed_dimension:] = normalize_vector( uj[fixed_dimension:] +  eta_u_value*np.dot(uj, vi)*vi[fixed_dimension:] ) * non_fixed_norm[j]

        ### Creator update
        new_V = np.copy(V)
        eta_c_value = eta_c(t)
        for i, J in enumerate(creator_to_users):
            if not J or creator_update_rule == "fixed":
                continue
            if creator_update_rule == "average":
                avg = np.mean(U[J], axis=0)
            else:  # "inner_product"
                tmp = ( U[J].T * np.dot(U[J], V[i]) ).T
                avg = np.mean(tmp, axis=0)
            new_V[i] = normalize_vector(V[i] + eta_c_value * avg)

        # record the new feature vectors
        U_record[t+1] = new_U
        V_record[t+1] = new_V

    return {"U_record":U_record, "V_record":V_record}

def get_dynamics_real_data(U_init, V_init, eta_u, eta_c, beta, T, model_perf_rec, div_len=5, rho = 0.5,
                       user_update_rule="inner_product", creator_update_rule="inner_product",
                       fixed_dimension=0):
    """ Simulate user/creator dynamics where recommendations come from a trained model.

    The true features start at (U_init, V_init) and evolve as in get_dynamics
    (inner-product rules). The creator shown to each user is sampled from
    the *model* features U_model / V_model, using the diversity-boosted
    distribution of get_dynamics_w_div. The model features start as copies
    of U_init / V_init. After every step they are replaced by the normalised
    tower outputs of model_perf_rec (get_embed_array), once the model has been
    retrained for 3 epochs on binary feedback (get_binary_label) for the shown
    (user, creator) pairs.

    Input:
        - U_init: m * d matrix, each row is a user feature vector
        - V_init: n * d matrix, each row is a creator feature vector
        - eta_u, eta_c: callables giving the user / creator step size at step t
        - beta: softmax parameter
        - T: number of recorded states
        - model_perf_rec: torch module with user_embed / movie_embed /
          user_tower / movie_tower (e.g. RecSysModel). It is trained IN PLACE.
          Its embedding tables must cover user ids 0..m-1 and movie ids 0..n-1.
        - div_len, rho: history length and weight of the diversity bonus
        - user_update_rule, creator_update_rule: accepted but not used; both
          updates always use the inner-product rule
        - fixed_dimension: as in get_dynamics
    Returns a dict consisting of:
        - "U_record": a  T * m * d array
        - "V_record": a  T * n * d array
    """
    n, d = V_init.shape
    m = U_init.shape[0]
    assert U_init.shape[1] == d

    user_id_ls = list(range(m))
    movie_id_ls = list(range(n))

    # Features used to label the feedback. They are never updated below.
    # NOTE(review): labels therefore reflect the *initial* preferences rather
    # than the drifting U_record[t] / V_record[t] — confirm this is intended.
    user_embeds_array = copy.deepcopy(U_init)
    movie_embeds_array = copy.deepcopy(V_init)

    loss_func = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model_perf_rec.parameters())
    # NOTE(review): sch.step() is never called, so this scheduler never
    # changes the learning rate.
    sch = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.7)

    # Records use the true feature width d (a hard-coded 16 broke any d != 16;
    # the model's own embedding width is independent of d).
    U_record = np.zeros((T, m, d))
    V_record = np.zeros((T, n, d))
    U_record[0] = copy.deepcopy(U_init)
    V_record[0] = copy.deepcopy(V_init)

    # Features the recommender samples from.
    U_model = copy.deepcopy(U_init)
    V_model = copy.deepcopy(V_init)

    hist_item_dict = {j: [] for j in range(m)}

    # Needed by the fixed_dimension branch of the user update; it was
    # previously referenced without being defined (NameError).
    if fixed_dimension > 0:
        non_fixed_norm = np.zeros((m))
        for j in range(m):
            non_fixed_norm[j] = np.sqrt( 1 - np.linalg.norm(U_init[j, :fixed_dimension])**2 )

    epochs = 3
    batch_size = 4

    for t in range(T-1):
        U = U_record[t]
        V = V_record[t]

        ### Sample: one creator per user, from the model's features
        user_to_creator = [None] * m
        creator_to_users = [[] for _ in range(n)]
        for j in range(m):
            prob = softmax_prob(U_model[j], V_model, beta)
            if len(hist_item_dict[j]) < div_len:
                i = np.random.choice(range(n), p=prob)
                hist_item_dict[j].append(i)
            else:
                div_score_arr = diversity_score_arr(V_model, hist_item_dict[j], movie_id_ls)
                div_total = div_score_arr.sum()
                prob_w_div_arr = np.array(prob)
                # Skip the bonus when the diversity mass is zero or NaN
                # instead of producing NaN probabilities.
                if div_total > 0:
                    prob_w_div_arr = prob_w_div_arr + rho * (div_score_arr / div_total)
                prob_w_div_norm = (prob_w_div_arr / prob_w_div_arr.sum()).tolist()
                i = np.random.choice(range(n), p=prob_w_div_norm)
                # slide the history window: drop the oldest, append the newest
                hist_item_dict[j] = hist_item_dict[j][1:] + [i]
            user_to_creator[j] = i
            creator_to_users[i].append(j)

        ### True user update (inner-product rule)
        new_U = np.copy(U)
        eta_u_value = eta_u(t)
        for j in range(m):
            uj = U[j]
            vi = V[user_to_creator[j]]
            if fixed_dimension == 0:
                new_U[j] = normalize_vector( uj + eta_u_value * np.dot(uj, vi) * vi )
            else:
                new_U[j, fixed_dimension:] = normalize_vector( uj[fixed_dimension:] +  eta_u_value*np.dot(uj, vi)*vi[fixed_dimension:] ) * non_fixed_norm[j]

        ### True creator update (inner-product rule)
        new_V = np.copy(V)
        eta_c_value = eta_c(t)
        for i, J in enumerate(creator_to_users):
            if J:
                tmp = ( U[J].T * np.dot(U[J], V[i]) ).T
                avg = np.mean(tmp, axis=0)
                new_V[i] = normalize_vector(V[i] + eta_c_value * avg)

        ### Feedback for the shown pairs, labelled from the static features
        rec_user_id_ls = list(range(m))
        rec_movie_id_ls = [user_to_creator[user_id] for user_id in rec_user_id_ls]
        rec_rating_ls = [
            get_binary_label(user_embeds_array[user_id, :], movie_embeds_array[movie_id, :])
            for user_id, movie_id in zip(rec_user_id_ls, rec_movie_id_ls)
        ]

        rec_df = pd.DataFrame(data={'userId': rec_user_id_ls, 'movieId': rec_movie_id_ls, 'rating': rec_rating_ls})
        perf_rec_dataset = MovieDataset(
            users=rec_df.userId.values,
            movies=rec_df.movieId.values,
            ratings=rec_df.rating.values
        )
        # NOTE: a new loader (and, with num_workers=2, new worker processes)
        # is created at every step.
        perf_rec_loader = DataLoader(dataset=perf_rec_dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=2, drop_last=True)

        ### Retrain the recommender on this step's feedback
        model_perf_rec.train()
        for _ in range(epochs):
            for perf_rec_data in perf_rec_loader:
                output = model_perf_rec(perf_rec_data["users"],
                                        perf_rec_data["movies"])
                # reshape (batch,) ratings to match the model's (batch, 1) output
                rating = perf_rec_data["ratings"].view(-1, 1).to(torch.float32)

                loss = loss_func(output, rating)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # the next step samples from the retrained model's normalised features
        U_model, V_model = get_embed_array(model_perf_rec, user_id_ls, movie_id_ls)

        U_record[t+1, :, :] = new_U
        V_record[t+1, :, :] = new_V

    return {"U_record":U_record, "V_record":V_record}

def show_dynamics_high_dim(record, show_every=1, show_array=None, title_string="", projection="first_two"):
    """ Plot snapshots of user/creator features and the per-step change of both.

    record is a dictionary of U_record, V_record, where
      U_record: T * m * d
      V_record: T * n * d
    The first and last states are always drawn. For the states in between:
    if show_array is given (any container, including a numpy array), state t
    is drawn when t is in show_array; otherwise it is drawn when
    (t + 1) % show_every == 0.
    projection: "first_two" (plot coordinates 0 and 1; always used when d == 2),
      "pca" or "tsne" (2-D embedding of the stacked user and creator vectors).
    Raises:
        ValueError: if a snapshot needs an unsupported projection.
    """
    U_record, V_record = record["U_record"], record["V_record"]
    T, m, d = U_record.shape;  n = V_record.shape[1]
    assert T == V_record.shape[0];  assert d == V_record.shape[2]

    user_deltas = np.zeros(T-1)
    creator_deltas = np.zeros(T-1)

    T_list = list(range(T))
    for t in range(T):
        U, V = U_record[t], V_record[t]

        if t == 0 or t == T-1:
            to_show = True
        elif show_array is not None:
            # `is not None`: `!= None` on a numpy array is elementwise and
            # raises "truth value of an array is ambiguous".
            to_show = t in show_array
        else:
            to_show = (t+1) % show_every == 0

        ### Plot:
        if to_show:
            suffix = f";  t={t+1}" if t < T-1 else f";  T={T}"
            if d == 2 or projection == "first_two":
                fig, ax = plt.subplots()
                ax.scatter(V[:, 0], V[:, 1], marker='X')
                ax.scatter(U[:, 0], U[:, 1], s=2)
                ax.set(xlim=(-1.1, 1.1), ylim=(-1.1, 1.1))
                ax.set_title(title_string + suffix)

            elif projection in ("pca", "tsne"):
                # users in rows 0..m-1, creators in rows m..m+n-1
                X = np.zeros(shape=(m + n, d))
                X[0:m, :] = U;  X[m:, :] = V
                if projection == "pca":
                    X_2d = PCA(n_components=2).fit_transform(X)
                    method = "PCA"
                else:
                    X_2d = TSNE(n_components=2, random_state=42).fit_transform(X)
                    method = "t-SNE"

                plt.figure()
                plt.scatter(x=X_2d[m:, 0],  y=X_2d[m:,  1], marker='X')
                plt.scatter(x=X_2d[0:m, 0], y=X_2d[0:m, 1], s=4)
                plt.title(f"{method} visualization of feature vectors\n" + title_string + suffix)

            else:
                raise ValueError(f"projection {projection} not supported")

        # record the change of feature vectors over transition t -> t+1
        if t < T-1:
            user_deltas[t] = np.linalg.norm(U_record[t+1] - U_record[t])
            creator_deltas[t] = np.linalg.norm(V_record[t+1] - V_record[t])

    plt.figure()
    plt.plot(T_list[:-1], user_deltas)
    plt.title('Change of user features')

    plt.figure()
    plt.plot(T_list[:-1], creator_deltas)
    plt.title('Change of creator features')
    plt.show()
    
def avg_pairwise_distance(V):
    """ Mean Euclidean distance over all distinct pairs of rows of V, an (n, d) matrix. """
    condensed_distances = sp.spatial.distance.pdist(V)
    return np.mean(condensed_distances)

def user_relevance(U, V, beta):
    """ Average over users of the expected inner product under each user's softmax.

    U: (m, d),  V: (n, d),   beta: the parameter in softmax
    For each user the softmax over creators uses scores beta * <u, v>. The
    result is sum_j sum_i p_ji * <u_j, v_i> / m.
    """
    m = U.shape[0]
    pairwise_dot = U @ V.T  # m * n
    logits = beta * pairwise_dot
    # Row-wise max subtraction: same probabilities, no exp overflow for large beta.
    logits = logits - logits.max(axis=1, keepdims=True, initial=-np.inf)
    exp_score = np.exp(logits)  # m * n
    probs = exp_score / exp_score.sum(axis=1, keepdims=True)  # each row sums to 1
    return np.sum( probs * pairwise_dot ) / m

def user_weighted_variance(U, V, beta):
    """ U: (m, d),  V: (n, d)
      return the weighted variance of the creators recommended to a user,
      averaged over users. For each user the weights are the softmax
      probabilities of beta * <u, v>, and the variance is the weighted squared
      distance of each creator to the weighted mean creator V_bar.
    """
    m = U.shape[0]
    pairwise_dot = U @ V.T  # m * n
    logits = beta * pairwise_dot
    # Row-wise max subtraction: same probabilities, no exp overflow for large beta.
    logits = logits - logits.max(axis=1, keepdims=True, initial=-np.inf)
    exp_score = np.exp(logits)  # m * n
    probs = exp_score / exp_score.sum(axis=1, keepdims=True)  # each row sums to 1
    V_bar = probs @ V  # (m, n) * (n, d) = (m, d)
    V_V_bar_distance_squared = sp.spatial.distance.cdist(V_bar, V, 'sqeuclidean')  # m * n
    return np.sum( probs * V_V_bar_distance_squared) / m

def tendency_to_polarization(V):
    """ V: (n, d); mean absolute entry of the Gram matrix V @ V.T. """
    n_rows = V.shape[0]
    gram = V @ V.T
    return np.sum(np.abs(gram)) / n_rows**2

def get_all_measures(U_record, V_record, beta):
    """ compute all the measures at every recorded step, given:
      - U_record: T * m * d array of user features (m users)
      - V_record: T * n * d array of creator features (n creators)
      - beta: parameter in softmax
      Return: a dictionary consisting of:
      {"CD": creator diversity (avg_pairwise_distance of V):      (T,) array
       "UR": user relevance (user_relevance of U, V):             (T,) array
       "UD": user diversity (user_weighted_variance of U, V):     (T,) array
       "TP": tendency to polarization (tendency_to_polarization): (T,) array
      }
    """
    T = U_record.shape[0]; assert V_record.shape[0] == T
    assert U_record.shape[2] == V_record.shape[2]
    CD = np.zeros(T);  UR = np.zeros(T)
    UD = np.zeros(T);  TP = np.zeros(T)
    for t, (U, V) in enumerate(zip(U_record, V_record)):
        CD[t] = avg_pairwise_distance(V)
        UR[t] = user_relevance(U, V, beta)
        UD[t] = user_weighted_variance(U, V, beta)
        TP[t] = tendency_to_polarization(V)
    return {"CD":CD, "UR":UR, "UD":UD, "TP":TP}

def plot_mean_and_conf(data, labels, conf=0.95, smooth_window=10,
                       ylabel='Y_label', xlabel='X_label', title='Title', figsize=(10, 8), fontsize=16, alpha=0.4,
                       save_to=""
                       ):
    """Plot the smoothed mean of each experiment with a t-distribution confidence band.

    Args:
        data: numpy array with shape (n_experiments, n_trials, n_timesteps).
        labels: list of strings; labels[i] is the legend name of experiment i.
        conf: confidence level, e.g. 0.95 yields 95% confidence intervals.
        smooth_window: number of timesteps over which means and confidence
            half-widths are averaged (mn_vec) before plotting; a higher value
            yields smoother lines.
        ylabel, xlabel, title, figsize, fontsize, alpha: plot cosmetics; no
            title is drawn when title is None.
        save_to: if not "", the figure is also saved to this path.
    """
    plt.figure(figsize=figsize)
    t_quantile = (1 + conf) / 2
    for idx, trials in enumerate(data):
        n_trials, n_steps = trials.shape[0], trials.shape[1]
        # half-width of the confidence interval at each timestep
        half_width = stats.sem(trials, axis=0) * stats.t.ppf(t_quantile, n_trials - 1)

        smoothed_mean = mn_vec(np.mean(trials, axis=0), smooth_window)
        smoothed_ci = mn_vec(half_width, smooth_window)
        steps = np.arange(1, n_steps + 1)

        plt.plot(steps, smoothed_mean, label=labels[idx], linewidth=3)
        plt.fill_between(steps, smoothed_mean + smoothed_ci, smoothed_mean - smoothed_ci, alpha=alpha)

    plt.xlabel(xlabel, fontsize=fontsize)
    plt.ylabel(ylabel, fontsize=fontsize)
    if title is not None:
        plt.title(title, fontsize=1.1 * fontsize)
    plt.legend(fontsize=fontsize)
    plt.grid(True)
    if save_to != "":
        plt.savefig(save_to)

    plt.show()


def mn_vec(vec, smooth_window):
    """Moving average of *vec* with a (roughly) centred window.

    Element i is the mean of vec[lo:hi] with lo = i - w // 2 and hi = lo + w,
    clipped at the array start, where w = smooth_window.
    - For even w this is exactly the window [i - w/2, i + w/2).
    - For odd w the window holds w elements centred on i. The old
      [i - w//2, i + w//2) slice held only w - 1 elements and was empty
      (NaN) for w == 1.
    - Non-positive windows are treated as 1, i.e. no smoothing.

    Returns a numpy array of the same length as vec.
    """
    n = len(vec)
    window = max(1, int(smooth_window))
    half = window // 2
    # hi is computed from the unclipped lo so edge windows shrink from the left only
    ret_val = [np.mean(vec[max(0, i - half): i - half + window]) for i in range(n)]
    return np.array(ret_val)