
import math
import numpy as np
import time

def generate_multivariate_gaussian(d, sigma):
    """
    Draw one sample from the isotropic d-dimensional Gaussian N(0, sigma^2 I).

    :param d: dimension of the random vector.
    :param sigma: standard deviation of every component; the covariance
        matrix used is ``sigma**2 * I`` (zero mean).
    :return: ndarray of shape (d,).
    """
    return np.random.multivariate_normal(np.zeros(d), sigma**2 * np.eye(d))

def get_r(env, s):
    """
    Compute a per-row reward from three 2-D points stored in the state.

    Columns (4, 5), (6, 7) and (8, 9) of ``s`` are read as points P, Q and T.
    For every row the reward is ``|P - T| - |Q - T|`` (Euclidean distances).

    :param env: not used; kept for interface compatibility.
    :param s: batch of states supporting ``s[:, k]`` column indexing
        (presumably a numpy array of shape (batch, >= 10) — confirm with caller).
    :return: one reward per row of ``s``.
    """
    target_x = s[:, 8]
    target_y = s[:, 9]

    def _dist_to_target(x_col, y_col):
        # Euclidean distance from the point in (x_col, y_col) to T.
        return np.sqrt(np.square(s[:, x_col] - target_x) + np.square(s[:, y_col] - target_y))

    return _dist_to_target(4, 5) - _dist_to_target(6, 7)

def get_omegas_bs(env, sigma, m):
    """
    Sample frequencies and phases for ``m`` random Fourier features.

    Each frequency vector comes from ``generate_multivariate_gaussian`` in
    ``env.state_dim`` dimensions with standard deviation ``sigma``. Each phase
    is drawn uniformly from [0, 2*pi). The draws are interleaved: omega, then b,
    for each feature in turn, so seeded runs stay reproducible.

    :param env: object exposing ``state_dim``.
    :param sigma: standard deviation passed to the Gaussian sampler.
    :param m: number of features.
    :return: ``(omegas, bs)``, two Python lists of length ``m``.
    """
    dim = env.state_dim
    draws = [
        (generate_multivariate_gaussian(dim, sigma), np.random.uniform(0, 2 * np.pi))
        for _ in range(m)
    ]
    omegas = [omega for omega, _ in draws]
    bs = [phase for _, phase in draws]
    return omegas, bs

def get_rf(env, m, s, a, b, omegas, bs, r, device):
    """
    Compute random Fourier features of the predicted next states and append
    the reward column(s).

    Each feature is ``cos(omega_j . s' + b_j)``, where ``s'`` is a next state
    returned by ``env.batch_next_state``.

    :param env: object exposing ``batch_next_state(s, a, b, device)``, whose
        result must support ``.detach().cpu().numpy()`` (presumably a torch
        tensor — confirm against the env implementation).
    :param m: number of features; not used by the computation (the feature
        count comes from ``omegas``), kept for interface compatibility.
    :param s: forwarded unchanged to ``env.batch_next_state``.
    :param a: forwarded unchanged to ``env.batch_next_state``.
    :param b: forwarded unchanged to ``env.batch_next_state``.
    :param omegas: frequency vectors, one per feature. Either an ndarray of
        shape (m, state_dim) or a list of vectors, as returned by
        ``get_omegas_bs``.
    :param bs: phase offsets, one per feature; ndarray or list.
    :param r: reward values appended as trailing column(s). It must have the
        same number of dimensions as the features, or one fewer (e.g. a 1-D
        per-sample reward); in the latter case it becomes a single column.
    :param device: forwarded unchanged to ``env.batch_next_state``.
    :return: ndarray of features followed by the reward column(s).
    """
    next_states = env.batch_next_state(s, a, b, device).detach().cpu().numpy()

    # get_omegas_bs returns plain lists, which have no ``.T``; asarray makes
    # both list and ndarray inputs work (a no-op for ndarrays).
    omegas = np.asarray(omegas)
    bs = np.asarray(bs)

    phi = np.cos(np.dot(next_states, omegas.T) + bs.T)

    r = np.asarray(r)
    if r.ndim == phi.ndim - 1:
        # A per-sample reward vector (e.g. the 1-D output of get_r) would
        # make np.concatenate fail; treat it as one trailing column.
        r = np.expand_dims(r, axis=-1)
    return np.concatenate((phi, r), axis=-1)

def get_rd():
    """
    Draw a random 2-D vector whose components are uniform in [-10, 10).

    :return: ndarray of shape (2,).
    """
    unit_sample = np.random.uniform(low=-1, high=1, size=2)
    return 10 * unit_sample


def get_rd_state(world):
    """
    Build a flat state vector in which every agent's velocity and position
    are replaced by random 2-D vectors from ``get_rd``.

    Concatenation order:
      1. random velocities (one per agent),
      2. random positions (one per agent),
      3. goal positions ``agent.goal_a.state.p_pos``,
      4. agent colors,
      5. landmark positions ``landmark.state.p_pos``,
      6. landmark colors.

    :param world: object exposing ``landmarks`` and ``agents``.
    :return: 1-D ndarray produced by ``np.concatenate``.
    """
    landmark_rows = [(lm.state.p_pos, lm.color) for lm in world.landmarks]
    landmark_pos = [pos for pos, _ in landmark_rows]
    landmark_color = [color for _, color in landmark_rows]

    # Tuple elements are evaluated left to right, so for each agent the
    # velocity is drawn before the position (same RNG order as a plain loop).
    agent_rows = [
        (get_rd(), get_rd(), ag.goal_a.state.p_pos, ag.color)
        for ag in world.agents
    ]
    vels = [row[0] for row in agent_rows]
    positions = [row[1] for row in agent_rows]
    goals = [row[2] for row in agent_rows]
    colors = [row[3] for row in agent_rows]

    parts = vels + positions + goals + colors + landmark_pos + landmark_color
    return np.concatenate(parts)
    
def get_tag_rd_state(world):
    """
    Build a flat state vector (no goal entries) in which every agent's
    velocity and position are replaced by random 2-D vectors from ``get_rd``.

    Concatenation order:
      1. random velocities (one per agent),
      2. random positions (one per agent),
      3. agent colors,
      4. landmark positions ``landmark.state.p_pos``,
      5. landmark colors.

    :param world: object exposing ``landmarks`` and ``agents``.
    :return: 1-D ndarray produced by ``np.concatenate``.
    """
    landmark_rows = [(lm.state.p_pos, lm.color) for lm in world.landmarks]
    landmark_pos = [pos for pos, _ in landmark_rows]
    landmark_color = [color for _, color in landmark_rows]

    # Velocity is drawn before position for each agent (left-to-right tuple
    # evaluation), matching the RNG order of a plain loop.
    agent_rows = [(get_rd(), get_rd(), ag.color) for ag in world.agents]
    vels = [row[0] for row in agent_rows]
    positions = [row[1] for row in agent_rows]
    colors = [row[2] for row in agent_rows]

    parts = vels + positions + colors + landmark_pos + landmark_color
    return np.concatenate(parts)


def get_k(x1, x2):
    """
    Gaussian (RBF) kernel with unit bandwidth: ``exp(-||x1 - x2||^2 / 2)``.

    :param x1: first vector; an ndarray or any array-like of numbers
        (lists and tuples are accepted).
    :param x2: second vector, same shape as ``x1``.
    :return: kernel value as a Python float in [0, 1] (1 when x1 == x2).
    """
    # asarray lets plain sequences work; bare lists do not support ``-``.
    dis = np.asarray(x1, dtype=float) - np.asarray(x2, dtype=float)
    sq_norm = float(np.sum(np.square(dis)))
    return math.exp(-sq_norm / 2)


def get_gram(states):
    """
    Build the symmetric Gram matrix of ``get_k`` over a sequence of states.

    Only the upper triangle (diagonal included) is evaluated with ``get_k``,
    row by row. The strictly-lower triangle is then copied from its mirror.

    :param states: indexable sequence of vectors accepted by ``get_k``.
    :return: ndarray of shape (len(states), len(states)).
    """
    n = len(states)
    gram = np.zeros((n, n))
    for row in range(n):
        for col in range(row, n):
            gram[row, col] = get_k(states[row], states[col])
    # Fill the strictly-lower triangle from the transposed upper triangle.
    lower = np.tril_indices(n, -1)
    gram[lower] = gram.T[lower]
    return gram
    
def get_nystrom(env, m, s, a, b, m_states, gram, r, device):
    """
    Compute Nystrom-style features of the predicted next states and append
    the reward column(s).

    For each next state ``x``, the RBF kernel values
    ``k(x, c_i) = exp(-||x - c_i||^2 / 2)`` are computed against the first
    ``m`` landmark states ``c_i``. This is the same kernel ``get_k`` /
    ``get_gram`` use. The resulting (batch, m) matrix is right-multiplied
    by ``gram``.

    :param env: object exposing ``batch_next_state(s, a, b, device)``, whose
        result must support ``.detach().cpu().numpy()`` and be 2-D
        (batch, state_dim) (presumably a torch tensor — confirm).
    :param m: number of landmark states to use (the first ``m`` of ``m_states``).
    :param s: forwarded unchanged to ``env.batch_next_state``.
    :param a: forwarded unchanged to ``env.batch_next_state``.
    :param b: forwarded unchanged to ``env.batch_next_state``.
    :param m_states: landmark states; an ndarray of shape (>= m, state_dim)
        or a list of such vectors.
    :param gram: matrix right-multiplied onto the kernel matrix; must have
        ``m`` rows. Presumably the (inverse square root of the) landmark Gram
        matrix — NOTE(review): confirm with the caller.
    :param r: reward values appended as trailing column(s). It must have the
        same number of dimensions as the features, or one fewer (e.g. a 1-D
        per-sample reward); in the latter case it becomes a single column.
    :param device: forwarded unchanged to ``env.batch_next_state``.
    :return: ndarray of features followed by the reward column(s).
    :raises IndexError: if ``m_states`` holds fewer than ``m`` states.
    """
    next_states = env.batch_next_state(s, a, b, device).detach().cpu().numpy()

    centers = np.asarray(m_states)[:m]
    if centers.shape[0] < m:
        raise IndexError(
            f"m_states has {centers.shape[0]} states but m={m} were requested"
        )

    # Pairwise squared distances, shape (batch, m).
    diffs = next_states[:, None, :] - centers[None, :, :]
    sq_dists = np.sum(np.square(diffs), axis=-1)
    # Kernel values, not raw squared distances, so the features are consistent
    # with the RBF kernel that get_k / get_gram use to build the Gram matrix.
    ks = np.exp(-sq_dists / 2)

    phi = np.dot(ks, gram)

    r = np.asarray(r)
    if r.ndim == phi.ndim - 1:
        # Treat a per-sample reward vector as one trailing column.
        r = np.expand_dims(r, axis=-1)
    return np.concatenate((phi, r), axis=-1)