from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from dataloader import *

from graph_util import num_atom_features, num_bond_features
from collections import OrderedDict
import logging
logger = logging.getLogger()
logger.setLevel(logging.INFO)

import sys
sys.path.insert(0, '../src')
from util import output_classification_result


# Per-target hit ratio for the tox21 tasks. The entry selected by
# ``target_name`` is forwarded as ``hit_ratio`` to
# ``output_classification_result`` in ``test``.
# NOTE(review): presumably the fraction of active compounds per target --
# confirm against the dataset preparation code.
target_name_to_hit_ratio = {
    'NR-AR': 0.0420254202542,
    'NR-AR-LBD': 0.0344599072233,
    'NR-AhR': 0.118911502401,
    'NR-Aromatase': 0.0506948018528,
    'NR-ER': 0.125271542361,
    'NR-ER-LBD': 0.0492261392949,
    'NR-PPAR-gamma': 0.0287236506833,
    'SR-ARE': 0.16196287821,
    'SR-ATAD5': 0.0369510135135,
    'SR-HSE': 0.0578114246387,
    'SR-MMP': 0.156483629801,
    'SR-p53': 0.0597477317991,
}


def tensor_to_variable(x):
    """Cast ``x`` to float and wrap it in an autograd ``Variable``.

    The tensor is moved to the GPU first whenever CUDA is available.

    Args:
        x: a ``torch`` tensor.

    Returns:
        A float ``Variable`` holding the data of ``x``.
    """
    placed = x.cuda() if torch.cuda.is_available() else x
    return Variable(placed.float())


def variable_to_numpy(x):
    """Return the data held by ``x`` as a NumPy array.

    When CUDA is available the value is copied to host memory first, since
    ``numpy()`` only works on CPU tensors.

    Args:
        x: a ``Variable``/tensor.

    Returns:
        numpy.ndarray with the contents of ``x.data``.
    """
    host_side = x.cpu() if torch.cuda.is_available() else x
    return host_side.data.numpy()


class Flatten(nn.Module):
    """Module that collapses every dimension after the first into one."""

    def forward(self, x):
        """Reshape ``x`` to ``(x.size(0), -1)``.

        ``contiguous()`` is applied first so that ``view`` also works on
        inputs whose memory layout is not contiguous.
        """
        leading = x.size(0)
        flat = x.contiguous().view(leading, -1)
        return flat


class GraphModel(nn.Module):
    """Fully connected binary classifier over n-gram-pooled atom features.

    ``forward`` multiplies the transposed ``incidence_matrix`` with
    ``node_attr_matrix`` (batched matmul) and feeds the result through
    ``fc_layer``, which ends in a sigmoid and therefore yields one value in
    (0, 1) per sample.

    Args:
        atom_attr_dim: size of the last dimension of ``node_attr_matrix``
            (input width of the first linear layer).
        n_gram_num: number of rows produced by the batched matmul per sample;
            the flattened width fed to the last linear layer is
            ``n_gram_num * 30``.
    """

    def __init__(self, atom_attr_dim, n_gram_num=6):
        super(GraphModel, self).__init__()
        self.n_gram_num = n_gram_num
        self.atom_attr_dim = atom_attr_dim

        # NOTE(review): there is no activation between the first two linear
        # layers, so together they form a single affine map. Left unchanged
        # so that parameter names/shapes of existing checkpoints still match.
        self.fc_layer = nn.Sequential(
            nn.Linear(self.atom_attr_dim, 100),
            nn.Linear(100, 50),
            nn.Sigmoid(),
            nn.Linear(50, 30),
            nn.Sigmoid(),
            Flatten(),
            nn.Linear(self.n_gram_num * 30, 1),
            nn.Sigmoid(),
        )

    def forward(self, node_attr_matrix, incidence_matrix):
        """Compute the per-sample probability.

        Args:
            node_attr_matrix: tensor of shape ``(batch, A, atom_attr_dim)``.
            incidence_matrix: tensor of shape ``(batch, A, n_gram_num)``; it
                is transposed to ``(batch, n_gram_num, A)`` before the matmul.

        Returns:
            Tensor of shape ``(batch, 1)`` with sigmoid outputs.
        """
        incidence_matrix = incidence_matrix.transpose(1, 2)
        # (batch, n_gram_num, A) x (batch, A, atom_attr_dim)
        #   -> (batch, n_gram_num, atom_attr_dim)
        x = torch.bmm(incidence_matrix, node_attr_matrix)
        x = self.fc_layer(x)
        return x

    def loss_(self, y_predicted, y_actual, alpha=5e-3):
        """Return the mean binary cross-entropy between prediction and label.

        Args:
            y_predicted: probabilities in [0, 1], same shape as ``y_actual``.
            y_actual: float targets.
            alpha: unused; kept only so existing callers keep working.

        Returns:
            Scalar loss tensor (mean over all elements).
        """
        # The functional form defaults to the mean reduction, which is what
        # the deprecated ``nn.BCELoss(size_average=True)`` computed, without
        # building a module and emitting a deprecation warning on every batch.
        return F.binary_cross_entropy(y_predicted, y_actual)


def visualize(model):
    """Print a summary of ``model``'s weights to stdout.

    First every ``state_dict`` entry is listed (sorted by key) with its shape,
    then every named parameter with its ``requires_grad`` flag and values.

    Args:
        model: an ``nn.Module``.

    Returns:
        None.
    """
    state = model.state_dict()
    for key in sorted(state):
        print(key, state[key].shape)
    for param_name, parameter in model.named_parameters():
        print(param_name, '\t', parameter.requires_grad, '\t', parameter.data)


def train(data_loader):
    """Run one optimisation pass over ``data_loader`` and return the mean loss.

    Uses the module-level ``graph_model`` and ``optimizer`` that the
    ``__main__`` section creates.

    Args:
        data_loader: sized iterable yielding
            ``(node_attr_matrix, incidence_matrix, y_label)`` tensor batches.

    Returns:
        float: sum of the per-batch losses divided by the number of batches.

    Raises:
        ZeroDivisionError: if ``data_loader`` contains no batches.
    """
    graph_model.train()
    total_loss = 0.0
    for node_attr_matrix, incidence_matrix, y_label in data_loader:
        node_attr_matrix = tensor_to_variable(node_attr_matrix)
        incidence_matrix = tensor_to_variable(incidence_matrix)
        y_label = tensor_to_variable(y_label)
        y_predicted = graph_model(node_attr_matrix=node_attr_matrix, incidence_matrix=incidence_matrix)
        loss = graph_model.loss_(y_predicted=y_predicted, y_actual=y_label)
        # The loss is a 0-dim tensor; ``loss.data[0]`` raises IndexError on
        # PyTorch >= 0.5, ``item()`` is the supported way to read the scalar.
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    total_loss /= len(data_loader)
    return total_loss


def val(data_loader):
    """Return the mean loss of the module-level ``graph_model`` on ``data_loader``.

    The model is switched to eval mode and no parameters are updated.

    Args:
        data_loader: sized iterable yielding
            ``(node_attr_matrix, incidence_matrix, y_label)`` tensor batches.

    Returns:
        float: sum of the per-batch losses divided by the number of batches.

    Raises:
        ZeroDivisionError: if ``data_loader`` contains no batches.
    """
    graph_model.eval()
    total_loss = 0.0
    # No backward pass happens here, so skip building autograd graphs.
    with torch.no_grad():
        for node_attr_matrix, incidence_matrix, y_label in data_loader:
            node_attr_matrix = tensor_to_variable(node_attr_matrix)
            incidence_matrix = tensor_to_variable(incidence_matrix)
            y_label = tensor_to_variable(y_label)
            y_predicted = graph_model(node_attr_matrix=node_attr_matrix, incidence_matrix=incidence_matrix)
            loss = graph_model.loss_(y_predicted=y_predicted, y_actual=y_label)
            # The loss is a 0-dim tensor; ``loss.data[0]`` raises IndexError
            # on PyTorch >= 0.5, ``item()`` is the supported accessor.
            total_loss += loss.item()
    total_loss /= len(data_loader)
    return total_loss


def make_predictions(data_loader):
    """Collect labels and model outputs for every sample in ``data_loader``.

    Uses the module-level ``graph_model``; the caller is responsible for
    putting it into eval mode.

    Args:
        data_loader: iterable yielding
            ``(node_attr_matrix, incidence_matrix, y_label)`` tensor batches,
            or ``None``.

    Returns:
        ``(y_label_array, y_pred_array)`` as NumPy arrays with one entry per
        sample, or ``(None, None)`` when ``data_loader`` is ``None``.
    """
    if data_loader is None:
        return None, None

    # Imported here because this file never imports numpy explicitly; ``np``
    # was otherwise only reachable through ``from dataloader import *``.
    import numpy as np

    y_label_list = []
    y_pred_list = []
    # Pure inference: avoid recording autograd graphs for every batch.
    with torch.no_grad():
        for node_attr_matrix, incidence_matrix, y_label in data_loader:
            node_attr_matrix = tensor_to_variable(node_attr_matrix)
            incidence_matrix = tensor_to_variable(incidence_matrix)
            y_label = tensor_to_variable(y_label)
            y_pred = graph_model(node_attr_matrix=node_attr_matrix, incidence_matrix=incidence_matrix)
            # ``extend`` iterates over the first axis, i.e. appends per sample.
            y_label_list.extend(variable_to_numpy(y_label))
            y_pred_list.extend(variable_to_numpy(y_pred))
    return np.array(y_label_list), np.array(y_pred_list)


def test(train_dataloader=None, val_dataloader=None, test_dataloader=None):
    """Evaluate the module-level ``graph_model`` and report the metrics.

    Predictions are gathered for each loader that is given (a ``None`` loader
    yields ``None`` labels/predictions) and passed on to
    ``output_classification_result`` together with the hit ratio of the
    module-level ``target_name``.

    Args:
        train_dataloader: loader for the training split, or ``None``.
        val_dataloader: loader for the validation split, or ``None``.
        test_dataloader: loader for the test split, or ``None``.

    Returns:
        None.
    """
    graph_model.eval()

    train_labels, train_scores = make_predictions(train_dataloader)
    val_labels, val_scores = make_predictions(val_dataloader)
    test_labels, test_scores = make_predictions(test_dataloader)

    report_kwargs = dict(
        y_train=train_labels,
        y_pred_on_train=train_scores,
        y_val=val_labels,
        y_pred_on_val=val_scores,
        y_test=test_labels,
        y_pred_on_test=test_scores,
        EF_ratio_list=[0.001, 0.0015, 0.01, 0.02],
        hit_ratio=target_name_to_hit_ratio[target_name],
    )
    output_classification_result(**report_kwargs)


def save_model(weight_path):
    """Serialise the module-level ``graph_model`` to ``weight_path``.

    The whole module object is written with ``torch.save`` (not just its
    ``state_dict``), after announcing the destination on stdout.

    Args:
        weight_path: destination file path; the file is overwritten.
    """
    print('Saving weight path:\t', weight_path)
    with open(weight_path, 'wb') as checkpoint_file:
        torch.save(graph_model, checkpoint_file)


def load_best_model(weight_path):
    """Load and return the object stored at ``weight_path`` via ``torch.load``.

    Args:
        weight_path: path of a file written with ``torch.save``.

    Returns:
        The deserialised object. It is returned to the caller; the
        module-level ``graph_model`` is not modified.
    """
    # Without ``map_location`` a checkpoint saved from a GPU cannot be
    # deserialised on a CPU-only host; fall back to CPU in that case and keep
    # the default placement when CUDA is available.
    map_location = None if torch.cuda.is_available() else 'cpu'
    with open(weight_path, 'rb') as f_:
        # SECURITY: ``torch.load`` unpickles the file, which can execute
        # arbitrary code -- only load checkpoints from trusted sources.
        # NOTE(review): PyTorch >= 2.6 defaults to ``weights_only=True``,
        # which rejects whole-module pickles; such checkpoints would need
        # ``weights_only=False`` there -- confirm the targeted PyTorch version.
        return torch.load(f_, map_location=map_location)


# Script entry point: parse the CLI options, build the model, data loaders,
# optimizer and LR scheduler, train for the requested number of epochs and
# report classification metrics. The names ``graph_model``, ``optimizer`` and
# ``target_name`` assigned here become module globals that the functions
# above (train / val / make_predictions / test / save_model) read.
if __name__ == '__main__':
    import time
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', action='store', dest='epoch',
                        type=int, required=False, default=20)
    parser.add_argument('--batch_size', action='store', dest='batch_size',
                        type=int, required=False, default=128)
    parser.add_argument('--seed', action='store', dest='seed',
                        type=int, required=False, default=123)
    parser.add_argument('--target_name', action='store', dest='target_name',
                        type=str, required=False, default='NR-AhR')
    parser.add_argument('--n_gram_num', action='store', dest='n_gram_num',
                        type=int, required=False, default=4)
    given_args = parser.parse_args()

    # # LifeChem
    # K = 5
    # directory = '../datasets/keck_pria_lc/{}_grammed_graph.npz'
    # file_list = []
    # for i in range(K):
    #     file_list.append(directory.format(i))
    # file_list.append('../datasets/keck_pria_lc/keck_lc4_grammed_graph.npz')

    # tox21
    # Six files (indices 0..5) per target; split below into 4 train / 1 val / 1 test.
    K = 6
    target_name = given_args.target_name
    directory = '../datasets/tox21/{}/{}_grammed_matrix.npz'
    file_list = []
    for i in range(K):
        file_list.append(directory.format(target_name, i))

    EPOCHS = given_args.epoch
    BATCH = given_args.batch_size
    # NOTE(review): MAX_ATOM_NUM and BOND_FEATURE_DIM are not referenced
    # anywhere else in this file.
    MAX_ATOM_NUM = 55
    ATOM_FEATURE_DIM = num_atom_features()
    BOND_FEATURE_DIM = num_bond_features()
    N_GRAM_NUM = given_args.n_gram_num
    # Seed torch's RNG before the model is built so weight init is reproducible.
    torch.manual_seed(given_args.seed)

    graph_model = GraphModel(atom_attr_dim=ATOM_FEATURE_DIM, n_gram_num=N_GRAM_NUM)
    if torch.cuda.is_available():
        graph_model.cuda()
    # graph_model.apply(weights_init)
    # visualize(graph_model)
    print(graph_model)

    # NOTE(review): the training split is a list of four paths while the
    # validation/test splits are single path strings -- presumably
    # GraphDataset_N_Gram_Embedding accepts both; confirm in dataloader.
    train_graph_matrix_file = file_list[:4]
    val_graph_matrix_file = file_list[4]
    test_graph_matrix_file = file_list[5]

    train_dataset = GraphDataset_N_Gram_Embedding(train_graph_matrix_file, n_gram_num=N_GRAM_NUM)
    val_dataset = GraphDataset_N_Gram_Embedding(val_graph_matrix_file, n_gram_num=N_GRAM_NUM)
    test_dataset = GraphDataset_N_Gram_Embedding(test_graph_matrix_file, n_gram_num=N_GRAM_NUM)

    # NOTE(review): the validation loader is shuffled as well; the loss is an
    # average over batches, so shuffling looks unnecessary there -- confirm.
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH, shuffle=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH, shuffle=False)

    # optimizer = optim.Adam(graph_model.parameters(), lr=3e-2)
    optimizer = optim.SGD(graph_model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-5)
    # Halve the learning rate (not below 1e-4) when the validation loss passed
    # to scheduler.step() stops improving for more than `patience` epochs.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=3, min_lr=1e-4, verbose=True)

    for epoch in range(1, 1+EPOCHS):
        print('Epoch: {}'.format(epoch))

        train_start_time = time.time()
        train_loss = train(train_dataloader)
        train_end_time = time.time()
        print('Train time: {:.3f}s. Train loss is {}.'.format(train_end_time - train_start_time, train_loss))

        val_start_time = time.time()
        val_loss = val(val_dataloader)
        scheduler.step(val_loss)
        val_end_time = time.time()
        print('Valid time: {:.3f}s. Val loss is {}.'.format(val_end_time - val_start_time, val_loss))
        print()

        # Every 10th epoch: report metrics on the train and validation splits
        # only (the test loader is deliberately left out here).
        if epoch % 10 == 0:
            test_start_time = time.time()
            test(train_dataloader=train_dataloader, val_dataloader=val_dataloader, test_dataloader=None)
            test_end_time = time.time()
            print('Test time: {:.3f}s.'.format(test_end_time - test_start_time))
            print()

    # Final report after training, this time including the held-out test split.
    test_start_time = time.time()
    test(train_dataloader=train_dataloader, val_dataloader=val_dataloader, test_dataloader=test_dataloader)
    test_end_time = time.time()
    print('Test time: {:.3f}s.'.format(test_end_time - test_start_time))
    print()
