from copy import copy

import xgboost
import numpy as np
import pandas as pd
import shap
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import spacy
import stanza

# NLP pipelines are created once, at import time, as module-level globals.
# spaCy's small English pipeline; in the code visible here it is only used to
# count tokens per story (see CausalShapleyAnalysisEngine.__init__).
nlp = spacy.load('en_core_web_sm')
# Stanza English pipeline limited to tokenization + sentiment, with GPU use
# disabled; in the code visible here it supplies per-sentence sentiment values.
stanza_nlp = stanza.Pipeline(lang='en', processors='tokenize,sentiment', use_gpu=False)

from collections import defaultdict
import string
import json

# Annotation categories and the segment tags grouped under each of them.
# 'Event_Cause' bundles the same seven tags that `actual_hierarchy` splits into
# its 'action_omission', 'event_normality' and 'time' groups.
entity_to_seg_marking = {
    'Event_Cause': [
        'EC_4_Action',
        'EC_4_Omission',
        'EC_4_Normal_Event',
        'EC_4_Abnormal_Event',
        'EC_4_Same_Time_Cause',
        'EC_4_Early_Cause',
        'EC_4_Late_Cause',
    ],
    'Norm': [
        'Norm_3_Prescriptive_Norm',
        'Norm_3_Statistics_Norm',
    ],
    'Causal_Struct': [
        'CS_1_Disjunctive',
        'CS_1_Conjunctive',
    ],
    'Agent_Knowledge': [
        'AK_2_Agent_Aware',
        'AK_2_Agent_Unaware',
    ],
    'Outcome': [
        'Outcome_5_negative',
        'Outcome_5_positive',
    ],
}

# Factor name -> the mutually exclusive level tags that make up that factor.
# merge_dummy_variables_with_levels() iterates this dict in insertion order and
# encodes each factor as one integer column (0 = no level set, then 1, 2, ...
# following the order of the tags listed here), so ordering is significant.
actual_hierarchy = {
    'action_omission': [
        'EC_4_Action',
        'EC_4_Omission',
    ],
    'event_normality': [
        'EC_4_Normal_Event',
        'EC_4_Abnormal_Event',
    ],
    'time': [
        'EC_4_Same_Time_Cause',
        'EC_4_Early_Cause',
        'EC_4_Late_Cause',
    ],
    'norm_type': [
        'Norm_3_Prescriptive_Norm',
        'Norm_3_Statistics_Norm',
    ],
    'agent_awareness': [
        'AK_2_Agent_Aware',
        'AK_2_Agent_Unaware',
    ],
    'causal_structure': [
        'CS_1_Disjunctive',
        'CS_1_Conjunctive',
    ],
    'outcome': [
        'Outcome_5_negative',
        'Outcome_5_positive',
    ],
}

def get_necessary_data(
        legend_path="/piech/u/anie/moca_github/data/causal_annotations_raw/Moca_public/annotations-legend.json"):
    """Load the annotation legend and build its inverse lookup.

    Previously this function opened the file without closing it and threw
    both mappings away (it implicitly returned ``None``). It now closes the
    file deterministically and returns what it computed.

    Args:
        legend_path: Location of the JSON legend file. Defaults to the path
            that used to be hard-coded, so existing zero-argument calls
            read the same file as before.

    Returns:
        A ``(legend, reverse_legend)`` tuple: the parsed JSON object and a
        dict with its keys and values swapped. If several keys share a
        value, the last one encountered wins in ``reverse_legend``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        TypeError: If a legend value is unhashable and so cannot be a key.
    """
    # JSON text is UTF-8 by specification; do not depend on the locale default.
    with open(legend_path, encoding="utf-8") as legend_file:
        legend = json.load(legend_file)

    reverse_legend = {value: key for key, value in legend.items()}
    return legend, reverse_legend

class CausalShapleyAnalysisEngine(object):
    def __init__(self, causal_engine, preds, story_id_to_segments, original_stories, sentiment_for_each_story=None):
        """Assemble the feature table and the XGBoost ``DMatrix`` for analysis.

        Requires ``pandas`` to be importable as ``pd`` at module level (the
        original code used ``pd.DataFrame`` without importing pandas).

        Args:
            causal_engine: Object exposing ``label_vocab`` (feature names) and
                ``feat``, read as ``feat[:, i]`` for the i-th name.
                Presumably ``feat`` has one row per story, in the iteration
                order of ``story_id_to_segments`` -- TODO confirm with caller.
            preds: Predictions indexable by story id (e.g. a list of 0/1
                integers). They become the training labels.
            story_id_to_segments: Mapping whose keys select and order the
                stories. Only the keys are read here.
            original_stories: Indexable by story id; each entry's ``'input'``
                field is the story text.
            sentiment_for_each_story: Optional precomputed sentiment values,
                expected to line up one-to-one with the stories. When
                ``None``, Stanza is run on every story and the mean of its
                per-sentence sentiment values is used.

        Attributes set:
            data: Raw feature columns keyed by name, plus ``'Pred'``.
            data_with_dummy: Merged factor columns plus ``'Lengths'`` and
                ``'Sentiment'``.
            df: ``pandas.DataFrame`` built from ``data_with_dummy``.
            labels: ``numpy`` array of the per-story predictions.
            Xd: ``xgboost.DMatrix`` over ``df`` labelled with ``labels``.
        """
        # Fix the story order once so every per-story list below lines up.
        story_ids = list(story_id_to_segments)

        # One raw column per annotated feature name.
        data = {
            name: causal_engine.feat[:, col]
            for col, name in enumerate(causal_engine.label_vocab)
        }

        # The prediction for each story is the target the model is fit to.
        labels = [preds[story_id] for story_id in story_ids]
        data['Pred'] = labels

        # Extra covariates: story length in spaCy tokens and a sentiment score.
        compute_sentiment = sentiment_for_each_story is None
        lens, sentiments = [], []
        for story_id in tqdm(story_ids):
            text = original_stories[story_id]['input']
            lens.append(len(nlp(text)))
            if compute_sentiment:
                sentences = stanza_nlp(text).sentences
                sentiments.append(np.mean([s.sentiment for s in sentences]))
        if not compute_sentiment:
            sentiments = sentiment_for_each_story

        self.data = data
        self.merge_dummy_variables_with_levels()

        self.data_with_dummy['Lengths'] = np.array(lens) / 100  # we scale this
        self.data_with_dummy['Sentiment'] = sentiments

        # 'Pred' is not added to the frame: it is passed as the label instead.
        self.df = pd.DataFrame(data=self.data_with_dummy)

        self.labels = np.array(labels)
        self.Xd = xgboost.DMatrix(self.df, label=self.labels)

    def merge_dummy_variables_with_levels(self):
        """Collapse each factor's 0/1 indicator columns into one coded column.

        For every factor in ``actual_hierarchy`` except ``'outcome'``, the
        indicator columns in ``self.data`` named by that factor's level tags
        are merged row by row into a list of integer codes: ``0`` when no
        indicator is set, otherwise ``1, 2, ...`` following the tag order.
        The result is stored as ``self.data_with_dummy``.

        Raises:
            Exception: If a row has more than one indicator equal to 1 within
                a factor, or holds a value that is neither 0 nor 1.
        """
        merged_columns = {}

        for factor, level_tags in actual_hierarchy.items():
            if factor == 'outcome':
                continue

            # The empty string stands for "no level active" and takes code 0.
            code_of = {'': 0}
            for code, tag in enumerate(level_tags, start=1):
                code_of[tag] = code

            column = merged_columns[factor] = []
            row_count = len(self.data[level_tags[0]])

            for row in range(row_count):
                active_tag = ''
                for tag in level_tags:
                    value = self.data[tag][row]
                    if value == 1 and active_tag == '':
                        active_tag = tag
                    elif value == 0:
                        continue
                    else:
                        # Second active indicator, or a non-binary value.
                        raise Exception("shouldn't be here...")
                column.append(code_of[active_tag])

        self.data_with_dummy = merged_columns

    def build_model(self):
        """Train an XGBoost booster on ``self.Xd`` and keep it as ``self.model``.

        Uses a logistic objective with ``eta`` 0.01 for 100 boosting rounds;
        every other setting is left at the library default. No explicit
        interaction terms (e.g. "causal structure x other") are built here.
        """
        training_params = {
            "eta": 0.01,
            "objective": "binary:logistic",
        }
        boosting_rounds = 100

        self.model = xgboost.train(
            training_params,
            self.Xd,
            num_boost_round=boosting_rounds,
        )

    def test_model(self, verbose=True):
        """Return the fitted model's accuracy on the matrix it was built from.

        Scores from ``self.model.predict(self.Xd)`` are thresholded at 0.5
        and compared with ``self.labels``. Note that ``self.Xd`` is the same
        matrix ``build_model`` trains on, so this is not a held-out score.

        Args:
            verbose: When true, also print the accuracy.

        Returns:
            The value computed by ``accuracy_score``.
        """
        scores = self.model.predict(self.Xd)
        hard_labels = (scores >= 0.5).astype(int)
        accuracy = accuracy_score(self.labels, hard_labels)

        if not verbose:
            return accuracy

        print(f"Accuracy is {accuracy}")
        return accuracy