import numpy as np
from itertools import product
import math

def Beta(a, b, obs):
    """Return the Beta function B(a + k, b + n - k) after observing ``obs``.

    ``n = len(obs)`` and ``k = sum(obs)``, so this is the Beta function at the
    posterior hyper-parameters of a Beta(a, b) prior updated with binary
    observations.  With ``obs`` empty it is just B(a, b); hence
    ``Beta(a, b, obs) / Beta(a, b, [])`` is the Beta-Bernoulli marginal
    likelihood of ``obs``.

    Integral hyper-parameters use exact factorials,
    B(p, q) = (p-1)! (q-1)! / (p+q-1)!.  Non-integral (positive real)
    hyper-parameters fall back to log-gamma, because ``math.factorial`` only
    accepts integers.

    Args:
        a: First Beta hyper-parameter (prior pseudo-count of ones).
        b: Second Beta hyper-parameter (prior pseudo-count of zeros).
        obs: Sequence of observations; presumably each is 0 or 1.

    Returns:
        B(a + k, b + n - k) as a float.

    Raises:
        ValueError: if a posterior hyper-parameter is not positive.
    """
    import numbers

    # An empty ``obs`` gives k == n == 0, so no special case is needed.
    k = sum(obs)
    a_post = a + k
    b_post = b + len(obs) - k
    if isinstance(a_post, numbers.Integral) and isinstance(b_post, numbers.Integral):
        # Exact integer arithmetic; true division yields a correctly rounded float.
        # math.factorial raises ValueError for non-positive parameters here.
        return (math.factorial(a_post - 1) * math.factorial(b_post - 1)) / math.factorial(
            a_post + b_post - 1
        )
    if a_post <= 0 or b_post <= 0:
        raise ValueError(
            f"Beta function requires positive parameters, got a={a_post}, b={b_post}"
        )
    # Log space avoids overflow of Gamma for large real parameters.
    return math.exp(math.lgamma(a_post) + math.lgamma(b_post) - math.lgamma(a_post + b_post))


def compute_z(x, y):
    """Combine ``x`` and ``y`` as ``(1 - x) * y + (1 - y) * x``.

    For binary 0/1 inputs this is ``x XOR y``: it is 1 exactly when the two
    values differ and 0 when they agree.
    """
    y_only = (1 - x) * y
    x_only = (1 - y) * x
    return y_only + x_only

def analytic_causal_effect_computation(interve_var, interve_val, dag, alpha, beta):
    """Compute the interventional joint table over (x1, x2, y1, y2) analytically.

    All four variables are binary.  Every factor is a Beta-Bernoulli marginal
    likelihood ``Beta(alpha, beta, obs) / Beta(alpha, beta, [])``.  The factor of
    the intervened variable is replaced by the indicator
    ``key[var] == interve_val``.  With ``z_i = compute_z(x_i, y_i)`` (x_i XOR y_i
    for binary values), the factorisations are:

    * ``'xtoy'``: P(x1, x2) * P(z1, z2)
    * ``'ytox'``: P(y1, y2) * P(z1, z2)
    * anything else: P(x1, x2) * P(y1, y2), i.e. no x-y edge.

    Args:
        interve_var: Intervened variable; one of 'x1', 'x2', 'y1', 'y2'.
        interve_val: Value the variable is set to, presumably 0 or 1.  A value
            equal to neither makes every entry zero.
        dag: Graph selector as described above.
        alpha: First Beta prior hyper-parameter shared by every factor.
        beta: Second Beta prior hyper-parameter shared by every factor.

    Returns:
        ``{'do(<var>)=<val>': {(x1, x2, y1, y2): probability}}``.  The inner dict
        contains all 16 binary assignments in ``itertools.product`` order.

    Raises:
        ValueError: if ``interve_var`` is not one of the four variable names.
    """
    # All 16 binary assignments of (x1, x2, y1, y2).
    keys = list(product([0, 1], repeat=4))
    intervention_desc = ('do(' + interve_var + ')' + '=' + str(interve_val))
    table = dict.fromkeys(keys, 0)
    joint_causal_effect = {intervention_desc: table}
    var_all = ['x1', 'x2', 'y1', 'y2']
    position = var_all.index(interve_var)
    on_x = position <= 1

    # B(alpha, beta) does not depend on the assignment: compute it once instead
    # of once per factor per key.
    prior = Beta(alpha, beta, [])

    def marginal(obs):
        # Beta-Bernoulli marginal likelihood of the observations ``obs``.
        return Beta(alpha, beta, obs) / prior

    for key in keys:
        x1, x2, y1, y2 = key
        indicator = (key[position] == interve_val)
        if dag == 'xtoy':
            z1, z2 = compute_z(x1, y1), compute_z(x2, y2)
            if on_x:
                # do(x_i): x_i's factor becomes the indicator; the other x keeps
                # its own marginal and the z pair is untouched.
                p_cause = indicator * marginal([key[1 - position]])
                p_effect = marginal([z1, z2])
            else:
                # do(y_i): only the other unit's z term remains.
                p_cause = marginal([x1, x2])
                p_effect = indicator * marginal([z2] if position == 2 else [z1])
        elif dag == 'ytox':
            z1, z2 = compute_z(x1, y1), compute_z(x2, y2)
            if on_x:
                # do(x_i): only the other unit's z term remains.
                p_cause = marginal([y1, y2])
                p_effect = indicator * marginal([z2] if position == 0 else [z1])
            else:
                # do(y_i): the other y keeps its own marginal.
                p_cause = indicator * marginal([y2] if position == 2 else [y1])
                p_effect = marginal([z1, z2])
        else:
            # No x-y edge: the x pair and the y pair factorise independently.
            if on_x:
                p_cause = indicator * marginal([key[1 - position]])
                p_effect = marginal([y1, y2])
            else:
                p_cause = marginal([x1, x2])
                p_effect = indicator * marginal([y2] if position == 2 else [y1])
        table[key] = p_effect * p_cause

    return joint_causal_effect
