# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Train and eval functions used in main.py
"""
import math
import numpy as np
import sys
from typing import Dict, Iterable, Optional
import copy
import torch
import torch.nn
import torch.optim

import util.dist as dist
from datasets.coco_eval import CocoEvaluator
from datasets.flickr_eval import FlickrEvaluator, FlickrCaptionEvaluator
from datasets.refexp import RefExpEvaluator
from util.metrics import MetricLogger, SmoothedValue
from util.misc import targets_to
from util.optim import adjust_learning_rate, update_ema
import json
from text_attack import BERTATT
sys.path.append('cleverhans')
import cleverhans.torch.attacks.projected_gradient_descent as pgd
def train_one_epoch(
    model: torch.nn.Module,
    criterion: Optional[torch.nn.Module],
    weight_dict: Dict[str, float],
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    args,
    max_norm: float = 0,
    model_ema: Optional[torch.nn.Module] = None,
):
    """Train ``model`` for one epoch over ``data_loader``.

    Each step moves the batch to ``device`` and runs the model twice: first with
    ``encode_and_save=True`` to build a memory cache, then with
    ``encode_and_save=False`` to decode from that cache. It then computes the
    ``weight_dict``-weighted sum of the criterion losses and back-propagates it.
    If ``max_norm > 0`` it clips gradients, then steps the optimizer and updates
    the learning rate through ``adjust_learning_rate``. When ``model_ema`` is
    given, it also updates the EMA copy.

    Args:
        model: the model being trained (switched to train mode).
        criterion: module returning a dict of named losses. Training needs at
            least one loss whose name is in ``weight_dict``.
        weight_dict: per-loss multipliers; losses not listed here are ignored
            for optimization but still logged (unscaled).
        data_loader: iterable of batch dicts with ``"samples"`` and ``"targets"``
            (and optionally ``"positive_map"``). Must support ``len()``.
        optimizer: optimizer to step.
        device: device the batch is moved to.
        epoch: current epoch index, used for step counting and LR scheduling.
        args: config namespace; ``args.epochs`` and ``args.ema_decay`` are read
            here and it is forwarded to ``adjust_learning_rate``.
        max_norm: gradient-clipping norm; clipping is disabled when ``<= 0``.
        model_ema: optional EMA copy of ``model`` updated after each step.

    Returns:
        Dict mapping each logged meter name to its global average over the epoch.

    Raises:
        ValueError: if a batch yields no loss term that appears in ``weight_dict``
            (e.g. ``criterion`` is None), since there is nothing to optimize.
        SystemExit: if the reduced, weighted loss is not finite.
    """
    model.train()
    if criterion is not None:
        criterion.train()
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
    metric_logger.add_meter("lr_backbone", SmoothedValue(window_size=1, fmt="{value:.6f}"))
    metric_logger.add_meter("lr_text_encoder", SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)
    print_freq = 1000

    # Total number of optimizer steps in the whole run, for the LR schedule.
    num_training_steps = int(len(data_loader) * args.epochs)
    for i, batch_dict in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        curr_step = epoch * len(data_loader) + i
        samples = batch_dict["samples"].to(device)
        positive_map = batch_dict["positive_map"].to(device) if "positive_map" in batch_dict else None
        targets = batch_dict["targets"]
        captions = [t["caption"] for t in targets]

        targets = targets_to(targets, device)

        # Two-pass forward: encode once into a cache, then decode from it.
        memory_cache = model(samples, captions, targets, encode_and_save=True)
        outputs = model(samples, captions, targets, encode_and_save=False, memory_cache=memory_cache)

        loss_dict = {}
        if criterion is not None:
            loss_dict.update(criterion(outputs, targets, positive_map))

        weighted_terms = [loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict]
        if not weighted_terms:
            # Without this, `losses` would be the int 0 and fail later with an
            # opaque AttributeError on `.item()` / `.backward()`.
            raise ValueError(
                "No loss term matched weight_dict; nothing to optimize (is criterion None?)"
            )
        losses = sum(weighted_terms)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = dist.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {f"{k}_unscaled": v for k, v in loss_dict_reduced.items()}
        loss_dict_reduced_scaled = {k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict}
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        loss_value = losses_reduced_scaled.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        adjust_learning_rate(
            optimizer,
            epoch,
            curr_step,
            num_training_steps=num_training_steps,
            args=args,
        )
        if model_ema is not None:
            update_ema(model, model_ema, args.ema_decay)

        metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
        # NOTE(review): assumes the optimizer has at least three param groups,
        # presumably (main, backbone, text encoder) per the meter names — confirm.
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(lr_backbone=optimizer.param_groups[1]["lr"])
        metric_logger.update(lr_text_encoder=optimizer.param_groups[2]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
class Adv_attack:
    """Adversarial-text evaluation harness built around two copies of one model.

    ``white_model`` is loaded with the ``white_box`` checkpoint's EMA weights;
    :meth:`pgd_attack` exposes its encoder features. ``black_model`` is loaded
    with the ``black_box`` checkpoint's EMA weights; :meth:`evaluate` attacks it
    with ``BERTATT``.
    """

    def __init__(self, model, device, white_box=None, black_box=None):
        """Build the white-box and black-box model copies.

        Args:
            model: template module; it is deep-copied twice and not modified.
            device: device that batches are moved to.
            white_box: path to a checkpoint dict containing a ``"model_ema"`` state dict.
            black_box: path to a checkpoint dict containing a ``"model_ema"`` state dict.

        Raises:
            ValueError: if either checkpoint path is missing.
        """
        if white_box is None or black_box is None:
            raise ValueError("Adv_attack requires both white_box and black_box checkpoint paths")
        self.white_model = self._load_ema_copy(model, white_box)
        self.black_model = self._load_ema_copy(model, black_box)
        self.device = device
        # Batch dict used by pgd_attack(); must be assigned by the caller first.
        self.batch = None

    @staticmethod
    def _load_ema_copy(model, checkpoint_path):
        """Return a deep copy of ``model`` loaded with the checkpoint's ``"model_ema"`` weights."""
        model_copy = copy.deepcopy(model)
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # strict=False: keys missing from or unexpected in the checkpoint are ignored.
        model_copy.load_state_dict(checkpoint["model_ema"], strict=False)
        return model_copy

    @staticmethod
    def _load_correct_ids(path, limit):
        """Read one integer id per line from ``path``; return the first ``limit`` ids as a set.

        Blank lines are skipped. A set is returned because callers only test membership.
        """
        with open(path, 'r') as fh:
            ids = [int(line.strip()) for line in fh if line.strip()]
        return set(ids[:limit])

    def pgd_attack(self, x):
        """Encode the stored batch with its image tensor replaced by ``x`` (white-box model).

        Args:
            x: tensor assigned to ``samples.tensors`` before encoding; presumably an
               image batch shaped like the original — TODO confirm.

        Returns:
            ``[ori_res_feats, ori_enc_feats]``. The first element is
            ``memory_cache['ori_res_feats']``. The second is
            ``memory_cache['enc_feats'][:, idx, 0, :]``, where ``idx`` are the
            positions at which ``memory_cache['mask'][0]`` is zero/False.

        Raises:
            ValueError: if ``self.batch`` has not been set.
        """
        if self.batch is None:
            raise ValueError("Adv_attack.batch must be set before calling pgd_attack")
        samples = self.batch["samples"].to(self.device)
        targets = self.batch["targets"]
        captions = [t["caption"] for t in targets]
        targets = targets_to(targets, self.device)
        samples.tensors = x
        memory_cache = self.white_model(samples, captions, targets, encode_and_save=True)
        # Indices where the first mask row is False; logical_not() is correct for
        # bool and numeric masks alike (unlike `~`, which is bitwise on ints).
        unmasked = torch.where(memory_cache['mask'][0].logical_not())
        ori_enc_feats = memory_cache['enc_feats'][:, unmasked[0], 0, :]
        ori_res_feats = memory_cache['ori_res_feats']
        return [ori_res_feats, ori_enc_feats]

    @torch.no_grad()
    def evaluate(
        self,
        criterion: Optional[torch.nn.Module],
        postprocessors: Dict[str, torch.nn.Module],
        weight_dict: Dict[str, float],
        data_loader,
        evaluator_list,
        args,
        *,
        correct_ids_path: str = 'right_refcoco_5k.txt',
        max_correct_ids: int = 5000,
        output_path: str = 'all_adv_bert_attack_bank_on_refcoco_ori_90.txt',
    ):
        """Run BERTATT against the black-box model and report the attack success rate.

        Only batches whose first target ``image_id`` appears among the first
        ``max_correct_ids`` ids in ``correct_ids_path`` are attacked. For each
        attacked batch the black-box predictions are fed to ``evaluator_list[-1]``.
        The caption is then attacked, and the adversarial text (with the success
        flag appended) is recorded under ``str(image_id)``. The running success
        rate is printed every 10 attacked samples.

        At the end the final rate is printed and the record dict is written as
        JSON to ``output_path``. The process is then terminated with
        ``sys.exit()``, so this method never returns.

        NOTE(review): only ``targets[0]`` is inspected, which presumably assumes
        batch size 1 — confirm.

        Args:
            criterion: switched to eval mode if given; otherwise unused.
            postprocessors: must contain ``"bbox"``; also passed to ``BERTATT``.
            weight_dict: unused; kept for interface compatibility.
            data_loader: iterable of batch dicts with ``"samples"`` and ``"targets"``.
            evaluator_list: its last element is updated and passed to ``BERTATT``.
            args: unused; kept for interface compatibility.
            correct_ids_path: text file listing one image id per line.
            max_correct_ids: number of leading ids from that file to use.
            output_path: destination of the JSON dump of adversarial texts.
        """
        self.white_model.eval()
        self.black_model.eval()
        if criterion is not None:
            criterion.eval()

        metric_logger = MetricLogger(delimiter="  ")
        header = "Test:"

        correct_ids = self._load_correct_ids(correct_ids_path, max_correct_ids)
        num_success = 0
        num_attacked = 0
        adv_text_dict = {}
        evaluator = evaluator_list[-1]
        bert_attack = BERTATT(self.black_model, evaluator, postprocessors)

        for batch_dict in metric_logger.log_every(data_loader, 100, header):
            targets = batch_dict["targets"]
            if targets[0]["image_id"].item() not in correct_ids:
                continue
            samples = batch_dict["samples"].to(self.device)
            captions = [t["caption"] for t in targets]
            targets = targets_to(targets, self.device)
            orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
            memory_cache = self.black_model(samples, captions, targets, encode_and_save=True)
            outputs = self.black_model(samples, captions, targets, encode_and_save=False, memory_cache=memory_cache)

            results = postprocessors["bbox"](outputs, orig_target_sizes)
            if results[0]['boxes'].shape[0] == 1:
                # Single predicted box: reshape scores/labels via unsqueeze(1)[0].
                # NOTE(review): shape-preserving for a 1-element 1-D tensor — confirm intent.
                for result in results:
                    result['scores'] = result['scores'].unsqueeze(1)[0]
                    result['labels'] = result['labels'].unsqueeze(1)[0]
            res = {target["image_id"].item(): output for target, output in zip(targets, results)}
            evaluator.update(res)
            # Result is not used here; the call is kept in case BERTATT relies on
            # state it leaves in the evaluator.
            evaluator.summarize()

            ori_box = results[0]['boxes'].cpu()
            adv_text, success, _ = bert_attack.attack(samples, captions, ori_box, orig_target_sizes, targets)
            adv_text.append(success)
            adv_text_dict[str(targets[0]["image_id"].item())] = adv_text

            num_attacked += 1
            if success == 1:
                print('success', adv_text, captions)
                num_success += 1
            if num_attacked % 10 == 0:
                print('refcoco_asr:', num_success / num_attacked)

        # Guard against ZeroDivisionError when no id in the list was encountered,
        # which would otherwise also skip writing the results file.
        if num_attacked:
            print('final_refcoco_asr:', num_success / num_attacked)
        else:
            print('final_refcoco_asr: no samples were attacked')
        with open(output_path, 'w') as file:
            json.dump(adv_text_dict, file)

        # sys.exit() rather than the site-provided exit(), which may be unavailable.
        sys.exit()
