import torch
from loaders.compas_loader import load_data, mono_list
import torch.utils.data as Data
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from monotonenorm.monotonicnetworks import SigmaNet, direct_norm, GroupSort
def empirical_lipschitz(model, x, scale=1, eps=0.05, n_pairs=3000):
    """Estimate a Lipschitz constant of ``model`` from random input pairs.

    The model is evaluated once on ``scale * x``. Then ``n_pairs`` random pairs
    of distinct row indices ``(i, j)`` are drawn with ``np.random.choice``, and
    the largest ratio ``||f_i - f_j|| / ||x_i - x_j||`` is returned. Each norm
    is the Euclidean norm of the flattened row difference. The denominator uses
    the *unscaled* ``x``, so the value is a lower bound on the Lipschitz
    constant of ``x -> model(scale * x)``.

    Args:
        model: Callable mapping a batch tensor to a batch of outputs.
        x: Input batch whose first dimension indexes samples (needs >= 2 rows).
        scale: Multiplier applied to ``x`` before the forward pass.
        eps: Unused; kept only for backward compatibility of the signature.
        n_pairs: Number of random pairs to sample (default 3000, the value
            that used to be hard-coded).

    Returns:
        The largest observed ratio as a Python float. Returns 0.0 if no pair
        with distinct inputs was sampled.

    Raises:
        ValueError: From ``np.random.choice`` if ``x`` has fewer than 2 rows.
    """
    if n_pairs <= 0:
        return 0.0
    n = x.shape[0]
    # Sample exactly as the per-pair loop did (one choice call per pair), so
    # results under a fixed numpy seed are unchanged.
    pairs = np.array(
        [np.random.choice(np.arange(n), 2, replace=False) for _ in range(n_pairs)],
        dtype=np.int64,
    )
    # No gradients are needed: avoid building an autograd graph over all of x.
    with torch.no_grad():
        out = model(scale * x)
        idx = torch.as_tensor(pairs, dtype=torch.long, device=x.device)
        first, second = idx[:, 0], idx[:, 1]
        out_dist = (out[first] - out[second]).reshape(n_pairs, -1).norm(dim=1)
        in_dist = (x[first] - x[second]).reshape(n_pairs, -1).norm(dim=1)
    # Distinct indices can still hold identical inputs; skip those pairs
    # instead of dividing by zero (inf/nan would corrupt the maximum).
    valid = in_dist > 0
    if not bool(valid.any()):
        return 0.0
    return float((out_dist[valid] / in_dist[valid]).max())
def square_wave(size, xrand=True):
    """Generate a 1-D monotone step-function regression dataset.

    The target is ``y = x + 1[x > 0]``: the identity plus a unit jump at 0.
    Despite the name, this is a single step, not a periodic wave. Points with
    ``x == 0`` stay on the lower branch.

    Args:
        size: Number of samples.
        xrand: If True, draw ``x`` uniformly from ``[-2, 2)`` via
            ``torch.rand``. Otherwise use ``size`` evenly spaced points on
            ``[-0.5, 0.5]``.

    Returns:
        Tuple ``(x, y)`` of float tensors, each of shape ``(size, 1)``.
    """
    if xrand:
        x = 2 * (2 * torch.rand(size, 1) - 1)
    else:
        x = torch.linspace(-0.5, 0.5, size).reshape((size, 1))
    # Vectorised form of the former per-element loop. Evaluation order is the
    # same (0 + x, then + 1 where x > 0), so the values are bit-identical.
    y = torch.zeros((size, 1)) + x + (x > 0).to(x.dtype)
    return x, y
from argparse import ArgumentParser

# Command-line options: the norm kind for the first layer, the Lipschitz
# target gamma, and the number of training epochs.
parser = ArgumentParser()
parser.add_argument("-norm", "--norm", type=str, default="one-inf")
parser.add_argument("-gamma", "--gamma", type=float, default=1)
parser.add_argument("-ep", "--epochs", type=int, default=2000)
args = parser.parse_args()

# Use the first GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Synthetic 1-D task: randomly placed training points and an evenly spaced
# test grid (the training set is drawn first, so it consumes the torch RNG).
train_size, test_size = 300, 2000
x, y = square_wave(train_size)
xt, yt = square_wave(test_size, xrand=False)
# Host-side copy of the raw data; not read elsewhere in the visible script.
data = dict(xt=xt, yt=yt, x=x, y=y)

# Device-resident tensors used for training and evaluation.
X_train, y_train = x.to(device), y.to(device)
X_test, y_test = xt.to(device), yt.to(device)

# One entry per input feature, passed to SigmaNet as monotonic_constraints
# (presumably 1 = monotone increasing; confirm against SigmaNet).
monotone_constraints = [1]
# Target Lipschitz constant taken from the command line.
gamma = args.gamma
# max_norm applied to every direct_norm-wrapped linear layer.
per_layer_lip = 1




def _build_network(width=64):
    """Return the SigmaNet-wrapped, norm-constrained GroupSort MLP on ``device``."""
    backbone = torch.nn.Sequential(
        direct_norm(
            torch.nn.Linear(X_train.shape[1], width),
            kind=args.norm,
            max_norm=per_layer_lip,
        ),
        GroupSort(width // 2),
        direct_norm(torch.nn.Linear(width, width), kind="inf", max_norm=per_layer_lip),
        GroupSort(width // 2),
        direct_norm(torch.nn.Linear(width, 1), kind="inf", max_norm=per_layer_lip),
    )
    # The outer Sequential is kept so module/state_dict layout is unchanged.
    # SigmaNet's exact use of sigma/invlip/scale is defined in monotonenorm
    # (not visible here); scale is gamma / 2.
    network = torch.nn.Sequential(
        SigmaNet(
            backbone,
            sigma=1,
            invlip=0,
            scale=gamma / 2,
            monotonic_constraints=monotone_constraints,
        ),
    )
    return network.to(device)


def _train(network, epochs, lr_max=0.01, batch_size=50):
    """Train ``network`` in place with MSE loss; return the best per-epoch test MSE."""
    optimizer = torch.optim.Adam(network.parameters(), lr=lr_max, weight_decay=0)

    def lr_schedule(t):
        # Piecewise linear: 0 -> lr_max at 40% of training, -> lr_max/20 at 80%,
        # -> 0 at the last epoch. t is measured in (fractional) epochs.
        knots = [0, epochs * 2 // 5, epochs * 4 // 5, epochs]
        return np.interp([t], knots, [0, lr_max, lr_max / 20.0, 0])[0]

    loader = Data.DataLoader(
        dataset=Data.TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True
    )
    steps_per_epoch = len(loader)
    # Start at +inf so the first evaluation becomes the best loss (starting at
    # a large negative number made min() never update).
    best_loss = float("inf")
    bar = tqdm(range(epochs))
    for epoch in bar:
        for batch_idx, (xb, yb) in enumerate(loader):
            lr = lr_schedule(epoch + (batch_idx + 1) / steps_per_epoch)
            optimizer.param_groups[0].update(lr=lr)
            loss_train = F.mse_loss(network(xb), yb)
            optimizer.zero_grad()
            loss_train.backward()
            optimizer.step()

        with torch.no_grad():
            # Plain float so the best value can be returned and averaged with
            # numpy even when the tensors live on a GPU.
            loss = F.mse_loss(network(X_test), y_test).item()
        best_loss = min(best_loss, loss)
        bar.set_description(
            f"train: {loss_train.item():.4f}, test: {loss:.4f}, current loss: {loss:.4f}, best loss: {best_loss:.4f}, lr : {lr:.4f}"
        )
    return best_loss


def _plot_fit(network):
    """Plot the true test targets against the network's predictions into result.png."""
    with torch.no_grad():
        y_pred = network(X_test)
    print("plotting...")
    x_np = X_test.cpu().numpy()
    # A fresh figure per call, closed afterwards, so repeated runs do not
    # pile their curves onto one shared pyplot figure.
    fig, ax = plt.subplots()
    ax.plot(x_np, y_test.cpu().numpy(), label="True")
    ax.scatter(x_np, y_pred.cpu().numpy(), label="Learned")
    ax.legend()
    print("result.png")
    fig.savefig("result.png")
    plt.close(fig)


def run(seed):
    """Train, evaluate and report one seeded monotonic Lipschitz network.

    Seeds torch and builds the network. It trains on the module-level
    ``X_train``/``y_train`` for ``args.epochs`` epochs, evaluating on
    ``X_test``/``y_test`` after each epoch. It then saves a plot of the fit to
    ``result.png`` (overwritten by each run) and prints an empirical Lipschitz
    estimate relative to ``gamma``.

    Args:
        seed: Value passed to ``torch.manual_seed``.

    Returns:
        The lowest test MSE observed over all epochs, as a Python float.
    """
    torch.manual_seed(seed)
    network = _build_network(width=64)
    print("params:", sum(p.numel() for p in network.parameters()))

    best_loss = _train(network, args.epochs)

    _plot_fit(network)

    print("calculating Lipschitz")
    # X_test already lives on `device`; the old `xt.cuda()` crashed on
    # CPU-only machines.
    with torch.no_grad():
        gam = empirical_lipschitz(network, X_test)
    print(f"Lipschitz capcity: {gam:.4f}/{gamma:.2f}, {100*gam/gamma:.2f}")
    return best_loss


# Repeat the experiment for seeds 0..4 and summarise the per-run results.
accs = list(map(run, range(5)))
print("mean: {:.4f}, std: {:.4f}".format(np.mean(accs), np.std(accs)))
