
# %%
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# %%
# Metric columns to aggregate and the hyper-parameter columns that identify
# one experimental configuration.
METRICS = ['auc', 'bacc']
PARAM_COLS = ['aug_params', 'tran_params', 'dr_params', 'head_params']

df1 = pd.read_csv('./out/quinn2020.csv')

# %%

# Two successive aggregations with pivot_table's default aggregation (mean):
#   1. group by dataset + configuration, collapsing every column that is not
#      part of the index (the seeds, per the original note);
#   2. group by configuration only, collapsing the datasets.
# Averaging in two steps gives each dataset equal weight in the final mean.
means = df1
for group_cols in (['data_idx'] + PARAM_COLS, PARAM_COLS):
    means = means.pivot_table(values=METRICS, index=group_cols).reset_index()

# Show the 40 configurations with the highest mean AUC (best one last).
print(means.sort_values('auc').tail(40))

# %%

# Keep only the rows produced with this transform.
# NOTE(review): `means` is overwritten in place, so re-running this cell with
# a different `trans` requires re-running the aggregation cell first.
trans = "clr"
uses_trans = means['tran_params'].eq(trans)
means = means.loc[uses_trans]

# print(means.sort_values('auc'))

# %%

# Augmentation factors plotted on the x-axis. The first entry is paired with
# the empty aug_params string "{}" (presumably the no-augmentation baseline).
x = [1, 2, 5, 10, 20]

# Augmentation key to plot. The results table also contains 'comb' entries;
# switch this value to plot them instead.
aug_kind = 'conv'

# aug_params labels matching `x` one-to-one. They are derived from `x` so the
# two lists cannot drift apart in length or in factor values.
y_idx = ["{}"] + [
    "{'%s': 'rand', 'space': 'clr', 'factor': %d}" % (aug_kind, factor)
    for factor in x[1:]
]

# y_idx = [
#     "{}",
#     "{'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 2}",
#     "{'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 5}",
#     "{'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 10}",
# ]

# %%

# Metric column to plot; it also drives the y-axis label and output filename.
metric = 'bacc'

# Y-axis label per metric column; unknown metrics fall back to the raw name.
METRIC_LABELS = {
    'bacc': "Average test Balanced Accuracy",
    'auc': "Average test AUC",
}

# (head_params string as stored in the results table, legend label).
# Disabled heads are kept as comments so they can be re-enabled quickly.
HEADS = [
    ("{'model': 'mlp'}", 'NN'),
    # ("{'model': 'mlp', 'early': True}", None),
    ("{'model': 'rf'}", 'RF'),
    # ("{'model': 'svm'}", 'SVM'),
    ("{'model': 'ridge'}", 'CLR-ridge'),
]

# %%

# One curve per classifier head: metric value against augmentation factor.
for head_params, label in HEADS:
    means_head = means[means['head_params'] == head_params].set_index("aug_params")
    # .loc with a list raises KeyError if any aug_params label is missing.
    # NOTE(review): dr_params is not filtered here; if several dr_params
    # values exist per aug_params, y will be longer than x and plotting fails.
    y = means_head.loc[y_idx][metric]
    plt.plot(x, y, label=label)

# %%
plt.legend()
plt.xlabel("Augmentation factor (*n)")
plt.ylabel(METRIC_LABELS.get(metric, metric))
# File name follows the metric so plots of different metrics do not
# overwrite each other.
plt.savefig(metric + ".pdf")
plt.show()
