import numpy as np
import pandas as pd
import pickle

# Method labels. The order must match the order of the per-method entries in
# each result's 'length' and 'coverage' sequences, because those are looked
# up by position (enumerate(methods)) in the loops of this script.
methods = ["ols-split", "ols-cc", "ols-ccrand", "qr-split", "qrf-split", "qrf-ccrand"]

def _unpickle(path):
    """Return the object stored in the pickle file at *path*.

    pickle can run arbitrary code while loading, so only point this at
    result files you produced yourself.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Result collections from results_lp.pkl and results_qr.pkl; they are
# consumed pairwise (zipped) further down in this script.
res = _unpickle('/Users/cherian/Projects/conformal-docker/results_lp.pkl')
res_qr = _unpickle('/Users/cherian/Projects/conformal-docker/results_qr.pkl')

# Median prediction-interval length per method, one entry per trial, built
# from the paired result lists `res` and `res_qr`.
#
# For the quantile-regression methods the per-point lengths from `res` are
# shifted by an offset taken from `res_qr`:
#   * qr-split  : offset = r_qr['length'][0] - its own lengths
#   * qrf-split : offset = r_qr['length'][1] - the qrf-split lengths
#   * qrf-cc*   : reuses the qrf-split offset of the same trial
# The split variants therefore end up equal (up to float rounding) to the
# `res_qr` lengths, and the qrf conformal variants get the same shift.
#
# The qrf offset is computed once per trial, before the method loop. It
# therefore no longer depends on 'qrf-split' being listed before 'qrf-ccrand'
# in `methods`, and it can never leak over from the previous trial.
# NOTE(review): zip() stops at the shorter of res / res_qr; presumably the
# two lists are aligned trial-by-trial with equal length. Confirm.
method_lengths = {m: [] for m in methods}
_qrf_split_idx = methods.index("qrf-split") if "qrf-split" in methods else None
for r, r_qr in zip(res, res_qr):
    if _qrf_split_idx is not None:
        pred_lengths_qrf = (
            r_qr['length'][1].flatten() - r['length'][_qrf_split_idx].flatten()
        )
    for i, method in enumerate(methods):
        lengths = r['length'][i].flatten()
        if "qrf-split" in method or "qrf-cc" in method:
            if _qrf_split_idx is None:
                raise ValueError(
                    f"{method!r} needs 'qrf-split' in methods to compute its QR offset"
                )
            lengths = lengths + pred_lengths_qrf
        elif "qr-split" in method:
            pred_lengths_qr = r_qr['length'][0].flatten() - lengths
            lengths = lengths + pred_lengths_qr
        method_lengths[method].append(np.median(lengths))

# `method_lengths` is rebound later in this script for another result set.
# Keep these medians reachable under their own name as well.
method_lengths_lp = method_lengths

# Worst-case coverage deviation per method on `res`, one entry per trial.
# Each row of x_test.T (i.e. each column of x_test) is rescaled to sum to 1.
# It is then used as a weighting of the per-point coverage deviations from
# the 0.9 target. The recorded statistic is the largest absolute weighted
# deviation over all rows.
# NOTE(review): a row of x_test.T that sums to 0 would divide by zero here;
# presumably the features have non-zero column sums. Confirm.
method_wc_cov = {m: [] for m in methods}
for r in res:
    x = r['x_test'].T
    x = x / np.sum(x, axis=1).reshape(-1, 1)
    for i, method in enumerate(methods):
        coverage_dev = r['coverage'][i].reshape(-1, 1) - 0.9
        wc_cov = np.max(np.abs(x @ coverage_dev))
        method_wc_cov[method].append(wc_cov)

# `method_wc_cov` is re-initialised later in this script for another result
# set. Keep these values reachable under their own name as well.
method_wc_cov_lp = method_wc_cov

# Results stored in results_small.pkl under the conformal-gan project.
# Unpickling can run arbitrary code: only load files you produced yourself.
_small_results_path = '/Users/cherian/Projects/conformal-gan/results_small.pkl'
with open(_small_results_path, 'rb') as small_fp:
    res_small = pickle.load(small_fp)

# Worst-case coverage deviation per method on `res_small`, one entry per trial.
# Each row of x_test.T is normalised to sum to one and used to weight the
# per-point coverage deviations from the 0.9 target. The recorded value is
# the largest absolute weighted deviation over all rows.
method_wc_cov = {name: [] for name in methods}
for trial in res_small:
    feats = trial['x_test'].T
    row_totals = np.sum(feats, axis=1).reshape(-1, 1)
    weights = feats / row_totals
    for idx, name in enumerate(methods):
        deviation = trial['coverage'][idx].reshape(-1, 1) - 0.9
        weighted = weights @ deviation
        method_wc_cov[name].append(np.max(np.abs(weighted)))

# Median prediction-interval length per method on `res_small`, one entry per
# trial. Only the stored 'length' arrays are needed here, so the feature
# matrix is not touched.
method_lengths = {m: [] for m in methods}
for r in res_small:
    for i, method in enumerate(methods):
        lengths = r['length'][i].flatten()
        method_lengths[method].append(np.median(lengths))

# Drop into an interactive IPython shell with every result computed above
# available in the namespace.
import IPython
IPython.embed()