from __future__ import print_function
import copy
import math
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Variable


# test_with_margin reports the (100 * margin_epsilon)-th percentile of the
# per-sample margins, i.e. the margin exceeded by all but this fraction.
margin_epsilon = 0.05

# Value bounds; not referenced anywhere in the code shown here.
MIN_VALUE, MAX_VALUE = -3, 3


def accuracy(output, target, topk=(1,)):
    """Count correct top-k predictions for each requested k.

    Despite the historical "precision@k" wording, this returns raw *counts*,
    not percentages; callers normalise by the dataset size themselves.

    Args:
        output: score tensor of shape (batch, num_classes).
        target: tensor of true class indices, shape (batch,).
        topk: iterable of k values to evaluate.

    Returns:
        A list with one float tensor of shape (1,) per entry of ``topk``,
        holding the number of samples whose target is among the k
        highest-scoring classes.
    """
    # topk() raises when k exceeds the number of classes. Any k >= num_classes
    # trivially covers every class, so it is enough to search that far.
    maxk = min(max(topk), output.size(1))

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch): row j holds each sample's j-th best class
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` is derived from a transposed tensor and
        # is generally non-contiguous, which makes view(-1) raise.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k)
    return res


def train(model, train_loader, optimizer, args):
    """Run one pass of cross-entropy training of ``model`` over ``train_loader``.

    Each (inputs, labels) batch is moved to the GPU when ``args.cuda`` is set,
    and the optimizer takes exactly one step per batch. Returns None.
    """
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    for inputs, labels in train_loader:
        if args.cuda:
            inputs = inputs.cuda()
            labels = labels.cuda()
        inputs = Variable(inputs)
        labels = Variable(labels)

        optimizer.zero_grad()
        loss_fn(model(inputs), labels).backward()
        optimizer.step()


def test(model, data_loader, args):
    """Evaluate ``model`` on every batch of ``data_loader``.

    Returns:
        (prec1, prec5, loss): top-1 and top-5 accuracy in percent, and the
        cross-entropy loss averaged over all ``len(data_loader.dataset)``
        samples.
    """
    # Sum (not average) inside each batch so that dividing by the dataset
    # size below gives the true per-sample mean, even when the last batch is
    # smaller than the others.
    criterion = nn.CrossEntropyLoss(reduction='sum')
    model.eval()
    prec1, prec5, loss = 0, 0, 0
    # Evaluation needs no gradients; skipping graph construction saves memory.
    with torch.no_grad():
        for data, target in data_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            loss += criterion(output, target).item()  # sum up batch loss
            prec1_, prec5_ = accuracy(output, target, topk=(1, 5))
            prec1 += prec1_.cpu().sum()
            prec5 += prec5_.cpu().sum()

    n_samples = len(data_loader.dataset)
    loss /= n_samples
    prec1 = 100. * prec1.numpy() / n_samples
    prec5 = 100. * prec5.numpy() / n_samples
    return prec1, prec5, loss


def test_with_margin(model, data_loader, args):
    """Evaluate ``model`` and estimate its classification margin.

    A sample's margin is its true-class logit minus the largest logit among
    the remaining classes (negative when misclassified). The reported margin
    is the ``100 * margin_epsilon`` percentile of these values over the whole
    loader; it is also printed.

    Returns:
        (prec1, prec5, loss, margin): top-1/top-5 accuracy in percent, the
        cross-entropy loss averaged over ``len(data_loader.dataset)`` samples,
        and the margin percentile.
    """
    # Sum within each batch so dividing by the dataset size yields the mean.
    criterion = nn.CrossEntropyLoss(reduction='sum')
    model.eval()
    prec1, prec5, loss = 0, 0, 0
    batch_margins = []

    with torch.no_grad():
        for data, target in data_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            loss += criterion(output, target).item()  # sum up batch loss
            prec1_, prec5_ = accuracy(output, target, topk=(1, 5))
            prec1 += prec1_.cpu().sum()
            prec5 += prec5_.cpu().sum()

            index = target.view(-1, 1)
            true_logit = output.gather(1, index).squeeze(1)
            # Replace each true-class logit with its row minimum so that the
            # row max picks the best competing class.
            others = output.clone()
            others.scatter_(1, index, output.min(1, keepdim=True)[0])
            batch_margins.append(true_logit - others.max(1)[0])

    # Concatenate once at the end instead of growing a tensor every batch.
    margin_list = torch.cat(batch_margins, 0).cpu()
    margin = np.percentile(margin_list.numpy(), 100 * margin_epsilon)

    N = len(data_loader.dataset)
    loss /= N
    prec1 = 100. * prec1.numpy() / N
    prec5 = 100. * prec5.numpy() / N
    print('margin\t', margin)
    return prec1, prec5, loss, margin


def n_param(module, init_module):
    """Return the number of weight and bias entries of ``module``.

    Missing parameters (``None``, e.g. a layer built without bias or a
    BatchNorm without affine parameters) count as zero. ``init_module`` is
    unused; it is accepted so this function matches the
    ``measure_func(module, init_module, **kwargs)`` call made by
    ``calc_measure``.
    """
    weight_count = 0 if module.weight is None else module.weight.numel()
    bias_count = 0 if module.bias is None else module.bias.numel()
    return weight_count + bias_count


def norm(module, init_module, p=2, q=2):
    """Return the q-norm of the per-row p-norms of ``module.weight``.

    The weight is flattened to (weight.size(0), -1), each row is reduced with
    a p-norm, and the resulting vector with a q-norm. With p = q = 2 this is
    the Frobenius norm of the weight. ``init_module`` is unused (it matches
    the measure_func signature used by calc_measure).
    """
    weight = module.weight
    rows = weight.view(weight.size(0), -1)
    row_norms = torch.norm(rows, p=p, dim=1)
    return torch.norm(row_norms, q).item()


def op_norm(module, init_module, p=float('Inf')):
    """Return the p-norm of the singular values of the flattened weight.

    The weight is reshaped to (weight.size(0), -1) before the SVD. With the
    default p = inf this is the largest singular value (spectral norm).
    ``init_module`` is unused (it matches the calc_measure measure_func
    signature).
    """
    weight = module.weight
    singular_values = weight.view(weight.size(0), -1).svd()[1]
    return singular_values.norm(p).item()


def reparam(model, prev_layer=None):
    """Fold eval-mode BatchNorm statistics into the preceding weight layer, in place.

    Children are visited depth-first in registration order. Each
    BatchNorm1d/BatchNorm2d is folded into the most recently visited
    Linear/Conv layer (``prev_layer``). Afterwards the BatchNorm's weight is 1
    and it no longer rescales its input, while the network's eval-mode output
    is preserved. A BatchNorm with no preceding Linear/Conv layer is left
    untouched.

    NOTE(review): assumes each BatchNorm directly follows the layer it is
    folded into (no nonlinearity in between) and that children are registered
    in forward order -- confirm for the models this is used on.

    Returns:
        The last Linear/Conv layer seen, so recursive calls can chain.
    """
    for child in model.children():
        module_name = child._get_name()
        prev_layer = reparam(child, prev_layer)
        if module_name in ['Linear', 'Conv1d', 'Conv2d', 'Conv3d']:
            prev_layer = child
        elif module_name in ['BatchNorm2d', 'BatchNorm1d'] and prev_layer is not None:
            _fold_batchnorm(prev_layer, child)
    return prev_layer


def _fold_batchnorm(layer, bn):
    """Absorb ``bn``'s eval-mode affine map into ``layer``; leave ``bn`` as a pure shift."""
    if bn.running_var is None:
        # Without running statistics BN uses batch stats; nothing to fold.
        return
    with torch.no_grad():
        gamma = bn.weight if bn.weight is not None else torch.ones_like(bn.running_var)
        beta = bn.bias if bn.bias is not None else torch.zeros_like(bn.running_mean)
        # Eval-mode BN computes y -> scale * y + shift (per channel).
        scale = gamma / (bn.running_var + bn.eps).sqrt()
        shift = beta - scale * bn.running_mean

        # Scale each output unit (dim 0 of the weight): reversing the dims
        # puts that dim last, where it broadcasts against ``scale``.
        perm = list(reversed(range(layer.weight.dim())))
        layer.weight.copy_((layer.weight.permute(perm) * scale).permute(perm))
        if layer.bias is not None:
            layer.bias.copy_(scale * layer.bias + shift)
            shift = torch.zeros_like(shift)

        # Make bn compute y -> y + shift: with running_var = 1 - eps the
        # normaliser sqrt(running_var + eps) is 1 (up to rounding).
        if bn.weight is not None:
            bn.weight.fill_(1.)
        bn.running_var.fill_(1. - bn.eps)
        if bn.bias is not None:
            # Keep the residual shift (non-zero when ``layer`` has no bias)
            # instead of discarding it.
            bn.bias.copy_(shift)
            bn.running_mean.fill_(0)
        else:
            # No affine bias to hold the shift; encode it in the mean.
            bn.running_mean.copy_(-shift)


def calc_measure(model, init_model, measure_func, operator, kwargs=None, p=1):
    """Aggregate a per-layer measure over the weight layers of ``model``.

    ``model`` and ``init_model`` are walked in parallel (children are paired
    with ``zip``). ``measure_func(layer, init_layer, **kwargs)`` is called on
    each Linear / Conv1d / Conv2d / Conv3d / BatchNorm1d / BatchNorm2d child.
    Every other child is recursed into and its result combined the same way.

    Operators:
        'sum'         -- sum of measure ** p
        'norm'        -- ('sum' result) ** (1 / p)
        'log_product' -- sum of log(measure)
        'product'     -- exp('log_product' result)
        'max'         -- largest measure (0 if there are no layers)

    Raises:
        ValueError: if ``operator`` is not one of the above.
    """
    if kwargs is None:
        kwargs = {}
    if operator == 'product':
        return math.exp(calc_measure(model, init_model, measure_func, 'log_product', kwargs, p))
    if operator == 'norm':
        return calc_measure(model, init_model, measure_func, 'sum', kwargs, p=p) ** (1. / p)
    if operator not in ('log_product', 'sum', 'max'):
        raise ValueError('unknown operator: {}'.format(operator))

    measure_val = 0
    for child, init_child in zip(model.children(), init_model.children()):
        if child._get_name() in ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'BatchNorm2d', 'BatchNorm1d']:
            value = measure_func(child, init_child, **kwargs)
            if operator == 'log_product':
                value = math.log(value)
            elif operator == 'sum':
                value = value ** p
        else:
            value = calc_measure(child, init_child, measure_func, operator, kwargs, p=p)
        # Nested containers must be combined with the same operator: summing
        # their maxima (as '+=' would) is not a max.
        if operator == 'max':
            measure_val = max(measure_val, value)
        else:
            measure_val += value
    return measure_val


def lp_path_norm(model, args, p, input_size):
    """Return the l_p path norm of ``model`` evaluated on an all-ones input.

    A deep copy of ``model`` is put in eval mode and cast to float64. Every
    trainable parameter is replaced by ``|param| ** p``, the copy is run on
    ``torch.ones(input_size)`` (moved to the GPU when ``args.cuda``), and the
    summed output is raised to ``1 / p``. ``model`` itself is not modified.
    Returns a 0-d tensor.
    """
    tmp_model = copy.deepcopy(model)
    tmp_model.eval()
    tmp_model.double()
    # Pure inference on a throwaway copy: no autograd graph is needed.
    with torch.no_grad():
        for param in tmp_model.parameters():
            if param.requires_grad:
                param.copy_(param.abs().pow(p))
        data_ones = torch.ones(input_size).double()
        if args.cuda:
            data_ones = data_ones.cuda()
        tmp_out = tmp_model(data_ones)
        return tmp_out.sum() ** (1. / p)


def calculate_complexity(model, init_model, margin, args, nchannels=3, img_dim=32):
    """Compute margin-normalised norm-based complexity measures of ``model``.

    Every measure is divided by ``margin ** 2``. The path norms are evaluated
    on an all-ones input of shape (1, nchannels, img_dim, img_dim). The
    defaults (3, 32) presumably correspond to 32x32 RGB images (e.g. CIFAR);
    pass other values for other input shapes.

    Returns:
        (Frobenius-norm product, l1 path norm, l2 path norm,
        spectral-norm product), each divided by ``margin ** 2``.
    """
    # No defensive deepcopy of ``model`` is needed: calc_measure, norm and
    # op_norm only read the weights, and lp_path_norm works on its own copy.
    margin_sq = margin ** 2
    input_size = [1, nchannels, img_dim, img_dim]

    frobenius_norm = calc_measure(model, init_model, norm, 'product', {'p': 2, 'q': 2}) / margin_sq
    spectral_norm = calc_measure(model, init_model, op_norm, 'product', {'p': float('Inf')}) / margin_sq
    l1_path_norm = lp_path_norm(model, args, p=1, input_size=input_size) / margin_sq
    l2_path_norm = lp_path_norm(model, args, p=2, input_size=input_size) / margin_sq
    print('Frobenious norm: {}\tspectral norm: {}'.format(frobenius_norm, spectral_norm))
    print('l1 path norm: {}\tl2 path norm: {}'.format(l1_path_norm, l2_path_norm))
    return frobenius_norm, l1_path_norm, l2_path_norm, spectral_norm