#!/usr/bin/env python

from __future__ import print_function

import sys
import os
import struct
import numpy as np

import tree_pb2 as tree_proto
import store_kv_pb2 as store_proto
import tree_builder


class TreeLearner(object):
    """Load, query, and rewrite a tree stored as length-prefixed KVItem records.

    Nodes are addressed by integer codes laid out as an implicit binary tree:
    the children of code ``c`` are ``2c + 1`` and ``2c + 2``, its parent is
    ``(c - 1) // 2``, and level ``k`` holds the codes in
    ``[2**k - 1, 2**(k + 1) - 1)``.
    """

    # Maximum number of id/code pairs written into one IdCodePart record.
    ID_CODE_PART_SIZE = 512

    def __init__(self, filename):
        self.filename = filename
        self.tree_meta = tree_proto.TreeMeta()
        self.id_code = dict()      # item id -> code, from the IdCodePart records
        self.item_codes = set()    # every code listed in the IdCodePart records
        self.nodes = dict()        # code -> tree_proto.Node

    @staticmethod
    def _parent(code):
        """Return the parent code of ``code`` using exact integer arithmetic."""
        return (code - 1) // 2

    def load_tree(self, tree_file):
        """Populate ``tree_meta``, ``id_code``, ``item_codes`` and ``nodes`` from a file.

        The file is a sequence of records.  Each record is a native ``'i'``
        length followed by that many bytes of a serialized ``KVItem``.  The
        key ``.tree_meta`` holds the ``TreeMeta``.  Keys containing ``Part_``
        hold ``IdCodePart`` messages.  Every other key is a byte-reversed
        packed node code whose value is a serialized ``Node``.

        Raises:
            ValueError: if a record is shorter than its length prefix says.
            KeyError / AssertionError: if a node stored at an item code does
                not map back to that code through ``id_code``.
        """
        nodes = []
        id_code_parts = []
        with open(tree_file, 'rb') as f:
            while True:
                header = f.read(4)
                if not header:
                    break
                num = struct.unpack('i', header)[0]
                payload = f.read(num)
                if len(payload) != num:
                    raise ValueError('truncated record in %s' % tree_file)
                kv_item = store_proto.KVItem()
                kv_item.ParseFromString(payload)
                key = kv_item.key
                value = kv_item.value
                # Keys are raw bytes (packed codes); b'' literals compare
                # correctly on both Python 2 and Python 3.
                if key == b'.tree_meta':
                    self.tree_meta.ParseFromString(value)
                elif b'Part_' in key:
                    part = tree_proto.IdCodePart()
                    part.ParseFromString(value)
                    id_code_parts.append(part)
                else:
                    node = tree_proto.Node()
                    node.ParseFromString(value)
                    nodes.append((struct.unpack('L', key[::-1])[0], node))

        # Only parts listed in the meta record are used.  The meta record may
        # come after the parts, so filtering happens once everything is read.
        wanted_parts = set(self.tree_meta.id_code_part)
        for part in id_code_parts:
            if part.part_id not in wanted_parts:
                continue
            for id_code in part.id_code_list:
                self.id_code[id_code.id] = id_code.code
                self.item_codes.add(id_code.code)

        for code, node in nodes:
            if code in self.item_codes:
                # Sanity check: an item node must map back to its own code.
                assert self.id_code[node.id] == code
            self.nodes[code] = node

    def get_ancestor(self, code, level):
        """Return the ancestor of ``code`` on ``level``.

        If ``code`` is already on ``level`` or shallower, it is returned unchanged.
        """
        code_max = 2 ** (level + 1) - 1
        while code >= code_max:
            code = self._parent(code)
        return code

    def get_nodes_given_level(self, level):
        """Return the codes in ``self.nodes`` that lie on ``level``."""
        code_min = 2 ** level - 1
        code_max = 2 ** (level + 1) - 1
        return [code for code in self.nodes if code_min <= code < code_max]

    def get_children_given_ancestor_and_level(self, ancestor, level):
        """Return the descendants of ``ancestor`` on ``level`` that exist in ``self.nodes``.

        Codes are returned in ascending order.  If ``ancestor`` already lies on
        ``level``, the result is ``[ancestor]`` (when it is present).  If
        ``ancestor`` lies deeper than ``level``, the result is empty.
        """
        code_min = 2 ** level - 1
        code_max = 2 ** (level + 1) - 1
        # Expand one level at a time.  Each frontier holds codes of a single
        # level in ascending order, so checking its first element is enough.
        frontier = [ancestor]
        while frontier[0] < code_min:
            frontier = [child for p in frontier for child in (2 * p + 1, 2 * p + 2)]
        return [c for c in frontier if c < code_max and c in self.nodes]

    def get_parent_path(self, child, ancestor):
        """Walk from ``child`` towards the root, collecting codes while they exceed ``ancestor``.

        When ``ancestor`` is a true ancestor of ``child``, this is the path from
        ``child`` up to (but excluding) ``ancestor``.
        """
        res = []
        while child > ancestor:
            res.append(child)
            child = self._parent(child)
        return res

    def assign_leaf_nodes(self, parent_relation, tree_file):
        """Write a new tree file with items re-assigned to new leaf codes.

        Args:
            parent_relation: mapping of item id -> new leaf code.  Each item's
                probability is read from the node at its current code
                (``self.id_code[item_id]``).
            tree_file: output path; opened with ``'wb'``, so it is overwritten.

        Every code in ``self.nodes`` is written:

        * Codes that appear as values in ``parent_relation`` are written as
          leaves carrying the item id and probability.
        * All other codes are written as non-leaf nodes.  Their probability is
          the summed probability of the assigned leaves beneath them, or 0.0
          when there are none.

        After the nodes come the id/code pairs, in ``IdCodePart`` records of
        at most ``ID_CODE_PART_SIZE`` entries.  The ``.tree_meta`` record is
        written last.

        NOTE(review): leaf codes in ``parent_relation`` that are not already
        keys of ``self.nodes`` are never written; confirm this is intended.
        """
        # Leaf code -> (item id, probability); ancestor code -> summed probability.
        stat = dict()
        pstat = dict()
        for item_id, code in parent_relation.items():
            prob = self.nodes[self.id_code[item_id]].probality
            stat[code] = (item_id, prob)
            for anc in self._ancestors(code):
                pstat[anc] = pstat.get(anc, 0.0) + prob

        meta = tree_proto.TreeMeta()
        meta.max_level = self.tree_meta.max_level
        id_code_part = []
        with open(tree_file, 'wb') as f:
            for code in self.nodes:
                node = tree_proto.Node()
                node.leaf_cate_id = 0
                if code in stat:
                    item_id, prob = stat[code]
                    node.id = item_id
                    node.is_leaf = True
                    node.probality = prob
                    if (not id_code_part or
                            len(id_code_part[-1].id_code_list) >= self.ID_CODE_PART_SIZE):
                        part = tree_proto.IdCodePart()
                        part.part_id = b'Part_' + self._make_key(len(id_code_part) + 1)
                        id_code_part.append(part)
                    id_code = id_code_part[-1].id_code_list.add()
                    # Record this leaf's own item id.
                    id_code.id = item_id
                    id_code.code = code
                else:
                    node.id = self.nodes[code].id
                    node.is_leaf = False
                    # Codes with no assigned leaf beneath them have no pstat entry.
                    node.probality = pstat.get(code, 0.0)
                self._write_record(f, self._make_key(code), node.SerializeToString())

            # The meta record lists every part id so load_tree can filter parts.
            for part in id_code_part:
                meta.id_code_part.append(part.part_id)
                self._write_record(f, part.part_id, part.SerializeToString())

            self._write_record(f, b'.tree_meta', meta.SerializeToString())

    def _ancestors(self, code):
        """Return the ancestors of ``code``, from its parent up to the root (code 0)."""
        ancs = []
        while code > 0:
            code = self._parent(code)
            ancs.append(code)
        return ancs

    # Backward-compatible alias for the original (misspelled) name.
    _ancessors = _ancestors

    def _make_key(self, code):
        """Pack ``code`` as a native unsigned long (``'L'``) and reverse its bytes."""
        return struct.pack('L', code)[::-1]

    def _write_record(self, fwr, key, value):
        """Wrap ``key``/``value`` in a ``KVItem`` and append it to ``fwr``."""
        kv_item = store_proto.KVItem()
        kv_item.key = key
        kv_item.value = value
        self._write_kv(fwr, kv_item.SerializeToString())

    def _write_kv(self, fwr, message):
        """Write ``message`` to ``fwr``, prefixed by its length packed as native ``'i'``."""
        fwr.write(struct.pack('i', len(message)))
        fwr.write(message)
