from itertools import combinations, chain

def all_k_way(domain, k):
    """
    Enumerate every combination of k distinct attributes from a domain

    Arguments:
    domain: an object whose `attrs` attribute is an iterable of attributes
    k: the number of attributes in each combination (positive integer)

    Returns:
    a list of k-tuples of attributes, in the order itertools.combinations
    yields them for domain.attrs (list)
    """
    attributes = domain.attrs
    k_way_iter = combinations(attributes, k)
    return [*k_way_iter]

def powerset(iterable_set):
    """
    Return an iterator over every non-empty subset of the given set
    powerset([1,2,3]) --> (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)

    Note: despite the name, the empty subset () is not produced. Subsets are
    yielded as tuples in order of increasing size, and within one size in
    itertools.combinations order.

    Arguments:
    iterable_set: the iterable that corresponds to the set (iterable)

    Returns:
    an iterator over the non-empty subsets of the set (iterator)
    """
    # Materialize the input now (as the original did) so a one-shot iterator
    # is consumed at call time rather than on first iteration.
    items = tuple(iterable_set)
    sizes = range(1, len(items) + 1)
    return (subset for size in sizes for subset in combinations(items, size))

def downward_closure(W):
    """
    Take a workload and return the downward closure of the workload (the union of the power sets of the marginals)

    The empty marginal () is always included. Marginals are returned as
    tuples, ordered by size. Marginals of the same size appear in the order
    they are first encountered while scanning W, so the result does not vary
    between interpreter runs.

    Arguments:
    W: the workload (iterable of marginals, each an iterable of attributes)

    Returns:
    the downward closure of the workload (list of tuples)
    """
    # An insertion-ordered dict acts as an ordered set. Iterating a plain set
    # of tuples of strings depends on hash randomization, which made the
    # within-size order differ from one run to the next.
    closure = {(): None}
    for marginal in W:
        closure.update(dict.fromkeys(powerset(marginal)))
    # sorted() is stable, so first-seen order is preserved within each size.
    return sorted(closure, key=len)

def subsets_strict(iterable_set):
    """
    Take an iterable that corresponds to a set and return all strict subsets of that set as an iterator
    subsets_strict([1,2,3]) --> (1,) (2,) (3,) (1,2) (1,3) (2,3)

    Note: both the empty subset () and the set itself are excluded. A set
    with fewer than two elements therefore yields nothing.

    Arguments:
    iterable_set: the iterable that corresponds to the set (iterable)

    Returns:
    an iterator over the non-empty strict subsets of the set, as tuples in
    order of increasing size (iterator)
    """
    s = list(iterable_set)
    # Sizes 1 .. len(s)-1: skip the empty subset and the full set itself.
    return chain.from_iterable(combinations(s, r) for r in range(1, len(s)))


def downward_closure_strict(W):
    """
    Take a workload and return the strict downward closure of the workload (the union of the strict subsets of the marginals)

    The empty marginal () is always included. A marginal of W is itself
    included only if it is a strict subset of some other marginal in W.
    Marginals are returned as tuples, ordered by size. Marginals of the same
    size appear in the order they are first encountered while scanning W, so
    the result does not vary between interpreter runs.

    Arguments:
    W: the workload (iterable of marginals, each an iterable of attributes)

    Returns:
    the strict downward closure of the workload (list of tuples)
    """
    # An insertion-ordered dict acts as an ordered set. Iterating a plain set
    # of tuples of strings depends on hash randomization, which made the
    # within-size order differ from one run to the next.
    closure = {(): None}
    for marginal in W:
        closure.update(dict.fromkeys(subsets_strict(marginal)))
    # sorted() is stable, so first-seen order is preserved within each size.
    return sorted(closure, key=len)