#include <stdio.h>
#include <time.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <random>
#include <utility>
#include <vector>

#include <mex.h>
#include "matrix.h"

using namespace std;
// Sort predicate used with std::sort to rank (score, node) pairs best-first:
// returns true when lhs has a strictly larger score than rhs. Node ids are not
// compared.
bool comp(pair<double, int> lhs, pair<double, int> rhs){
    const double left_score = lhs.first;
    const double right_score = rhs.first;
    return right_score < left_score;
}
// Writes one 101-entry curve for a single (method, class, trial) run.
// curve[c] (c = 0..100) becomes the number of nodes labelled `target` among the
// first ceil(class_size * c / 100) entries of `ranked` (ranked best-first).
// Requires class_size >= 1 and ranked.size() >= class_size.
static void fill_roc_curve(double * const curve, const vector<pair<double, int>> & ranked,
        const vector<int> & label, int target, int class_size){
    const int bins = 100;
    curve[0] = 0;
    int count = 0;
    for(int h2 = 0; h2 < class_size; h2++){
        // Open every bin whose start rank (class_size * count / 100) has been
        // reached. With class_size < 100 a single rank must open several bins;
        // a plain `if` here would leave the tail of the curve at 0.
        while(count < bins && h2 >= class_size * count * 0.01){
            curve[count + 1] = curve[count];
            count++;
        }
        if(label[ranked[h2].second] == target){
            curve[count] += 1;
        }
    }
    // Carry the running total into any bins that were not opened above.
    while(count < bins){
        curve[count + 1] = curve[count];
        count++;
    }
}

// Label-propagation ranking experiment. For every class i in [0, label_num)
// and every trial j in [0, times):
//   1. Up to seed_num distinct nodes of class i are drawn uniformly at random
//      and given mass 1. A class with fewer than seed_num members has all of
//      its members seeded; a class with no members is skipped entirely.
//   2. For k = 0..step-1 the mass is pushed one hop: every node h splits its
//      mass evenly over the nodes listed in Adj[h]. Nodes with an empty
//      Adj[h] drop their mass.
//   3. method_num scores per node are kept. Writing x for the mass after hop k:
//        methods h1 < method_num-3 : += x / lambda[h1]^(k+1)
//        method  method_num-3      :  = x   (current mass only)
//        method  method_num-2      : += x * PRalpha^(k+1)
//        method  method_num-1      : += x * heatPRbeta^(k+1) / (k+1)!
//      Scores start at 1 for seeds and 0 elsewhere.
//   4. After every hop, nodes are ranked by descending score. The number of
//      non-class-i nodes among the top |class i| is added to
//        real_err[k + step*(h1 + method_num*(i + label_num*j))].
//      When k == ROCpoint, entries c = 0..100 of
//        ROC[101*(h1 + method_num*(i + label_num*j)) + c]
//      are set to the number of class-i nodes among the top
//      ceil(|class i| * c / 100) ranked nodes.
//
// Requirements: method_num >= 3, lambda.size() >= method_num - 3, labels are
// 0-based class ids, Adj holds 0-based node indices < n.
// real_err must hold step*method_num*label_num*times doubles and be zero on
// entry (it is accumulated with +=). ROC must hold 101*method_num*label_num*times
// doubles; entries for skipped classes (or if ROCpoint is never reached) are
// left untouched. The random engine is default-constructed, so identical inputs
// draw identical seeds on every call.
void main_fun(double * const real_err, double * const ROC, const vector<vector<int>> & Adj,
        const vector<int> & label, int step, int n, int seed_num,
        int method_num, const vector<double> & lambda, double PRalpha,
        double heatPRbeta, int times, int label_num, int ROCpoint){
    const int lambda_num = method_num - 3;   // methods 0..lambda_num-1 use lambda
    const int plain_method = method_num - 3; // score = current mass only
    const int pr_method = method_num - 2;    // weights PRalpha^(k+1)
    const int heat_method = method_num - 1;  // weights heatPRbeta^(k+1)/(k+1)!
    const size_t roc_len = 101;              // ROC entries per run (bins 0..100)

    vector<int> d(n);
    for(int h = 0; h < n; h++){
        d[h] = static_cast<int>(Adj[h].size());
    }
    default_random_engine generator;
    const auto by_score_desc = [](const pair<double, int> & a, const pair<double, int> & b){
        return a.first > b.first;
    };
    vector<vector<pair<double, int>>> res(method_num, vector<pair<double, int>>(n));
    vector<vector<pair<double, int>>> sort_res(method_num, vector<pair<double, int>>(n));

    for(int i = 0; i < label_num; i++){
        vector<int> label_index;
        for(int h = 0; h < n; h++){
            if(label[h] == i){
                label_index.push_back(h);
            }
        }
        const int class_size = static_cast<int>(label_index.size());
        if(class_size == 0){
            // uniform_int_distribution(0, -1) is undefined and no seed exists.
            continue;
        }
        // Without the clamp the rejection loop below never terminates when
        // seed_num exceeds the class size.
        const int seeds = min(seed_num, class_size);
        uniform_int_distribution<int> distribution(0, class_size - 1);
        for(int j = 0; j < times; j++){
            const size_t run_base = static_cast<size_t>(method_num) *
                    (static_cast<size_t>(i) + static_cast<size_t>(label_num) * j);
            vector<double> x(n);
            vector<double> nx(n);      // x[h] / d[h]: share sent to each neighbour
            vector<double> next_x(n);
            for(int h1 = 0; h1 < method_num; h1++){
                for(int h = 0; h < n; h++){
                    res[h1][h].first = 0;
                    res[h1][h].second = h;
                }
            }
            // Rejection sampling of `seeds` distinct class members.
            for(int remaining = seeds; remaining > 0; ){
                const int node = label_index[distribution(generator)];
                if(x[node] == 0){
                    remaining--;
                    x[node] = 1;
                    nx[node] = 1.0 / d[node];
                    for(int h1 = 0; h1 < method_num; h1++){
                        res[h1][node].first = x[node];
                    }
                }
            }
            vector<double> lastlambda(lambda_num, 1);
            double lastalpha = 1;
            double lastbeta = 1;
            for(int k = 0; k < step; k++){
                for(int h = 0; h < n; h++){
                    for(const int ele : Adj[h]){
                        next_x[ele] += nx[h];
                    }
                }
                for(int h1 = 0; h1 < lambda_num; h1++){
                    lastlambda[h1] *= lambda[h1];
                }
                lastbeta *= heatPRbeta / (k + 1);
                lastalpha *= PRalpha;
                for(int h = 0; h < n; h++){
                    x[h] = next_x[h];
                    // d[h] == 0 yields inf/NaN here, but then Adj[h] is empty so
                    // nx[h] is never read.
                    nx[h] = x[h] / d[h];
                    next_x[h] = 0;
                    for(int h1 = 0; h1 < lambda_num; h1++){
                        res[h1][h].first += x[h] / lastlambda[h1];
                    }
                    res[plain_method][h].first = x[h];
                    res[pr_method][h].first += x[h] * lastalpha;
                    res[heat_method][h].first += x[h] * lastbeta;
                }
                for(int h1 = 0; h1 < method_num; h1++){
                    const size_t run = run_base + h1;
                    sort_res[h1].assign(res[h1].begin(), res[h1].end());
                    sort(sort_res[h1].begin(), sort_res[h1].end(), by_score_desc);
                    for(int h2 = 0; h2 < class_size; h2++){
                        if(label[sort_res[h1][h2].second] != i){
                            real_err[static_cast<size_t>(step) * run + k] += 1;
                        }
                    }
                    // sort_res[h1] already holds this step's ranking; reuse it.
                    if(k == ROCpoint){
                        fill_roc_curve(ROC + roc_len * run, sort_res[h1], label, i, class_size);
                    }
                }
            }
        }
    }
}

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){
	if (nrhs != 9) {
    mexWarnMsgTxt("Check Parameters");
    return;
    }
    // read incidence_list
    int n = mxGetM(prhs[1]);
    int R = mxGetM(prhs[0]);
    double * adj_list = mxGetPr(prhs[0]);
    vector<vector<int>> Adj (n, vector<int>(0));
    int i;
    for(i = 0; i < R; i++){
        Adj[(int)*(adj_list + i)-1].push_back((int)*(adj_list + i + R)-1);
        //mexPrintf("%d, %d, %d\n", (int)*(adj_list + i), Adj[(int)*(adj_list + i)][Adj[(int)*(adj_list + i)].size()-1], i);
    }
    vector<int> label;
    int label_num = 0;
    double * label_list = mxGetPr(prhs[1]);
    for(i = 0; i < n; i++){
        label.push_back((int)*(label_list + i));
        label_num = max(label_num, (int)*(label_list + i));
    }
    label_num++;
    int step = (int)*(mxGetPr(prhs[2]));
    int seed_num = (int)*(mxGetPr(prhs[3]));
    int lambda_num = mxGetN(prhs[4]);
    double * lambda_pr = mxGetPr(prhs[4]);
    vector<double> lambda;
    for(i = 0; i < lambda_num; i++){
        lambda.push_back(*(lambda_pr + i));
    }
    
    double PRalpha = *(mxGetPr(prhs[5]));
    double heatPRbeta = *(mxGetPr(prhs[6]));
    int times = (int)*(mxGetPr(prhs[7]));
    int ROCpoint = (int)*(mxGetPr(prhs[8]));
    int method_num = lambda_num + 3;
    plhs[0] = mxCreateDoubleMatrix(step * method_num * label_num * times, 1, mxREAL);
    double * realerr = mxGetPr(plhs[0]);
    plhs[1] = mxCreateDoubleMatrix((100+1) * method_num * label_num * times, 1, mxREAL);
    double * ROC = mxGetPr(plhs[1]);
    main_fun(realerr, ROC, Adj, label, step, n, seed_num, method_num, lambda,
            PRalpha, heatPRbeta, times, label_num, ROCpoint);
}







