import argparse
import json
import sys

import numpy as np
# Command-line interface: one or more JSON metric files to aggregate.
# `required=True` turns a missing --files into a clean usage error; without it
# args.files would be None and the loading loop would fail with a TypeError.
parser = argparse.ArgumentParser(
    description="Aggregate metrics across several JSON result files.")
parser.add_argument('--files', nargs='+', required=True,
                    help='JSON metric files to aggregate; all must contain the same keys')
args = parser.parse_args()

# Load every results file, preserving the command-line order.
# JSON text is UTF-8, so decode it as such instead of relying on the
# platform's locale-dependent default encoding.
all_dicts = []
for f in args.files:
    with open(f, 'r', encoding='utf-8') as fh:
        all_dicts.append(json.load(fh))

# Every results file must report the same set of metric keys (dict key views
# compare like sets, so order does not matter). Note that `assert (genexpr)`
# would always pass because a generator object is truthy, and asserts are
# stripped under `python -O`, so the check is an explicit raise.
keys = all_dicts[0].keys()
mismatched = [path for path, res in zip(args.files, all_dicts) if res.keys() != keys]
if mismatched:
    raise ValueError(f"metric keys differ from {args.files[0]} in: {mismatched}")

# Aggregate each metric across all runs. total_vals[k] holds one entry per
# results file for metric k (one row per run when the stored value is a list).
# NOTE(review): values are presumably a scalar, a list (e.g. paired expansion
# values) or an eval dict carrying 'eval_accuracy' -- confirm against the
# producer of these JSON files.
total_vals = {}
for k in keys:
    all_vals = []
    for d in all_dicts:
        v = d[k]
        if isinstance(v, dict):
            all_vals.append(v['eval_accuracy'])
        # JSON integers (e.g. an exact 0 or 1) load as int, not float; accept
        # them so the run is not silently dropped from this metric's array.
        # bool is excluded because it subclasses int.
        elif isinstance(v, (list, int, float)) and not isinstance(v, bool):
            all_vals.append(v)
        # Values of any other type (e.g. strings) are skipped.

    all_vals = np.array(all_vals)
    total_vals[k] = all_vals
    mean = all_vals.mean(axis=0)
    std = all_vals.std(axis=0)

    # 'expansion' metrics: smallest value across runs; 'bdval' metrics:
    # largest value; everything else: mean and std scaled by 100.
    if 'expansion' in k:
        print(k, np.round(all_vals.min(axis=0), 2))
    elif 'bdval' in k:
        print(k, np.round(all_vals.max(axis=0), 2))
    else:
        print(k, np.round(mean * 100, 1), np.round(std * 100, 1))

# Overall test accuracy per run: a fixed weighted sum over the four test groups.
# NOTE(review): the weights (nominally summing to 1) presumably are each
# group's share of the test set -- confirm. Assuming each group's entry is a
# 1-D array with one value per run, `accuracies` has one row per group.
test_groups = ('S0_test', 'S1_test', 'T0_test', 'T1_test')
group_weights = np.array([.28, .21, .22, .29])
accuracies = np.vstack([total_vals[name] for name in test_groups])
test_accs = (group_weights[:, None] * accuracies).sum(axis=0)
print("test acc",
      np.round(100 * test_accs.mean(), 1),
      np.round(100 * test_accs.std(), 1))

# Worst-case bound on T-group error, for group index 0 and 1.
# Across runs it takes the max weak error and alpha and the min of each
# expansion column (presumably the pessimistic choice for every term).
# NOTE(review): column 0 / column 1 of the expansion metric are assumed to be
# the good / bad expansion, matching the variable names -- confirm.
for i in ('0', '1'):
    weak_err = total_vals[f'S{i}_test_weakerr'].max()
    expansion = total_vals[f'S{i}_test_T{i}_test_expansion']
    good_expansion = expansion[:, 0].min()
    bad_expansion = expansion[:, 1].min()
    alpha = total_vals[f'S{i}_test_alpha'].max()

    bound_numerator = weak_err - bad_expansion * alpha
    bound_denominator = good_expansion - (good_expansion + bad_expansion) * alpha
    # A negative numerator or a non-positive denominator yields no meaningful
    # bound. Checking before dividing also avoids a division by zero, which
    # would otherwise print inf/nan as if it were a bound value.
    if bound_numerator < 0 or bound_denominator <= 0:
        print(f"T{i} worst case bound is vacuous")
    else:
        bound_value = bound_numerator / bound_denominator
        print(f"T{i} worst case bound value: {bound_value}")
        avg_acc = total_vals[f'T{i}_test'].mean()
        worst_acc = total_vals[f'T{i}_test'].min()
        print(f"T{i} actual avg error: {1-avg_acc}")
        print(f"T{i} actual worst error: {1-worst_acc}")
    print('----------')


# "New" worst-case bound evaluated on the S groups, for index 0 and 1, next to
# the observed average/worst error across runs.
# NOTE(review): neither denominator below is guarded; because the inputs are
# numpy scalars (.max()/.min()), a zero denominator yields inf/nan with a
# RuntimeWarning rather than an exception.
for i in '01':
    src = f"S{i}"
    weak_err = total_vals[f'{src}_test_weakerr'].max()
    bad_expansion = total_vals[f"{src}_testbad_{src}_testgood_expansion"].min()
    good_expansion = total_vals[f"{src}_testgood_{src}_testbad_expansion"].min()
    alpha = total_vals[f'{src}_test_alpha'].max()

    # Both terms of the new bound share this denominator.
    shared_denom = 1 - alpha - bad_expansion * alpha
    new_bound_value = ((1 - alpha + bad_expansion * alpha) / shared_denom * (weak_err + alpha)
                       - 2 * bad_expansion * alpha / shared_denom)
    badgood_c_lower_bound = (1 - alpha) * weak_err / (2 * alpha * (1 - alpha) - alpha * weak_err)

    print(f"alpha_{i} {alpha}")
    print(f"(worst) weak_err_{i} {weak_err}")
    print(f"{src} good expansion: {good_expansion}")
    print(f"{src} bad expansion lower bound: {badgood_c_lower_bound}")
    print(f"{src} bad expansion: {bad_expansion}")
    print(f"{src} worst case bound value (new bound): {new_bound_value}")

    src_acc = total_vals[f'{src}_test']
    avg_acc = src_acc.mean()
    worst_acc = src_acc.min()
    print(f"{src} actual avg error: {1 - avg_acc}")
    print(f"{src} actual worst error: {1 - worst_acc}")
    print('-----------')

print(accuracies)
# Drop into the debugger to inspect the aggregated globals, but only when
# attached to a terminal. With a non-interactive stdin (CI, nohup, piped
# input), pdb reads EOF and quits immediately, and the run typically ends in a
# BdbQuit traceback instead of exiting cleanly. (PYTHONBREAKPOINT=0 also
# disables this hook.)
if sys.stdin.isatty():
    breakpoint()
