import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import scipy.io as scio
import os 
import time
import csv
from Datagene import *
from Human_Model import *
from Policy_module import *
import sys

class HumanNonBayesian():
    """Model-based human decision maker (non-rational and non-Bayesian).

    The response is built in three stages (see ``forward_info``):

    1. ``pi * lam`` is normalised over the ``M`` axis and then blended with a
       constant base value, weighted by ``gamma`` (``gamma = 1`` uses only the
       constant base, ``gamma = 0`` uses only the normalised product).
    2. The blended weights are used to average ``uR`` over the ``M`` axis and
       a softmax scaled by ``beta`` is taken over the ``N`` axis.
    3. That softmax is blended with a uniform distribution over ``N``,
       weighted by ``alpha`` (``alpha = 1`` uses only the softmax).

    Args:
        alpha: Weight of the softmax response versus the uniform response.
        beta: Scale applied to the values before the softmax.
        gamma: Weight of the constant base versus the normalised ``pi * lam``.
        M, N, K: Sizes of the three axes used to unpack the flat input.
    """

    def __init__(self, alpha=1,beta=10,gamma=1,M=2,N=2,K=2):
        self.M, self.N, self.K = M, N, K
        self.alpha, self.beta, self.gamma = alpha, beta, gamma

    def forward(self, x):
        """Unpack a flat input and evaluate the human response.

        Args:
            x: Tensor that can be reshaped to ``(-1, M, N + 1 + K)``. Each of
                the ``M`` rows holds ``N`` entries of ``uR``, one entry of
                ``lam`` and ``K`` entries of ``pi``, in that order.

        Returns:
            Tensor of shape ``(batch, 1, N, K)``; see ``forward_info``.
        """
        M, N, K = self.M, self.N, self.K
        # Row width is N + 1 + K (uR, lam, pi). The previous N + N + 1 was
        # only correct when K == N.
        rows = x.reshape(-1, M, N + 1 + K)
        uR = rows[:, :, :N].reshape(-1, M, N, 1)
        lam = rows[:, :, N].reshape(-1, M, 1, 1)
        pi = rows[:, :, N + 1:].reshape(-1, M, 1, K)
        return self.forward_info(uR, lam, pi)

    def forward_info(self, uR, lam, pi):
        """Evaluate the response from already-unpacked inputs.

        Args:
            uR: Tensor of shape ``(batch, M, N, 1)``.
            lam: Tensor of shape ``(batch, M, 1, 1)``.
            pi: Tensor of shape ``(batch, M, 1, K)``.

        Returns:
            Tensor of shape ``(batch, 1, N, K)`` that sums to one along the
            ``N`` axis. If any entry of the result is NaN, the whole result is
            replaced by the uniform distribution over ``N``.
        """
        floor = 0.001  # lower bound that keeps the divisions away from zero

        # Normalise pi * lam over the M axis (out-of-place, so no tensor that
        # autograd may still need is overwritten).
        mu = pi * lam
        mu = mu / torch.maximum(mu.sum(axis=1, keepdim=True), torch.ones_like(mu) * floor)
        mu_re = torch.maximum(mu, torch.ones_like(mu) * floor)

        # Constant base, "similar to random" per the original author.
        # NOTE(review): 0.5 equals the uniform weight 1/M only for M == 2 -
        # confirm the intended base for other M.
        mu_base = torch.ones_like(mu_re) * 0.5
        mu_h = (1 - self.gamma) * mu_re + self.gamma * mu_base
        mu_h2 = mu_h / torch.maximum(mu_h.sum(axis=1, keepdim=True), torch.ones_like(mu_h) * floor)

        # Weighted average of uR over M -> (batch, 1, N, K).
        UR = (mu_h2 * uR).sum(axis=1, keepdim=True)

        UR_sm = F.softmax(self.beta * UR, dim=2)
        # Softmax of a constant tensor is the uniform distribution over N.
        UR_random = F.softmax(torch.ones_like(UR), dim=2)
        UR_h = self.alpha * UR_sm + (1 - self.alpha) * UR_random

        # Fall back to the uniform response if anything went numerically wrong.
        if torch.isnan(UR_h).any():
            UR_h = F.softmax(torch.ones_like(UR), dim=2)
        return UR_h
    
class PolicyNN_abc(pl.LightningModule):
    """Policy network whose loss is defined through a human-response model.

    The network is a four-layer MLP that maps a flat input of size
    ``M * (2N + 1)`` to ``M`` softmax distributions over ``K`` outputs,
    returned flattened to ``M * K``. The training objective is
    ``loss_fn_human``, which combines the network output with the input
    features and the response of ``human_model`` (or a built-in softmax
    response when no human model is given).
    """

    def __init__(self, M=2, N=2, K=2, batch_size=2**6, batch_num=100000, num_workers=8, tile=None, beta = 10, lr=1e-6, beta_end = None, human_model = None,traindata_path=None, 
                 valdata_path=None, fc_size = 512):
        """Build the layers and store the training configuration.

        Args:
            M, N, K: Problem sizes; the input width is ``M * (2N + 1)`` and
                the output width is ``M * K``.
            batch_size: Batch size of the training dataloader.
            batch_num: Together with ``batch_size`` sets the length of the
                ``train_beta`` tensor created in ``train_dataloader``.
            num_workers: Worker count of the training dataloader.
            tile: Stored on the instance; not used inside this class.
            beta: Softmax scale of the default (no human model) loss branch
                and start value of the ``beta_evolve`` schedule.
            lr: Learning rate for Adam.
            beta_end: End value of the ``beta_evolve`` schedule; defaults to
                ``beta`` when omitted.
            human_model: Object exposing ``forward(x=...)``; ``None`` selects
                the built-in softmax response.
            traindata_path: File read with ``torch.load`` for training data.
            valdata_path: File read with ``torch.load`` for validation data.
            fc_size: Width of the hidden layers.
        """
        super().__init__()
        self.M, self.N, self.K = M, N, K
        # Layer creation order is kept (fc1, fc2, fc4, fc3) so that seeded
        # initialisation stays reproducible.
        self.fc1 = nn.Linear(M * (N+N+1), fc_size) # u^S + u^R + λ
        self.fc2 = nn.Linear(fc_size,fc_size)
        self.fc4 = nn.Linear(fc_size,fc_size)
        self.fc3 = nn.Linear(fc_size, M*K)
        self.sm = nn.Softmax(dim=2)

        self.batch_size = batch_size
        self.batch_num = batch_num
        self.num_workers = num_workers
        self.tile = tile
        self.beta = beta
        self.beta_end = beta if beta_end is None else beta_end
        # Log-spaced schedule from beta to beta_end over a fixed 1000 steps.
        self.beta_evolve = torch.logspace(start=float(np.log10(self.beta)), end=float(np.log10(self.beta_end)), steps=1000)
        self.evolve_step = 0
        self.lr = lr
        self.human_model = human_model
        self.traindata_path = traindata_path
        self.valdata_path = valdata_path

    def forward(self, x):
        """Return ``M`` softmax distributions over ``K``, flattened to ``(batch, M*K)``."""
        out = F.relu(self.fc1(x.float()))
        out = F.relu(self.fc2(out))
        out = F.relu(self.fc4(out))
        out = self.fc3(out)
        # Softmax is taken over the K axis of the (batch, M, K) view.
        out = self.sm(out.view(-1,self.M,self.K))
        return out.view(-1,self.M*self.K)

    def training_step(self, batch, batch_idx):
        """Compute the human-model loss for one batch and log diagnostics."""
        x, y = batch
        y_hat = self.forward(x)
        # Advance the schedule index every 100 batches, capped at the last
        # entry. In this class the index only selects the logged 'beta'.
        if batch_idx % 100 == 0:
            self.evolve_step = min(self.evolve_step + 1 , 999)
        loss = self.loss_fn_human(x,y_hat,human_model = self.human_model, reduce=True)

        # NOTE(review): anomaly detection is a debugging aid and slows
        # training; it is kept because removing it would change how NaNs in
        # the backward pass are reported.
        torch.autograd.set_detect_anomaly(True)
        rmse = torch.sqrt(nn.MSELoss()(y, y_hat))
        loss_arg = 0
        # Slice is empty (mean is NaN) once batch_idx >= batch_num.
        meanbeta = torch.mean(self.train_beta[batch_idx * self.batch_size : (batch_idx+1) * self.batch_size])
        self.log("performance", {"iter": batch_idx, "loss": loss, "rmse": rmse, "meanbeta":meanbeta, 'argloss':loss_arg, 'beta':self.beta_evolve[self.evolve_step]})
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the human-model loss on a validation batch and log it."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn_human(x,y_hat,human_model = self.human_model,  reduce=True)
        rmse = torch.tensor(0.0)  # placeholder; RMSE is not computed here
        loss_arg = 0
        self.log("performance", {"iter": batch_idx, "val-loss": loss, "rmse": rmse, "entropy":F.cross_entropy(y_hat, y), 'argloss':loss_arg})
        return loss

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    @staticmethod
    def _load_xy(path, split):
        """Load the ``prob_x`` / ``prob_y`` tensors from ``path`` or exit.

        A missing (or unset) path prints a hint and terminates the process
        with a non-zero status so that callers and shell scripts can detect
        the failure.
        """
        if not path or not os.path.exists(path):
            print('check ' + split + ' data path:', path)
            sys.exit(1)
        data = torch.load(path)
        return data['prob_x'], data['prob_y']

    def train_dataloader(self):
        """Build the (unshuffled) training dataloader and the ``train_beta`` tensor."""
        x, y = self._load_xy(self.traindata_path, 'train')
        # Prefer the first GPU as before, but do not crash on CPU-only hosts.
        beta_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.train_beta = torch.ones(self.batch_num * self.batch_size) * self.beta_end  ### beta of training 
        self.train_beta = self.train_beta.reshape(-1,1,1,1).to(device=beta_device)
        ds = torch.utils.data.TensorDataset(x,y)
        dl = torch.utils.data.DataLoader(ds, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
        return dl

    def val_dataloader(self):
        """Build a validation dataloader that yields the whole set as one batch."""
        x, y = self._load_xy(self.valdata_path, 'val')
        ds = torch.utils.data.TensorDataset(x,y)
        dl = torch.utils.data.DataLoader(ds, batch_size=x.size()[0], shuffle=False, num_workers=1)
        return dl

    def loss_fn_human(self, x,pred,human_model = None,reduce=False,batch_idx=-1):
        """Negative expected payoff of the policy under a human-response model.

        Args:
            x: Input features, reshaped to ``(-1, M, 2N + 1)``. Per row the
                first ``N`` entries are read as ``uR``, the next ``N`` as
                ``uS`` and the last one as ``lam``.
            pred: Network output, reshaped to ``(-1, M, 1, K)``.
            human_model: Object with ``forward(x=...)`` returning something
                reshapeable to ``(-1, 1, N, K)``. ``None`` selects a softmax
                response scaled by ``self.beta``.
            reduce: When true, return the batch mean instead of one value
                per sample.
            batch_idx: Unused; kept for interface compatibility.

        Returns:
            Tensor of shape ``(batch,)``, or a scalar tensor if ``reduce``.
        """
        M,N,K = self.M, self.N, self.K
        x = x.reshape(-1,M,N+N+1)
        uR = x[:,:,:N].reshape(-1,M,N,1)
        uS = x[:,:,N:2*N].reshape(-1,M,N,1)
        lam = x[:,:,2*N:].reshape(-1,M,1,1)
        pi = pred.reshape(-1,M,1,K)

        if human_model is None: ### when human model is not defined, a softmax model will be used as default
            mu = pi*lam
            mu = mu / torch.maximum( mu.sum(axis=1,keepdim=True), torch.ones_like(mu)*0.001)
            UR = (mu * uR).sum(axis=1,keepdim=True)
            response = F.softmax(self.beta * UR, dim=2)
        else: ### use a human model
            # The human model receives rows laid out as [uR, lam, pi].
            human_x = torch.cat((uR.reshape(-1,M,N),lam.reshape(-1,M,1), pred.reshape(-1,M,K)),dim = 2).reshape(-1, M*(N+1+K))
            response = human_model.forward(x= human_x).reshape(-1, 1,N,K)

        total = -(lam * pi * uS * response).sum(axis=(1,2,3))
        if reduce:
            total = total.mean()
        return total
    
    
bs = 1000
batch_num = 20
batch_size = 1024
# Human-model parameters are taken from the command line:
#   argv[1] -> alpha, argv[2] -> beta, argv[3] -> gamma
a = float(sys.argv[1])
b = float(sys.argv[2])
c = float(sys.argv[3])

tstart = time.time()
M,N,K = 2,2,2
# NOTE(review): the training path points at the *validation* file, so the
# model is trained and validated on the same data - confirm this is intended.
trainp = 'ABC/val_dataN='+ str(M) + '.pt'
valp = 'ABC/val_dataN='+ str(M) + '.pt'

modelh1 = HumanNonBayesian(M=M,N=N,K=K,alpha=a, beta=b, gamma=c)

# Grid search over learning rate and hidden-layer width.
for lr in [0.001,0.002,0.01,0.02,0.1]:
    for fc_size in [256,512,1024]:

        # The swept learning rate is passed through; previously a literal
        # lr=1e-6 was used, so every run trained at the same rate while the
        # logs and checkpoints were labelled with different ones.
        policynn = PolicyNN_abc(M=M, N=N, K=K, batch_size=batch_size, batch_num=batch_num, num_workers=10, tile=None, beta = 10, lr=lr, beta_end = None, 
                                       human_model = modelh1,traindata_path=trainp, valdata_path=valp, fc_size = fc_size)

        csv_logger = CSVLogger("./logs", name="policyNN_abc", version=f'policy_4fc_n2a{a}b{b}c{c}lr{lr}fc{fc_size}')
        trainer = pl.Trainer(accelerator='gpu' , deterministic=False, max_epochs=100,check_val_every_n_epoch=10,
                             logger=csv_logger,enable_progress_bar = False)
        print('training begin!')
        history = trainer.fit(policynn)
        # The reported time is cumulative since the start of the sweep.
        print('training finished! time cost: ', time.time()-tstart, ' for parameter:(M,lr,fc_size)=', M,lr,fc_size)
        torch.save(policynn.state_dict(), f'ABC/PolicyNN_abc_mpolicy_a{a}b{b}c{c}lr{lr}fc{fc_size}.pt')
    