import numpy as np
import scipy as sp
from scipy import stats
import random
import math
import pickle
import argparse
import sys
import os
import torch, torchvision
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable, Function
import copy
import os.path as osp


# Run on the GPU whenever PyTorch reports one, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Command-line options as (flag, type, default, help) rows. They are
# registered in this order so that `--help` lists them in a stable order.
_OPTION_TABLE = (
    ('--n_task', int, 100, 'number of task'),
    ('--train_shot', int, 10, 'number of training samples in meta-training'),
    ('--test_shot', int, 10, 'number of training samples in meta-testing'),
    ('--query', int, 200, 'number of test samples in meta-testing'),
    ('--lamda', float, 0.5, 'reg. parameter'),
    ('--eta', float, 0.02, 'eta'),
    ('--gamma', float, 0.1, 'gamma'),
    ('--T', int, 100, 'meta-level iterations'),
    ('--K', int, 15, 'task-level iterations'),
)

parser = argparse.ArgumentParser(description='PyTorch CIFAR MAIL')
# `--seed` is the only option with a custom metavar, so it is added by hand.
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
for _flag, _type, _default, _help in _OPTION_TABLE:
    parser.add_argument(_flag, type=_type, default=_default, help=_help)

# Parsed once at import time; the rest of the script reads from `args`.
args = parser.parse_args()


class Regression_meta(nn.Module):
    """Fully connected regressor: two tanh hidden layers and a linear output.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, input_dim=1, output_dim=1, hidden_dim=40):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # Layers are built in fc1 -> fc2 -> fc3 order, which fixes both the
        # state-dict key order and the order in which weights are initialised.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

        self.act = nn.Tanh()

    def forward(self, x):
        """Map ``x`` through both hidden layers and the linear read-out."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = self.act(layer(hidden))
        # No activation on the output layer.
        return self.fc3(hidden)
        
class Averager():
    """Running arithmetic mean of the values passed to ``add``."""

    def __init__(self):
        self.n = 0  # how many values have been added so far
        self.v = 0  # mean of those values (0 while nothing has been added)

    def add(self, x):
        """Fold one more value ``x`` into the running mean."""
        new_count = self.n + 1
        # Rebuild the previous sum from mean * count, add x, then renormalise.
        self.v = (self.v * self.n + x) / new_count
        self.n = new_count

    def item(self):
        """Return the current mean."""
        return self.v
        

# Seed every random number generator the script draws from (NumPy, the
# standard `random` module, and PyTorch on CPU and CUDA) with the CLI seed.
seed = args.seed
for _seed_rng in (np.random.seed, random.seed, torch.manual_seed,
                  torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(seed)

# Problem sizes, unpacked from the parsed arguments.
n_task, train_shot = args.n_task, args.train_shot  # m: number of tasks; p: training samples per training task
test_shot, query = args.test_shot, args.query  # n: training samples per test task; test samples per test task

# Build the meta-training sample: one sinusoid per task,
# y = A * sin(x + phase), written out with the angle-addition identity.
ampl_tr_list = np.random.uniform(-5, 5, n_task)
phase_tr_list = np.random.uniform(0, 1, n_task) * math.pi

tr_shot_list, tr_shot_y_list = [], []
# These two stay empty: no query points are drawn for the training tasks.
tr_query_list, tr_query_y_list = [], []
task_tr_list = []

for task_idx in range(n_task):
    # Wrap the task's scalars as 1-element tensors on the compute device.
    amplitude = torch.tensor([ampl_tr_list[task_idx]]).to(device)
    phase = torch.tensor([phase_tr_list[task_idx]]).to(device)

    # train_shot inputs, uniform on [-5, 5).
    centered = torch.rand(train_shot) - 0.5
    x_shot = 10.0 * centered.float().to(device)

    sin_term = torch.sin(x_shot) * torch.cos(phase)
    cos_term = torch.sin(phase) * torch.cos(x_shot)
    y_shot = amplitude * (sin_term + cos_term)

    tr_shot_list.append(x_shot)
    tr_shot_y_list.append(y_shot)
    task_tr_list.append([amplitude, phase])
    
# Build the meta-test set: 1000 sinusoid tasks, each with a support ("shot")
# set of test_shot points and a query set of `query` points on [-5, 5).
ampl_val_list = np.random.uniform(-5, 5, 1000)
phase_val_list = np.random.uniform(0, 1, 1000) * math.pi

val_shot_list, val_shot_y_list = [], []
val_query_list, val_query_y_list = [], []
task_val_list = []

for task_idx in range(len(ampl_val_list)):
    ampl = torch.tensor([ampl_val_list[task_idx]]).to(device)
    phase = torch.tensor([phase_val_list[task_idx]]).to(device)
    cos_phase, sin_phase = torch.cos(phase), torch.sin(phase)

    # Support inputs are drawn before query inputs; the draw order matters
    # for reproducibility under a fixed seed.
    x_shot = (torch.rand(test_shot) - 0.5).to(device) * 10.0
    y_shot = ampl * (torch.sin(x_shot) * cos_phase + sin_phase * torch.cos(x_shot))

    x_query = (torch.rand(query) - 0.5).to(device) * 10.0
    y_query = ampl * (torch.sin(x_query) * cos_phase + sin_phase * torch.cos(x_query))

    val_shot_list.append(x_shot)
    val_shot_y_list.append(y_shot)
    val_query_list.append(x_query)
    val_query_y_list.append(y_query)
    task_val_list.append([ampl, phase])
    
# Hyper-parameters, taken from the command line.
T = args.T          # meta-level iterations
K = args.K          # task-level iterations
hdim = 40           # hidden width of the regressor
eta = args.eta      # learning rate of the task-level SGD
gamma = args.gamma  # learning rate of the meta-level SGD
lamda = args.lamda  # weight of the parameter-distance regulariser

# One run directory per hyper-parameter configuration.
# NOTE: `dir` shadows the builtin, but the name is read by the rest of the
# script, so it is kept.
dir = (f'seed{seed}_lamda{lamda}_T{T}_gamma{gamma}_K{K}_eta{eta}'
       f'_m{n_task}_p{train_shot}_n{test_shot}_query{query}')

# exist_ok=True replaces the race-prone "check, then create" pair.
os.makedirs('log/'+dir, exist_ok=True)

# From here on everything printed goes to the log file instead of the
# console. The handle is kept open on purpose for the rest of the script.
log_filename = 'log/'+dir+'/log.txt'
log = open(log_filename, 'w')
sys.stdout = log


# Meta-model (scalar input, scalar output) and its meta-level optimiser.
model = Regression_meta(1, 1, hdim).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=gamma)

def save_model(dir,name):
    """Write the meta-model's weights to ``log/<dir>/<name>.pth``.

    Args:
        dir: Name of the run directory below ``log/``.
        name: Checkpoint name without the ``.pth`` extension.
    """
    run_dir = 'log/'+dir
    # Only the run directory has to exist. The checkpoint is a file placed
    # directly inside it, so no per-checkpoint sub-directory is created.
    os.makedirs(run_dir, exist_ok=True)
    torch.save(model.state_dict(), osp.join(run_dir+'/'+name+'.pth'))

# Training log: per-epoch / final metrics plus the run's hyper-parameters.
trlog = {
    'train_loss': [],
    'emp_loss': [],
    'test_loss': [],
    'gap': [],
    # Start at +inf so the first finite epoch loss always counts as the best
    # one; a finite sentinel could leave the best checkpoint unwritten when
    # every epoch loss exceeds it.
    'min_loss': float('inf'),
    'lambda': lamda,
    'eta': eta,
    'gamma': gamma,
    'K': K,
    'T': T,
    'n_task': n_task,
    'n_train_training': train_shot,
    'n_train_testing': test_shot,
    'n_test_testing': query,
}


# Meta-training: T epochs, each visiting every training task once.
for epoch in range(1, T + 1):
    model.train()
    epoch_avg = Averager()

    for step in range(n_task):
        x_shot = tr_shot_list[step].reshape(-1, 1)
        # Convert the targets once per task instead of once per inner step.
        y_shot = tr_shot_y_list[step].float().to(device)

        # Task-level model starts from the current meta-parameters.
        regressor = copy.deepcopy(model)
        optimizer_innertask = torch.optim.SGD(regressor.parameters(), lr=eta)

        # The meta-parameters are constants during task-level adaptation.
        # Detaching them skips a backward pass into `model` whose gradients
        # would be zeroed before use anyway.
        anchor = [w.detach() for w in model.parameters()]

        # K steps of SGD on: support MSE + lamda * distance to meta-parameters.
        for _ in range(K):
            pred_y_shot = regressor(x_shot)
            # squeeze(-1) keeps a 1-D prediction even for a single sample.
            loss_support = F.mse_loss(pred_y_shot.squeeze(-1), y_shot)
            reg = sum(F.mse_loss(w1, w2)
                      for w1, w2 in zip(anchor, regressor.parameters()))
            loss_support = loss_support + lamda * reg
            optimizer_innertask.zero_grad()
            loss_support.backward()
            optimizer_innertask.step()

        regressor.eval()

        # Meta-level step: pull the meta-parameters towards the adapted ones.
        # The adapted parameters are detached because `regressor` is thrown
        # away after this step.
        reg = sum(F.mse_loss(w1, w2.detach())
                  for w1, w2 in zip(model.parameters(), regressor.parameters()))
        loss = lamda * reg
        # Log the unweighted distance directly; dividing the weighted loss by
        # lamda would fail for lamda == 0.
        epoch_avg.add(reg.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    tl = epoch_avg.item()
    print('epoch {}, train, loss={:.4f}'.format(epoch, tl))

    # Keep a copy of the weights with the lowest epoch loss seen so far.
    if tl < trlog['min_loss']:
        trlog['min_loss'] = tl
        save_model(dir,'min-loss')

    trlog['train_loss'].append(tl)
    torch.save(trlog, osp.join('log/'+dir, 'trlog'))
    save_model(dir,'epoch-last')
    

def _adapt_to_task(meta_model, x_support, y_support, n_steps, step_size, reg_weight):
    """Return a copy of ``meta_model`` fine-tuned on one task's support set.

    Runs ``n_steps`` SGD steps (learning rate ``step_size``) on the support
    MSE plus ``reg_weight`` times the summed per-tensor MSE between the
    copy's parameters and those of ``meta_model``. ``meta_model`` itself is
    not modified. The returned model is in eval mode.
    """
    task_model = copy.deepcopy(meta_model)
    inner_opt = torch.optim.SGD(task_model.parameters(), lr=step_size)
    # The meta-parameters are constants here, so no gradient is tracked for them.
    anchor = [w.detach() for w in meta_model.parameters()]
    for _ in range(n_steps):
        # squeeze(-1) keeps a 1-D prediction even for a single sample.
        fit = F.mse_loss(task_model(x_support).squeeze(-1), y_support)
        prox = sum(F.mse_loss(w_meta, w_task)
                   for w_meta, w_task in zip(anchor, task_model.parameters()))
        inner_opt.zero_grad()
        (fit + reg_weight * prox).backward()
        inner_opt.step()
    task_model.eval()
    return task_model


def _predict_and_score(task_model, x_eval, y_eval):
    """Return ``(predictions, mse)`` of ``task_model`` on ``(x_eval, y_eval)``."""
    # No gradients are needed for scoring, so the stored predictions do not
    # keep an autograd graph alive.
    with torch.no_grad():
        pred = task_model(x_eval)
        mse = F.mse_loss(pred.squeeze(-1), y_eval).item()
    return pred, mse


# Evaluate the checkpoint with the lowest training loss.
model.load_state_dict(torch.load('log/'+dir+'/min-loss.pth'))
model.eval()
emp_loss = Averager()
test_loss = Averager()
y_train_query_pred = []
y_test_query_pred = []

# Empirical loss: adapt on each training task and score on the same samples.
for step in range(n_task):
    x_shot = tr_shot_list[step].reshape(-1, 1)
    y_shot = tr_shot_y_list[step].float().to(device)
    regressor = _adapt_to_task(model, x_shot, y_shot, K, eta, lamda)
    pred_y_query, loss = _predict_and_score(regressor, x_shot, y_shot)
    y_train_query_pred.append(pred_y_query)
    emp_loss.add(loss)

# Test loss: adapt on each test task's support set and score on its query
# set. Iterating the lists themselves avoids hard-coding the task count.
for x_shot, y_shot, x_query, y_query in zip(val_shot_list, val_shot_y_list,
                                            val_query_list, val_query_y_list):
    x_shot = x_shot.reshape(-1, 1)
    y_shot = y_shot.float().to(device)
    x_query = x_query.reshape(-1, 1)
    y_query = y_query.float().to(device)
    regressor = _adapt_to_task(model, x_shot, y_shot, K, eta, lamda)
    pred_y_query, loss = _predict_and_score(regressor, x_query, y_query)
    y_test_query_pred.append(pred_y_query)
    test_loss.add(loss)


gap = test_loss.item() - emp_loss.item()
trlog['emp_loss'].append(emp_loss.item())
trlog['test_loss'].append(test_loss.item())
trlog['gap'].append(gap)
torch.save(trlog, osp.join('log/'+dir, 'trlog'))
print(trlog)
# T equals the last epoch index of the training loop, so the line reads the
# same without depending on a leaked loop variable.
print('epoch {}, emp_loss={:.4f}, test_loss={:.4f}, gap={:.4f}'.format(T, emp_loss.item(), test_loss.item(), gap))
# stdout may be redirected to a file; push the final lines out explicitly.
sys.stdout.flush()
