import numpy as np
import torch
from sklearn.model_selection import cross_validate, cross_val_predict
from sklearn import datasets, linear_model
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis ### todo review this!!!
from snapml import BoostingMachineClassifier as SnapBoostingMachineClassifier
from sklearn.experimental import enable_iterative_imputer 
from sklearn.impute import KNNImputer, IterativeImputer
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.neural_network import MLPRegressor
from catboost import CatBoostClassifier
from xgboost import XGBClassifier
from sklearn.svm import SVC
from sklearn.kernel_approximation import RBFSampler, Nystroem
import math

from .supervised_imputer import SupervisedImputer
from ..divergences.discriminator_divergence_estimator import div_multiple_clf, div_tv_clf



class AdversarialIterativeBatchWrapper():
    '''
    Run AdversarialIterativePerSampleImputerV2 on the whole data or on random row batches.

    The first `num_corrupted_feats` columns of `query` are imputed; the remaining columns are
    context. In the batched path `reference` and `query` are indexed with the same row
    permutation, so they are expected to have the same number of rows.

    Parameters
    ----------
    config : stored as `self.config`; not used by this class.
    base_classifier : classifier handed to every per-batch imputer.
    reference, query : 2-D numpy arrays (rows x features).
    num_corrupted_feats : number of leading columns to impute.
    batch_size : rows per batch. NOTE: forced to 3000 when `reference` has more than
        15000 rows, whatever value is passed.
    max_dims : if given, keep only the `max_dims` context columns of `query` with the largest
        standard deviation (the corrupted columns are always kept) in both arrays.
    '''

    def __init__(self, config, base_classifier, reference, query, num_corrupted_feats, batch_size=5000, max_dims=None):

        # Untrimmed copy of the query; used to rebuild a full-width result when max_dims drops columns.
        self.query_original = query.copy()

        # Large inputs are always processed in batches of 3000 rows (overrides `batch_size`).
        if reference.shape[0] > 15000:
            batch_size = 3000

        if max_dims is not None and query[:, num_corrupted_feats:].shape[1] > max_dims:
            std = np.std(query[:, num_corrupted_feats:], axis=0)
            # Slice from an explicit start index: `[-max_dims:]` would keep every column when max_dims == 0.
            top_context = np.argsort(std)[std.shape[0] - max_dims:] + num_corrupted_feats
            indexes = np.concatenate([np.arange(num_corrupted_feats), top_context])
            reference = reference[:, indexes]
            query = query[:, indexes]
            print(f'{len(indexes)} features in reference and query.')

        self.config = config
        self.reference = reference
        self.query = query
        self.num_corrupted_feats = num_corrupted_feats
        self.base_classifier = base_classifier

        self.num_samples = reference.shape[0]
        self.batch_size = batch_size
        self.max_dims = max_dims    # TODO: remove this line and parameter
        # The last batch absorbs the remainder, so there are floor(n / batch_size) batches.
        self.num_batches = math.floor(self.num_samples/self.batch_size)

    def _restore_full_width(self, query_imputed):
        '''
        Return `query_imputed`, or, when max_dims dropped context columns, a fresh copy of the
        original query whose first `num_corrupted_feats` columns are taken from `query_imputed`.
        '''
        if self.query.shape[1] < self.query_original.shape[1]:
            query = self.query_original.copy()
            query[:, :self.num_corrupted_feats] = query_imputed[:, :self.num_corrupted_feats]
            return query
        return query_imputed

    def fit_transform(self):
        '''
        Impute the first `num_corrupted_feats` columns of the query and return the result.

        Works on copies: neither the caller's arrays nor this wrapper's stored arrays are
        modified, so every call starts from the same query.

        Returns
        -------
        numpy array with the same number of columns as the `query` passed to __init__.
        '''
        if self.reference.shape[0] <= self.batch_size:
            print('NO BATCH PROCESSING!')
            imputer = AdversarialIterativePerSampleImputerV2(None, self.base_classifier, self.reference.copy(),
                                                             self.query.copy(), self.num_corrupted_feats)
            return self._restore_full_width(imputer.fit_transform())

        query_imputed = self.query.copy()
        perm_indices = np.random.permutation(self.reference.shape[0])
        for k in range(self.num_batches):
            print('PROCESSING BATCH ', k, self.num_batches)
            s, e = self.batch_size*k, self.batch_size*(k+1)
            if k == (self.num_batches - 1):
                # The last batch takes all remaining rows.
                e = self.reference.shape[0]
            idx_update = perm_indices[s:e]
            # Fancy indexing already returns copies, so the imputer cannot touch our arrays.
            imputer = AdversarialIterativePerSampleImputerV2(None, self.base_classifier, self.reference[idx_update, :],
                                                             self.query[idx_update, :], self.num_corrupted_feats)
            query_imputed[idx_update, :] = imputer.fit_transform()

        return self._restore_full_width(query_imputed)
            

class AdversarialIterativePerSampleImputerV2():
    '''
    Impute the first `num_corrupted_feats` columns of `query` so that the completed query is
    hard for a classifier to tell apart from `reference`.

    Outline (see `fit_transform`):
    1. Warm-up: fit simple imputers on `reference` (linear regression via SupervisedImputer and
       1-NN KNNImputers), build whole-query candidates from them and from the reference's own
       values, and keep the candidate with the lowest `div_tv_clf` score as the starting query.
    2. Fit two discriminators (reference = 0, query/augmentations = 1), each using a different
       half of the query rows. Rank query rows by a cross-validated reference-vs-query
       classifier, and re-impute the flagged rows one at a time. For each flagged row, choose,
       among values taken from `reference`, a feature-wise shuffled copy of it, and the
       linear-regression imputation, the candidate that the other half's discriminator scores
       as least query-like.

    `reference` rows are written into copies of `query` when building proposals, so both are
    expected to have the same number of rows.
    '''

    def __init__(self, config, base_classifier, reference, query, num_corrupted_feats):
        self.config = config
        self.reference = reference
        # NOTE: not copied. It may be replaced by a warm-up proposal, and is then edited in place.
        self.query = query
        self.query_original = query.copy()
        self.num_corrupted_feats = num_corrupted_feats

        # Query snapshots taken after each epoch; later discriminators learn them as label 1.
        self.query_epoch_log = []

        # One classifier object serves both as the discriminator template and as the evaluator.
        self.discriminators_list = []
        self.discriminator_base_classifier = base_classifier
        self.eval_classifier = base_classifier

        # Training data to include in discriminator fitting
        self.include_samplewise_shuffle_aug = True
        self.include_featurewise_shuffle_aug = True
        self.include_imputed_samples_aug = False

        # Supervised imputers (fitted on the reference) and their outputs on the query, in the same order.
        self.imputer_list = []
        self.imputed_query_list = []

    def fit(self, x):
        '''No-op; returns self. All fitting happens inside `fit_transform`.'''
        return self

    def fit_transform(self):
        '''
        Main function: run the warm-up, then the per-sample imputation loop.

        Returns the imputed query array (`self.query`).
        '''
        self.warm_up()
        self.impute_dataset()
        return self.query

    def warm_up(self):
        '''
        Initial phase that:
        1. fits the internal imputers (linear regression and k-NN) on the reference;
        2. imputes the query with the imputers fitted in step (1);
        3. builds starting-point proposals: the reference's own values for the corrupted
           columns, plus each imputation from step (2);
        4. evaluates each proposal from step (3) and keeps the one with the lowest divergence.
        '''
        ## Initial steps
        print('fitting and imputing data')
        self.fit_internal_imputers()
        self.impute_with_sklearn_imputers()
        print(len(self.imputed_query_list))
        print('getting initial proposals')
        _, _ = self.get_initial_proposals()
        self.evaluate_initial_proposals()
        return self

    def _add_proposal_inside_query(self, proposal):
        '''Return a copy of the query whose first `num_corrupted_feats` columns come from `proposal`.'''
        _proposal = self.query.copy()
        _proposal[:, 0:self.num_corrupted_feats] = proposal[:, 0:self.num_corrupted_feats]
        return _proposal

    def get_initial_proposals(self):
        '''
        (warmup) Build whole-query proposals: reference values first, then one per imputer output.

        Stores the list in `self.initial_proposals`. Returns (list of full-width proposals,
        their corrupted columns stacked row-wise).
        '''
        proposals = [self._add_proposal_inside_query(self.reference.copy())]
        for imp in self.imputed_query_list:
            proposals.append(self._add_proposal_inside_query(imp))

        proposals_array = np.concatenate(proposals, axis=0)[:, 0:self.num_corrupted_feats]

        self.initial_proposals = proposals

        return proposals, proposals_array

    def _iteration_candidates(self):
        '''
        Candidate sources for per-sample imputation, in this order:
        the reference, a copy of it with each corrupted column row-shuffled independently, and
        the linear-regression imputation. Only their first `num_corrupted_feats` columns are used.
        '''
        ref_perm = self.reference.copy()
        for jj in range(self.num_corrupted_feats):
            randperm = np.random.permutation(self.reference.shape[0])
            ref_perm[:, jj] = ref_perm[randperm, jj]
        # Only the linear-regression imputation (first fitted imputer) is used here.
        return [self.reference, ref_perm, self.imputed_query_list[0]]

    def get_iteration_proposals(self):
        '''
        Return (list of full-width query proposals, their corrupted columns stacked row-wise)
        built from `_iteration_candidates`.
        '''
        proposals = [self._add_proposal_inside_query(c) for c in self._iteration_candidates()]
        proposals_array = np.concatenate(proposals, axis=0)[:, 0:self.num_corrupted_feats]
        return proposals, proposals_array

    def evaluate_initial_proposals(self):
        '''
        (warmup) Score every initial proposal and keep the best one as `self.query`.

        A proposal replaces the query only if its score is below 1.0 and below every earlier
        score. Evaluation stops early once the best score drops below 0.1.
        '''
        best_tv = 1.0
        for j, proposal in enumerate(self.initial_proposals):
            print(f'Evaluating proposal {j}')
            tv = self.evaluate_proposal(proposal)
            print(tv, best_tv)

            if tv < best_tv:
                best_tv = tv
                self.query = proposal.copy()

            if best_tv < 0.1:
                break

    def evaluate_proposal(self, proposal):
        '''Return the first value of `div_tv_clf(reference, proposal, clf=eval_classifier)`.'''
        tv, _ = div_tv_clf(self.reference, proposal, clf=self.eval_classifier)
        return tv

    def fit_initial_discriminators(self):
        '''
        Fit two discriminators separating `reference` (label 0) from query-like data (label 1).

        Entry 0 is trained with the odd query rows (`query[1::2]`) and entry 1 with the even
        rows (`query[0::2]`). Therefore `discriminators_list[index % 2]` was not trained on
        query row `index`. Both discriminators also receive, as label 1, the enabled
        augmentations (shuffled copies of `reference`, optionally the warm-up imputations) and
        every snapshot in `query_epoch_log`.
        '''
        import copy  # local import: deep-copies the base classifier once per half

        k = self.num_corrupted_feats
        # Start empty so that entries 0/1 always belong to this call (impute_sample reads only those two).
        self.discriminators_list = []
        for i in [1, 0]:
            x = np.concatenate([self.reference, self.query[i::2, :]], axis=0)
            y = np.concatenate([np.zeros(self.reference.shape[0]), np.ones(self.query[i::2, :].shape[0])], axis=0)

            if self.include_samplewise_shuffle_aug:
                # The corrupted columns of each row are replaced, as a block, by those of a random row.
                ref_perm = self.reference.copy()
                randperm = np.random.permutation(self.reference.shape[0])
                ref_perm[:, 0:k] = ref_perm[randperm, 0:k]
                x = np.concatenate([x, ref_perm], axis=0)
                y = np.concatenate([y, np.ones(ref_perm.shape[0])], axis=0)

            if self.include_featurewise_shuffle_aug:
                # Each corrupted column is row-shuffled independently of the others.
                ref_perm = self.reference.copy()
                for jj in range(k):
                    randperm = np.random.permutation(self.reference.shape[0])
                    ref_perm[:, jj] = ref_perm[randperm, jj]
                x = np.concatenate([x, ref_perm], axis=0)
                y = np.concatenate([y, np.ones(ref_perm.shape[0])], axis=0)

            if self.include_imputed_samples_aug:  # TODO: this could be removed
                for que_imputed in self.imputed_query_list:
                    x = np.concatenate([x, que_imputed[i::2, :]], axis=0)
                    y = np.concatenate([y, np.ones(que_imputed[i::2, :].shape[0])], axis=0)

            for prev_que in self.query_epoch_log:  # TODO: only add features corrupted! self.query[i::2,:]
                x = np.concatenate([x, prev_que], axis=0)
                y = np.concatenate([y, np.ones(prev_que.shape[0])], axis=0)

            # Fit an independent copy per half. Fitting the shared base object twice left both
            # list entries aliasing a single estimator trained only on the last (even-row) half.
            clf = copy.deepcopy(self.discriminator_base_classifier)
            clf.fit(x, y)
            self.discriminators_list.append(clf)

    def evaluate_divergence(self, evaluate_background_div = False):  # This function is not used
        '''
        Mean cross-validated accuracy of `eval_classifier` separating reference (0) from query (1).

        With `evaluate_background_div`, only the columns after the first `num_corrupted_feats`
        are used (4 folds). Otherwise all columns are used (3 folds).
        '''
        # cross_val_score is not among the module-level imports (only cross_validate and
        # cross_val_predict are), so bring it in here instead of failing with NameError.
        from sklearn.model_selection import cross_val_score

        k = self.num_corrupted_feats
        if evaluate_background_div:
            ref, que, n_folds = self.reference[:, k:], self.query[:, k:], 4
        else:
            ref, que, n_folds = self.reference, self.query, 3
        x = np.concatenate([ref, que], axis=0)
        y = np.concatenate([np.zeros(ref.shape[0]), np.ones(que.shape[0])], axis=0)
        acc = cross_val_score(self.eval_classifier, x, y, cv=n_folds)
        return np.mean(acc)

    def detect_bad_samples(self):
        '''
        4-fold cross-validated `predict_proba` of `eval_classifier` on reference (0) vs query (1).

        Returns (class-1 probability for each query row, indices of query rows where it exceeds 0.5).
        '''
        cv_pred = cross_val_predict(self.eval_classifier, np.concatenate([self.reference, self.query], axis=0), np.concatenate([np.zeros(self.reference.shape[0]), np.ones(self.query.shape[0])], axis=0), cv=4, method='predict_proba')
        cv_pred = cv_pred[self.reference.shape[0]:, 1]  # Keep query results only
        return cv_pred, np.where(cv_pred > 0.5)[0]

    def impute_sample(self, index):
        '''Return a re-imputed copy of query row `index` (see `optimize_simple`).'''
        x = self.query[index, :]
        k = self.num_corrupted_feats
        # Only the corrupted columns of each candidate are needed; build them directly instead
        # of first materializing full query-sized copies.
        candidates = np.concatenate([c[:, 0:k] for c in self._iteration_candidates()], axis=0)
        # Row `index` is in half `index % 2`; discriminators_list[index % 2] was trained on the other half.
        classifier = self.discriminators_list[index % 2]
        return self.optimize_simple(x, candidates, classifier)

    def impute_dataset(self):
        '''
        Main loop: re-impute the query rows a classifier most easily tells apart from the reference.

        Each epoch (currently exactly one) runs these steps:
        1. Fit the two half-split discriminators.
        2. Score every query row with `detect_bad_samples` and rank rows scoring > 0.5, highest first.
        3. Stop if fewer than 53% of rows are flagged. Otherwise re-impute flagged rows in rank
           order with `impute_sample`. Stop after num_elements_per_batch + 1 rows, where
           num_elements_per_batch = max(0, int(#flagged - 0.5 * #rows)).
        4. Append a snapshot of the query to `query_epoch_log`.
        '''
        # Internal evaluation is hard-wired off; the guarded blocks are kept for debugging.
        self.do_internal_eval = False

        if self.do_internal_eval:
            print('doing eval')
            tv_all, tv_mean, tv_max = div_multiple_clf(self.reference, self.query)
            print(tv_all, tv_mean, tv_max)

        self.num_epochs = 1
        for epoch in range(self.num_epochs):
            print('epoch is ', epoch)

            self.fit_initial_discriminators()

            print('detecting bad samples')
            preds, _ = self.detect_bad_samples()
            # Flagged rows (score > 0.5), most query-like first.
            order = np.argsort(preds)[::-1]
            bad_idx = order[preds[order] > 0.5]
            num_elements_per_batch = max(0, int(bad_idx.shape[0] - self.query.shape[0]*0.5))
            print('num bad samples is ', len(list(bad_idx)), num_elements_per_batch)

            if bad_idx.shape[0]/self.query.shape[0] < 0.53:
                print('finishing process!')
                break

            for j, idx in enumerate(list(bad_idx)):  # TODO: could this be made paralelized ?
                if j % 50 == 0:
                    print(j)

                if j % 400 == 399 and self.do_internal_eval:
                    print('doing eval', j)
                    tv_all, tv_mean, tv_max = div_multiple_clf(self.reference, self.query, clf_list=[CatBoostClassifier(verbose=False)])
                    print(tv_all, tv_mean, tv_max)

                self.query[idx, :] = self.impute_sample(idx)

                # Checked after imputing, so num_elements_per_batch + 1 rows are processed.
                if j == num_elements_per_batch:
                    print('breaking!', j)
                    break

            print('re-fitting internal discriminators')
            self.query_epoch_log.append(self.query.copy())

        return self

    def optimize_simple(self, x, initial_proposals, classifier):
        '''
        Return the variant of the single row `x` that `classifier` finds least query-like.

        Each row of `initial_proposals` supplies the first `num_corrupted_feats` columns of a
        copy of `x`. The copy with the smallest `classifier.predict_proba(...)[:, 1]` is returned.
        '''
        x = np.repeat(x[np.newaxis, :], initial_proposals.shape[0], axis=0)
        x[:, 0:self.num_corrupted_feats] = initial_proposals[:, 0:self.num_corrupted_feats]
        loss = classifier.predict_proba(x)[:, 1]
        return x[np.argmin(loss), :]

    def fit_internal_imputers(self):
        '''
        (Warmup) Fit the internal imputers on the reference.

        Order in `imputer_list`: linear-regression SupervisedImputer, then 1-NN KNNImputer with
        'distance' weights, then with 'uniform' weights. Returns `imputer_list`.
        '''
        # Linear reg. imputer
        imputer = SupervisedImputer(LinearRegression(), self.num_corrupted_feats)
        imputer.fit(self.reference)
        self.imputer_list.append(imputer)

        # K-NN imputers
        for k in [1,]:  # 2 ,5, 10
            for weights in ['distance', 'uniform']:
                imputer = KNNImputer(n_neighbors=k, weights=weights)
                imputer.fit(self.reference)
                self.imputer_list.append(imputer)
        return self.imputer_list

    def impute_with_sklearn_imputers(self):
        '''
        (Warmup) Apply every fitted internal imputer to the query.

        Before imputing, the first `num_corrupted_feats` columns are set to NaN. Outputs are
        appended to `imputed_query_list` in `imputer_list` order, and the list is returned.
        '''
        que_missing = self.query.copy()
        que_missing[:, 0:self.num_corrupted_feats] = np.nan

        for imputer in self.imputer_list:
            que_imputed = imputer.transform(que_missing.copy())
            self.imputed_query_list.append(que_imputed)
        return self.imputed_query_list
