import argparse
import pickle
from pathlib import Path

import nltk
import torch

from models import InferSent
nltk.download('punkt')

'''Use a pretrained InferSent model to export sentence embeddings.
Note that regardless of which dataset is being annotated, the train subset is used to build the vocabulary.
'''

if __name__ == "__main__":
    # Command-line entry point.
    #
    # With --savemodel: build an InferSent model from the pretrained encoder
    # weights, build its vocabulary from the train split and pickle it next to
    # the input file. Without it: load that previously pickled model.
    # With --exportembeddings: encode every sentence of the input file and
    # pickle the embeddings grouped per conversation.
    #
    # NOTE(review): pickle.load can execute arbitrary code; only point this
    # script at files you trust.
    argparser = argparse.ArgumentParser()
    # The path is dereferenced unconditionally below, so it has to be required.
    argparser.add_argument('--filepath', type=str,
                           required=True, help="Dataset path")
    argparser.add_argument('--streaming', action="store_true", default=False, help="Whether to use streaming "
                                                                                   "pickle or dump everything "
                                                                                   "at once. Use streaming for "
                                                                                   "large datasets")
    argparser.add_argument('--savemodel', action="store_true", default=False,
                           help="Whether to save the trained infersent model.")
    argparser.add_argument('--exportembeddings', action="store_true", default=False,
                           help="Whether to export infersent embeddings.")
    argparser.add_argument('--debuglen', type=int, default=5, help="Number of the sentences to show the output for")
    argparser.add_argument('--step', type=int, default=100, help="Number of steps for saving output")
    # Only versions 1 and 2 have a word-vector file configured below.
    argparser.add_argument('--version', type=int, default=1, choices=(1, 2),
                           help="Which model version of inferSent to use. "
                                "V1 has been trained on GloVe. "
                                "V2 has been trained on fastText.")
    argparser.add_argument('--bsize', type=int, default=64, help="batch size")
    argparser.add_argument('--word_emb_dim', type=int, default=300, help="Dimension of word embeddings")
    argparser.add_argument('--enc_lstm_dim', type=int, default=2048, help="Dimension of LSTM")

    args = argparser.parse_args()
    MODEL_PATH = f'encoder/infersent{args.version}.pickle'
    params_model = {'bsize': args.bsize, 'word_emb_dim': args.word_emb_dim, 'enc_lstm_dim': args.enc_lstm_dim,
                    'pool_type': 'max', 'dpout_model': 0.0, 'version': args.version}
    w2v_paths = {
        1: 'dataset/GloVe/glove.840B.300d.txt',
        2: 'dataset/fastText/crawl-300d-2M-subword.vec',
    }

    # Work on the absolute path so a bare file name (no directory component)
    # still has a well-defined parent and grandparent directory.
    input_path = Path(args.filepath).absolute()
    data_dir = input_path.parent
    # The vocabulary always comes from the sibling "train" directory.
    vocab_file_path = data_dir.parent / 'train' / 'raw_sentences.pkl'
    streaming_suffix = '_streaming' if args.streaming else ''
    output_path = data_dir / f'sentence_embeddings_{args.version}{streaming_suffix}.pkl'
    updated_model_path = data_dir / f'updated_infersent{args.version}.pkl'

    # The input is iterated as a sequence of conversations, each a sequence of
    # sentences.
    with open(input_path, 'rb') as in_file:
        sentences = pickle.load(in_file)

    if args.savemodel:
        # The vocabulary sentences are only needed when (re)building the model.
        if vocab_file_path == input_path:
            vocab_sentences = sentences
        else:
            with open(vocab_file_path, 'rb') as vocab_file:
                vocab_sentences = pickle.load(vocab_file)
        flattened_vocab_sentences = [utterance for conversation in vocab_sentences for utterance in conversation]

        model = InferSent(params_model)
        model.load_state_dict(torch.load(MODEL_PATH))

        W2V_PATH = w2v_paths[args.version]
        model.set_w2v_path(W2V_PATH)

        model.build_vocab(flattened_vocab_sentences, tokenize=True)
        with open(updated_model_path, 'wb') as model_file:
            pickle.dump(model, model_file)
    else:
        with open(updated_model_path, 'rb') as model_file:
            model = pickle.load(model_file)

    if args.exportembeddings:
        flattened_sentences = [utterance for conversation in sentences for utterance in conversation]
        # Assumes encode returns one embedding per input sentence, in input
        # order, so the flat result can be re-grouped by walking the
        # conversations again - TODO confirm against InferSent.encode.
        flattened_embeddings = model.encode(flattened_sentences, tokenize=True, bsize=args.bsize)

        sentence_embeddings = []  # only filled in non-streaming mode
        idx = 0  # conversations processed so far
        sent_idx = 0  # position in the flat embedding sequence
        with open(output_path, 'wb') as out_file:
            for idx, conversation in enumerate(sentences, start=1):
                conversation_embeddings = []
                for _ in conversation:
                    embedding = flattened_embeddings[sent_idx]
                    conversation_embeddings.append(list(embedding))
                    # Show the first --debuglen sentence embeddings.
                    if sent_idx < args.debuglen:
                        print(embedding)
                    sent_idx += 1
                # Guard against --step 0, which would divide by zero.
                if args.step > 0 and idx % args.step == 0:
                    print(f'Conversations: {idx}, Sentences embedded: {sent_idx}')
                if args.streaming:
                    # One pickle record per conversation, written back to back;
                    # nothing is kept in memory after the write.
                    pickle.dump(conversation_embeddings, out_file)
                else:
                    sentence_embeddings.append(conversation_embeddings)
            print(f'Conversations: {idx}, Sentences embedded: {sent_idx}')
            if not args.streaming:
                # Single pickle record holding every conversation.
                pickle.dump(sentence_embeddings, out_file)
