import argparse
import os
import pickle

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from actfuns import *
from dataset import JSBDataset
from models import FCNN

# Command-line interface: every option is a plain typed scalar with a default,
# so the options are registered from a small table instead of one call each.
parser = argparse.ArgumentParser()
for _flag, _default, _kind in (
    ("--epochs", 100, int),
    ("--lr", 0.001, float),
    ("--seq_len", 4, int),
    ("--batch_size", 32, int),
    ("--idx", 1, int),
):
    parser.add_argument(_flag, default=_default, type=_kind)
args = parser.parse_args()
# The array task index arrives 1-indexed; shift it to a 0-based list index.
args.idx = args.idx - 1

# Names of the activation functions this script can train with. One of them is
# selected per run through --idx (1-based on the command line).
ACTFUN = [
    "relu",
    "prelu",
    "maxout",
    "max_min_dup",
    "signedgeomean",
    "ail_and",
    "ail_xnor",
    "ail_or",
    "ail_and_or_dup",
    "ail_or_xnor_part",
    "ail_or_xnor_dup",
    "ail_and_or_xnor_part",
    "ail_and_or_xnor_dup",
]
# Hidden width handed to FCNN as nc_hidden.
h = 1536

# args.idx has already been shifted to 0-based. Validate it explicitly: without
# this check, --idx 0 becomes -1 and silently selects the *last* entry through
# negative indexing, and too-large values fail with a bare IndexError.
if not 0 <= args.idx < len(ACTFUN):
    parser.error(
        "--idx must be between 1 and %d (got %d)" % (len(ACTFUN), args.idx + 1)
    )
a = ACTFUN[args.idx]

# Build the train and test datasets from the same JSON file; only the `split`
# argument differs between the two.
fname = "Jsb16thSeparated.json"
train_data = JSBDataset(fname, seq_len=args.seq_len, num_tokens=37, split="train")
# Training batches are reshuffled every epoch; drop_last=True discards a final
# incomplete batch so every batch has exactly args.batch_size examples.
train_loader = DataLoader(
    train_data, batch_size=args.batch_size, shuffle=True, drop_last=True
)
test_data = JSBDataset(fname, seq_len=args.seq_len, num_tokens=37, split="test")
# The test loader must wrap test_data. It previously wrapped train_data, which
# left test_data unused and would make any evaluation run on the training set.
test_loader = DataLoader(
    test_data, batch_size=args.batch_size, shuffle=False, drop_last=True
)

# Input width of the network: element count of the first training example's
# input once flattened to one dimension.
nc_in = train_data[0][0].flatten().shape[0]

# Build the fully-connected network with the selected activation function and
# move it to the GPU before the optimizer collects its parameters.
model = FCNN(nc_in, nc_hidden=h, actfun=a)
model = model.cuda()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)

# Plain supervised training: args.epochs full passes over train_loader, one
# Adam step per batch. Nothing is logged or evaluated along the way.
model.train()
for epoch in range(1, args.epochs + 1):
    for batch in train_loader:
        # Each batch is indexed positionally: element 0 is the network input,
        # element 1 is the target handed to the loss.
        features = batch[0].cuda()
        targets = batch[1].cuda()

        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        # Keep the batch dimension and collapse everything else into one axis.
        logits = model(torch.flatten(features, start_dim=1))
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()
# Leave the model in evaluation mode once training has finished.
model.eval()

# Save the trained model's fc1 weight, one file per activation function name.
# Create the output directory first: without it, open() raises
# FileNotFoundError only after the entire training run has completed.
weight_dir = "logs/logs_jsb_actfun_weights"
os.makedirs(weight_dir, exist_ok=True)
with open("%s/weight_%s.pt" % (weight_dir, a), "wb") as f:
    torch.save(model.fc1.weight, f)
