import sys
sys.path.append(r'./dialogue_generation/')
import torch

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
is_cuda = torch.cuda.is_available()
device = "cuda" if is_cuda else "cpu"

from simctgdialogue import SimCTGDialogue
# Build the SimCTG dialogue model from the local './simctg_lccc' directory,
# move it to the selected device and call eval() on it before decoding.
print ('Loading model...')
model_path = r'./simctg_lccc'
eos_token = '[SEP]'
pad_token = '[PAD]'
model = SimCTGDialogue(model_path, eos_token, pad_token)
model = model.to(device)
tokenizer = model.tokenizer
model.eval()
print ('Model loaded!')

# Demo dialogue contexts (prefixes). Each entry is a list of utterance
# strings; the whole list is handed to the decoding methods as one context.
context_list = [
    ['老铁家好吃贾三不好吃'],
    [
        '话说红海真的好看！！！准备二刷了！！！但求多排场！！！',
        '我明天二刷！',
    ],
    ['你家真有钱', '从何说起？', '可以买粮'],
    ['为何突然伤感?'],
    [
        '刺猬很可爱！以前别人送了只没养，味儿太大！',
        '是很可爱但是非常臭',
        '是啊，没办法养',
        '那个怎么养哦不会扎手吗',
    ],
    ['杂糅太多反而摸不到主题了是吗'],
]

# For every dialogue context, run each decoding strategy in turn and print
# its output. Everything runs under torch.no_grad() (no gradient tracking).
with torch.no_grad():
    # Hyperparameters are the same for every context, so set them once.
    decoding_len = 64
    alpha = 0.6
    nucleus_p = 0.95
    contrastive_beam_width = 5  # beam width handed to contrastive search
    beam_width = 10             # beam width handed to beam search

    # (label, decoder) pairs, listed in the order their results are printed.
    decoders = (
        ('Contrastive Search',
         lambda ctx: model.contrastive_search(
             ctx, contrastive_beam_width, alpha, decoding_len,
             cuda_available=is_cuda)),
        ('Greedy Search',
         lambda ctx: model.greedy_search(
             ctx, decoding_len, cuda_available=is_cuda)),
        ('Beam Search',
         lambda ctx: model.beam_search(
             ctx, beam_width, decoding_len, cuda_available=is_cuda)),
        # The result of nucleus sampling is different at every run due to
        # its stochasticity.
        ('Nucleus Sampling',
         lambda ctx: model.nucleus_sampling(
             ctx, nucleus_p, decoding_len, cuda_available=is_cuda)),
    )

    rule = '-----------------------'
    for context in context_list:
        print('#########################################################################################################')
        print('------ Dialogue Context (i.e., Prefix) is ------')
        print(context)
        print(rule)

        for label, decode in decoders:
            # The header is printed before decoding starts, as in the
            # original straight-line version.
            print('------ {} Result ------'.format(label))
            print(decode(context))
            print(rule)

        print('\n')


