import torch
import torch.nn as nn

import numpy as np

import math

from train_model import training_setup, train
from dataset import read_dataset

# Experiment configuration.
dataset = 'iris'
noise = 0.4  # selects the ./datasets/{dataset}/noise_{noise} directory
trial = 2    # trial index for the per-trial split (see NOTE(review) below)

# Regularization strength paired with this noise level; scaled by each entry
# of reg_factors in the sweep below.
base_reg = 0.5 * noise ** 2
reg_factors = [0.25, 0.5, 1, 2, 4]

lr = 0.1
patience = 50  # how often to decay learning rate
factor = 0.3  # how much to decay learning rate

epochs = 1000

torch.manual_seed(3)

# The datasets do not depend on the regularization factor, so read them from
# disk once instead of on every loop iteration.
# NOTE(review): the previous version first read
# f'./datasets/{dataset}/noise_{noise}/trial_{trial}' and then immediately
# overwrote the result with the path below, so the per-trial data was never
# used. The effective (non-trial) path is kept; switch to the trial path if
# that was the intent.
X_raw, y0, y1_raw = read_dataset(f'./datasets/{dataset}/noise_{noise}')

# Clean (noise-free) dataset. The final shape report needs it (it was
# previously only loaded in commented-out code, so the print raised
# NameError), and so does the commented-out comparison inside the loop.
X_orig, y0_orig, y1_orig = read_dataset(f'./datasets/{dataset}')
X_orig = torch.tensor(X_orig)
y1_orig = torch.tensor(y1_orig)

for f in reg_factors:
    # NOTE(review): reg is only consumed by the commented-out clean-data run
    # below; the noisy-data run trains with reg=None, so while that block is
    # disabled the iterations presumably differ only through the RNG state.
    reg = f * base_reg

    # Build fresh tensors each iteration (torch.tensor copies its input) so
    # any in-place modification made during training cannot leak into the
    # next run.
    X = torch.tensor(X_raw)
    y1 = torch.tensor(y1_raw)

    # One batch spanning X.shape[0] rows (presumably full-batch training).
    batch_size = X.shape[0]

    model, optimizer, scheduler = training_setup(X, lr, factor, patience)
    loss = train(model, optimizer, scheduler, X, y1, epochs, batch_size, reg=None)

    # batch_size = X_orig.shape[0]

    # model_orig, optimizer, scheduler = training_setup(X_orig, lr, factor, patience)
    # loss_orig = train(model_orig, optimizer, scheduler, X_orig, y1_orig, epochs, batch_size, reg)

    # print(loss, model.loss(X, y1).item(), model.loss(X_orig, y1_orig).item(), model.norm().item(), model.loss(X_orig, y1_orig).item() * math.exp(0.5 * noise ** 2 * model.norm().item()))
    # print(loss_orig, model_orig.loss(X, y1).item(), model_orig.loss(X_orig, y1_orig).item(), model_orig.norm().item(), model_orig.loss(X_orig, y1_orig).item() * math.exp(0.5 * noise ** 2 * model_orig.norm().item()))

print(X.shape, X_orig.shape)