### Contains all data generators

import torch
import numpy as np
from torch.distributions import Dirichlet
import pulp
import time
import torch.nn.functional as F
from Human_Model import HumanDecisionMakerModel 

def getXY(bs=1000, seed=None, tile=None):
    """Generate a batch of 2-state / 2-action instances with closed-form labels.

    Every row of ``x`` has 10 entries, 5 per state, laid out as
    ``[uR_0, uR_1, uS_0, uS_1, lam]`` (the layout ``decode_tensor`` unpacks).
    ``uS`` is fixed to ``[0, 1]`` for both states and the two ``lam`` entries
    sum to one.

    Args:
        bs: number of instances to draw.
        seed: if not None, passed to ``torch.manual_seed`` before sampling.
        tile: if truthy, only the first instance is kept and repeated
            ``tile`` times along the batch axis.

    Returns:
        ``(x, y)`` with ``x`` of shape ``(bs, 10)`` and ``y`` of shape
        ``(bs, 4)`` (a flattened ``(M, K)`` policy), or ``(tile, 10)`` /
        ``(tile, 4)`` when ``tile`` is given.
    """
    M, N, K = 2, 2, 2
    if seed is not None:
        torch.manual_seed(seed)

    feats = torch.zeros(bs, M * (N + N + 1))
    # The random draws below are kept in this exact order so that a given
    # seed always reproduces the same batch.
    feats[:, 0] = torch.rand(bs) / 2 + 0.5   # state 0, uR_0 in [0.5, 1)
    feats[:, 1] = torch.rand(bs) / 2         # state 0, uR_1 in [0, 0.5)
    feats[:, 4] = torch.rand(bs) * 0.8 + 0.1  # state 0 weight in [0.1, 0.9)
    feats[:, 5] = torch.rand(bs) / 2         # state 1, uR_0 in [0, 0.5)
    feats[:, 6] = torch.rand(bs) / 2 + 0.5   # state 1, uR_1 in [0.5, 1)
    # uS = [0, 1] in both states: columns 2 and 7 stay at zero.
    feats[:, 3] = 1.0
    feats[:, 8] = 1.0
    feats[:, 9] = 1 - feats[:, 4]

    # Closed-form label: probability of signal 1 in state 0, capped at 1.
    odds = feats[:, 9] / (1 - feats[:, 9])
    uncapped = odds * (feats[:, 6] - feats[:, 5]) / (feats[:, 0] - feats[:, 1])

    policy = torch.zeros(bs, M, K)
    policy[:, 0, 1] = torch.minimum(uncapped, torch.tensor(1))
    policy[:, 0, 0] = 1 - policy[:, 0, 1]
    policy[:, 1, 1] = 1  # state 1 always sends signal 1
    labels = policy.reshape(bs, M * K)

    if tile:
        feats = torch.tile(feats[0:1], (tile, 1))
        labels = torch.tile(labels[0:1], (tile, 1))
    return feats, labels

def getXY_nstate_2action(M = 2,N=2,K=2 ,bs=1000, seed=None, tile=None, solved = False, sender_u_fix = True):
    """Sample random instances and optionally label them with ``lp_solver``.

    Args:
        M, N, K: sizes used for the state, action and signal axes.
        bs: number of instances.
        seed: if not None, passed to ``torch.manual_seed`` before sampling.
        tile: accepted but not used by this function.
        solved: when True, every instance is solved with ``lp_solver`` and
            the flattened policy is returned as ``y``; otherwise ``y`` is
            all zeros.
        sender_u_fix: when True, ``uS`` is 0 for action 0 and 1 for action 1
            in every state; when False, ``uS`` is drawn from [0.01, 1) and
            sorted ascending along the action axis.

    Returns:
        ``(x, y)``: ``x`` has shape ``(bs, M*(2N+1))`` and stores, per state,
        ``[uR (N), uS (N), lam]``; ``y`` has shape ``(bs, M*K)``.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Draw order (lam, then uS if random, then uR) fixes the random stream.
    lam = Dirichlet(torch.ones(M)).sample([bs]).reshape(bs, M, 1)

    if sender_u_fix:
        uS = torch.zeros(bs, M, N)
        uS[:, :, 1] = 1.0
    else:
        uS, _ = torch.sort(torch.rand(bs, M, N) * 0.99 + 0.01, dim=2, descending=False)

    uR = torch.rand(bs, M, N)
    # Only the first two states are ordered: ascending in state 0,
    # descending in state 1.
    ascending, _ = torch.sort(uR[:, 0, :], dim=1, descending=False)
    descending, _ = torch.sort(uR[:, 1, :], dim=1, descending=True)
    uR[:, 0, :] = ascending
    uR[:, 1, :] = descending

    if sender_u_fix:
        # Where the lam-weighted utility of action 1 exceeds that of action 0,
        # swap the two action columns so action 0 is the better one a priori.
        judger = (lam.reshape(bs, M) * (uR[:, :, 1] - uR[:, :, 0])).sum(axis=1) > 0
        uR[judger, :, :] = uR[judger][:, :, [1, 0]]

    y = torch.zeros(bs, M * K)
    if solved:
        tstart = time.time()
        policy = torch.zeros(bs, M, K)
        for case in range(bs):
            policy[case, :, :] = lp_solver(M, N, K, uS[case, :, :], uR[case, :, :], lam[case, :])
            if case % 100 == 0:
                print('solving instance case=',case, ' in ',time.time()-tstart, ' seconds')
        y = policy.reshape(bs, M * K)
        print('finish all instance case=',bs, ' in ',time.time()-tstart, ' seconds')

    x = torch.cat((uR, uS, lam), dim=2).reshape(bs, M * (N + N + 1))
    return x, y


def getXY_human_model(M = 2,N=2,K=2 ,bs=1000, seed=None, human_model=None, invpolicy=False):
    """Sample instances and random policies, and query a human-response model.

    Args:
        M, N, K: sizes used for the state, action and signal axes.
        bs: number of instances.
        seed: if not None, passed to ``torch.manual_seed`` before sampling.
        human_model: object exposing ``forward_info(uR=, lam=, pi=)``. When
            None, no model is queried: the policy stays all zeros and
            ``human_y`` is returned as None.
        invpolicy: when True (and a model is given), the second half of the
            batch repeats the instances of the first half with the two
            policy columns swapped. With an odd ``bs`` the last sample is
            left unpaired.

    Returns:
        ``(prob_x, prob_y, human_x, human_y)``:
        ``prob_x`` is ``(bs, M*(2N+1))`` holding ``[uR, uS, lam]`` per state,
        ``prob_y`` is the flattened policy ``(bs, M*K)``,
        ``human_x`` is ``(bs, M*(N+1+K))`` holding ``[uR, lam, pi]`` per state,
        ``human_y`` is the model output reshaped to ``(-1, N*K)`` or None.
    """
    if seed is not None:
        torch.manual_seed(seed)

    lam = Dirichlet(torch.ones(M)).sample([bs]).reshape(bs, M, 1)

    # uS is 0 for action 0 and 1 for action 1 in every state.
    uS = torch.zeros(bs, M, N)
    uS[:, :, 1] = 1.0

    uR = torch.rand(bs, M, N)
    # Only the first two states are ordered: ascending in state 0,
    # descending in state 1.
    r1, _ = torch.sort(uR[:, 0, :], dim=1, descending=False)
    r2, _ = torch.sort(uR[:, 1, :], dim=1, descending=True)
    uR[:, 0, :] = r1
    uR[:, 1, :] = r2
    # Where the lam-weighted utility of action 1 exceeds that of action 0,
    # swap the action columns so action 0 is the better one a priori.
    judger = (lam.reshape(bs, M) * (uR[:, :, 1] - uR[:, :, 0])).sum(axis=1) > 0
    uR[judger, :, :] = uR[judger][:, :, [1, 0]]

    if human_model is None:
        # Nothing to query: keep the all-zero placeholder policy so the
        # returned tensors still have their usual shapes.
        pi = torch.zeros(bs, M, K)
        human_res = None
    else:
        # Random row-stochastic policy used to probe the model.
        raw = torch.rand(bs, M, K)
        pi = raw / raw.sum(axis=2, keepdim=True)

        if invpolicy:
            half = bs // 2
            # Explicit upper bound keeps both sides the same length even
            # when bs is odd.
            uS[half:2 * half] = uS[:half]
            uR[half:2 * half] = uR[:half]
            lam[half:2 * half] = lam[:half]
            pi[half:2 * half, :, [1, 0]] = pi[:half, :, [0, 1]]

        human_res = human_model.forward_info(
            uR=uR.reshape(-1, M, N, 1),
            lam=lam.reshape(-1, M, 1, 1),
            pi=pi.reshape(-1, M, 1, K),
        )

    prob_x = torch.cat((uR, uS, lam), dim=2).reshape(bs, M * (N + N + 1))
    prob_y = pi.reshape(-1, M * K)
    human_x = torch.cat((uR, lam, pi), dim=2).reshape(bs, M * (N + 1 + K))
    human_y = None if human_res is None else human_res.reshape(-1, N * K)
    return prob_x, prob_y, human_x, human_y

    
def getXY_human_model_exp(M = 2,N=2,K=2 ,bs=1000, seed=None, human_model=None, invpolicy=False, ur_grain = None, lambda_grain = None, policy_grain=None, inv=False):
    """Sample (optionally discretised) instances and query a human-response model.

    Args:
        M, N, K: sizes used for the state, action and signal axes. The
            grain options below write to states 0 and 1 explicitly.
        bs: number of instances.
        seed: if not None, passed to ``torch.manual_seed`` before sampling.
        human_model: object exposing ``forward(human_x)``. When None, no
            model is queried: the policy stays all zeros and ``human_y`` is
            returned as None.
        invpolicy: accepted but not used by this function.
        ur_grain: if truthy, ``uR`` is drawn on the grid ``{0, ..., g-1}/g``
            with ``uR[:, 0, 1]`` and ``uR[:, 1, 0]`` forced to 0.
        lambda_grain: if truthy (must be >= 2), ``lam[:, 0]`` is drawn from
            ``{1, ..., g-1}/g`` and ``lam[:, 1] = 1 - lam[:, 0]``.
        policy_grain: if truthy (must be >= 2), ``pi[:, :, 0]`` is drawn from
            ``{1, ..., g-1}/g`` and ``pi[:, :, 1] = 1 - pi[:, :, 0]``.
        inv: selects which action column of ``uS`` is set to 1
            (column 1 when False, column 0 when True).

    Returns:
        ``(prob_x, prob_y, human_x, human_y)``:
        ``prob_x`` is ``(bs, M*(2N+1))`` holding ``[uR, uS, lam]`` per state,
        ``prob_y`` is the flattened policy ``(bs, M*K)``,
        ``human_x`` is ``(bs, M*(N+1+K))`` holding ``[uR, lam, pi]`` per state,
        ``human_y`` is the model output reshaped to ``(-1, N*K)`` or None.
    """
    if seed is not None:
        torch.manual_seed(seed)

    uS = torch.zeros(bs, M, N)
    if not inv:  # persuade to not purchase
        uS[:, :, 1] = 1.0
    else:  # persuade to purchase
        uS[:, :, 0] = 1.0

    lam = Dirichlet(torch.ones(M)).sample([bs]).reshape(bs, M, 1)
    if lambda_grain:
        a1 = torch.randint(low=1, high=lambda_grain, size=(bs,), dtype=torch.float)
        lam[:, 0, 0] = a1 / float(lambda_grain)
        lam[:, 1, 0] = 1.0 - lam[:, 0, 0]

    uR = torch.rand(bs, M, N)
    # Only the first two states are ordered: descending in state 0,
    # ascending in state 1.
    r1, _ = torch.sort(uR[:, 0, :], dim=1, descending=True)
    r2, _ = torch.sort(uR[:, 1, :], dim=1, descending=False)
    uR[:, 0, :] = r1
    uR[:, 1, :] = r2
    if ur_grain:
        # The grid size for uR is ur_grain itself, independent of
        # lambda_grain. The full-size draw is kept so states beyond the
        # first two also land on the grid.
        uR = torch.randint(low=0, high=ur_grain, size=(bs, M, N), dtype=torch.float)
        uR[:, 0, 0] = torch.randint(low=0, high=ur_grain, size=(bs,), dtype=torch.float)
        uR[:, 1, 1] = torch.randint(low=0, high=ur_grain, size=(bs,), dtype=torch.float)
        uR[:, 0, 1] = 0
        uR[:, 1, 0] = 0
        uR = uR / float(ur_grain)

    if human_model is None:
        # Nothing to query: keep the all-zero placeholder policy so the
        # returned tensors still have their usual shapes.
        pi = torch.zeros(bs, M, K)
    else:
        ### use random policy to query human response
        raw = torch.rand(bs, M, K)
        pi = raw / raw.sum(axis=2, keepdim=True)
        if policy_grain:
            p1 = torch.randint(low=1, high=policy_grain, size=(bs, M,), dtype=torch.float)
            pi[:, :, 0] = p1 / float(policy_grain)
            pi[:, :, 1] = 1.0 - pi[:, :, 0]

    prob_x = torch.cat((uR, uS, lam), dim=2).reshape(bs, M * (N + N + 1))
    prob_y = pi.reshape(-1, M * K)
    human_x = torch.cat((uR, lam, pi), dim=2).reshape(bs, M * (N + 1 + K))
    if human_model is None:
        human_y = None
    else:
        human_y = human_model.forward(human_x).reshape(-1, N * K)
    return prob_x, prob_y, human_x, human_y

def decode_tensor(data, label='x', M=2, N=2, K=2):
    """Unflatten a feature or policy tensor into its per-state components.

    Args:
        data: tensor to decode.
        label: ``'x'`` for features, ``'y'`` for policies.
        M, N, K: sizes used for the state, action and signal axes.

    Returns:
        For ``'x'``: the tuple ``(uR, uS, lam)`` with shapes ``(-1, M, N)``,
        ``(-1, M, N)`` and ``(-1, M, 1)``, taken from a ``(-1, M, 2N+1)``
        view of ``data``. For ``'y'``: ``data`` reshaped to ``(-1, M, K)``.
        Any other label yields None.
    """
    if label == 'y':
        return data.reshape(-1, M, K)
    if label != 'x':
        return None

    per_state = data.reshape(-1, M, N + N + 1)
    # Per state the columns are [uR (N), uS (N), lam (1)].
    receiver = per_state[:, :, :N].reshape(-1, M, N)
    sender = per_state[:, :, N:2 * N].reshape(-1, M, N)
    weight = per_state[:, :, 2 * N:].reshape(-1, M, 1)
    return receiver, sender, weight

def lp_solver(M=2, N=2, K=2, uS=None, uR=None, lam=None):
    """Solve the information-design LP for a single instance.

    Variables ``x[m][k] >= 0`` form a row-stochastic policy. The objective
    maximises ``sum_m sum_i lam[m] * uS[m, i] * x[m][i]`` subject to each row
    summing to 1 and, for every pair ``i != j``,
    ``sum_m lam[m] * x[m][i] * (uR[m, i] - uR[m, j]) >= 0``.

    Args:
        M: number of states.
        N: number of actions.
        K: number of signals; always overridden by ``N`` (K=N in IC setting).
        uS: ``(M, N)`` utilities used in the objective.
        uR: ``(M, N)`` utilities used in the pairwise constraints.
        lam: ``M`` state weights; any shape with ``M`` elements is accepted.

    Returns:
        A ``(M, N)`` float tensor with the solved policy, or all zeros (after
        printing the solver status) when the LP is not solved to optimality.

    Raises:
        TypeError: if ``uS``, ``uR`` or ``lam`` is missing.
    """
    if uS is None or uR is None or lam is None:
        raise TypeError('lp_solver() requires uS, uR and lam')

    K = N  # K=N in IC setting

    # Hand PuLP plain Python floats rather than tensor elements, so the LP
    # coefficients do not depend on tensor/PuLP operator fallbacks or on
    # whether ``lam`` arrives as (M,) or (M, 1).
    prior = torch.as_tensor(lam).reshape(-1).tolist()
    sender = torch.as_tensor(uS).reshape(M, N).tolist()
    receiver = torch.as_tensor(uR).reshape(M, N).tolist()

    prob = pulp.LpProblem(sense=pulp.LpMaximize)
    x = pulp.LpVariable.dicts(name='x', indices=(range(M), range(K)), lowBound=0, cat=pulp.LpContinuous)

    # Objective: lam-weighted uS payoff of the recommended action.
    prob += pulp.lpSum(prior[m] * sender[m][i] * x[m][i] for m in range(M) for i in range(N))

    # Each state's row is a probability distribution.
    for m in range(M):
        prob += pulp.lpSum(x[m][i] for i in range(N)) == 1.0

    # Following recommendation i must be at least as good as deviating to j.
    for i in range(N):
        for j in range(N):
            if i != j:
                prob += pulp.lpSum(
                    prior[m] * (receiver[m][i] - receiver[m][j]) * x[m][i] for m in range(M)
                ) >= 0.0

    prob.solve(pulp.PULP_CBC_CMD(msg=False))

    y = torch.zeros(M, K)
    if pulp.LpStatus[prob.status] != 'Optimal':
        print('LP solve status: ', pulp.LpStatus[prob.status])
        return y

    for m in range(M):
        for k in range(K):
            y[m, k] = x[m][k].varValue
    return y
    
def getXY_full(bs=1000, seed=None):
    """Generate 2-state / 2-action instances with fully random utilities.

    Every row of ``x`` has 10 entries, 5 per state, laid out as
    ``[uR_0, uR_1, uS_0, uS_1, lam]``. All utilities are uniform on [0, 1),
    the state-0 weight is uniform on [0.1, 0.9) and the state-1 weight is its
    complement. The label is the same constant policy for every row.

    Args:
        bs: number of instances.
        seed: if not None, passed to ``torch.manual_seed`` before sampling.

    Returns:
        ``(x, y)`` with ``x`` of shape ``(bs, 10)`` and ``y`` of shape
        ``(bs, 4)``, every row of ``y`` being ``[0, 0, 1, 1]``.
    """
    # The column indices below hard-code the 2x2x2 layout, so the sizes are
    # fixed here instead of being read from module-level names.
    M, N, K = 2, 2, 2
    if seed is not None:
        torch.manual_seed(seed)

    x = torch.zeros(bs, M * (N + N + 1))
    # One draw per column, in column order, so a given seed always yields
    # the same batch.
    for col in range(9):
        x[:, col] = torch.rand(bs)
    # The state-0 weight is kept away from 0 and 1; this rescales the draw
    # already stored in column 4.
    x[:, 4] = x[:, 4] * 0.8 + 0.1
    x[:, 9] = 1 - x[:, 4]

    y = torch.zeros(bs, M * K)
    y[:, 2] = 1
    y[:, 3] = 1
    return x, y

def quick_loss_fn(x,pred,M =2,N=2,K=2, beta=10, reduce=False, verbose=False):
    """Negative ``uS`` payoff of a policy under a softmax response.

    For every signal the ``lam``-weighted policy is normalised over states,
    the resulting weights give an expected ``uR`` per action, and the action
    distribution is ``softmax(beta * expected uR)`` over the action axis.

    Args:
        x: features reshapable to ``(-1, M, 2N+1)``, per state
            ``[uR (N), uS (N), lam]``.
        pred: policy reshapable to ``(-1, M, 1, K)``.
        M, N, K: sizes used for the state, action and signal axes.
        beta: softmax sharpness.
        reduce: when True, return the batch mean instead of per-sample values.
        verbose: when True, print intermediate tensors, including a hard
            argmax response shown for comparison only.

    Returns:
        Tensor of shape ``(batch,)`` (or a scalar when ``reduce``) holding
        minus the summed ``lam * pi * uS * response``.
    """
    x = x.reshape(-1, M, N + N + 1)
    uR = x[:, :, :N].reshape(-1, M, N, 1)
    uS = x[:, :, N:2 * N].reshape(-1, M, N, 1)
    lam = x[:, :, 2 * N:].reshape(-1, M, 1, 1)
    pi = pred.reshape(-1, M, 1, K)

    # Normalise over states per signal; the 0.001 floor guards against
    # division by zero for signals that are (almost) never sent.
    mu = pi * lam
    mu = mu / torch.maximum(mu.sum(axis=1, keepdim=True), torch.ones_like(mu) * 0.001)
    UR = (mu * uR).sum(axis=1, keepdim=True)

    if verbose:
        print(UR.shape)
        print(UR)

        # Argmax (diagnostic only). Work on a detached copy: the tie-break
        # nudge towards action 1 must not modify UR, which feeds the softmax
        # that defines the returned loss.
        UR_am = UR.detach().clone()
        UR_am[:, :, 1, :] += 0.00001
        idx = torch.argmax(UR_am, axis=2)
        UR_am = torch.zeros_like(UR_am).scatter_(dim=2, index=idx.unsqueeze(1), src=torch.ones_like(UR_am))
        print('----------argmax')
        print(UR_am.shape)
        print(UR_am)

    # Softmax
    UR_sm = F.softmax(beta * UR, dim=2)

    if verbose:
        print('----------softmax')
        print(UR_sm.shape)
        print(UR_sm)

    total = -(lam * pi * uS * UR_sm).sum(axis=(1, 2, 3))

    if verbose:
        print('----------utility')
        print((lam * pi * uS * UR_sm))

    if reduce:
        total = total.mean()
    return total

def quick_loss_fn_argmax(x,pred, M=2,N=2,K=2, beta=10, reduce=False, verbose=False):
    """Negative ``uS`` payoff of a policy under a hard argmax response.

    For every signal the ``lam``-weighted policy is normalised over states;
    the resulting weights give an expected ``uR`` and ``uS`` per action. The
    chosen action maximises expected ``uR`` plus ``0.001`` times expected
    ``uS``, so near-ties are resolved towards the higher ``uS``.

    Args:
        x: features reshapable to ``(-1, M, 2N+1)``, per state
            ``[uR (N), uS (N), lam]``.
        pred: policy reshapable to ``(-1, M, 1, K)``.
        M, N, K: sizes used for the state, action and signal axes.
        beta: accepted but not used by this function.
        reduce: when True, return the batch mean instead of per-sample values.
        verbose: when True, print intermediate tensors.

    Returns:
        Tensor of shape ``(batch,)`` (or a scalar when ``reduce``) holding
        minus the summed ``lam * pi * uS * one_hot(chosen action)``.
    """
    feats = x.reshape(-1, M, N + N + 1)
    recv_u = feats[:, :, :N].reshape(-1, M, N, 1)
    send_u = feats[:, :, N:2 * N].reshape(-1, M, N, 1)
    weight = feats[:, :, 2 * N:].reshape(-1, M, 1, 1)
    policy = pred.reshape(-1, M, 1, K)

    # Normalise over states per signal; the 0.001 floor guards against
    # division by zero for signals that are (almost) never sent.
    joint = policy * weight
    floor = torch.ones_like(joint) * 0.001
    post = joint / torch.maximum(joint.sum(axis=1, keepdim=True), floor)

    recv_val = (post * recv_u).sum(axis=1, keepdim=True)
    # sender utility
    send_val = (post * send_u).sum(axis=1, keepdim=True)

    if verbose:
        print(post)
        print(recv_val.shape)
        print(recv_val)

    # Argmax over the action axis, turned into a one-hot mask.
    scores = recv_val + 0.001 * send_val
    best = torch.argmax(scores, axis=2)
    chosen = torch.zeros_like(scores).scatter_(dim=2, index=best.unsqueeze(1), src=torch.ones_like(scores))

    if verbose:
        print('----------argamx')
        print(chosen.shape)
        print(chosen)

    total = -(weight * policy * send_u * chosen).sum(axis=(1, 2, 3))

    if verbose:
        print('----------utility')
        print()
        print((weight * policy * send_u * chosen))

    return total.mean() if reduce else total


if __name__ == '__main__':
    # Manual smoke run: draw a tiny closed-form batch, then solve a few
    # 5-state instances with the LP and re-solve one of them on its own.
    print('Running some checking codes to debug generator')
    bs, seed = 2, 42
    M, N, K = 2, 2, 2
    x, y = getXY(bs=bs, seed=seed)
    print('setting-----------------')
    print(x, y)

    # Unpack the batch into its components (kept for interactive inspection;
    # nothing below reads them).
    x = x.reshape(-1, M, N + N + 1)
    uR = x[:, :, :N].reshape(-1, M, N)
    uS = x[:, :, N:2 * N].reshape(-1, M, N)
    lam = x[:, :, 2 * N:].reshape(-1, M, 1)
    pi = y.reshape(-1, K, M).mT.reshape(-1, M, K)

    #### generate a list of test case
    bs, M, N, K = 10, 5, 2, 2
    seed = 42
    x, y = getXY_nstate_2action(M=M, bs=bs, seed=seed, solved=True)

    # Inspect one instance and check that solving it alone gives a policy.
    case = 6
    print(decode_tensor(x[case, :], 'x', M, N, K))
    print(decode_tensor(y[case, :], 'y', M, N, K))
    data = decode_tensor(x[case, :], 'x', M, N, K)
    pi = lp_solver(
        M=M,
        N=2,
        K=2,
        uS=data[1].reshape(M, N),
        uR=data[0].reshape(M, N),
        lam=data[2].reshape(M),
    )
    print(pi)



    
    
    
    