import numpy as np
from tf_utils import *
from sklearn.cluster import KMeans, SpectralClustering

# Number of class labels that average_by_label groups examples into by default.
NUM_CLASSES = 1000



class WeightStaircaseDescender(object):
    """Drive a set of variables towards zero in equal, evenly spaced stairs.

    Starting at ``start_step``, every ``len_staircase`` steps each variable in
    ``weights`` is reduced by ``1 / num_staircase`` of its value at the start of
    the schedule, so ``num_staircase`` stairs remove the whole starting value.
    """

    def __init__(self, sess, weights, init_step, num_staircase=100, len_staircase=100, start_step=0):
        """
        Args:
            sess: session used to read (``sess.run(w)``) and assign the variables.
            weights: variables to descend.
            init_step: stored on the instance but not used by this class.
            num_staircase: number of stairs (descend events).
            len_staircase: training steps between consecutive stairs.
            start_step: training step of the first stair.
        """
        self.sess = sess
        self.weights = weights
        self.init_step = init_step
        self.len_staircase = len_staircase
        self.num_staircase = num_staircase
        self.target_steps = [start_step + i * len_staircase for i in range(num_staircase)]
        # One stair length past the last target step.
        self.max_step = start_step + num_staircase * len_staircase
        self.start_step = start_step
        # Amount subtracted from each variable per stair; computed lazily on the first descend().
        self.step_values = None
        print('weight staircase descender target steps: ', self.target_steps)

    def descend(self, step):
        """Subtract one stair from every variable when ``step`` is a target step."""
        # No target step lies at or beyond max_step. Stopping there also avoids a
        # ZeroDivisionError below: at step == max_step, former_staircases would
        # equal num_staircase.
        if step >= self.max_step or step < self.start_step:
            return
        if self.step_values is None:
            # When first called mid-schedule, `former_staircases` stairs are assumed
            # to be already subtracted, so rescale the current values back up to the
            # schedule's starting values before computing the per-stair amount.
            former_staircases = (step - self.start_step) // self.len_staircase
            init_weights = [self.num_staircase / (self.num_staircase - former_staircases) * self.sess.run(w) for w in self.weights]
            ratio = 1. / self.num_staircase
            self.step_values = [ratio * iw for iw in init_weights]
        if step in self.target_steps:
            print('staircase descend starting:')
            for w, sv in zip(self.weights, self.step_values):
                value = self.sess.run(w) - sv
                # NOTE(review): a new assign op is added to the graph on every stair.
                assign = tf.assign(w, value)
                self.sess.run(assign)
                print('assign value for ', w.name)

class WeightDecayer(object):
    """Repeatedly scale the weights of clustered filters by a constant factor.

    For each layer in ``layer_idx_to_eqcls``, every filter belonging to an
    equivalence class with more than one member has its kernel slice and bias
    multiplied by ``decay_factor`` each time ``decay`` runs inside
    ``[start_step, end_step)``. Singleton classes are left untouched.
    """

    def __init__(self, sess, layer_idx_to_eqcls, decay_factor=0.999, start_step=5000, end_step=15000):
        """
        Args:
            sess: session that runs the grouped assign op.
            layer_idx_to_eqcls: maps a layer index (into tf_extract_kernel_tensors()
                and tf_extract_bias_tensors()) to that layer's equivalence classes.
            decay_factor: multiplier applied on every decay() call.
            start_step: first step (inclusive) at which decay() acts.
            end_step: step (exclusive) at which decay() stops acting.
        """
        self.sess = sess
        self.start_step = start_step
        self.end_step = end_step
        assign_ops = []
        kernel_weights = tf_extract_kernel_tensors()
        bias_weights = tf_extract_bias_tensors()

        for layer_idx, eqcls in layer_idx_to_eqcls.items():
            k_w = kernel_weights[layer_idx]
            b_w = bias_weights[layer_idx]
            k_mask = np.ones(k_w.get_shape())
            b_mask = np.ones(b_w.get_shape())
            for eqcl in eqcls:
                if len(eqcl) == 1:
                    continue
                for e in eqcl:
                    b_mask[e] = decay_factor
                    # The filter index is the last kernel axis; `...` keeps this
                    # working for 2-D (dense) kernels as well as 4-D conv kernels.
                    k_mask[..., e] = decay_factor
            assign_ops.append(tf.assign(k_w, k_w * k_mask))
            assign_ops.append(tf.assign(b_w, b_w * b_mask))
            print('prepare decay ops for layer ', layer_idx)

        self.grouped_assign_op = tf.group(*assign_ops)
        print('weight decayer factor: ', decay_factor)

    def decay(self, step):
        """Apply one decay step when start_step <= step < end_step; otherwise do nothing."""
        if step >= self.end_step or step < self.start_step:
            return
        self.sess.run(self.grouped_assign_op)
        if step % 200 == 0:
            print('decayer still working!')





def near(a, b):
    """Return whether ``a`` and ``b`` differ by at most an absolute 1e-6."""
    tolerance = 1e-6
    gap = abs(a - b)
    return gap <= tolerance

def average_by_label(outs, labels, num_classes=None):
    """Average the rows of ``outs`` separately for each class label.

    Args:
        outs: 2-D array with one row per example and one column per filter.
        labels: 1-D array holding one integer class label per row of ``outs``.
        num_classes: number of classes; defaults to the module-level NUM_CLASSES.

    Returns:
        ``(num_classes, filters)`` array whose row ``i`` is the mean of the rows of
        ``outs`` labelled ``i``. Classes with no examples get a row of zeros
        instead of NaN (np.mean of an empty selection).

    Raises:
        ValueError: if ``outs`` and ``labels`` disagree on the number of examples.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES
    num_examples = outs.shape[0]
    filters = outs.shape[1]
    if num_examples != labels.shape[0]:
        raise ValueError('outs has {} rows but labels has {} entries'.format(num_examples, labels.shape[0]))
    result = np.zeros((num_classes, filters))
    for cls in range(num_classes):
        selected = outs[labels == cls, :]
        if selected.shape[0] > 0:
            result[cls, :] = np.mean(selected, axis=0)
    return result

def add_to_eqcl(outs, eqc, new_idx):
    """Add filter ``new_idx`` to equivalence class ``eqc``, updating ``outs`` in place.

    The columns of ``outs`` listed in ``eqc`` are assumed to already hold the
    class mean, so only ``eqc[0]`` is read. After the call ``eqc`` contains
    ``new_idx``, and every column of the enlarged class holds the new mean.
    """
    previous_size = len(eqc)
    running_total = outs[:, eqc[0]] * previous_size + outs[:, new_idx]
    eqc.append(new_idx)
    updated_mean = running_total / len(eqc)
    for column in eqc:
        outs[:, column] = updated_mean

def search_for_largest_coe(avg_outs_by_class, idx_to_eqcl, max_eqcl_size=None):
    """Find the pair of equivalence classes whose columns are most correlated.

    Only the column of each class's key in ``avg_outs_by_class`` is compared.
    When ``max_eqcl_size`` is given, classes already holding that many members
    or more are excluded.

    Returns:
        ``(max_value, row_key, col_key)``: the largest off-diagonal Pearson
        correlation and the two class keys that produced it.

    NOTE(review): a zero-variance column makes np.corrcoef return NaN, and
    np.argmax then selects the first NaN.
    """
    if max_eqcl_size is None:
        candidate_keys = list(idx_to_eqcl.keys())
    else:
        candidate_keys = [key for key, members in idx_to_eqcl.items() if len(members) < max_eqcl_size]

    n = len(candidate_keys)
    columns = avg_outs_by_class[:, np.array(candidate_keys, dtype=np.int32)]
    corr = np.corrcoef(columns, rowvar=False)
    # Rule out pairing a class with itself.
    diagonal = np.arange(n, dtype=np.int32)
    corr[diagonal, diagonal] = -10000
    best_row, best_col = np.unravel_index(np.argmax(corr), (n, n))
    return corr[best_row, best_col], candidate_keys[best_row], candidate_keys[best_col]

def merge_two_eqcls(outs, idx_to_eqcl, host, guest):
    """Merge equivalence class ``guest`` into class ``host`` in place.

    Each class's columns in ``outs`` are assumed to hold that class's mean, so
    only the first member of each class is read. The size-weighted mean is
    written to every column of the merged class, ``guest`` is removed from
    ``idx_to_eqcl``, and its members are appended to ``host``'s list.
    """
    assert host != guest
    members = idx_to_eqcl[host]
    absorbed = idx_to_eqcl.pop(guest)
    weighted_total = outs[:, members[0]] * len(members) + outs[:, absorbed[0]] * len(absorbed)
    members.extend(absorbed)
    merged_mean = weighted_total / len(members)
    for column in members:
        outs[:, column] = merged_mean

def merge_abandoned_eqcls(outs, idx_to_eqcl, abandon_idxes):
    """Fold the abandoned filters into a single class keyed by the first of them.

    The classes keyed by ``abandon_idxes[1:]`` are removed from ``idx_to_eqcl``,
    and their indices are appended to the class keyed by ``abandon_idxes[0]``.
    That host's column of ``outs`` is then set to the constant 1; the other
    columns of ``outs`` are left untouched. An empty ``abandon_idxes`` is a no-op.
    """
    if len(abandon_idxes) == 0:
        return
    host_idx = abandon_idxes[0]
    rest = list(abandon_idxes[1:])
    idx_to_eqcl[host_idx].extend(rest)
    for i in rest:
        idx_to_eqcl.pop(i)
    # The host column becomes a constant, so the abandoned columns' values are never
    # read (the former summation into the host column was immediately overwritten).
    # NOTE(review): a constant column has zero variance, so np.corrcoef gives NaN for it.
    outs[:, host_idx] = 1

def calculate_eqcls_from_raw(raw_outs, num_eqcls, abandon_thresh=0.000001):
    """Greedily merge filters by output correlation until ``num_eqcls`` classes remain.

    Each iteration merges the two classes whose columns of ``raw_outs`` are most
    correlated (search_for_largest_coe) and prints that correlation. ``raw_outs``
    is modified in place: merged columns are overwritten with their class mean.
    ``abandon_thresh`` is accepted but currently unused.

    Returns:
        list of equivalence classes, each a list of filter indices.
    """
    idx_to_eqcl = {f: [f] for f in range(raw_outs.shape[1])}
    while len(idx_to_eqcl) > num_eqcls:
        best_value, host, guest = search_for_largest_coe(raw_outs, idx_to_eqcl)
        print(best_value)
        merge_two_eqcls(raw_outs, idx_to_eqcl, host, guest)
    return list(idx_to_eqcl.values())

def _pairwise_xor_counts(boolean_outs):
    """Return a float (F, F) matrix counting the rows where columns i and j of ``boolean_outs`` differ."""
    as_int = boolean_outs.astype(np.int64)
    inverted = 1 - as_int
    # XOR(a, b) = a * (1 - b) + (1 - a) * b, summed over rows for every column pair.
    return (np.dot(as_int.T, inverted) + np.dot(inverted.T, as_int)).astype(np.float64)


def calculate_eqcls_from_raw_bidirection(raw_outs, num_eqcls, phase_one_target=None):
    """Cluster filters in two phases.

    Phase one repeatedly merges the pair of classes whose binarised outputs
    (``raw_outs > 0``) disagree on the most rows (largest XOR count). It runs until
    ``phase_one_target`` classes remain; the default target is a quarter of the way
    from ``num_eqcls`` up to the filter count. Phase two then merges by the largest
    correlation of the ReLU'd outputs until ``num_eqcls`` classes remain.

    ``raw_outs`` is modified in place: merged columns hold their class mean.

    Returns:
        list of equivalence classes, each a list of filter indices.
    """
    filters = raw_outs.shape[1]
    idx_to_eqcl = {i: [i] for i in range(filters)}

    if phase_one_target is None:
        phase_one_target = num_eqcls + (filters - num_eqcls) * 1 // 4

    boolean_outs = raw_outs > 0
    xor_mat = _pairwise_xor_counts(boolean_outs)
    while len(idx_to_eqcl) > phase_one_target:
        keys = list(idx_to_eqcl.keys())
        a_keys = np.array(keys, dtype=np.int32)
        # Fancy indexing returns a copy, so masking it below leaves xor_mat intact.
        selected_mat = xor_mat[np.ix_(a_keys, a_keys)]
        # A class XORed with itself is 0. Mask the diagonal so that a matrix whose
        # entries are all equal can never select row == col (which would trip the
        # host != guest assertion in merge_two_eqcls).
        np.fill_diagonal(selected_mat, -1)
        row, col = np.unravel_index(np.argmax(selected_mat), selected_mat.shape)
        print(selected_mat.shape, row, col)
        print('max xor ', selected_mat[row, col])
        row = keys[row]
        col = keys[col]
        merge_two_eqcls(raw_outs, idx_to_eqcl, row, col)
        # merge_two_eqcls overwrote every member column of the merged class with the
        # class mean. Refresh their binarised outputs and the host's XOR row/column,
        # so the next iteration compares against the merged class, not stale columns.
        members = idx_to_eqcl[row]
        boolean_outs[:, members] = raw_outs[:, members] > 0
        new_xor = np.sum(boolean_outs != boolean_outs[:, [row]], axis=0)
        xor_mat[row, :] = new_xor
        xor_mat[:, row] = new_xor

    while len(idx_to_eqcl) > num_eqcls:
        relu_outs = np.maximum(raw_outs, 0)
        value, row, col = search_for_largest_coe(relu_outs, idx_to_eqcl)
        print(value)
        merge_two_eqcls(raw_outs, idx_to_eqcl, row, col)
    return list(idx_to_eqcl.values())

def merge_two_eqcls_with_weights(outs, idx_to_eqcl, host, guest, filter_to_weight):
    """Merge class ``guest`` into class ``host``, weighting by per-filter weights.

    Each class's columns in ``outs`` are assumed to hold that class's mean, so
    only the first member of each class is read. The merged mean weights each
    class by the sum of ``filter_to_weight`` over its members. If both sums are
    zero, the merged columns are set to 0. ``guest`` is removed from
    ``idx_to_eqcl`` and its members are appended to ``host``'s list.

    Raises AssertionError when ``host == guest``, matching merge_two_eqcls.
    """
    # Merging a class with itself would pop its key and duplicate its member list.
    assert host != guest
    host_eqcl = idx_to_eqcl[host]
    guest_eqcl = idx_to_eqcl.pop(guest)
    host_weight_sum = sum((filter_to_weight[e] for e in host_eqcl), 0.)
    guest_weight_sum = sum((filter_to_weight[e] for e in guest_eqcl), 0.)
    new_sum = outs[:, host_eqcl[0]] * host_weight_sum + outs[:, guest_eqcl[0]] * guest_weight_sum
    host_eqcl.extend(guest_eqcl)
    total_weight = host_weight_sum + guest_weight_sum
    if total_weight == 0:
        new_mean = 0
    else:
        new_mean = new_sum / total_weight
    for e in host_eqcl:
        outs[:, e] = new_mean


def calculate_eqcls_with_weights_and_abandon(raw_outs, num_eqcls, featuremap_weights, abandon_num):
    """Drop the lowest-weight filters, then greedily merge the rest by correlation.

    The ``abandon_num`` filters with the smallest ``featuremap_weights`` are
    excluded from clustering, and their weights are set to 0 IN PLACE in the
    caller's array. The remaining classes are merged (weighted by
    ``featuremap_weights``) until ``num_eqcls`` remain. ``raw_outs`` is modified
    in place.

    Returns:
        list of equivalence classes; the abandoned filter indices are appended to
        the last class.
    """
    idx_to_eqcl = {i: [i] for i in range(raw_outs.shape[1])}

    idxes_to_abandon = np.argsort(featuremap_weights)[:abandon_num]
    featuremap_weights[idxes_to_abandon] = 0
    print('abandoned idxes: ', idxes_to_abandon)
    for abandoned in idxes_to_abandon:
        del idx_to_eqcl[abandoned]

    while len(idx_to_eqcl) > num_eqcls:
        value, row, col = search_for_largest_coe(raw_outs, idx_to_eqcl)
        print(value, row, col, featuremap_weights[row], featuremap_weights[col])
        merge_two_eqcls_with_weights(raw_outs, idx_to_eqcl, row, col, featuremap_weights)
    eqcls = list(idx_to_eqcl.values())

    # Equivalent to abandoning: the abandoned indices join the last class, and their
    # weights are already zero.
    eqcls[-1].extend(idxes_to_abandon.tolist())
    return eqcls

def calculate_eqcls_with_weights(raw_outs, num_eqcls, featuremap_weights, max_eqcl_size=None):
    """Greedily merge filters by correlation, with weighted means, until ``num_eqcls`` remain.

    ``featuremap_weights`` weights each filter when merged means are written back
    into ``raw_outs`` (modified in place); None means uniform weights.
    ``max_eqcl_size`` is forwarded to search_for_largest_coe to cap class size.

    Returns:
        list of equivalence classes, each sorted in ascending filter order.
    """
    num_filters = raw_outs.shape[1]
    weights = np.ones(num_filters) if featuremap_weights is None else featuremap_weights
    idx_to_eqcl = {i: [i] for i in range(num_filters)}
    while len(idx_to_eqcl) > num_eqcls:
        value, row, col = search_for_largest_coe(raw_outs, idx_to_eqcl, max_eqcl_size)
        print(value, row, col, weights[row], weights[col])
        merge_two_eqcls_with_weights(raw_outs, idx_to_eqcl, row, col, weights)
    return [sorted(eqcl) for eqcl in idx_to_eqcl.values()]

def calculate_eqcls_randomly(raw_outs, num_eqcls, featuremap_weights, max_eqcl_size=None):
    """Merge randomly chosen pairs of classes until ``num_eqcls`` remain.

    ``featuremap_weights`` (uniform when None) weight the merged means written
    back into ``raw_outs``, which is modified in place. ``max_eqcl_size`` is
    accepted for signature compatibility but ignored. Uses ``np.random``, so
    results depend on the global NumPy seed.

    Returns:
        list of equivalence classes, each a list of filter indices.
    """
    filters = raw_outs.shape[1]
    # Compare with None explicitly: the truth value of a multi-element NumPy array
    # is ambiguous and raises ValueError. The weights are loop-invariant, so resolve once.
    weights = featuremap_weights if featuremap_weights is not None else np.ones(filters)
    idx_to_eqcl = {i: [i] for i in range(filters)}
    while len(idx_to_eqcl) > num_eqcls:
        random_choice = np.random.choice(list(idx_to_eqcl.keys()), 2, replace=False)
        print('random! ', random_choice[0], random_choice[1])
        merge_two_eqcls_with_weights(raw_outs, idx_to_eqcl, random_choice[0], random_choice[1], filter_to_weight=weights)
    return list(idx_to_eqcl.values())

def _sk_cluster(model, layer_idx, num_eqcls, cluster_class):
    """Cluster one layer's filters by their flattened kernel weights.

    Args:
        model: object providing ``get_kernel_variables()`` and ``get_value(var)``.
        layer_idx: index into ``model.get_kernel_variables()``.
        num_eqcls: number of clusters to form.
        cluster_class: estimator class constructed as ``cluster_class(n_clusters=...)``
            that exposes ``labels_`` after ``fit`` (e.g. scikit-learn's KMeans).

    Returns:
        list of ``num_eqcls`` lists of filter indices.

    Raises:
        ValueError: for depthwise kernels, or if a cluster ends up empty.
    """
    kernel_var = model.get_kernel_variables()[layer_idx]
    kernel_name = kernel_var.name
    x = model.get_value(kernel_var)
    if x.ndim == 4:
        if 'depth' in kernel_name:
            print('got depthwise kernel, something wrong!')
            # Raise rather than `assert False`, which is stripped under `python -O`.
            raise ValueError('depthwise kernel {} cannot be clustered'.format(kernel_name))
        # Keep the last (filter) axis and flatten everything else into it.
        x = np.reshape(x, (-1, x.shape[3]))
    # One row per filter: the estimator treats rows as samples.
    x = np.transpose(x, [1, 0])

    if num_eqcls == x.shape[0]:
        return [[i] for i in range(num_eqcls)]
    print('I cluster {} filters of layer {} into {} clusters'.format(x.shape[0], layer_idx, num_eqcls))

    km = cluster_class(n_clusters=num_eqcls)
    km.fit(x)
    result = [[] for _ in range(num_eqcls)]
    for i, c in enumerate(km.labels_):
        result[c].append(i)
    for cluster_id, members in enumerate(result):
        if not members:
            raise ValueError('cluster {} of layer {} is empty'.format(cluster_id, layer_idx))
    return result

def calculate_eqcls_by_kmeans(model, layer_idx, num_eqcls):
    """Cluster the filters of layer ``layer_idx`` into ``num_eqcls`` classes with KMeans."""
    print('applying kmeans clustering')
    eqcls = _sk_cluster(model, layer_idx, num_eqcls, cluster_class=KMeans)
    return eqcls

def calculate_eqcls_by_spectral(model, layer_idx, num_eqcls):
    """Cluster the filters of layer ``layer_idx`` into ``num_eqcls`` classes with SpectralClustering."""
    print('applying spectral clustering')
    eqcls = _sk_cluster(model, layer_idx, num_eqcls, cluster_class=SpectralClustering)
    return eqcls

def calculate_eqcls_evenly(filters, num_eqcls, max_eqcl_size=None):
    """Split filter indices ``0..filters-1`` into ``num_eqcls`` contiguous, near-equal classes.

    The first ``filters % num_eqcls`` classes get one extra member.
    ``max_eqcl_size`` is accepted but ignored.
    """
    base_size, remainder = divmod(filters, num_eqcls)
    eqcls = []
    next_idx = 0
    for k in range(num_eqcls):
        size = base_size + 1 if k < remainder else base_size
        eqcls.append(list(range(next_idx, next_idx + size)))
        next_idx += size
    return eqcls

def calculate_eqcls_biasly(filters, num_eqcls, max_eqcl_size=None):
    """Put the first ``filters - num_eqcls + 1`` filters in one class and every other filter alone.

    ``max_eqcl_size`` is accepted but ignored.
    """
    head_size = filters - num_eqcls + 1
    singletons = [[head_size + k] for k in range(num_eqcls - 1)]
    return [list(range(head_size))] + singletons


#   Conv kernels are indexed as kv[:, :, :, filter]; each filter's slice is [h, w, c].
def tf_aggregate_filters(eqcls, layer_idx, method='mean'):
    """Make all filters of each multi-member equivalence class identical.

    For layer ``layer_idx``, every class with more than one member has its kernel
    slices ``kv[:, :, :, members]`` and biases overwritten by one representative:
    ``'mean'`` averages the members, ``'first'`` copies the first member, and any
    other ``method`` copies the member with the smallest sum of squared kernel
    weights. The updated values are written back with set_value.
    """
    kernels = tf_extract_kernel_tensors()
    biases = tf_extract_bias_tensors()
    assert len(kernels) == len(biases)
    kv = get_value(kernels[layer_idx])
    bv = get_value(biases[layer_idx])

    number_filters_seen = 0
    num_filters_alike = 0
    for eqcl in eqcls:
        number_filters_seen += len(eqcl)
        if len(eqcl) == 1:
            continue
        num_filters_alike += len(eqcl)

        members = np.array(eqcl)
        member_k = kv[:, :, :, members]
        member_b = bv[members]
        if method == 'mean':
            rep_k, rep_b = np.mean(member_k, axis=3), np.mean(member_b)
        elif method == 'first':
            rep_k, rep_b = member_k[:, :, :, 0], member_b[0]
        else:
            smallest = np.argmin(np.sum(np.square(member_k), axis=(0, 1, 2)))
            rep_k, rep_b = member_k[:, :, :, smallest], member_b[smallest]

        # Broadcast the single representative over every member slot.
        kv[:, :, :, members] = rep_k[:, :, :, np.newaxis]
        bv[members] = rep_b

    set_value(kernels[layer_idx], kv)
    set_value(biases[layer_idx], bv)
    print('aggregation completed! {} eqcls. {} filters seen. {} filters alike. {} filters eliminated'.format(len(eqcls), number_filters_seen, num_filters_alike,  number_filters_seen - len(eqcls)))


#   Conv kernels are indexed as [:, :, :, filter]; each filter's slice is [h, w, c].
def tf_prune_aggregated_filters_and_save(eqcls, layer_idx, save_np_file, followed_by_fc):
    """Delete redundant filters of layer ``layer_idx`` and save all weights with np.save.

    For every class with more than one member, only the first member is kept in
    layer ``layer_idx`` (kernel axis 3 and bias). The following layer's input
    weights for the whole class are summed into the kept member's input weights,
    and the removed members' input weights are deleted. When the class members
    are identical (e.g. after tf_aggregate_filters), this preserves the following
    layer's output.

    Args:
        eqcls: list of equivalence classes (lists of filter indices).
        layer_idx: index of the layer to prune; layer ``layer_idx + 1`` is adjusted.
        save_np_file: path passed to np.save; a dict {variable name: array} is saved.
        followed_by_fc: True if layer ``layer_idx + 1`` is a dense layer whose input
            rows for conv channel ``c`` are ``c, c + C, c + 2C, ...`` (C = filter
            count), as the indexing below assumes; False if it is a 4-D conv kernel.
    """
    result = dict()
    number_filters_seen = 0
    num_filters_alike = 0
    kernels = tf_extract_kernel_tensors()
    biases = tf_extract_bias_tensors()
    assert len(kernels) == len(biases)

    # Every layer other than the pruned one and its successor is saved unchanged.
    for i, (k, b) in enumerate(zip(kernels, biases)):
        if i != layer_idx and i != (layer_idx + 1):
            result[k.name] = get_value(k)
            result[b.name] = get_value(b)

    kv = get_value(kernels[layer_idx])
    bv = get_value(biases[layer_idx])

    kvf = get_value(kernels[layer_idx + 1])
    bvf = get_value(biases[layer_idx + 1])
    if followed_by_fc:
        conv_indexes_to_delete = []
        fc_indexes_to_delete = []
        assert kvf.shape[0] % kv.shape[3] == 0
        last_conv_origin_deps = kv.shape[3]
        corresponding_neurons_per_kernel = kvf.shape[0] // kv.shape[3]
        # Dense-layer input rows fed by conv channel c are base + c.
        base = np.arange(0, corresponding_neurons_per_kernel) * last_conv_origin_deps
        for eqcl in eqcls:
            number_filters_seen += len(eqcl)
            if len(eqcl) == 1:
                continue
            num_filters_alike += len(eqcl)
            conv_indexes_to_delete += eqcl[1:]
            for removed in eqcl[1:]:
                fc_indexes_to_delete.append(base + removed)
            merged = np.sum(np.stack([kvf[base + member, :] for member in eqcl], axis=0), axis=0)
            kvf[base + eqcl[0], :] = merged
        # np.concatenate([]) raises ValueError, so only delete rows if something was merged.
        if fc_indexes_to_delete:
            kvf = np.delete(kvf, np.concatenate(fc_indexes_to_delete, axis=0), axis=0)
        kv = np.delete(kv, conv_indexes_to_delete, axis=3)
        bv = np.delete(bv, conv_indexes_to_delete)
    else:
        indexes_to_delete = []
        for eqcl in eqcls:
            number_filters_seen += len(eqcl)
            if len(eqcl) == 1:
                continue
            num_filters_alike += len(eqcl)
            indexes_to_delete += eqcl[1:]
            eqc = np.array(eqcl)
            # Sum the next conv layer's input-channel slices of the class into the kept channel.
            kvf[:, :, eqcl[0], :] = np.sum(kvf[:, :, eqc, :], axis=2)
        kvf = np.delete(kvf, indexes_to_delete, axis=2)
        kv = np.delete(kv, indexes_to_delete, axis=3)
        bv = np.delete(bv, indexes_to_delete)
    result[kernels[layer_idx].name] = kv
    result[biases[layer_idx].name] = bv
    result[kernels[layer_idx + 1].name] = kvf
    result[biases[layer_idx + 1].name] = bvf
    np.save(save_np_file, result)
    print('pruned aggregated filters. {} filters seen. {} filters alike. shape of pruned kernel {}, shape of following kernel {}'
        .format(number_filters_seen, num_filters_alike, kv.shape, kvf.shape))

# 0.4: MEAN_100K:   0.1865
# 0.5: FIRST:0.2393    MAX_L2:0.2376   MAX_L1:0.2302   MIN:0.2127  MEAN_100k:0.3695   MEAN_1280K:0.3507
# 0.6: MEAN_100K:  0.5399    MEAN_1280K:0.5340

def double_bias_gradients(origin_gradients):
    """Return a new list of ``(grad, var)`` pairs with bias gradients multiplied by 2.

    A variable counts as a bias when ``'bias'`` occurs in ``var.name``. Pairs
    whose gradient is None (as optimizers report for variables the loss does not
    reach) are passed through unchanged instead of failing on ``2 * None``.
    """
    print('doubling bias gradients')
    bias_cnt = 0
    result = []
    for grad, var in origin_gradients:
        if grad is not None and 'bias' in var.name:
            result.append((2 * grad, var))
            bias_cnt += 1
        else:
            result.append((grad, var))
    print('doubled gradients for {} bias variables'.format(bias_cnt))
    return result

def tf_aggregate_gradients(origin_gradients, eqcls, layer_idx=-1):
    """Replace one layer's kernel and bias gradients with class-averaged versions.

    The kernel and bias gradients of layer ``layer_idx`` are looked up with
    tf_get_gradients_by_idx and averaged per equivalence class (get_agg_kernel,
    get_agg_1d_tensor). All other ``(grad, var)`` pairs pass through unchanged.

    Returns:
        new list of ``(grad, var)`` pairs.
    """
    target_k = tf_get_gradients_by_idx(origin_gradients, layer_idx, 'kernel')
    target_b = tf_get_gradients_by_idx(origin_gradients, layer_idx, 'bias')
    agg_k = get_agg_kernel(target_k, eqcls)
    agg_b = get_agg_1d_tensor(target_b, eqcls)
    result = []
    for (g, v) in origin_gradients:
        # Identity, not `==`: with TF2 tensor-equality semantics, `Tensor == Tensor`
        # builds an element-wise comparison rather than returning a bool.
        if g is target_k:
            result.append((agg_k, v))
        elif g is target_b:
            result.append((agg_b, v))
        else:
            result.append((g, v))
    print('aggregate gradients!')
    return result






#   for 4-D kernels
def get_agg_kernel(target_k, eqcls):
    """Average a 4-D tensor across each equivalence class along axis 3.

    Returns a tensor shaped like ``target_k`` in which every slice
    ``[:, :, :, i]`` whose index belongs to a class is replaced by the mean of
    that class's slices. Indices not in any class pass through unchanged. If an
    index appears in several classes, the last class listing it wins.
    """
    owner = {}            # filter index -> position of its class in member_slices
    member_slices = []    # per class: its members' slices, each expanded to [..., 1]
    for eqcl in eqcls:
        member_slices.append([])
        for member in eqcl:
            owner[member] = len(member_slices) - 1

    num_filters = target_k.get_shape()[3]
    passthrough = {}
    for i in range(num_filters):
        if i in owner:
            member_slices[owner[i]].append(tf.expand_dims(target_k[:, :, :, i], axis=3))
        else:
            passthrough[i] = target_k[:, :, :, i]

    group_means = [tf.reduce_mean(tf.concat(slices, axis=3), axis=3) for slices in member_slices]
    num_tensors_to_merge = sum(len(slices) for slices in member_slices)

    columns = []
    for i in range(num_filters):
        chosen = group_means[owner[i]] if i in owner else passthrough[i]
        columns.append(tf.expand_dims(chosen, axis=3))

    print('{} tensors to merge, got {} tensors'.format(num_tensors_to_merge, len(member_slices)))

    return tf.concat(columns, axis=3)

def get_agg_1d_tensor(target_t, eqcls):
    """Average a 1-D tensor across each equivalence class.

    Returns a tensor shaped like ``target_t`` in which every element whose index
    belongs to a class is replaced by the mean of that class's elements. Indices
    not in any class pass through unchanged. If an index appears in several
    classes, the last class listing it wins.
    """
    owner = {}            # element index -> position of its class in member_slices
    member_slices = []    # per class: its members' elements, each expanded to shape [1]
    for eqcl in eqcls:
        member_slices.append([])
        for member in eqcl:
            owner[member] = len(member_slices) - 1

    num_filters = target_t.get_shape()[0]
    passthrough = {}
    for i in range(num_filters):
        if i in owner:
            member_slices[owner[i]].append(tf.expand_dims(target_t[i], axis=0))
        else:
            passthrough[i] = target_t[i]

    group_means = [tf.reduce_mean(tf.concat(slices, axis=0)) for slices in member_slices]
    num_tensors_to_merge = sum(len(slices) for slices in member_slices)

    columns = []
    for i in range(num_filters):
        chosen = group_means[owner[i]] if i in owner else passthrough[i]
        columns.append(tf.expand_dims(chosen, axis=0))

    print('{} tensors to merge, got {} tensors'.format(num_tensors_to_merge, len(member_slices)))

    return tf.concat(columns, axis=0)



