import argparse
import inspect
import os

import numpy as np
import torch

from dataset.HuffPost import HuffPost
# Not referenced directly in this script; presumably kept so the pickled
# checkpoint object (which exposes `.predictions`) can be unpickled.
from learner.learner import Learner

# Presumably provides `dataset` and `months`, which this script uses but does
# not define. Kept last so its names take precedence.
from init import *
# `dataset` is not defined in this file (the `from dataset.HuffPost import`
# above binds only `HuffPost`), so it presumably comes from `from init import *`.
dataset.load_embedding_RobertaBase()

parser = argparse.ArgumentParser(description='')
# Training span per base model, in months (assumed). The duration certificate
# treats a poisoned span of d months as reaching d + n_month_train - 1
# consecutive models. Overridden by --init_from_last.
parser.add_argument('--n_month_train', default=12, type=int, help='')
# Checkpoint name: predictions are read from ./checkpoints/<name>.t7 and the
# certificates are written under ./certs/ with this name as prefix.
parser.add_argument('--name', default='tmp', type=str, help='')
# Sets n_month_train = len(months), so any earlier month may reach any later
# model (presumably models were warm-started from the previous month's model).
parser.add_argument('--init_from_last', action='store_true')
# Number of most recent models that vote on each evaluation month.
parser.add_argument('--n_aggregation', default=12, type=int)
# Index into `months` of the first month to evaluate.
parser.add_argument('--eval_start', default=0, type=int)
# Print per-sample predictions, vote counts and certificates.
parser.add_argument('--debug', action='store_true')

args = parser.parse_args()

checkpoint_dir = './checkpoints'
eval_dir = './certs'

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(eval_dir, exist_ok=True)

load_path = os.path.join(checkpoint_dir, args.name + '.t7')
save_path = os.path.join(
    eval_dir,
    f'{args.name}_agg{args.n_aggregation}_eval_start{args.eval_start}.t7',
)

# The checkpoint is a pickled object exposing `.predictions`. That attribute
# maps str(model index) -> per-sample predictions, and each entry's `.item()`
# is a class index.
# torch>=2.6 defaults to weights_only=True, which refuses arbitrary pickled
# objects, so opt out explicitly wherever torch.load accepts the flag.
# SECURITY: full unpickling can execute code; only load trusted checkpoints.
_load_kwargs = {}
if 'weights_only' in inspect.signature(torch.load).parameters:
    _load_kwargs['weights_only'] = False
predictions = torch.load(load_path, **_load_kwargs).predictions

# One entry per evaluated sample. -1 in both lists marks a sample whose
# aggregated prediction does not match dataset item [2].
certs_in_advance = []
certs_duration = []

# With --init_from_last every earlier month can reach every later model, so
# the duration certificate uses the full history length as the training span.
if args.init_from_last:
    args.n_month_train = len(months)


def _c_wins(votes_majority, votes_c, majority, c):
    """Return True if class `c` beats `majority` in the aggregated vote.

    np.argmax returns the first maximal index, so ties go to the smaller
    class label. `c` therefore needs strictly more votes when c > majority,
    and only an equal count when c < majority.
    """
    return votes_majority < votes_c + (majority > c)


def _flips_from_latest(voter_preds, votes, majority, c):
    """Return how many of the latest voters must be flipped for `c` to win.

    Voters are flipped newest-first. A flipped voter stops voting for
    `majority` (if it did) and votes for `c` instead. If `c` still does not
    win after every voter is flipped, ``len(voter_preds)`` is returned. When
    `votes` is the tally of `voter_preds`, that only happens with no voters.
    """
    n_voters = len(voter_preds)
    votes_majority, votes_c = votes[majority], votes[c]
    flipped = 0
    while flipped < n_voters and not _c_wins(votes_majority, votes_c, majority, c):
        flipped += 1
        pred = voter_preds[-flipped]
        votes_majority -= pred == majority
        votes_c += pred != c
    return flipped


def _shortest_flipping_window(voter_preds, votes, majority, c):
    """Return the fewest consecutive voters whose flipping lets `c` win.

    Uses a two-pointer scan over windows voter_preds[start:end], flipping
    voters as in `_flips_from_latest`. Returns None if no window works. When
    `votes` is the tally of `voter_preds`, that only happens with no voters.
    """
    n_voters = len(voter_preds)
    votes_majority, votes_c = votes[majority], votes[c]
    end = 0
    shortest = None
    for start, start_pred in enumerate(voter_preds):
        # Grow the window until c wins or the voters run out.
        while end < n_voters and not _c_wins(votes_majority, votes_c, majority, c):
            pred = voter_preds[end]
            votes_majority -= pred == majority
            votes_c += pred != c
            end += 1
        if not _c_wins(votes_majority, votes_c, majority, c):
            # Every later window is a sub-window of [start, n_voters), so none can win.
            break
        shortest = end - start if shortest is None else min(shortest, end - start)
        # Slide forward: undo the flip of the voter leaving the window.
        votes_majority += start_pred == majority
        votes_c -= start_pred != c
    return shortest


for month_eval in range(args.eval_start, len(months)):
    idx_start, idx_end = dataset.set_range_date(months[month_eval][0], months[month_eval][1])

    print(month_eval)

    # Models voter_start .. month_eval - 1 (at most n_aggregation) vote here.
    voter_start = max(0, month_eval - args.n_aggregation)
    voter_months = range(voter_start, month_eval)

    for sample_idx in range(idx_start, idx_end):
        # Read every voter's prediction once; the searches below revisit
        # them for each competing class.
        voter_preds = [predictions[str(m)][sample_idx].item() for m in voter_months]

        votes = np.zeros(dataset.n_class)
        for pred in voter_preds:
            votes[pred] += 1

        if args.debug:
            print(voter_preds)
            print(votes)

        # NOTE(review): with no voters (e.g. month_eval == 0), votes is all
        # zeros and argmax picks class 0. If n_class > 1, such samples get
        # cert_in_advance == -1 but cert_duration == len(months). Confirm
        # this is the intended handling.
        majority = np.argmax(votes)

        # dataset is indexed relative to the current month range; item [2]
        # is presumably the label.
        if majority != dataset[sample_idx - idx_start][2]:
            certs_in_advance.append(-1)
            certs_duration.append(-1)
            if args.debug:
                print(-1, -1, flush=True)
            continue

        cert_in_advance = len(months)
        cert_duration = len(months)

        for c in range(dataset.n_class):
            if c == majority:
                continue

            # Open-ended poisoning flips the latest k voters (and all later
            # models). The certificate is one less than the smallest k that
            # lets some class overturn the prediction.
            flipped = _flips_from_latest(voter_preds, votes, majority, c)
            cert_in_advance = min(cert_in_advance, flipped - 1)

            # Bounded poisoning flips a run of consecutive voters. The formula
            # treats d poisoned months as reaching d + n_month_train - 1
            # models, so overturning via k voters needs
            # d >= max(1, k - (n_month_train - 1)).
            window = _shortest_flipping_window(voter_preds, votes, majority, c)
            if window is not None:
                cert_duration = min(
                    cert_duration,
                    max(1, window - (args.n_month_train - 1)) - 1,
                )

        certs_in_advance.append(cert_in_advance)
        certs_duration.append(cert_duration)

        if args.debug:
            print(cert_in_advance, cert_duration, flush=True)


torch.save({
    'certs_in_advance': certs_in_advance,
    'certs_duration': certs_duration,
}, save_path)

            


                



            
        
        
