


from collections import Counter

import numpy as np


def distinct(seqs):
    """Calculate intra/inter distinct-1 and distinct-2 over a batch of sequences.

    Distinct-n is the number of unique n-grams divided by the total number
    of n-grams.

    * intra: the ratio is computed per sequence, then averaged over the batch.
    * inter: the ratio is computed once over the n-grams pooled from every
      sequence in the batch.

    A small epsilon is added to each numerator (1e-12) and denominator (1e-5)
    so empty or single-token sequences never divide by zero. As a side effect,
    the reported ratios sit very slightly below their exact values.

    Args:
        seqs: iterable of token sequences. Each sequence may be any iterable
            of hashable tokens (list, tuple, string, generator, ...).

    Returns:
        Tuple ``(intra_dist1, intra_dist2, inter_dist1, inter_dist2)``.
        The intra values come from ``np.average``. If ``seqs`` is empty they
        are ``nan`` (numpy warns about the empty mean).
    """
    intra_dist1, intra_dist2 = [], []
    unigrams_all, bigrams_all = Counter(), Counter()
    for seq in seqs:
        # Materialise once. Counting unigrams would otherwise exhaust a
        # one-shot iterable before the bigram slice below could read it.
        tokens = tuple(seq)
        unigrams = Counter(tokens)
        bigrams = Counter(zip(tokens, tokens[1:]))
        n_tokens = len(tokens)
        intra_dist1.append((len(unigrams) + 1e-12) / (n_tokens + 1e-5))
        # A sequence of n tokens has n - 1 bigrams (0 for an empty sequence).
        intra_dist2.append(
            (len(bigrams) + 1e-12) / (max(0, n_tokens - 1) + 1e-5))

        unigrams_all.update(unigrams)
        bigrams_all.update(bigrams)

    # Pooled counts: distinct n-grams over total n-gram occurrences.
    inter_dist1 = (len(unigrams_all) + 1e-12) / (sum(unigrams_all.values()) + 1e-5)
    inter_dist2 = (len(bigrams_all) + 1e-12) / (sum(bigrams_all.values()) + 1e-5)
    intra_dist1 = np.average(intra_dist1)
    intra_dist2 = np.average(intra_dist2)
    return intra_dist1, intra_dist2, inter_dist1, inter_dist2



#fn = '/mnt/efs/fs2/hzt/causal/Optimus/outputs//finetune_lm/vae_gpt2encoder/basic-s2-beta1-gsfixed-newmask-t1_w0_wr1_lr5e5_gumbel_samelen_bz16_bak/outputs-242000/debug-semi-s2-gsfixed-newmask-t05_wc05_wzc05_wz05_w05_wr1_lr1e6_gumbel_samelen_bz8_dtrm_bak/outputs-226000/train_gan_dtrm_lr1e4_it226000_ep5_debug-semi-s2-gsfixed-newmask-t05_wc05_wzc05_wz05_w05_wr1_lr1e6_gumbel_samelen_bz8_dtrm_bak_it226000/outputs-158000/ppl.txt'
#fn = '/mnt/efs/fs2/hzt/causal/Optimus/data/multi_yelp_tst/full_5_15_clean/yelp_text_attrs_5_15_cor90_test.txt.reformat.len20.sub.13125'
#fn = '/mnt/efs/fs2/hzt/causal/Optimus/outputs/finetune_causal_lm/bias_basic_lm_yelp_10ep_sent/checkpoint-48000/gen_n10000.txt'
#fn = '/mnt/efs/fs2/hzt/causal/Optimus/outputs//finetune_causal_lm/bias_lm_yelp_10ep_sent_corrected/checkpoint-48000/gen_n10000.txt'

#fn = '/mnt/efs/fs2/hzt/causal/Optimus/outputs_gender//finetune_causal_lm/genderfirst_gender_bias_lm_50ep_sent_cor94/checkpoint-3600/gen_n10000.txt'
#fn = '/mnt/efs/fs2/hzt/causal/Optimus/outputs_gender/finetune_causal_lm/genderfirst_gender_bias_basic_lm_50ep_sent_cor94/checkpoint-3600/gen_n10000.txt'
#fn = '/mnt/efs/fs2/hzt/causal/Optimus/outputs/finetune_causal_lm/bias_basic_lm_yelp_10ep_sent/checkpoint-48000/for_classifier/ppl.txt'
# Default input: one generated sentence per line.
fn = '/mnt/efs/fs2/hzt/causal/Optimus/outputs/finetune_lm_gender/vae_gpt2encoder/genderfirst-basic-s3-beta1-gsfixed-newmask-t1_w0_wr1_lr5e5_gumbel_samelen_bz16_bak/outputs-45000/genderfirst-semi-s103-gsfixed-newmask-t05_wc1_wzc1_wz05_w1_wr1_lr1e6_gumbel_samelen_bz8_dtrm_bak/outputs-139200/ppl.txt'


if __name__ == '__main__':
    # Guarded so that importing `distinct` from this module does not try to
    # read a machine-specific file and print at import time.
    import sys

    # Optional override of the hard-coded default path:
    #   python <this_script>.py path/to/generations.txt
    if len(sys.argv) > 1:
        fn = sys.argv[1]

    # Tokens are whitespace-separated and lower-cased, so case variants count
    # as the same n-gram. Every line (including blank ones) is one sequence.
    with open(fn, 'r', encoding='utf-8') as fin:
        seqs = [line.strip().lower().split() for line in fin]

    results = distinct(seqs)

    print(results)
