import numpy as np

# ###### Divisive local search algorithms
def local_search(similarities, max_swaps_fn=None, normalize=False, verbose=False):
    r""" Search for a cut (A, B) that locally maximizes |B|\sum_{a,a' \in A} w(a, a') + [symmetric in B]

    Starts from a random cut (drawn with np.random) and repeatedly moves the single
    element whose move yields the largest strictly positive revenue gain. Stops when
    no move improves the revenue or when the swap budget is used up. A move that
    would empty one side of the cut is never considered.

    similarities: (n, n) matrix of similarity scores with n > 1. Must be symmetric
    max_swaps_fn: function of n indicating how many swaps to allow, to bound running time.
        If None, a very large fixed bound is used
    normalize: If True, use the bisecting k-means score, which normalizes the revenue by |A||B|
    verbose: If True, print the search state before every move and the final revenue

    Output: (A_inds, B_inds), two integer arrays holding, in increasing order, the
    indices of the elements on each side of the cut
    """
    n = similarities.shape[0]
    assert n > 1
    # start with random cut; if one side came out empty, move element 0 across
    A = np.random.randint(low=0, high=2, size=n)
    n_a = np.sum(A)
    if n_a == 0:
        A[0] = 1
        n_a = 1
    if n_a == n:
        A[0] = 0
        n_a = n-1
    n_b = n-n_a

    # Define the weight function
    # Total revenue is f(|A|, |B|)w_A + f(|B|, |A|)w_B
    if normalize:
        def f(size_own, size_other):
            return 1./size_own
    else:
        def f(size_own, size_other):
            return size_other

    # Define the maximum number of swaps allowed
    if max_swaps_fn is None:
        max_swaps = 1000000000 # maxint
    else:
        max_swaps = max_swaps_fn(n)

    A_mask = A.astype(bool)
    B_mask = ~A_mask

    # np.tril keeps each unordered pair once (plus the diagonal)
    w_A = np.sum(np.tril(similarities[A_mask][:, A_mask])) # sum of weights among A
    w_B = np.sum(np.tril(similarities[B_mask][:, B_mask])) # sum of weights among B
    w_xA = np.sum(similarities[:, A_mask], axis=-1) # for each element x, sum of weights between x and A
    w_xB = np.sum(similarities[:, B_mask], axis=-1) # for each element x, sum of weights between x and B
    revenue = f(n_a, n_b)*w_A + f(n_b, n_a)*w_B

    for _ in range(max_swaps):
        if verbose:
            print("A, w_A, w_B, w_xA, w_xB, revenue", A, w_A, w_B, w_xA, w_xB, revenue)
        # find best element to move
        # Consider the case moving an element x from A to B
        # Old revenue: f(|A|, |B|)(w_A - w_xA) + f(|A|, |B|)w_xA + f(|B|, |A|)w_B
        # New revenue: f(|A|-1, |B|+1)(w_A - w_xA) + f(|B|+1, |A|-1)w_xB + f(|B|+1, |A|-1)w_B
        rev = np.zeros(n)
        if n_a > 1:
            # calculate revenue gain from switching an element from A to B
            rev_A = (f(n_a-1, n_b+1)-f(n_a, n_b))*(w_A-w_xA) \
                    + (f(n_b+1, n_a-1)-f(n_b, n_a))*w_B \
                    + f(n_b+1, n_a-1)*w_xB - f(n_a, n_b)*w_xA
            rev_A = rev_A * A # mask for only elements in A
            rev = rev + rev_A
        if n_b > 1:
            # calculate revenue gain from switching an element from B to A
            rev_B = (f(n_b-1, n_a+1)-f(n_b, n_a))*(w_B-w_xB) \
                    + (f(n_a+1, n_b-1)-f(n_a, n_b))*w_A \
                    + f(n_a+1, n_b-1)*w_xA - f(n_b, n_a)*w_xB
            rev_B = rev_B * (1-A) # mask for only elements in B
            rev = rev + rev_B
        if np.max(rev) > 0.:
            ind = np.argmax(rev)
            revenue += rev[ind]
            if verbose: print("swapping index ", ind)
            a = 1 if A[ind] else -1 # +1 if ind leaves A, -1 if it joins A
            w_A = w_A - a * w_xA[ind]
            w_B = w_B + a * w_xB[ind]
            w_xA = w_xA - a * similarities[ind, :] # for every x, w_xA -= a*w(x, ind)
            w_xB = w_xB + a * similarities[ind, :]
            n_a -= a
            n_b += a
            A[ind] = 1 - A[ind]
        else:
            # No more local moves
            break
    if verbose:
        print("revenue:", revenue)

    # Builtin int replaces the np.int alias, which NumPy 1.24 removed
    A_mask = A.astype(bool)
    indices = np.arange(n, dtype=int)
    A_inds = indices[A_mask]
    B_inds = indices[~A_mask]
    return A_inds, B_inds



def divisive_local_search(similarities, bkm=False):
    """ Divisive local search from similarity scores.
    See [Moseley, Wang] Approximation Bounds for Hierarchical Clustering: Average Linkage, Bisecting K-means, and Local Search

    The points are split in two with local_search, each half is clustered
    recursively, and the two subtrees are joined under a new root.

    similarities: (n, n) matrix of similarity scores
    bkm: if True, use bisecting k-means objective instead

    Output: array of parent pointers of length 2n-1. The last element is guaranteed to be the root, i.e. parent[-1] = -1
    Nodes 0..n-1 are the input points; nodes n..2n-2 are the internal nodes.
    """
    n = similarities.shape[0]
    if n == 1: return np.array([-1])

    parent = np.zeros(2*n-1, dtype=int) - 1

    A, B = local_search(similarities, normalize=bkm)
    n_a, n_b = len(A), len(B)
    A_similarities = similarities[A][:, A]
    B_similarities = similarities[B][:, B]
    A_parent = divisive_local_search(A_similarities, bkm)
    B_parent = divisive_local_search(B_similarities, bkm)

    # Map each subtree's local node ids to ids in this tree: its leaves keep
    # their original indices, its internal nodes get fresh ids. A's internal
    # nodes take n..n+n_a-2, B's take n+n_a-1..2n-3, and 2n-2 is the new root.
    # Builtin int replaces the np.int alias, which NumPy 1.24 removed.
    A_idx = np.concatenate([A, np.arange(n, n+n_a-1, dtype=int)])
    B_idx = np.concatenate([B, np.arange(n+n_a-1, 2*n-2, dtype=int)])

    # Each subtree root carries parent -1, which indexes the last entry of its
    # id map (the root itself); the two assignments below overwrite that
    # placeholder with the new root.
    parent[A_idx] = A_idx[A_parent]
    parent[B_idx] = B_idx[B_parent]
    parent[A_idx[-1]] = 2*n-2
    parent[B_idx[-1]] = 2*n-2

    return parent

def test_local_search():
    """ Smoke test: print a 6x6 matrix with entries |j - i| and run one local search on it. """
    n = 6
    positions = np.arange(n)
    # entry (i, j) is |j - i|, stored as floats
    similarities = np.abs(positions[np.newaxis, :] - positions[:, np.newaxis]).astype(float)
    print("Similarity matrix:")
    print(similarities)
    side_a, side_b = local_search(similarities, max_swaps_fn=None, normalize=False)

def test_divisive_local_search():
    """ Smoke test: print a 6x6 matrix with entries |j - i|, then print the parent
    arrays produced without and with the bisecting k-means objective. """
    n = 6
    positions = np.arange(n)
    # entry (i, j) is |j - i|, stored as floats
    similarities = np.abs(positions[np.newaxis, :] - positions[:, np.newaxis]).astype(float)
    print("Similarity matrix:")
    print(similarities)

    for use_bkm in (False, True):
        tree = divisive_local_search(similarities, bkm=use_bkm)
        print(tree)

if __name__ == "__main__":
    # When run as a script, execute the divisive clustering smoke test.
    # Uncomment the next line to also run the single-cut smoke test.
    # test_local_search()
    test_divisive_local_search()
