import pandas as pd
import torch
from torch.nn import functional as F
import tqdm
# Reported metric: plain accuracy maximized over a sweep of decision thresholds (not balanced accuracy).
from monotonenorm.monotonicnetworks import GroupSort, direct_norm, SigmaNet
from BLNN import PICNN

# Load the train and test splits; the paths are relative to the current working directory.
df_train, df_test = (
    pd.read_csv(path) for path in ("data/heart_train.csv", "data/heart_test.csv")
)

def get_comp(list, length):
    """Return the indices in ``range(length)`` that are NOT in ``list``.

    Used below to select the complement of a set of feature columns.

    Args:
        list: iterable of integer indices to exclude (Python ints, numpy ints
            or a 1-D integer tensor). Entries outside ``range(length)`` are
            simply ignored. (The name shadows the builtin; it is kept so
            keyword callers keep working.)
        length: size of the index range ``0 .. length-1``.

    Returns:
        1-D ``torch.long`` tensor of the remaining indices in increasing order.
        It is integer-typed even when empty, so it is always valid for indexing.
    """
    # Normalise entries to plain ints so tensor/numpy elements hash like ints,
    # then use a set for O(1) membership tests.
    excluded = {int(i) for i in list}
    # Explicit dtype: torch.tensor([]) would otherwise default to float32,
    # which raises when used as an index.
    return torch.tensor(
        [i for i in range(length) if i not in excluded], dtype=torch.long
    )



def preprocess(df, ref=None):
    """Convert a DataFrame into standardized feature and label tensors.

    Args:
        df: DataFrame with a ``target`` column plus numeric feature columns.
        ref: optional DataFrame with the same columns (including ``target``)
            whose column-wise feature mean/std are used for standardization.
            Defaults to ``df`` itself. Pass the training frame when
            preprocessing held-out data so both splits share one scaling and
            no test-set statistics leak into the inputs.

    Returns:
        ``(X, Y)``. ``X`` is a float32 tensor with one row per row of ``df``
        and one column per feature, standardized column-wise. ``Y`` is a
        float32 tensor of shape ``(n_rows, 1)``.
    """
    def _features(frame):
        # Every column except the label, as float32.
        return torch.tensor(
            frame.drop(columns=['target']).values.astype(float),
            dtype=torch.float32,
        )

    X = _features(df)
    Y = torch.tensor(df['target'].values.astype(float), dtype=torch.float32).view(-1, 1)

    stats_src = X if ref is None else _features(ref)
    mean = stats_src.mean(0)
    std = stats_src.std(0)
    # A constant column gives std == 0 and a single row gives NaN; divide by 1
    # there so those columns become finite instead of NaN.
    std = torch.where(std > 0, std, torch.ones_like(std))
    X = (X - mean) / std
    return X, Y


Xtr, Ytr = preprocess(df_train)
# Standardize the test split with the training split's statistics.
Xts, Yts = preprocess(df_test, ref=df_train)

def get_acc(Yhat, Y):
    """Return the best accuracy achieved over a grid of decision thresholds.

    For each of 100 evenly spaced thresholds ``cut`` in [-1, 1], the boolean
    prediction ``Yhat > cut`` is compared element-wise with ``Y``. The score
    for that threshold is the fraction of matching elements. The maximum score
    is returned as a Python number, and ``0`` if no threshold gets any match.

    NOTE(review): the threshold is chosen on the same data being scored, so the
    result is an optimistic (best-case) accuracy.
    """
    def _accuracy_at(cut):
        hits = (Yhat > cut) == Y
        return hits.sum().item() / hits.numel()

    scores = (_accuracy_at(cut) for cut in torch.linspace(-1, 1, 100))
    return max(0, *scores)

# Train one PICNN per seed and record its best test accuracy per run.
# Loop-invariant configuration is set up once, before the seed loop.
width = 10
# Columns 4 and 5 go through the PICNN's second input; the remaining columns go
# through the first. NOTE(review): presumably the constrained (monotone/convex)
# pathway — confirm against BLNN.PICNN.
mono_list = [4, 5]
n_features = Xtr.shape[1]
c_mono = get_comp(mono_list, n_features)
epochs = 300

accs = []
for seed in range(3):
    torch.manual_seed(seed)

    model = PICNN(n_features - len(mono_list), len(mono_list), width, 1, 10, 2)

    # total number of parameter elements in the model
    print(sum(p.numel() for p in model.parameters()))

    optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)

    best_acc = 0
    bar = tqdm.tqdm(range(epochs))
    for _ in bar:
        optimizer.zero_grad()
        yhat = model(Xtr[:, c_mono], Xtr[:, mono_list])
        # binary_cross_entropy_with_logits applies the sigmoid internally, so
        # the raw model output goes in. Wrapping it in Sigmoid first (as before)
        # squashed it twice. Assumes the model emits an unnormalized logit —
        # TODO confirm for PICNN.
        loss = F.binary_cross_entropy_with_logits(yhat, Ytr)
        loss.backward()
        optimizer.step()
        train_acc = get_acc(yhat.detach(), Ytr)
        with torch.no_grad():
            yhat = model(Xts[:, c_mono], Xts[:, mono_list])
            # Score raw outputs on both splits so get_acc's [-1, 1] threshold
            # grid has the same meaning for train and test accuracy.
            accuracy = get_acc(yhat, Yts)
            # NOTE(review): best_acc is selected on the test set (over epochs
            # and thresholds), so it is an optimistic estimate.
            best_acc = max(best_acc, accuracy)
            bar.set_description(f"loss {loss.item():.3f} train acc: {train_acc:.3f}, test acc: {accuracy:.3f}, best acc: {best_acc:.3f}")
    accs.append(best_acc)

print(f"mean: {torch.tensor(accs).mean():.3f}, std: {torch.tensor(accs).std():.4f}")
