from pulp import *

import numpy as np
import random

def GreedyDecomposition(K, C, demand_sets, total_demand):
    """Build the arm-by-arm baseline decomposition.

    For each arm k one assignment is formed:
    - every user whose demand set contains k pulls k;
    - every other user pulls the first arm of their demand set;
    - a user with an empty demand set pulls k itself.
    That block of K assignments is then repeated total_demand times.

    Args:
        K: number of arms.
        C: clustering threshold; not used here, kept so every decomposition
            shares the same signature.
        demand_sets: demand_sets[i] is the list of arm indices user i wants.
        total_demand: number of times the block of K assignments is repeated.

    Returns:
        A list of K * total_demand assignments. Each assignment is a list
        giving the arm pulled by every user. Every entry is an independent
        list object.
    """
    base = []
    for k in range(K):
        # users with no demands pull k: harmless for them, and it can only
        # help arm k reach its cluster size
        base.append([k if (k in wanted or not wanted) else wanted[0]
                     for wanted in demand_sets])

    # Copy each row per repetition: ``base * total_demand`` would alias the
    # same list objects, so mutating one assignment would change every copy.
    return [list(row) for _ in range(total_demand) for row in base]

def RandomizedDecomposition(K, C, demand_sets, total_demand):
    """Sample random assignments until every user's feedback quota is met.

    User i needs total_demand // len(demand_sets[i]) "clustered" pulls of each
    arm in its demand set. A pull is clustered when at least C + 1 users pull
    that arm in the same assignment.

    Each round, every user pulls a uniformly random arm from its demand set,
    or any random arm when the set is empty.

    Sampling stops once all quotas are met or K * total_demand assignments
    have been produced. In the latter case the result is no better than
    GreedyDecomposition, which is returned instead.

    Returns:
        A list of assignments; each is a list of arm indices, one per user.
    """
    budget = K * total_demand
    assignments = []

    # remaining[i][k]: clustered pulls of arm k that user i still needs
    remaining = [[0] * K for _ in range(len(demand_sets))]
    for i, wanted in enumerate(demand_sets):
        for k in wanted:
            remaining[i][k] = total_demand // len(wanted)

    while len(assignments) < budget:
        if not any(need > 0 for row in remaining for need in row):
            break

        pulls = []
        for wanted in demand_sets:
            pulls.append(random.choice(wanted) if wanted
                         else random.randint(0, K - 1))

        arm_load = [0] * K
        for arm in pulls:
            arm_load[arm] += 1

        for i, arm in enumerate(pulls):
            if arm_load[arm] >= C + 1:
                remaining[i][arm] -= 1

        assignments.append(pulls)

    if len(assignments) == budget:
        # randomization failed, resort to Greedy
        return GreedyDecomposition(K, C, demand_sets, total_demand)

    return assignments

def LPDecomposition(K, C, demand_sets, total_demand):
    """Decompose demand by writing an LP solution as a mix of strong vertices.

    Steps:
    1. Solve find_alpha for a fractional assignment x.
    2. Express x approximately as a convex combination of sampled strong
       vertices via caratheodory.
    3. Emit int(coeff * total_demand) + 1 copies of each vertex's assignment.

    Falls back to RandomizedDecomposition when either LP step fails, or when
    the result would need at least K * total_demand assignments.

    Returns:
        A list of assignments (one arm index per user). Every entry is an
        independent list object.
    """
    N = len(demand_sets)

    result = find_alpha(N, K, C, demand_sets)
    if result is None:
        # failed at finding valid alpha (presumably the underlying graph
        # doesn't satisfy the C-user-clustering assumption)
        return RandomizedDecomposition(K, C, demand_sets, total_demand)

    _alpha, x = result
    convex_decomp = caratheodory(x, N, K, C, demand_sets)

    if convex_decomp is None:
        # failed at caratheodory (presumably the underlying graph doesn't
        # satisfy the C-user-clustering assumption)
        return RandomizedDecomposition(K, C, demand_sets, total_demand)

    def vertex_to_assignment(vertex):
        # Vertices produced by get_random_vertex_lp only require each row to
        # sum to at most 1, so a user can be left unassigned (an all-zero
        # row, where row.index(1) would raise ValueError). Such a user pulls
        # a random arm from its own demand set, as HeuristicLPDecomposition
        # does.
        return [row.index(1) if 1 in row else random.choice(demand_sets[i])
                for i, row in enumerate(vertex)]

    assignments = []
    for vertex, coeff in convex_decomp:
        num_copies = int(coeff * total_demand) + 1
        assignment = vertex_to_assignment(vertex)
        # independent copies, so callers mutating one entry don't alter others
        assignments.extend(list(assignment) for _ in range(num_copies))

    if len(assignments) >= K * total_demand:
        # worse than greedy, resort to randomized (happens when total_demand is small)
        return RandomizedDecomposition(K, C, demand_sets, total_demand)

    return assignments

HEURISTIC_SCALE = 2.5
def HeuristicLPDecomposition(K, C, demand_sets, total_demand):
    """Decompose demand using random weak vertices and a covering LP.

    Steps:
    1. Sample int(HEURISTIC_SCALE * N * K) random weak vertices and drop
       duplicates.
    2. Find nonnegative weights whose combination covers the uniform demand
       point (user i spreads weight 1 evenly over demand_sets[i]).
    3. Emit int(weight * total_demand) + 1 copies of each vertex's assignment.

    Falls back to RandomizedDecomposition when the covering LP has no optimal
    solution, or when the result would need at least K * total_demand
    assignments. Progress messages are printed to stdout.

    Returns:
        A list of assignments (one arm index per user). Every entry is an
        independent list object.
    """
    print('decomposing...')
    N = len(demand_sets)

    V = int(HEURISTIC_SCALE*N*K)

    # Generate up to V distinct weak vertices, keeping first-seen order.
    # A set of hashable keys makes the duplicate check O(1) instead of a
    # scan over every vertex collected so far.
    vertices = []
    seen = set()
    for _ in range(V):
        next_vertex = get_random_weak_vertex(N, K, C, demand_sets)
        key = tuple(tuple(row) for row in next_vertex)
        if key not in seen:
            seen.add(key)
            vertices.append(next_vertex)

    V = len(vertices)
    print('V:', V)

    # uniform demand point: user i splits weight 1 evenly over demand_sets[i]
    demand_point = [[0 for _ in range(K)] for _ in range(N)]
    for i in range(N):
        for j in demand_sets[i]:
            demand_point[i][j] = 1/len(demand_sets[i])

    lambdas = get_heuristic_decomposition(N, K, demand_point, vertices)
    if lambdas is None:
        # the covering LP did not report an optimal solution
        return RandomizedDecomposition(K, C, demand_sets, total_demand)

    def vertex_to_assignment(v):
        # weak vertices can leave a user unassigned (all-zero row); that user
        # pulls a random arm from its own demand set
        return [row.index(1) if 1 in row else random.choice(demand_sets[i])
                for i, row in enumerate(v)]

    assignments = []
    for vertex, coeff in zip(vertices, lambdas):
        num_copies = int(coeff * total_demand) + 1
        assignment = vertex_to_assignment(vertex)
        # independent copies, so callers mutating one entry don't alter others
        assignments.extend(list(assignment) for _ in range(num_copies))

    if len(assignments) >= K * total_demand:
        print('resorting to randomized')
        # worse than greedy, resort to randomized (happens when total_demand is small)
        return RandomizedDecomposition(K, C, demand_sets, total_demand)

    print('done')

    return assignments

#### Convex geometry helper functions

def get_heuristic_decomposition(N, K, x, verts):
    """Find minimum-total-weight coefficients that cover x with verts.

    Solves: minimise sum(lambda) subject to lambda_k in [0, 1] and
    sum_k lambda_k * verts[k][i][j] >= x[i][j] for every i < N, j < K.

    Args:
        N: number of rows (users) of x and of each vertex.
        K: number of columns (arms) of x and of each vertex.
        x: N x K nested sequence of target values.
        verts: sequence of N x K nested sequences.

    Returns:
        The list of lambda values (one per vertex), or None if the solver
        does not report an optimal solution.
    """
    V = len(verts)

    prob = LpProblem("HeuristicDecompositionLP", LpMinimize)

    lambdas = [LpVariable("lambda_{}".format(i), 0, 1) for i in range(V)]

    # objective: use as little total weight as possible
    prob += lpSum(lambdas)

    # the weighted vertices must cover x entrywise
    for i in range(N):
        for j in range(K):
            prob += (lpSum(lambdas[k]*verts[k][i][j] for k in range(V)) >= x[i][j])

    status = prob.solve(PULP_CBC_CMD(msg=0))
    if LpStatus[status] == 'Optimal':
        return [value(lambdas[i]) for i in range(V)]
    return None
    
def find_alpha(N, K, C, G):
    """Maximise the uniform demand fraction alpha achievable by an LP assignment.

    LP over x[i][j] in [0, 1] and alpha in [0, 1], maximising alpha:
      * x[i][j] >= alpha / len(G[i]) for arms j in G[i], x[i][j] == 0 otherwise;
      * each user's row sums to at most 1;
      * each arm's column sums to at least C + 1.

    Args:
        N: number of users.
        K: number of arms.
        C: clustering threshold; every arm needs total mass >= C + 1.
        G: G[i] is the list of arms user i demands. When K > 0 it must be
            non-empty, otherwise 1 / len(G[i]) divides by zero.

    Returns:
        (alpha, x) with x an N x K nested list of floats, or None when the
        solver reports no optimal solution (or alpha has no value).
    """
    prob = LpProblem("PrivateBanditLP", LpMaximize)
    grid = [[LpVariable("x_{},{}".format(i, j), 0, 1) for j in range(K)]
            for i in range(N)]
    alpha = LpVariable("alpha", 0, 1)

    prob += alpha  # objective

    for i, row in enumerate(grid):
        for j, cell in enumerate(row):
            factor = 1. / len(G[i])
            prob += (cell >= alpha * factor) if j in G[i] else (cell == 0)

    for row in grid:
        prob += lpSum(row) <= 1

    for j in range(K):
        prob += lpSum(row[j] for row in grid) >= C + 1

    status = prob.solve(PULP_CBC_CMD(msg=0))
    if LpStatus[status] != 'Optimal' or value(alpha) is None:
        return None
    return value(alpha), [[value(cell) for cell in row] for row in grid]

def in_convex_hull(pts, x, N, K):
    """Find convex weights on pts whose combination covers 99% of x.

    This is a feasibility LP with a constant objective. It looks for
    weights w_m in [0, 1] with sum(w) == 1 and, entrywise,
    sum_m w_m * pts[m][i][j] >= 0.99 * x[i][j]. It is therefore an
    approximate dominance test, not an exact convex-hull membership check.

    Returns:
        The list of weights (one per point), or None if the solver does not
        report an optimal solution.
    """
    M = len(pts)

    model = LpProblem("ConvexHullLP", LpMaximize)
    weights = [LpVariable("lambda_{}".format(m), 0, 1) for m in range(M)]

    model += 0  # constant objective: only feasibility matters

    model += (lpSum(weights) == 1.0)

    for i in range(N):
        for j in range(K):
            covered = lpSum(weights[m]*pts[m][i][j] for m in range(M))
            model += (covered >= 0.99*x[i][j])

    status = model.solve(PULP_CBC_CMD(msg=0))
    if LpStatus[status] != 'Optimal':
        return None
    return [value(w) for w in weights]

def assignment_to_vertex(assignment, N, K, C, only_strong=True):
    """Convert an arm assignment into a 0/1 N x K incidence matrix.

    Entry [i][j] is 1 when user i is assigned arm j and more than C users in
    total are assigned arm j. A user on an arm pulled by C or fewer users gets
    an all-zero row.

    Args:
        assignment: assignment[i] is the arm index (0..K-1) pulled by user i.
            The list is not modified.
        N: number of users (rows of the result).
        K: number of arms (columns of the result).
        C: an arm only counts when more than C users pull it.
        only_strong: if True, return None unless every arm is pulled by more
            than C users.

    Returns:
        The N x K nested list, or None when only_strong is set and some arm
        has C or fewer users.
    """
    counts = [0] * K
    for arm in assignment:
        counts[arm] += 1

    if only_strong and any(c <= C for c in counts):
        return None

    # Build the matrix directly instead of overwriting the caller's
    # assignment with -1 markers.
    vertex = [[0] * K for _ in range(N)]
    for user, arm in enumerate(assignment):
        if counts[arm] > C:
            vertex[user][arm] = 1
    return vertex

SAMPLE_ITER_CUTOFF = 100
def get_random_vertex(N, K, C, G):
    """Sample a strong vertex by rejection, falling back to an LP.

    Up to SAMPLE_ITER_CUTOFF times, every user i picks a uniformly random arm
    from G[i]. The first assignment in which every arm has more than C users
    is returned as an incidence matrix.

    If all attempts fail, the result of get_random_vertex_lp is returned
    instead; that result may be None.
    """
    attempts = 0
    while attempts < SAMPLE_ITER_CUTOFF:
        attempts += 1
        picks = [random.choice(G[i]) for i in range(N)]
        candidate = assignment_to_vertex(picks, N, K, C)
        if candidate is not None:
            return candidate

    return get_random_vertex_lp(N, K, C, G)

def get_random_weak_vertex(N, K, C, G):
    """Return the incidence matrix of one uniformly random assignment.

    Every user i picks a random arm from G[i]. Users on arms chosen by C or
    fewer users get an all-zero row (the vertex is not required to be strong).
    """
    picks = [random.choice(G[user]) for user in range(N)]
    return assignment_to_vertex(picks, N, K, C, only_strong=False)

def get_random_vertex_lp(N, K, C, G):
    """Find a polytope vertex by maximising a random linear objective.

    The LP ranges over x in [0, 1]^(N x K) with:
      * x[i][j] == 0 for arms j outside G[i];
      * each user's row summing to at most 1;
      * each arm's column summing to at least C + 1.

    The objective weights are drawn from np.random.random(). Solver values
    are rounded to 0/1 at the 0.5 threshold, so a row may be all zeros.

    Returns:
        A tuple of N tuples of K ints, or None if the solver does not report
        an optimal solution.
    """
    model = LpProblem("RandomVertexLP", LpMaximize)
    grid = [[LpVariable("x_{},{}".format(i, j), 0, 1) for j in range(K)]
            for i in range(N)]

    weights = [[np.random.random() for _ in range(K)] for _ in range(N)]

    # terms are summed arm-major (j outer, i inner)
    model += lpSum(grid[i][j] * weights[i][j] for j in range(K) for i in range(N))

    for i in range(N):
        for j in range(K):
            if j not in G[i]:
                model += (grid[i][j] == 0)

    for row in grid:
        model += lpSum(row) <= 1

    for j in range(K):
        model += lpSum(row[j] for row in grid) >= C+1

    status = model.solve(PULP_CBC_CMD(msg=0))
    if LpStatus[status] != 'Optimal':
        return None
    return tuple(
        tuple(1 if value(cell) > 0.5 else 0 for cell in row)
        for row in grid
    )

STEP_VERTICES = 10
CARA_ITER_THRESHOLD = 100
CARA_VERT_THRESHOLD = 500
def caratheodory(x, N, K, C, G):
    """Approximate x as a convex combination of randomly sampled strong vertices.

    Distinct vertices are sampled STEP_VERTICES at a time. Once more than
    N * K have been collected, in_convex_hull is tried after every batch.

    Returns:
        A list of (vertex, coeff) pairs with coeff > 1e-4. Returns None when:
        - sampling cannot produce a vertex;
        - no new distinct vertex is found within CARA_ITER_THRESHOLD retries;
        - CARA_VERT_THRESHOLD vertices are reached without success.
    """
    vertices = []
    seen = set()

    def _key(vert):
        # get_random_vertex returns lists of lists or (from its LP fallback)
        # tuples of tuples; normalise so duplicates match across both forms
        return tuple(tuple(row) for row in vert)

    while True:
        # add STEP_VERTICES vertices at a time to speed things up
        for _ in range(STEP_VERTICES):
            new_vert = get_random_vertex(N, K, C, G)

            cnt = 0
            while new_vert is not None and _key(new_vert) in seen:
                new_vert = get_random_vertex(N, K, C, G)
                cnt += 1
                if cnt > CARA_ITER_THRESHOLD:
                    # hard to find new vertex
                    return None

            if new_vert is None:
                # the LP fallback found no vertex; appending None would crash
                # in_convex_hull later, so report failure instead
                return None

            seen.add(_key(new_vert))
            vertices.append(new_vert)

        if len(vertices) > N*K:
            convex_decomp = in_convex_hull(vertices, x, N, K)
            if convex_decomp is not None:
                return [(vert, coeff) for vert, coeff in zip(vertices, convex_decomp)
                        if coeff > 1e-4]

        if len(vertices) >= CARA_VERT_THRESHOLD:
            return None

