import argparse
import math
import os
import pickle
import random
import sys

import numpy as np
import torch
from tqdm import tqdm, trange

#from data_utils import TextClassificationDataProcessor
from data_utils import SST5Processor, TRECProcessor, IMDBProcessor

from generator import Generator
from classifier import Classifier

# Run on the GPU when CUDA is usable, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('device:', device)

# Command-line options as (flag, default, type) rows. They are registered in
# exactly this order so the help text and the printed namespace keep the
# same layout. (Earlier --percent / --dev_percent options were superseded by
# the per-class counts below.)
_OPTION_TABLE = [
    # run identity and data
    ('--seed', 159, int),
    ("--data_dir", "datasets", str),
    ("--name", "", str),
    ("--task_name", "stsa.fine", str),
    ("--max_seq_length", 64, int),
    ("--batch_size", 32, int),
    # optimisation
    ("--generator_learning_rate", 1e-5, float),
    ("--classifier_learning_rate", 1e-5, float),
    ("--warmup_proportion", 0.1, float),
    # training schedule
    ("--generator_pretrain_epochs", 60, int),
    ("--classifier_pretrain_epochs", 30, int),
    ("--finetune_epochs", 60, int),
    ("--min_finetune_epochs", 0, int),
    # sampling and data size
    ('--softmax_temperature', 1.0, float),
    ('--num_per_class', 100, int),
    ('--dev_num_per_class', 10, int),
    # mode switches and outputs
    ('--finetune_generator', 1, int),
    ('--base', 0, int),
    ('--evaluator_pretrained_path', 'output_evaluator_stsa.fine', str),
    ('--result_fn', 'results.txt', str),
    ('--num_aug', 1, int),
]

parser = argparse.ArgumentParser()
for _flag, _default, _type in _OPTION_TABLE:
    parser.add_argument(_flag, default=_default, type=_type)

# Parsing happens at import time; every function below reads `args`.
args = parser.parse_args()
print(args)


# Seed every random number generator this script imports so that --seed
# governs the whole run: Python's `random`, NumPy, and PyTorch. Without the
# PyTorch call, torch-side randomness is left unseeded by this script.
random.seed(args.seed)
np.random.seed(seed=args.seed)
torch.manual_seed(args.seed)


# Directory names encode the run configuration. (An older naming scheme
# keyed on a training percentage was replaced by the per-class counts.)
_cli_values = vars(args)
OUTPUT_DIR = (
    'log_{task_name}_tr{num_per_class}_dev{dev_num_per_class}'
    '_temp{softmax_temperature}_naug{num_aug}_{name}'
).format(**_cli_values)
PRETRAIN_MODEL_DIR = (
    '{task_name}_tr{num_per_class}_dev{dev_num_per_class}'
    '_pretrain_models_G{generator_pretrain_epochs}'
    '_C{classifier_pretrain_epochs}_{name}/'
).format(**_cli_values)
# os.mkdir raises FileExistsError when the directory is already present.
os.mkdir(OUTPUT_DIR)


# History log written by _print_log; it is not explicitly closed in this file.
log_history = open(os.path.join(OUTPUT_DIR, 'history.txt'), 'w')

def _print_log(*inputs):
    """Print ``inputs`` to stdout and write the same line to the history log.

    The history log is flushed on every call.
    """
    print(*inputs)
    print(*inputs, file=log_history, flush=True)

def _pretrain(generator, classifier, train_examples):
    """Pretrain the classifier and, unless running a baseline, the generator.

    Creates ``PRETRAIN_MODEL_DIR`` and stores in it the classifier checkpoint
    with the best dev accuracy and the generator checkpoint with the lowest
    average training loss.

    Args:
        generator: object providing ``train_epoch()``, ``model`` and
            ``optimizer``.
        classifier: object providing ``eval(set_type=...)``, ``model`` and
            ``optimizer``.
        train_examples: examples handed to ``_run_classifier`` (which
            shuffles the list in place).

    Returns:
        ``(best_dev_acc, best_dev_epoch, test_acc)`` when ``args.base`` is 1
        or 2, because generator pretraining is skipped in that case;
        otherwise ``(best_dev_acc, best_dev_epoch, test_acc, best_g_epoch)``.

    Raises:
        FileExistsError: from ``os.mkdir`` if ``PRETRAIN_MODEL_DIR`` exists.
        SystemExit: if the generator's average loss becomes NaN, or if its
            best average loss stays above 6.
    """
    os.mkdir(PRETRAIN_MODEL_DIR)

    best_dev_acc = -1.
    best_dev_epoch = 0
    test_acc = -1.
    for epoch in range(args.classifier_pretrain_epochs):
        # epoch=-1 is passed for every pretrain epoch, so _run_classifier
        # re-seeds its shuffle with the same value and rewrites the same
        # eval log file each time.
        dev_acc = _run_classifier(
            epoch=-1,
            generator=generator,
            classifier=classifier,
            train_examples=train_examples,
            do_augment=False)

        if dev_acc > best_dev_acc:
            best_dev_acc = dev_acc
            best_dev_epoch = epoch
            # Test accuracy is only measured at the best-dev epoch.
            test_acc = classifier.eval(set_type='test')
            torch.save(classifier.model.state_dict(), os.path.join(
                PRETRAIN_MODEL_DIR, 'classifier_model.pt'))
            torch.save(classifier.optimizer.state_dict(), os.path.join(
                PRETRAIN_MODEL_DIR, 'classifier_optimizer.pt'))

        _print_log('Classifier pretrain Epoch {}, Dev Acc: {:.4f}'.format(
            epoch, 100. * dev_acc))

    _print_log('Classifier pretrain, best dev acc: {:.4f}, epoch: {}, '
          'test acc: {:.4f}'.format(100. * best_dev_acc, best_dev_epoch, 100. * test_acc))

    # Baseline modes stop here: no generator pretraining, 3-tuple result.
    if args.base == 1 or args.base == 2:
        return best_dev_acc, best_dev_epoch, test_acc

    # G
    best_avg_loss = 1e10
    best_g_epoch = 0
    for epoch in range(args.generator_pretrain_epochs):
        _print_log('=' * 5, 'Generator PreTrain Epoch {}'.format(epoch), '=' * 5)
        avg_loss = generator.train_epoch()

        # A NaN loss never compares as smaller, so it is never checkpointed.
        if avg_loss < best_avg_loss:
            torch.save(generator.model.state_dict(), os.path.join(
                PRETRAIN_MODEL_DIR, 'generator_model.pt'))
            torch.save(generator.optimizer.state_dict(), os.path.join(
                PRETRAIN_MODEL_DIR, 'generator_optimizer.pt'))
            best_avg_loss = avg_loss
            best_g_epoch = epoch
        _print_log('Avg Loss: {:.4f}; Best Avg Loss: {:.4f}; Best Epoch: {}'.format(avg_loss, best_avg_loss, best_g_epoch))

        if math.isnan(avg_loss):
            _print_log('*' * 30)
            _print_log('G avg_loss is NAN')
            _print_log('*' * 30)
            # sys.exit() instead of the site-provided exit(), which is meant
            # for interactive use and is absent when Python runs with -S.
            sys.exit()

    if best_avg_loss > 6: #TODO
        _print_log('*' * 30)
        _print_log('G best_avg_loss: {}'.format(best_avg_loss))
        _print_log('*' * 30)
        sys.exit()

    return best_dev_acc, best_dev_epoch, test_acc, best_g_epoch


def _run_classifier(epoch, generator, classifier, train_examples, do_augment,
                    log_file=None):
    """Train the classifier for one pass over ``train_examples`` and evaluate.

    The examples are shuffled in place using a seed derived from ``epoch``.
    With ``do_augment`` set, every batch is first passed through
    ``generator.finetune_and_augment_batch`` and the classifier trains on
    what that call returns; the call logs to
    ``OUTPUT_DIR/aug_log{epoch}.txt``. After training,
    ``generator._augment_example`` is called on every example in guid order,
    logging to ``OUTPUT_DIR/eval_aug_log_epoch{epoch}.txt``.

    Args:
        epoch: epoch index; used for the shuffle seed and log file names.
        generator: object providing ``finetune_and_augment_batch`` and
            ``_augment_example``.
        classifier: object providing ``train_batch`` and ``eval``.
        train_examples: list of examples with a ``guid`` attribute; it is
            shuffled in place.
        do_augment: whether to augment each batch with the generator.
        log_file: kept for backward compatibility. It is not used: with
            ``do_augment`` a fresh log file is opened, and without it no
            augmentation log is written.

    Returns:
        The result of ``classifier.eval()``.
    """
    random.seed(199 * (epoch + 1))
    random.shuffle(train_examples)

    batch_size = args.batch_size

    aug_log = None
    if do_augment:
        aug_log = open(os.path.join(
            OUTPUT_DIR, 'aug_log{}.txt'.format(epoch)), 'w')

    # try/finally so the augmentation log is closed even if a batch fails;
    # one such file is opened per epoch.
    try:
        for start in trange(0, len(train_examples), batch_size,
                            desc='Classifier Epoch'):
            batch_examples = train_examples[start: start + batch_size]
            if do_augment:
                batch_examples = generator.finetune_and_augment_batch(
                    classifier=classifier,
                    examples=batch_examples,
                    max_seq_length=args.max_seq_length,
                    softmax_temperature=args.softmax_temperature,
                    log_file=aug_log,
                    finetune_generator=args.finetune_generator,
                    num_aug=args.num_aug)

            classifier.train_batch(
                batch_examples, args.max_seq_length, is_augment=do_augment)
    finally:
        if aug_log is not None:
            aug_log.close()

    # Fixed seed and fixed sampling settings (not the CLI values) are used
    # for this logging pass.
    random.seed(11111)
    eval_log_path = os.path.join(
        OUTPUT_DIR, 'eval_aug_log_epoch{}.txt'.format(epoch))
    with open(eval_log_path, 'w') as eval_log_file:
        # Sort on guid only; ties never fall back to comparing the example
        # objects themselves.
        for example in sorted(train_examples, key=lambda ex: ex.guid):
            generator._augment_example(
                example,
                max_seq_length=64,
                softmax_temperature=1.0,
                log_file=eval_log_file,
                return_all=True,
                num_aug=1)

    return classifier.eval()


def main():
    """Run the experiment selected by the command-line arguments.

    Steps:
      1. Load train/dev/test examples for ``args.task_name``.
      2. Build the generator and classifier and attach their data.
      3. If ``PRETRAIN_MODEL_DIR`` does not exist, run ``_pretrain``. When
         ``args.base`` is non-zero, the baseline result is appended to
         ``args.result_fn`` and the process exits.
      4. Load the generator checkpoint from ``PRETRAIN_MODEL_DIR`` and run
         ``args.finetune_epochs`` epochs of augmented classifier training,
         pickling the generator to ``g_{epoch}.p`` after each epoch.
      5. Log the final numbers and append one CSV line to
         ``args.result_fn``.

    Raises:
        ValueError: if ``args.task_name`` is not a known task.
    """
    ## data
    if args.task_name == 'stsa.fine':
        processor = SST5Processor()
    elif args.task_name == 'trec':
        processor = TRECProcessor()
    elif args.task_name == 'imdb':
        processor = IMDBProcessor()
    else:
        raise ValueError('Unknown task')

    # training: a negative count passes None to the processor.
    if args.num_per_class < 0:
        num_per_class = None
    else:
        num_per_class = {label: args.num_per_class for label in processor.get_labels()}
    train_examples = processor.get_train_examples(
        num_per_class=num_per_class)
    print("#train: {}".format(len(train_examples)))
    # dev
    if args.dev_num_per_class < 0:
        dev_num_per_class = None
    else:
        dev_num_per_class = {label: args.dev_num_per_class for label in processor.get_labels()}
    dev_examples = processor.get_dev_examples(
        num_per_class=dev_num_per_class)
    # train + dev (only used by the args.base == 2 baseline)
    train_dev_examples = train_examples + dev_examples
    # test
    test_examples = processor.get_test_examples()
    print('#test: {}'.format(len(test_examples)))

    generator = Generator(
        label_list=processor.get_labels(),
        learning_rate=args.generator_learning_rate,
        warmup_proportion=args.warmup_proportion,
        num_train_steps=-1,
        device=device)

    generator.load_train_data(
        train_examples=train_examples, #TODO
        max_seq_length=args.max_seq_length,
        batch_size=args.batch_size)
    generator.load_dev_data(
        dev_examples=dev_examples,
        max_seq_length=args.max_seq_length,
        batch_size=args.batch_size)

    classifier = Classifier(
        label_list=processor.get_labels(),
        learning_rate=args.classifier_learning_rate,
        warmup_proportion=args.warmup_proportion,
        num_train_steps=-1,
        device=device)

    classifier.load_dev_data(
        dev_examples=dev_examples,
        max_seq_length=args.max_seq_length,
        batch_size=args.batch_size)
    classifier.load_test_data(
        test_examples=test_examples,
        max_seq_length=args.max_seq_length,
        batch_size=args.batch_size)

    # Stays 0 when an existing PRETRAIN_MODEL_DIR is reused, since the epoch
    # that produced that checkpoint is not available here. Otherwise it is
    # the value reported by _pretrain and must not be overwritten afterwards.
    best_g_epoch = 0
    if not os.path.exists(PRETRAIN_MODEL_DIR):
        if args.base == 0:
            # The accuracy figures returned here are not used: the
            # "without augmentation" numbers are re-measured below.
            _, _, _, best_g_epoch = _pretrain(
                generator, classifier, train_examples)
        else:
            # base == 1 trains the baseline on train only; any other
            # non-zero value trains it on train + dev.
            if args.base == 1:
                baseline_examples = train_examples
            else:
                baseline_examples = train_dev_examples
            dev_acc_no_aug, best_dev_epoch, test_acc_no_aug = _pretrain(
                generator, classifier, baseline_examples)

            with open(args.result_fn, 'a+') as result_file:
                s_full = '{:.4f},{:.4f},{}'.format(
                    100. * dev_acc_no_aug, 100. * test_acc_no_aug, best_dev_epoch)
                print('{}'.format(s_full), file=result_file)
            # sys.exit() instead of the site-provided exit(), which is meant
            # for interactive use and is absent when Python runs with -S.
            sys.exit()

    # NOTE(review): only the generator checkpoint is restored. The
    # classifier_model.pt written by _pretrain is never loaded in this file,
    # so the classifier continues from its current state (freshly
    # constructed when PRETRAIN_MODEL_DIR already existed) - confirm this is
    # intended.
    generator.model.load_state_dict(
        torch.load(os.path.join(PRETRAIN_MODEL_DIR, 'generator_model.pt')))
    generator.optimizer.load_state_dict(
        torch.load(os.path.join(PRETRAIN_MODEL_DIR, 'generator_optimizer.pt')))
    dev_acc_no_aug = classifier.eval()
    test_acc_no_aug = classifier.eval(set_type='test')

    _print_log('=' * 50)
    _print_log('Begin Training...')
    _print_log('=' * 50)

    _print_log('Without Augmentation, Eval Acc: {:.4f}, Test Acc: {:.4f}'.format(
        100. * dev_acc_no_aug, 100. * test_acc_no_aug))

    best_dev_acc = 0 #TODO
    final_test_acc = 0
    best_dev_epoch = 0
    for epoch in range(args.finetune_epochs):
        dev_acc = _run_classifier(
            epoch=epoch,
            generator=generator,
            classifier=classifier,
            train_examples=train_examples,
            do_augment=True)

        improved = dev_acc > best_dev_acc
        if improved:
            best_dev_epoch = epoch + 1
            best_dev_acc = dev_acc

        _print_log('Joint Training Epoch {}, Dev Acc: {:.4f}'.format(
            epoch, 100. * dev_acc))
        _print_log('Best Ever: {:.4f}'.format(100. * best_dev_acc))

        # Test accuracy is only refreshed on epochs that improve dev accuracy.
        if improved:
            final_test_acc = classifier.eval(set_type='test')

        # Epochs below the minimum do not count towards the best result.
        if epoch < args.min_finetune_epochs:
            best_dev_acc = 0 #TODO
            final_test_acc = 0
            best_dev_epoch = 0

        # Context manager so the snapshot file is closed after every epoch.
        with open('g_{}.p'.format(epoch), 'wb') as generator_file:
            pickle.dump(generator, generator_file)

    # Output
    s_final = 'Final Eval Acc: {:.4f}, Test Acc: {:.4f}, epoch {}'.format(
        100. * best_dev_acc, 100. * final_test_acc, best_dev_epoch)
    s_prior = 'Without Augmentation, Eval Acc: {:.4f}, Test Acc: {:.4f}'.format(
        100. * dev_acc_no_aug, 100. * test_acc_no_aug)
    _print_log(s_final)
    _print_log(s_prior)

    with open(args.result_fn, 'a+') as result_file:
        s_full = '{:.4f},{:.4f},{:.4f},{:.4f},{},{}'.format(
            100. * dev_acc_no_aug, 100. * test_acc_no_aug,
            100. * best_dev_acc, 100. * final_test_acc,
            best_dev_epoch, best_g_epoch)
        print('{}'.format(s_full), file=result_file)


# Call main() only when this file is executed as a script.
if __name__ == "__main__":
    main()
