from collections import defaultdict, deque
import datetime
import errno
import logging
import models
import numpy as np
import os
import signal
import socket
import subprocess
import sys
import time
import random
from collections import OrderedDict
import glob

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data.dataloader import default_collate
import torchvision
import transforms as T

from datasets.UCF101 import UCF101
from datasets.HMDB51 import HMDB51
from datasets.AVideoDataset import AVideoDataset


def setup_for_distributed(is_master):
    """
    Replace the builtin ``print`` with a version gated on the master flag.

    After this call, ``print(...)`` only produces output when ``is_master``
    is true or when the caller passes ``force=True``. The ``force`` keyword
    is consumed here and never forwarded to the real ``print``.
    """
    import builtins

    original_print = builtins.print

    def gated_print(*args, **kwargs):
        forced = kwargs.pop('force', False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    builtins.print = gated_print


def SIGTERMHandler(a, b):
    """Signal handler that only reports the SIGTERM; both arguments are ignored."""
    print('received sigterm')


def signalHandler(a, b):
    """
    Signal handler that records the signal in the process environment.

    Prints the first argument together with the current timestamp and sets
    the ``SIGNAL_RECEIVED`` environment variable to ``'True'``. The second
    argument is unused.
    """
    received_at = time.time()
    print('Signal received', a, received_at, flush=True)
    os.environ['SIGNAL_RECEIVED'] = 'True'


def init_signal_handler():
    """
    Install the handlers for signals sent by SLURM (time limit / pre-emption).

    Resets ``SIGNAL_RECEIVED`` to ``'False'``, stores the current PID in
    ``MAIN_PID``, then binds SIGUSR1 to ``signalHandler`` and SIGTERM to
    ``SIGTERMHandler``.
    """
    os.environ['SIGNAL_RECEIVED'] = 'False'
    os.environ['MAIN_PID'] = str(os.getpid())

    handlers = (
        (signal.SIGUSR1, signalHandler),
        (signal.SIGTERM, SIGTERMHandler),
    )
    for signum, handler in handlers:
        signal.signal(signum, handler)
    print("Signal handler installed.", flush=True)


def trigger_job_requeue(checkpoint_filename):
    """
    Requeue the current SLURM job so it resumes from a checkpoint, then exit.

    The requeue is only issued when all of the following hold:
      * ``SLURM_PROCID`` is 0,
      * this process's PID matches the ``MAIN_PID`` environment variable,
      * ``checkpoint_filename`` exists on disk.

    Whether or not the requeue is issued, the process exits with status 0.

    Args:
        checkpoint_filename: path of the checkpoint that must exist before
            the job is put back in the queue.

    Raises:
        KeyError: if ``SLURM_PROCID`` (or ``MAIN_PID`` / ``SLURM_JOB_ID`` when
            they are reached) is missing from the environment.
        RuntimeError: if ``scontrol requeue`` cannot be run or fails.
        SystemExit: always, with code 0, when no error occurred.
    """
    print("IN JOB REQUEUE FUNCTION")
    print(checkpoint_filename)
    is_main = (
        int(os.environ['SLURM_PROCID']) == 0
        and str(os.getpid()) == os.environ['MAIN_PID']
    )
    if is_main and os.path.isfile(checkpoint_filename):
        print('time is up, back to slurm queue', flush=True)
        job_id = os.environ['SLURM_JOB_ID']
        print('scontrol requeue ' + job_id)
        # Pass an argument list (no shell) so the job id is never interpreted
        # by a shell; any failure is reported as the same RuntimeError.
        try:
            subprocess.run(['scontrol', 'requeue', job_id], check=True)
        except (OSError, subprocess.CalledProcessError) as err:
            raise RuntimeError('requeue failed') from err
        print('New job submitted to the queue', flush=True)
    # sys.exit is always available; the bare exit() builtin is injected by the
    # site module and is absent under `python -S` or in frozen interpreters.
    sys.exit(0)


def restart_from_checkpoint(args, ckp_path=None, run_variables=None, **kwargs):
    """
    Restore objects and run variables from a checkpoint file, if one exists.

    Args:
        args: object providing ``output_dir``, ``world_size`` and ``local_rank``.
        ckp_path: checkpoint file; defaults to
            ``<args.output_dir>/checkpoints/checkpoint.pth``.
        run_variables: optional dict; every key that is also present in the
            checkpoint is overwritten in place with the stored value.
        **kwargs: maps a checkpoint key to the object that should receive it
            through ``load_state_dict`` (example: ``model=model``).

    Returns nothing. If the file does not exist the function returns early
    without touching any argument.
    """
    if ckp_path is None:
        ckp_path = os.path.join(args.output_dir, 'checkpoints', 'checkpoint.pth')
    print(f'Ckpt path: {ckp_path}', flush=True)

    # nothing to resume from
    if not os.path.isfile(ckp_path):
        return
    print('Found checkpoint in experiment repository', flush=True)

    # with several processes, map tensors onto this process's own GPU
    device = "cuda:" + str(args.local_rank) if args.world_size > 1 else None
    state = torch.load(ckp_path, map_location=device)

    for entry, target in kwargs.items():
        if target is None or entry not in state:
            print("=> failed to load {} from checkpoint '{}'"
                        .format(entry, ckp_path))
            continue

        if entry == 'model':
            # stored keys get a 'module.' prefix before being matched against
            # the target's own state dict; unmatched keys are reported
            known = target.state_dict()
            remapped = OrderedDict()
            for param_name, tensor in state[entry].items():
                wrapped_name = 'module.' + param_name
                if wrapped_name not in known:
                    print("didnt load ", wrapped_name)
                    continue
                remapped[wrapped_name] = tensor
            target.load_state_dict(remapped)
        elif entry == 'model_ema':
            # keys are used as stored, no prefix added
            target.load_state_dict(OrderedDict(state[entry].items()))
        else:
            target.load_state_dict(state[entry])
        print("=> loaded {} from checkpoint '{}'"
                    .format(entry, ckp_path))

    # reload the variables that matter for the run
    if run_variables is None:
        return
    restored = {name: state[name] for name in run_variables if name in state}
    run_variables.update(restored)


def save_ckpt(args, epoch, model, optimizer, lr_scheduler, selflabels=None):
    """
    Write the training state under ``args.output_dir``.

    The state holds the unwrapped model weights (``model.module``), the
    optimizer and LR-scheduler states, ``epoch + 1`` and ``args``;
    ``selflabels`` is added when given. Files written (through
    ``save_on_master``):
      * ``checkpoints/checkpoint.pth`` on every call,
      * ``checkpoints/ckpt_<epoch>.pth`` when ``epoch`` is a multiple of 5,
      * ``model_weights/model_<epoch>.pth`` when ``epoch`` is a multiple of 10.
    """
    state = {
        'model': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'epoch': epoch + 1,
        'args': args
    }
    if selflabels is not None:
        state['selflabels'] = selflabels

    weights_dir = os.path.join(args.output_dir, 'model_weights')
    ckpt_dir = os.path.join(args.output_dir, 'checkpoints')
    mkdir(weights_dir)
    mkdir(ckpt_dir)

    if epoch % 10 == 0:
        save_on_master(state, os.path.join(weights_dir, f'model_{epoch}.pth'))
    save_on_master(state, os.path.join(ckpt_dir, 'checkpoint.pth'))
    if epoch % 5 == 0:
        save_on_master(state, os.path.join(ckpt_dir, f'ckpt_{epoch}.pth'))

    if args.global_rank == 0:
        print(f'Saving checkpoint to: {args.output_dir}', flush=True)
        print('Checkpoint saved', flush=True)


def init_distributed_mode(params, make_communication_groups=False):
    """
    Handle single and multi-GPU / multi-node / SLURM jobs.
    Initialize the following variables:
        - n_nodes
        - node_id
        - local_rank
        - global_rank
        - world_size

    Three launch modes are distinguished, in this order:
        1. SLURM job: 'SLURM_JOB_ID' is in the environment and
           params.debug_slurm is false; everything is read from SLURM_*
           environment variables.
        2. Launcher job: params.local_rank != -1; RANK, WORLD_SIZE and NGPU
           are read from the environment.
        3. Local job: a single process on a single GPU.

    Besides the variables listed above, this also sets params.is_slurm_job,
    params.n_gpu_per_node, params.is_master, params.multi_node and
    params.multi_gpu, and selects the CUDA device params.local_rank.
    params.job_id, params.master_addr and params.master_port are set in the
    SLURM branch only; params.distributed is set (to True) only when
    params.multi_gpu is true.

    Returns:
        A list of process groups when params.multi_gpu and
        make_communication_groups are both true; otherwise None (implicit).
    """
    params.is_slurm_job = 'SLURM_JOB_ID' in os.environ and not params.debug_slurm
    print("SLURM job: %s" % str(params.is_slurm_job))

    # SLURM job
    # (the second half of this condition is already implied by is_slurm_job)
    if params.is_slurm_job and not params.debug_slurm:

        assert params.local_rank == -1   # on the cluster, handled by SLURM

        SLURM_VARIABLES = [
            'SLURM_JOB_ID',
            'SLURM_JOB_NODELIST', 'SLURM_JOB_NUM_NODES', 'SLURM_NTASKS',
            'SLURM_TASKS_PER_NODE',
            'SLURM_MEM_PER_NODE', 'SLURM_MEM_PER_CPU',
            'SLURM_NODEID', 'SLURM_PROCID', 'SLURM_LOCALID', 'SLURM_TASK_PID'
        ]

        # log every SLURM variable, prefixed with this process's global rank
        PREFIX = "%i - " % int(os.environ['SLURM_PROCID'])
        for name in SLURM_VARIABLES:
            value = os.environ.get(name, None)
            print(PREFIX + "%s: %s" % (name, str(value)))

        # # job ID
        params.job_id = os.environ['SLURM_JOB_ID']

        # number of nodes / node ID
        params.n_nodes = int(os.environ['SLURM_JOB_NUM_NODES'])
        params.node_id = int(os.environ['SLURM_NODEID'])

        # local rank on the current node / global rank
        params.local_rank = int(os.environ['SLURM_LOCALID'])
        params.global_rank = int(os.environ['SLURM_PROCID'])

        # number of processes / GPUs per node
        params.world_size = int(os.environ['SLURM_NTASKS'])
        params.n_gpu_per_node = params.world_size // params.n_nodes

        # define master address and master port
        # the first hostname printed by `scontrol show hostnames` is the master
        hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames',
            os.environ['SLURM_JOB_NODELIST']])
        params.master_addr = hostnames.split()[0].decode('utf-8')
        # the port is hard-coded here
        params.master_port = 19500

        assert 10001 <= params.master_port <= 20000 or params.world_size == 1
        print(PREFIX + "Master address: %s" % params.master_addr)
        print(PREFIX + "Master port   : %i" % params.master_port)

        # set environment variables for 'env://'
        os.environ['MASTER_ADDR'] = str(params.master_addr)
        os.environ['MASTER_PORT'] = str(params.master_port)
        os.environ['WORLD_SIZE'] = str(params.world_size)
        os.environ['RANK'] = str(params.global_rank)

    # multi-GPU job (local/multi-node) - started with torch.distributed.launch
    elif params.local_rank != -1:

        assert params.master_port == -1

        # read environment variables
        # NOTE(review): NGPU must be exported by whoever launches the job;
        # nothing in this function sets it - confirm with the launch scripts
        params.global_rank = int(os.environ['RANK'])
        params.world_size = int(os.environ['WORLD_SIZE'])
        params.n_gpu_per_node = int(os.environ['NGPU'])

        # number of nodes / node ID
        params.n_nodes = params.world_size // params.n_gpu_per_node
        params.node_id = params.global_rank // params.n_gpu_per_node

    # local job (single GPU)
    else:
        assert params.local_rank == -1
        assert params.master_port == -1
        params.n_nodes = 1
        params.node_id = 0
        params.local_rank = 0
        params.global_rank = 0
        params.world_size = 1
        params.n_gpu_per_node = 1

    # sanity checks
    assert params.n_nodes >= 1
    assert 0 <= params.node_id < params.n_nodes
    assert 0 <= params.local_rank <= params.global_rank < params.world_size
    assert params.world_size == params.n_nodes * params.n_gpu_per_node

    # define whether this is the master process / if we are in distributed mode
    # NOTE(review): with the `> 8` threshold, jobs with 2 to 8 processes do not
    # call init_process_group below and never get params.distributed set -
    # confirm this threshold is intended (rather than `> 1`)
    params.is_master = params.node_id == 0 and params.local_rank == 0
    params.multi_node = params.n_nodes > 1
    params.multi_gpu = params.world_size > 8

    # summary
    PREFIX = "%i - " % params.global_rank
    print(PREFIX + "Number of nodes: %i" % params.n_nodes)
    print(PREFIX + "Node ID        : %i" % params.node_id)
    print(PREFIX + "Local rank     : %i" % params.local_rank)
    print(PREFIX + "Global rank    : %i" % params.global_rank)
    print(PREFIX + "World size     : %i" % params.world_size)
    print(PREFIX + "GPUs per node  : %i" % params.n_gpu_per_node)
    print(PREFIX + "Master         : %s" % str(params.is_master))
    print(PREFIX + "Multi-node     : %s" % str(params.multi_node))
    print(PREFIX + "Multi-GPU      : %s" % str(params.multi_gpu))
    print(PREFIX + "Hostname       : %s" % socket.gethostname())

    # set GPU device
    torch.cuda.set_device(params.local_rank)

    # initialize multi-GPU
    if params.multi_gpu:
        params.distributed = True

        # 'env://' will read these environment variables:
        # MASTER_PORT - required; has to be a free port on machine with rank 0
        # MASTER_ADDR - required (except for rank 0); address of rank 0 node
        # WORLD_SIZE - required; can be set either here, or in a call to init fn
        # RANK - required; can be set either here, or in a call to init function

        print("Initializing PyTorch distributed ...")
        torch.distributed.init_process_group(
            init_method='env://',
            backend='nccl',
            rank=params.global_rank,
            world_size=params.world_size,
        )
        print("Initialized!")

        if make_communication_groups:
            # super_classes is fixed to 1, so the loop below builds a single
            # group containing every rank
            params.super_classes = 1
            params.training_local_world_size = params.world_size // params.super_classes
            params.training_local_rank = params.global_rank % params.training_local_world_size
            params.training_local_world_id = params.global_rank // params.training_local_world_size

            # prepare training groups

            training_groups = []
            for group_id in range(params.super_classes):
                ranks = [params.training_local_world_size * group_id + i \
                        for i in range(params.training_local_world_size)]
                training_groups.append(dist.new_group(ranks=ranks))
            return training_groups


def mkdir(path):
    """Create ``path`` with all missing parents; an already-existing path is not an error."""
    try:
        os.makedirs(path)
    except OSError as err:
        if err.errno == errno.EEXIST:
            return
        raise


def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()


def get_rank():
    """Return this process's distributed rank, or 0 when no process group is initialized."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0


def is_main_process():
    """Return True when the current process has rank 0."""
    rank = get_rank()
    return rank == 0


def save_on_master(*args, **kwargs):
    """Forward all arguments to ``torch.save``, but only on the main process."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)


def _get_cache_path(dataset, mode, fold, clip_len, steps_between_clips):
    """
    Build the cache file path for a dataset configuration.

    The five arguments are stringified and concatenated, hashed with SHA-1,
    and the first 10 hex characters name the file:
    ``~/.torch/vision/datasets/<dataset>/<hash>.pt`` with ``~`` expanded.
    """
    import hashlib

    signature = "".join(
        str(part) for part in (dataset, mode, fold, clip_len, steps_between_clips)
    )
    digest = hashlib.sha1(signature.encode()).hexdigest()
    file_name = digest[:10] + ".pt"
    return os.path.expanduser(
        os.path.join("~", ".torch", "vision", "datasets", dataset, file_name)
    )


class SmoothedValue(object):
    """Keep a bounded window of recent values plus running totals.

    Window statistics (median, avg, max, value) come from the last
    ``window_size`` values; ``global_avg`` comes from the running
    ``total`` / ``count`` over every update.
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        """Record ``value`` once in the window and ``n`` times in the totals."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Sum ``count`` and ``total`` across processes.

        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        packed = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(packed)
        synced_count, synced_total = packed.tolist()
        self.count = int(synced_count)
        self.total = synced_total

    @property
    def median(self):
        """Median of the values currently in the window."""
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Running total divided by running count."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        stat_names = ('median', 'avg', 'global_avg', 'max', 'value')
        stats = {name: getattr(self, name) for name in stat_names}
        return self.fmt.format(**stats)


class MetricLogger(object):
    """Collection of named ``SmoothedValue`` meters with periodic logging.

    Meters are created on first use and can be read as attributes
    (``logger.loss`` returns ``logger.meters['loss']``).
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Push one value per keyword into the meter of the same name.

        Tensors are converted with ``.item()``; any other value must be an
        int or a float.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # __getattr__ only runs when normal lookup failed. Go through
        # __dict__ directly: reading self.meters here would re-enter
        # __getattr__ forever whenever 'meters' is not set yet (for example
        # on an instance created by copy / pickle before its state is
        # restored).
        state = self.__dict__
        meters = state.get('meters')
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in state:
            return state[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        return self.delimiter.join(
            "{}: {}".format(name, str(meter))
            for name, meter in self.meters.items()
        )

    def synchronize_between_processes(self):
        """Synchronize every meter across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing one."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None, logger=None, writer=None, mode='train', epoch=0, args=None):
        """Yield ``(index, item)`` from ``iterable`` while timing each step.

        Every ``print_freq`` iterations:
          * a progress line is sent through ``print_or_log``, but only when
            CUDA is NOT available; on CUDA hosts this class emits no
            progress line,
          * if ``writer`` is given, each meter's windowed average and the
            peak CUDA memory (in MB) are written with ``add_scalar`` under
            ``<mode>/<name>/iter`` at step ``epoch * len(iterable) + index``.

        ``iterable`` must support ``len()``. ``args`` is accepted but unused.
        """
        if not header:
            header = ''
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        num_iters = len(iterable)
        cuda_available = torch.cuda.is_available()
        space_fmt = ':' + str(len(str(num_iters))) + 'd'
        log_msg = self.delimiter.join([
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ])
        MB = 1024.0 * 1024.0
        for i, obj in enumerate(iterable):
            # time spent waiting for the next item
            data_time.update(time.time() - end)
            yield i, obj
            # full iteration time, including the consumer's work
            iter_time.update(time.time() - end)
            if i % print_freq == 0:
                if not cuda_available:
                    eta_seconds = iter_time.global_avg * (num_iters - i)
                    eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                    print_or_log(log_msg.format(
                        i,
                        num_iters,
                        eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)), logger=logger)
                if writer:
                    step = epoch * num_iters + i
                    for key in self.meters:
                        writer.add_scalar(
                            f'{mode}/{key}/iter',
                            self.meters[key].avg,
                            step
                        )
                    writer.add_scalar(
                        f'{mode}/memory/iter',
                        torch.cuda.max_memory_allocated() / MB,
                        step
                    )
            end = time.time()


class MetricLoggerSLLX(MetricLogger):
    """MetricLogger variant whose ``log_every`` always emits the progress line."""

    def log_every(self, iterable, print_freq, header=None, logger=None, writer=None, mode='train', epoch=0, args=None):
        """Yield ``(index, item)`` from ``iterable`` while timing each step.

        Every ``print_freq`` iterations a progress line goes through
        ``print_or_log`` (with a ``max mem`` field when CUDA is available)
        and, if ``writer`` is given, each meter's windowed average and the
        peak CUDA memory in MB are written with ``add_scalar``.
        ``iterable`` must support ``len()``. ``args`` is accepted but unused.
        """
        header = header if header else ''
        tick = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        total_iters = len(iterable)
        on_cuda = torch.cuda.is_available()

        fields = [
            header,
            '[{0:' + str(len(str(total_iters))) + 'd}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}',
        ]
        if on_cuda:
            fields.append('max mem: {memory:.0f}')
        template = self.delimiter.join(fields)
        bytes_per_mb = 1024.0 * 1024.0

        for step_idx, item in enumerate(iterable):
            data_time.update(time.time() - tick)
            yield step_idx, item
            iter_time.update(time.time() - tick)
            if step_idx % print_freq == 0:
                remaining = iter_time.global_avg * (total_iters - step_idx)
                values = {
                    'eta': str(datetime.timedelta(seconds=int(remaining))),
                    'meters': str(self),
                    'time': str(iter_time),
                    'data': str(data_time),
                }
                if on_cuda:
                    values['memory'] = torch.cuda.max_memory_allocated() / bytes_per_mb
                print_or_log(template.format(step_idx, total_iters, **values), logger=logger)
                if writer:
                    global_step = epoch * total_iters + step_idx
                    for key in self.meters:
                        writer.add_scalar(f'{mode}/{key}/iter', self.meters[key].avg, global_step)
                    writer.add_scalar(
                        f'{mode}/memory/iter',
                        torch.cuda.max_memory_allocated() / bytes_per_mb,
                        global_step
                    )
            tick = time.time()
    

class MetricLoggerKinetics(MetricLogger):
    """MetricLogger variant whose ``log_every`` always emits the progress line."""

    def log_every(self, iterable, print_freq, header=None, logger=None, writer=None, mode='train', epoch=0, args=None):
        """Yield ``(index, item)`` from ``iterable`` while timing each step.

        Every ``print_freq`` iterations a progress line goes through
        ``print_or_log`` (with a ``max mem`` field when CUDA is available)
        and, if ``writer`` is given, each meter's windowed average and the
        peak CUDA memory in MB are written with ``add_scalar``.
        ``iterable`` must support ``len()``. ``args`` is accepted but unused.
        """
        header = header if header else ''
        last_mark = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        n_batches = len(iterable)
        has_cuda = torch.cuda.is_available()

        pieces = [
            header,
            '[{0:' + str(len(str(n_batches))) + 'd}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}',
        ]
        if has_cuda:
            pieces.append('max mem: {memory:.0f}')
        line_format = self.delimiter.join(pieces)
        mb = 1024.0 * 1024.0

        for batch_idx, batch in enumerate(iterable):
            data_time.update(time.time() - last_mark)
            yield batch_idx, batch
            iter_time.update(time.time() - last_mark)
            if batch_idx % print_freq == 0:
                eta_seconds = iter_time.global_avg * (n_batches - batch_idx)
                named = {
                    'eta': str(datetime.timedelta(seconds=int(eta_seconds))),
                    'meters': str(self),
                    'time': str(iter_time),
                    'data': str(data_time),
                }
                if has_cuda:
                    named['memory'] = torch.cuda.max_memory_allocated() / mb
                print_or_log(line_format.format(batch_idx, n_batches, **named), logger=logger)
                if writer:
                    writer_step = epoch * n_batches + batch_idx
                    for key in self.meters:
                        writer.add_scalar(f'{mode}/{key}/iter', self.meters[key].avg, writer_step)
                    writer.add_scalar(
                        f'{mode}/memory/iter',
                        torch.cuda.max_memory_allocated() / mb,
                        writer_step
                    )
            last_mark = time.time()


class MyFormatter(logging.Formatter):
    """Formatter that picks its layout from the record's level.

    DEBUG, INFO and ERROR records use ``dbg_fmt``, ``info_fmt`` and
    ``err_fmt`` respectively; every other level uses the default layout
    passed to ``logging.Formatter`` in ``__init__``.
    """

    # Use %(message)s (the message with its arguments merged in) rather than
    # %(msg)s (the raw template), so logger.info("x=%s", x) is rendered
    # correctly. For calls without arguments both produce the same text.
    err_fmt = "%(asctime)s %(name)s %(module)s: %(lineno)d: %(levelname)s: %(message)s"
    dbg_fmt = "%(asctime)s %(module)s: %(lineno)d: %(levelname)s:: %(message)s"
    info_fmt = "%(message)s"

    def __init__(self):
        super().__init__(fmt="%(asctime)s %(name)s %(levelname)s: %(message)s",
                         datefmt=None,
                         style='%')

    def format(self, record):
        """Format ``record`` with the layout that matches its level."""
        per_level = {
            logging.DEBUG: self.dbg_fmt,
            logging.INFO: self.info_fmt,
            logging.ERROR: self.err_fmt,
        }
        level_fmt = per_level.get(record.levelno)
        if level_fmt is None:
            # no level-specific layout: use the default one unchanged
            return super().format(record)

        # Temporarily swap the layout; the finally clause guarantees the
        # default layout is restored even if formatting raises.
        format_orig = self._style._fmt
        self._style._fmt = level_fmt
        try:
            return super().format(record)
        finally:
            self._style._fmt = format_orig


def setup_logger(name, save_dir, is_master, logname="run.log"):
    """
    Return the logger called ``name`` with its level set to DEBUG.

    Only the master process gets handlers: a stdout stream handler and,
    when ``save_dir`` is truthy, a file handler writing to
    ``<save_dir>/<logname>``. Both use ``MyFormatter``. Non-master
    processes receive the logger with no handler added here.
    """
    log = logging.getLogger(name)
    log.setLevel(logging.DEBUG)
    if not is_master:
        # non-master processes do not log results
        return log

    console = logging.StreamHandler(stream=sys.stdout)
    shared_formatter = MyFormatter()
    console.setFormatter(shared_formatter)
    log.addHandler(console)

    print("Creating logger save dir")
    if save_dir:
        log_file = os.path.join(save_dir, logname)
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(shared_formatter)
        log.addHandler(file_handler)
    print(f"Finished creating logger: {save_dir}")

    return log


def print_or_log(message, logger=None):
    """Send ``message`` to ``logger.info`` when a logger is given, else print it (flushed)."""
    if logger is not None:
        logger.info(message)
        return
    print(message, flush=True)


def setup_tbx(save_dir, is_master):
    """
    Create a TensorBoard ``SummaryWriter`` on the master process.

    Args:
        save_dir: directory passed to ``SummaryWriter``.
        is_master: whether this process is the master.

    Returns:
        A ``SummaryWriter`` for the master process, ``None`` otherwise.
    """
    if not is_master:
        return None

    # Imported only when a writer is actually built, so non-master processes
    # neither pay for the import nor fail when tensorboard is not installed.
    from torch.utils.tensorboard import SummaryWriter

    return SummaryWriter(save_dir)


def accuracy(output, target, topk=(1,)):
    """Return, for each k in ``topk``, the percentage of rows whose target is among the k highest scores."""
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)

        # indices of the largest_k highest scores per row, transposed so that
        # row j holds every sample's (j+1)-th best prediction
        _, top_indices = output.topk(maxk := largest_k, 1, True, True) if False else output.topk(largest_k, 1, True, True)
        hits = top_indices.t().eq(target[None])

        scale = 100.0 / num_samples
        return [
            hits[:k].flatten().sum(dtype=torch.float32) * scale
            for k in topk
        ]


### MODELS
def load_model(
    model_name='r3d_18', 
    vid_base_arch='r2plus1d_18', 
    aud_base_arch='resnet18', 
    pretrained=False,
    norm_feat=True,
    use_mlp=False,
    mlptype=0,
    headcount=1,
    num_classes=256,
    temporal_length=8,
    return_features=False
):
    """
    Build a model by name.

    * 'resnet50', 'resnet18', 'resnet50_st_pool': built with the
      ``resnet50`` / ``imgresnet18`` / ``resnet50_st_pool`` factories
      ('resnet18' has its ``fc`` replaced by an identity).
    * 'r3d_18', 'mc3_18', 'r2plus1d_18': torchvision video model whose ``fc``
      is replaced by a fresh linear layer with ``num_classes`` outputs.
    * 'avc_concat': ``models.AVCconcat``.
    * any other name: ``models.AVC``.

    NOTE(review): ``resnet50``, ``imgresnet18`` and ``resnet50_st_pool`` are
    not imported in the visible part of this file - confirm they are defined
    elsewhere.
    """
    single_stream = (
        'resnet50', 'resnet18', 'resnet50_st_pool',
        'r3d_18', 'mc3_18', 'r2plus1d_18',
    )
    if model_name in single_stream:
        print(f"Loading {model_name} with num classes: {num_classes}", flush=True)
        if model_name == 'resnet50':
            return resnet50(pretrained=pretrained, num_classes=[num_classes])
        if model_name == 'resnet18':
            net = imgresnet18(pretrained=pretrained, out=[num_classes])
            net.fc = torch.nn.Identity()
            return net
        if model_name == 'resnet50_st_pool':
            return resnet50_st_pool(
                pretrained=pretrained,
                num_classes=[num_classes],
                temporal_length=temporal_length
            )
        # torchvision video architectures
        net = torchvision.models.video.__dict__[model_name](pretrained=pretrained)
        in_features = net.fc.in_features
        net.fc = nn.Linear(in_features, num_classes)
        return net

    # audio-visual models share one set of constructor arguments
    print(f"Loading {model_name}: {vid_base_arch} and {aud_base_arch}, using MLP head: {use_mlp}", flush=True)
    av_model_cls = models.AVCconcat if model_name == 'avc_concat' else models.AVC
    return av_model_cls(
        vid_base_arch=vid_base_arch,
        aud_base_arch=aud_base_arch,
        pretrained=pretrained,
        norm_feat=norm_feat,
        use_mlp=use_mlp,
        mlptype=mlptype,
        headcount=headcount,
        num_classes=num_classes,
        return_features=return_features
    )


def load_model_parameters(model, model_weights, only_encoder=False, on_ddp=False):
    """
    Copy matching entries of ``model_weights`` into ``model`` in place.

    Args:
        model: module whose ``state_dict()`` entries are overwritten.
        model_weights: mapping from parameter name to tensor.
        only_encoder: when true, entries whose name contains 'mlp' are
            skipped (and reported).
        on_ddp: when false, every 'module.' substring is removed from the
            incoming names before matching.

    Names that do not exist in the model's state dict are reported and skipped.
    """
    target_state = model.state_dict()
    for key, tensor in model_weights.items():
        if not on_ddp and 'module.' in key:
            key = key.replace('module.', '')
        if key not in target_state:
            print("didnt load ", key)
            continue
        if only_encoder and 'mlp' in key:
            print(f"Not loading {key}")
            continue
        target_state[key].copy_(tensor)


def set_parameter_requires_grad(model, feature_extracting):
    """Turn off ``requires_grad`` on every parameter of ``model`` when ``feature_extracting`` is truthy."""
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False


def str2bool(v):
    """
    Parse a textual boolean, case-insensitively.

    'yes', 'true', 't', '1' give True; 'no', 'false', 'f', '0' give False.
    Any other value raises ValueError.
    """
    v = v.lower()
    outcomes = dict.fromkeys(('yes', 'true', 't', '1'), True)
    outcomes.update(dict.fromkeys(('no', 'false', 'f', '0'), False))
    if v in outcomes:
        return outcomes[v]
    raise ValueError('Boolean argument needs to be true or false. '
        'Instead, it is %s.' % v)


def compute_metrics(y, pred, num_classes=10):
    """
    Compute performance metrics given the predicted labels and the true labels
    Args:
        y: True label vector, or a 2-D array of per-class scores / one-hot
           rows (reduced with argmax over axis 1)
           (Type: np.ndarray, torch.Tensor or sequence)
        pred: Predicted label vector, same accepted forms as ``y``
              (Type: np.ndarray, torch.Tensor or sequence)
        num_classes: Number of classes; per-class accuracy is computed for
                     class indices 0 .. num_classes - 1
    Returns:
        metrics: Metrics dictionary with keys
                 'accuracy': fraction of samples where y == pred,
                 'class_accuracy': list with one accuracy per class (NaN for
                     a class with no sample in ``y``),
                 'average_class_accuracy': mean of the per-class accuracies
                     over the classes that have at least one sample (NaN when
                     no class has any)
                 (Type: dict[str, *])
    """
    def _to_array(values):
        # Make sure everything is a numpy array
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        if isinstance(values, np.ndarray):
            return values
        return np.array(values)

    y = _to_array(y)
    pred = _to_array(pred)

    # Convert from one-hot to integer encoding if necessary
    if y.ndim == 2:
        y = np.argmax(y, axis=1)
    if pred.ndim == 2:
        pred = np.argmax(pred, axis=1)
    assert y.ndim == 1
    assert pred.ndim == 1

    hits = (y == pred)
    acc = hits.mean()

    class_acc = []
    for class_idx in range(num_classes):
        members = (y == class_idx)
        if members.any():
            class_acc.append(hits[members].mean())
        else:
            # A class without samples has no defined accuracy. Record NaN
            # explicitly instead of taking the mean of an empty array, which
            # also emits a RuntimeWarning.
            class_acc.append(np.float64(np.nan))

    # Average only over classes that are present, so one absent class does
    # not turn the whole average into NaN.
    present = [value for value in class_acc if not np.isnan(value)]
    ave_class_acc = np.mean(present) if present else np.float64(np.nan)

    return {
        'accuracy': acc,
        'class_accuracy': class_acc,
        'average_class_accuracy': ave_class_acc
    }


def load_dataset(
    dataset_name='kinetics',
    mode='train',
    fold=1,
    frames_per_clip=30,
    transforms=None,
    clips_per_video=1,
    num_data_samples=None,
    sampletime=1.2,
    subsample=False,
    seed=0,
    model='avc_cluster',
    sample_rate=1,
    train_crop_size=112,
    colorjitter=False,
    dualdata=False,
    synced=True,
    temp_jitter=True,
    center_crop=False,
    decode_audio=True,
    target_fps=30,
    aug_audio=[],
    num_sec=1,
    aud_sample_rate=48000,
    aud_spec_type=1,
    use_volume_jittering=False,
    use_temporal_jittering=False,
    z_normalize=False
):
    """Build the dataset object selected by ``dataset_name``.

    Audio-visual datasets ('kinetics', 'kinetics600', 'audioset', 'vggsound',
    'kinetics_sound', 'ave') are created through ``AVideoDataset`` with one
    spatial crop and ten ensemble views; 'ucf101' and 'hmdb51' use their own
    classes with ``step_between_clips=1``.

    ``transforms``, ``fold`` and ``subsample`` are only forwarded to the
    UCF101/HMDB51 datasets; the audio/augmentation options are only forwarded
    to ``AVideoDataset``. ``clips_per_video``, ``sampletime``, ``model``,
    ``dualdata`` and ``synced`` are accepted for call-site compatibility but
    are not used here. ``aug_audio`` is passed through unmodified.

    Returns:
        The constructed dataset.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    if dataset_name in ['kinetics', 'kinetics600', 'audioset', 'vggsound', 'kinetics_sound', 'ave']:
        print(f"Loading {dataset_name} dataset", flush=True)
        return AVideoDataset(
            ds_name=dataset_name,
            mode=mode,
            num_frames=frames_per_clip,
            seed=seed,
            sample_rate=sample_rate,
            num_spatial_crops=1,
            num_ensemble_views=10,
            train_crop_size=train_crop_size,
            num_data_samples=num_data_samples,
            colorjitter=colorjitter,
            temp_jitter=temp_jitter,
            center_crop=center_crop,
            decode_audio=decode_audio,
            aug_audio=aug_audio,
            target_fps=target_fps,
            num_sec=num_sec,
            aud_sample_rate=aud_sample_rate,
            aud_spec_type=aud_spec_type,
            use_volume_jittering=use_volume_jittering,
            use_temporal_jittering=use_temporal_jittering,
            z_normalize=z_normalize
        )

    if dataset_name == 'ucf101':
        dataset_cls = UCF101
    elif dataset_name == 'hmdb51':
        dataset_cls = HMDB51
    else:
        # The previous `assert ("...")` asserted a non-empty string, which is
        # always true, so unknown names silently produced None.
        raise ValueError(f"Dataset '{dataset_name}' is not supported")

    return dataset_cls(
        frames_per_clip=frames_per_clip,
        step_between_clips=1,
        transform=transforms,
        fold=fold,
        subsample=subsample,
        train=(mode == 'train')
    )


def collate_fn(batch):
    """Collate a batch while dropping samples that failed to load.

    ``None`` entries are discarded and only the first five fields of each
    remaining sample are kept before handing over to ``default_collate``.

    Returns:
        The collated batch, or ``None`` when no valid sample is left.
    """
    kept = []
    for sample in batch:
        if sample is None:
            continue
        kept.append(tuple(sample[field] for field in range(5)))
    if not kept:
        return None
    return default_collate(kept)


def load_optimizer(name, params, lr=1e-4, momentum=0.9, weight_decay=0, model=None):
    """Create a torch optimizer by name.

    Args:
        name: One of 'sgd', 'adam', 'lbfgs', 'adamax', 'adamw', 'rmsprop'.
        params: Parameters (or parameter groups) to optimise. Always used for
            'sgd' and 'adam'.
        lr: Learning rate.
        momentum: Momentum, used by 'sgd' and 'rmsprop'.
        weight_decay: Weight decay, used by every optimizer except 'lbfgs'.
        model: Optional module. For 'lbfgs', 'adamax', 'adamw' and 'rmsprop'
            its ``parameters()`` are optimised when given (the historical
            behaviour); when it is ``None`` those optimizers fall back to
            ``params`` instead of failing on ``None.parameters()``.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        ValueError: If ``name`` is not a supported optimizer.
    """
    def _model_params():
        # Resolved lazily so that generators are only created when needed.
        return model.parameters() if model is not None else params

    if name == 'sgd':
        return torch.optim.SGD(
            params,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay
        )
    if name == 'adam':
        return torch.optim.Adam(
            params,
            lr=lr,
            weight_decay=weight_decay
        )
    if name == 'lbfgs':
        return torch.optim.LBFGS(
            _model_params(),
            lr=lr,
            max_iter=10000
        )
    if name == 'adamax':
        return torch.optim.Adamax(
            _model_params(),
            lr=lr,
            weight_decay=weight_decay
        )
    if name == 'adamw':
        return torch.optim.AdamW(
            _model_params(),
            lr=lr,
            weight_decay=weight_decay
        )
    if name == 'rmsprop':
        return torch.optim.RMSprop(
            _model_params(),
            lr=lr,
            weight_decay=weight_decay,
            momentum=momentum
        )
    # The previous `assert("...")` could never fail and the function then
    # crashed with an UnboundLocalError on `return optimizer`.
    raise ValueError(
        f"Unsupported optimizer '{name}'; expected one of "
        "'sgd', 'adam', 'lbfgs', 'adamax', 'adamw', 'rmsprop'"
    )

'''
def _warmup_batchnorm(args, model, dataset, device, batches=100):
    """
    Run some batches through all parts of the model to warmup the running
    stats for batchnorm layers.
    """
    print("Warming up batchnorm", flush=True)

    # Create train sampler
    train_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    
    # Create dataloader
    data_loader = torch.utils.data.DataLoader(
        dataset, 
        batch_size=args.batch_size,
        sampler=train_sampler, 
        shuffle=False if train_sampler else True,
        num_workers=args.workers,
        collate_fn=None,
        pin_memory=True
    )
    
    # Put model in train mode
    model.train()

    # Iterate over dataloader batches times 
    for i, q in enumerate(data_loader):
        video, audio, _, _, _ = q
        if i == batches:
            break
        if args.global_rank == 0:
            print((i, video.shape), flush=True)
        video = video.to(device)
        audio = audio.to(device)

        # Forward pass: get features, compute loss and accuracy
        _ = model(video, audio)
    if args.distributed:
        dist.barrier()
    print("Finshed warming up batchnorm", flush=True)
'''

def get_transforms(args):
    """Build the train/test video transform pipelines for ``args.augtype``.

    Supported values of ``args.augtype``:
        0: resize to (128, 171), normalize, center crop 112x112 for both
           train and test (no augmentation).
        1: as 0, but training adds a random horizontal flip and uses a
           random 112x112 crop.
        2, 3: resize to (256, 320) with 224x224 crops (random crop + flip
           for training, center crop for testing).

    Side effects:
        For augtype 2 and 3 this overwrites ``args.batch_size`` with 8 and
        ``args.clip_len`` with 8 (augtype 2) or 32 (augtype 3).

    Returns:
        Tuple ``(transform_train, transform_test, subsample)``;
        ``subsample`` is always ``False``.

    Raises:
        ValueError: If ``args.augtype`` is not one of 0, 1, 2, 3.
    """
    if args.augtype not in (0, 1, 2, 3):
        # Previously this fell through every branch and failed with an
        # UnboundLocalError on the return statement.
        raise ValueError(
            f"Unsupported augtype {args.augtype!r}; expected 0, 1, 2 or 3"
        )

    # Only this normalization was ever effective: an earlier assignment with
    # different statistics was immediately overwritten and has been removed.
    normalize = T.Normalize(
        mean=[0.45, 0.45, 0.45],
        std=[0.225, 0.225, 0.225]
    )
    subsample = False

    def _pipeline(resize_to, crop, flip=False):
        # Order: to-float -> resize -> (flip) -> normalize -> crop.
        steps = [T.ToFloatTensorInZeroOne(), T.Resize(resize_to)]
        if flip:
            steps.append(T.RandomHorizontalFlip())
        steps.extend([normalize, crop])
        return torchvision.transforms.Compose(steps)

    if args.augtype == 1:
        transform_train = _pipeline((128, 171), T.RandomCrop((112, 112)), flip=True)
        transform_test = _pipeline((128, 171), T.CenterCrop((112, 112)))
    elif args.augtype in [2, 3]:
        # note that 8x224x224 = 400k, i.e. the same as 32x112x112
        transform_train = _pipeline((256, 320), T.RandomCrop((112*2, 112*2)), flip=True)
        transform_test = _pipeline((256, 320), T.CenterCrop((224, 224)))
        args.batch_size = 8  # default because resolution is higher (bigger memory footprint)
        if args.augtype == 2:
            args.clip_len = 8
        if args.augtype == 3:
            args.clip_len = 32
    else:  # augtype == 0
        transform_train = _pipeline((128, 171), T.CenterCrop((112, 112)))
        transform_test = _pipeline((128, 171), T.CenterCrop((112, 112)))
    return transform_train, transform_test, subsample

def get_ds(args, epoch, mode='train'):
    """Create the dataset described by ``args`` for the given epoch and mode.

    The training transform is handed to ``load_dataset`` regardless of
    ``mode`` and ``epoch`` is used as the dataset seed.

    Returns:
        The dataset returned by ``load_dataset``.
    """
    # Resolve transforms before reading any other field of `args`:
    # get_transforms may rewrite args.batch_size and args.clip_len.
    train_tf, _unused_test_tf, subsample = get_transforms(args)

    print("Loading data", flush=True)
    started = time.time()
    dataset_kwargs = dict(
        dataset_name=args.dataset,
        fold=args.fold,
        mode=mode,
        frames_per_clip=args.clip_len,
        transforms=train_tf,
        subsample=subsample,
        clips_per_video=args.clips_per_video,
        num_data_samples=args.num_data_samples,
        seed=epoch,
        model=args.model,
        sample_rate=args.sample_rate,
        train_crop_size=args.train_crop_size,
        colorjitter=args.colorjitter,
        temp_jitter=args.use_temp_jitter,
        center_crop=args.center_crop,
        decode_audio=args.decode_audio,
        target_fps=args.target_fps,
        aug_audio=args.aug_audio,
        num_sec=args.num_sec,
        aud_sample_rate=args.aud_sample_rate,
        aud_spec_type=args.aud_spec_type,
        use_volume_jittering=args.use_volume_jittering,
        use_temporal_jittering=args.use_temporal_jittering,
        z_normalize=args.z_normalize,
    )
    dataset = load_dataset(**dataset_kwargs)
    elapsed = time.time() - started
    print(f"Took {elapsed}", flush=True)

    return dataset


def get_dataloader(args, epoch, mode='train'):
    """Create the dataset and a ``DataLoader`` over it.

    Args:
        args: Run configuration. Read here: ``distributed``, ``batch_size``,
            ``workers`` plus every dataset option forwarded to
            ``load_dataset``.
        epoch: Used as the dataset seed and as the epoch of the distributed
            sampler.
        mode: 'train' uses ``args.batch_size``; any other mode uses four
            times that batch size.

    Returns:
        Tuple ``(dataset, data_loader)``.
    """
    # Must run before args.clip_len / args.batch_size are read below, since
    # get_transforms may overwrite both.
    # NOTE(review): the training transform is used for every mode and
    # drop_last=True is applied in eval mode too; confirm this is intended.
    transform_train, _, subsample = get_transforms(args)

    print("Loading data", flush=True)
    st = time.time()
    dataset = load_dataset(
        dataset_name=args.dataset,
        fold=args.fold,
        mode=mode,
        frames_per_clip=args.clip_len,
        transforms=transform_train,
        subsample=subsample,
        clips_per_video=args.clips_per_video,
        num_data_samples=args.num_data_samples,
        seed=epoch,
        model=args.model,
        sample_rate=args.sample_rate,
        train_crop_size=args.train_crop_size,
        colorjitter=args.colorjitter,
        temp_jitter=args.use_temp_jitter,
        center_crop=args.center_crop,
        decode_audio=args.decode_audio,
        target_fps=args.target_fps,
        aug_audio=args.aug_audio,
        num_sec=args.num_sec,
        aud_sample_rate=args.aud_sample_rate,
        aud_spec_type=args.aud_spec_type,
        use_volume_jittering=args.use_volume_jittering,
        use_temporal_jittering=args.use_temporal_jittering,
        z_normalize=args.z_normalize
    )
    print(f"Took {time.time() - st}", flush=True)

    print("Creating data loaders", flush=True)
    train_sampler = None
    if args.distributed:
        print("Loading distributed train sampler")
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        train_sampler.set_epoch(epoch)

    if mode == 'train':
        batch_size = args.batch_size
    else:
        batch_size = args.batch_size * 4

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        # Compare against None rather than testing truthiness: a sampler
        # defines __len__, so an empty one would be falsy and shuffle=True
        # would then be passed together with a sampler.
        shuffle=train_sampler is None,
        pin_memory=True,
        collate_fn=None,
        drop_last=True
    )
    return dataset, data_loader


class Flatten(nn.Module):
    """Module that reshapes its input to ``(batch, -1)``."""

    def __init__(self):
        super().__init__()

    def forward(self, input_tensor):
        """Keep dimension 0 and merge all remaining dimensions into one."""
        leading = input_tensor.size(0)
        flat = input_tensor.view(leading, -1)
        return flat

######### util for evaluation

def setup_runtime(seed=0, cuda_dev_id=[1]):
    """Initialize CUDA, CuDNN and the random seeds.

    Args:
        seed: Seed applied to ``random``, ``numpy``, ``torch`` and, when CUDA
            is available, all CUDA devices.
        cuda_dev_id: Non-empty sequence of device ids exported, comma
            separated, as ``CUDA_VISIBLE_DEVICES``. The default list is never
            mutated.

    Raises:
        ValueError: If ``cuda_dev_id`` is empty.

    Side effects:
        Sets ``CUDA_DEVICE_ORDER`` and ``CUDA_VISIBLE_DEVICES`` in
        ``os.environ``; when CUDA is available, enables cuDNN with
        benchmarking and non-deterministic algorithms.
    """
    if len(cuda_dev_id) == 0:
        raise ValueError("cuda_dev_id must contain at least one device id")

    # Setup CUDA
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(dev) for dev in cuda_dev_id)

    has_cuda = torch.cuda.is_available()
    if has_cuda:
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    # Fix random seeds
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if has_cuda:
        torch.cuda.manual_seed_all(seed)


class TotalAverage:
    """Mass-weighted running average over everything seen since ``reset``.

    Attributes are ``val`` (last value), ``mass`` (accumulated weight),
    ``sum`` (accumulated weighted values) and ``avg`` (``sum / mass``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all accumulated statistics."""
        for field in ('val', 'mass', 'sum', 'avg'):
            setattr(self, field, 0.)

    def update(self, val, mass=1):
        """Record ``val`` with weight ``mass`` and refresh the average."""
        self.val = val
        self.mass += mass
        self.sum += val * mass
        self.avg = self.sum / self.mass


class MovingAverage:
    """Exponential moving average: ``avg = inertia * avg + (1 - inertia) * val``."""

    def __init__(self, inertia=0.9):
        self.inertia = inertia
        self.reset()

    def reset(self):
        """Restart the average from zero."""
        self.avg = 0.

    def update(self, val):
        """Blend ``val`` into the running average."""
        carried_over = self.inertia * self.avg
        contribution = (1 - self.inertia) * val
        self.avg = carried_over + contribution


def imgaccuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: Score tensor; the top-k is taken along dimension 1.
        target: Tensor of ground-truth class indices, one per row of
            ``output``.
        topk: Iterable of k values to evaluate.

    Returns:
        List with one numpy array of shape ``(1,)`` per k, holding the
        percentage of samples whose target is among the k highest scores.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()

        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # `correct` inherits the transposed (non-contiguous) layout of
            # `pred`, so `.view(-1)` raises for k > 1; reshape copies if needed.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size).cpu().numpy())
        return res

class CRELU(nn.Module):
    """Linear head on concatenated-ReLU features.

    The input is batch-normalised (no affine parameters), then the ReLU of
    the result and the ReLU of its negation are concatenated along dim 1,
    passed through dropout (skipped when ``p == 0``) and a linear layer.
    """

    def __init__(self, in_planes, planes,p=0.05):
        super().__init__()
        self.bn = nn.BatchNorm1d(in_planes, affine=False)
        self.relu = nn.ReLU(inplace=False)
        if p != 0.:
            self.do = nn.Dropout(p)
        else:
            # Empty Sequential acts as an identity.
            self.do = nn.Sequential()
        # Concatenating both signs doubles the feature width.
        self.linear = nn.Linear(2 * in_planes, planes, bias=True)

    def forward(self, x):
        normed = self.bn(x)
        both_signs = torch.cat((self.relu(normed), self.relu(-normed)), dim=1)
        return self.linear(self.do(both_signs))

def prep_model(args):
    """Load a pretrained image ResNet and replace its head with a fresh classifier.

    Reads ``args.arch`` ('resnet18' or 'resnet50'), ``args.modelpath``
    (checkpoint file with a 'model' entry), ``args.feature`` and
    ``args.prob``.

    Steps: load the checkpoint onto CPU, build the backbone with one output
    size per 'fc' weight found in the checkpoint, copy every matching tensor
    (names with 'module.' removed), freeze all backbone parameters, drop the
    extra heads, and attach a new 1000-way head as ``model.fc``: dropout
    (if ``args.prob != 0``) + linear when ``args.feature == 'avg'``,
    otherwise a ``CRELU`` head.

    Returns:
        The prepared model; only the new ``fc`` head has trainable parameters.

    Raises:
        ValueError: If ``args.arch`` is not 'resnet18' or 'resnet50'.
    """
    # Validate first: an unknown architecture previously left `model`
    # unbound and only failed after the checkpoint had been loaded.
    if args.arch not in ('resnet18', 'resnet50'):
        raise ValueError(
            f"Unsupported arch {args.arch!r}; expected 'resnet18' or 'resnet50'"
        )
    is_resnet18 = args.arch == 'resnet18'

    dat = torch.load(args.modelpath, map_location=lambda storage, loc: storage)['model']

    # One output size per classification head stored in the checkpoint.
    ncls = [
        tensor.shape[0]
        for key, tensor in dat.items()
        if 'fc' in key and 'weight' in key
    ]
    if is_resnet18:
        model = imgresnet18(out=ncls)
    else:
        model = imgresnet50(out=ncls)

    new_state_dict = OrderedDict(
        (key.replace('module.', ''), value)  # remove `module.`
        for key, value in dat.items()
    )
    own_state = model.state_dict()

    for name, param in new_state_dict.items():
        if name not in own_state:
            print('not loaded:', name)
            continue
        if isinstance(param, torch.nn.Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        own_state[name].copy_(param)
    del dat

    for param in model.parameters():
        param.requires_grad = False

    if model.headcount > 1:
        for i in range(model.headcount):
            setattr(model, "top_layer%d" % i, None)
    model.headcount = 1
    model.withfeature = False
    model.return_feature_only = False
    model.return_feature = False

    feature_dim = 512 if is_resnet18 else 2048
    if args.feature == 'avg':
        top = []
        if args.prob != 0:
            top.append(nn.Dropout(p=args.prob))
        top.append(nn.Linear(feature_dim, 1000))
        model.fc = nn.Sequential(*top)
    elif is_resnet18:
        # NOTE(review): the resnet18 head keeps CRELU's default dropout and
        # ignores args.prob, unlike the resnet50 head - confirm intended.
        model.fc = CRELU(512, 1000)
    else:
        model.fc = CRELU(2048, 1000, p=args.prob)
    return model

def xmkdir(path):
    """Create directory PATH recursively if it does not exist.

    ``None`` is ignored. An already existing path is left untouched.
    """
    if path is not None and not os.path.exists(path):
        # exist_ok closes the race where another process creates the
        # directory between the existence check and makedirs.
        os.makedirs(path, exist_ok=True)

def clean_checkpoint(checkpoint_dir, dry_run=False, lowest=False):
    """Remove all but the two last checkpoint files (by sorted name).

    Args:
        checkpoint_dir: Directory that is searched.
        dry_run: When true, only report what would be deleted.
        lowest: When true, operate on 'lowest*.pth' files instead of
            'checkpoint*.pth' files.
    """
    pattern = 'lowest*.pth' if lowest else 'checkpoint*.pth'
    found = sorted(glob.glob(os.path.join(checkpoint_dir, pattern)))
    # Everything except the final two entries is redundant; the slice is
    # empty when two or fewer files exist.
    for stale in found[:-2]:
        if dry_run:
            print(f"Would delete redundant checkpoint file {stale}")
            continue
        print(f"Deleting redundant checkpoint file {stale}")
        os.remove(stale)

def save_checkpoint(checkpoint_dir, model, optimizer, metrics, epoch, defsave=False):
    """Write a checkpoint for ``epoch`` into ``checkpoint_dir``.

    Nothing happens when ``checkpoint_dir`` is None. The file is named
    'ckpt<epoch>.pth' when ``epoch`` is a multiple of 50 or ``defsave`` is
    set, and 'checkpoint<epoch>.pth' otherwise. It stores ``epoch + 1``, the
    metrics, the optimizer state and, when ``model`` is truthy, the model
    state (taken from ``model.module`` if that attribute path works).
    Afterwards redundant 'checkpoint*.pth' files are pruned.
    """
    if checkpoint_dir is None:
        return

    xmkdir(checkpoint_dir)

    has_model = bool(model)
    model_state = None
    if has_model:
        # Wrapped models expose the real network as `.module`.
        try:
            model_state = model.module.state_dict()
        except AttributeError:
            model_state = model.state_dict()

    keep_forever = (epoch % 50 == 0) or defsave
    prefix = 'ckpt' if keep_forever else 'checkpoint'
    name = os.path.join(checkpoint_dir, f'{prefix}{epoch:08}.pth')

    payload = {
        'epoch': epoch + 1,
        'metrics': metrics,
    }
    if has_model:
        payload['model'] = model_state
    payload['optimizer'] = optimizer.state_dict()

    torch.save(payload, name)
    clean_checkpoint(checkpoint_dir)

def load_checkpoint(checkpoint_dir, model, optimizer=None):
    """Restore state from the last 'checkpoint*.pth' file in ``checkpoint_dir``.

    Files are ordered by sorted name and the final one is loaded onto the
    device reported for ``model``. Model and optimizer states are restored
    only for arguments that are truthy.

    Returns:
        Tuple ``(epoch, metrics)`` from the checkpoint, or
        ``(0, {'train': [], 'val': []})`` when no checkpoint file exists.
    """
    pattern = os.path.join(checkpoint_dir, 'checkpoint*.pth')
    candidates = sorted(glob.glob(pattern))
    if not candidates:
        return 0, {'train': [], 'val': []}

    latest = candidates[-1]
    print(f"Loading checkpoint '{latest}'")
    device_name = str(get_model_device(model))
    cp = torch.load(latest, map_location=device_name)

    # Read the bookkeeping entries before touching model/optimizer state.
    epoch, metrics = cp['epoch'], cp['metrics']
    if model:
        model.load_state_dict(cp['model'])
    if optimizer:
        optimizer.load_state_dict(cp['optimizer'])
    return epoch, metrics


## Aggregate video-level softmaxes into an accuracy score
def aggregrate_video_accuracy(softmaxes, labels, topk=(1,), aggregate="mean"):
    """Compute top-k accuracy per video from clip-level predictions.

    Args:
        softmaxes: Mapping video id -> list of per-clip score tensors; the
            clips of one video are averaged.
        labels: Mapping video id -> label tensor for that video.
        topk: Iterable of k values to evaluate.
        aggregate: Currently unused; clip scores are always averaged.

    Returns:
        List with one tensor of shape ``(1,)`` per k, holding the percentage
        of videos whose label is among the k highest averaged scores.
    """
    maxk = max(topk)
    video_ids = list(softmaxes.keys())
    output_batch = torch.stack(
        [torch.stack(softmaxes[video_id]).mean(dim=0) for video_id in video_ids]
    )
    num_videos = output_batch.size(0)
    output_labels = torch.stack([labels[video_id] for video_id in video_ids])

    _, pred = output_batch.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(output_labels.expand_as(pred))

    res = []
    for k in topk:
        # `correct` follows the transposed layout of `pred`, which makes
        # `.view(-1)` fail for k > 1; reshape handles non-contiguous input.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / num_videos))
    return res


def _warmup_batchnorm(args, model, dataloader, batches=20, group=None):
    """
    Run some batches through all parts of the model to warmup the running
    stats for batchnorm layers.
    """
    print("Warming up batchnorm", flush=True)
    start = time.time()
    with torch.no_grad():
        # Put model in train mode
        model.train()
        # Iterate over dataloader batches times
        for i, batch in enumerate(dataloader):
            video, audio, _, _, idx = batch

            # Move to GPU
            video = video.cuda(non_blocking=True)
            audio = audio.cuda(non_blocking=True)
            if i == 0:
                print(video.shape, audio.shape)
            if i == batches:
                break
            # Forward pass: get features, compute loss and accuracy
            _ = model(video, audio)

        # Ensure processes reach to end of optim clusters
        if args.distributed and args.world_size > 1:
            if group is not None:
                dist.barrier(group=group)
            else:
                dist.barrier()
    print(f"Finshed warming up batchnorm!)"
          f"took {(time.time()-start)/60:.1f}min", flush=True)


def split_head(last_fc, which=None):
    """Return a linear layer with twice the outputs by perturbing each row.

    Every output row of ``last_fc`` (weight and bias together) is duplicated
    as ``row + noise`` and ``row - noise``; the columns of the result are
    then rescaled to the column norms of the original matrix. The new layer
    lives on the same device as ``last_fc`` (previously 'cuda' was
    hard-coded, so CPU layers failed).

    Quick check: ``split_head(split_head(torch.nn.Linear(5, 5)))``.

    Args:
        last_fc: ``torch.nn.Linear`` with a bias.
        which: NOTE(review): when not None a single noise row is shared by
            all outputs and the new layer is declared with 2 outputs while
            receiving ``2 * K`` rows - presumably unfinished; confirm before
            relying on it.

    Returns:
        The new ``torch.nn.Linear``.
    """
    with torch.no_grad():
        device = last_fc.weight.device
        K_in = last_fc.bias.size(0) if which is None else 1
        H = last_fc.weight.data.size(1)

        bias_in = last_fc.bias.data.clone().detach()
        weight_in = last_fc.weight.data.clone().detach()

        W_concat = torch.cat([weight_in, bias_in.unsqueeze(1)], dim=1)  # K x H+1

        # obtain norm and noise (noise scaled to 10% of each row's std)
        original_norms = torch.norm(W_concat, dim=0)  # size H+1
        row_scale = W_concat.std(dim=1).unsqueeze(1)
        noise = torch.randn(K_in, H + 1, device=device) * row_scale * 0.1  # K x H+1

        # compute new weights
        W_new = torch.cat([W_concat + noise, W_concat - noise], dim=0)  # 2K x H+1
        # normalize
        W_new = W_new / torch.norm(W_new, dim=0)  # 2K x H+1
        # multiply it back to the original lengths
        W_new = W_new * original_norms  # 2K x H+1

        new_fc = torch.nn.Linear(H, K_in * 2)
        # Store contiguous copies rather than strided views of W_new.
        new_fc.bias.data = W_new[:, -1].contiguous()
        new_fc.weight.data = W_new[:, :-1].contiguous()
    return new_fc.to(device)
