import torch.nn.functional as F
from utils.metrics import topk_corrects
import torch
from torch.autograd import grad
import numpy as np

def gather_flat_grad(loss_grad):
    """Flatten every non-None tensor in ``loss_grad`` and join them into one 1-D tensor.

    Args:
        loss_grad: iterable of tensors; ``None`` entries are skipped.

    Returns:
        A 1-D tensor holding the flattened entries in iteration order.
    """
    flattened = []
    for piece in loss_grad:
        if piece is None:
            continue
        # contiguous() first, because view(-1) needs a contiguous memory layout.
        flattened.append(piece.contiguous().view(-1))
    return torch.cat(flattened)

def neumann_hyperstep_preconditioner(d_val_loss_d_theta, d_train_loss_d_w, elementary_lr, num_neumann_terms, model):
    """Approximate a vector-inverse-Hessian product with a truncated fixed-point iteration.

    Starting from ``v = d_val_loss_d_theta`` (detached), each step replaces ``v``
    by ``v - elementary_lr * H(v)``, where ``H(v)`` is obtained by differentiating
    ``d_train_loss_d_w`` with respect to ``model.parameters()`` using ``v`` as
    ``grad_outputs``. The starting vector and every iterate are summed.

    Args:
        d_val_loss_d_theta: starting vector; it is detached before use.
        d_train_loss_d_w: tensor that is differentiated again on every step
            (presumably the flattened training-loss gradient built with a
            retained graph -- confirm at the call site).
        elementary_lr: scale applied to each Hessian term and to the final sum.
        num_neumann_terms: number of fixed-point steps to run.
        model: object whose ``parameters()`` are the differentiation targets.

    Returns:
        ``elementary_lr`` times the accumulated sum.
    """
    running_sum = d_val_loss_d_theta.detach()
    term = running_sum
    remaining = num_neumann_terms
    while remaining > 0:
        # retain_graph=True keeps the graph so it can be differentiated again
        # on the next step.
        second_order = grad(
            d_train_loss_d_w,
            model.parameters(),
            grad_outputs=term.view(-1),
            retain_graph=True,
        )
        # term <- term * (I - lr * hessian) = term - lr * (term * hessian)
        term = term - elementary_lr * gather_flat_grad(second_order)
        running_sum = running_sum + term
        remaining -= 1
    return elementary_lr * running_sum

def loss_adjust_cross_entropy(logits,targets,params,group_size=1):
    """Cross-entropy on logits scaled by ``sigmoid(dy)`` and shifted by ``ly``.

    Args:
        logits: tensor of raw scores handed (after adjustment) to ``F.cross_entropy``.
        targets: targets passed unchanged to ``F.cross_entropy``.
        params: sequence where ``params[0]`` is ``dy`` (passed through a sigmoid
            and multiplied into the logits) and ``params[1]`` is ``ly`` (added
            to the logits). When it has exactly three entries, ``params[2]`` is
            used as the ``weight`` argument of ``F.cross_entropy``.
        group_size: when not 1, every entry of ``dy`` and ``ly`` is repeated
            ``group_size`` times consecutively before being applied.

    Returns:
        The loss returned by ``F.cross_entropy``.
    """
    dy, ly = params[0], params[1]
    if group_size != 1:
        # Expand the per-group values so that each one covers group_size
        # consecutive logit columns.
        dy = dy.repeat_interleave(group_size)
        ly = ly.repeat_interleave(group_size)
    adjusted = logits * torch.sigmoid(dy) + ly
    # weight=None is F.cross_entropy's default, i.e. the unweighted loss.
    weight = params[2] if len(params) == 3 else None
    return F.cross_entropy(adjusted, targets, weight=weight)

def loss_adjust_dy(logits,targets,params,group_size=1):
    """Cross-entropy on logits scaled per sample by ``sigmoid(dy[target])`` and shifted by ``ly``.

    Args:
        logits: tensor of raw scores whose first dimension indexes samples.
        targets: integer class indices; used both to look up each sample's
            entry of ``dy`` and as the ``F.cross_entropy`` targets.
        params: sequence where ``params[0]`` is ``dy`` (indexed by target,
            passed through a sigmoid, multiplied into that sample's logits) and
            ``params[1]`` is ``ly`` (added to the logits).
        group_size: unused; kept so the signature matches the other loss
            functions in this file.

    Returns:
        The loss returned by ``F.cross_entropy``.
    """
    dy, ly = params[0], params[1]
    # One scale per sample; unsqueeze(1) lets it broadcast over the class
    # dimension, which replaces the former transpose / multiply / transpose.
    per_sample_scale = torch.sigmoid(dy[targets]).unsqueeze(1)
    adjusted = logits * per_sample_scale + ly
    return F.cross_entropy(adjusted, targets)

def cross_entropy(logits,targets,params,group_size=1):
    """Plain cross-entropy, optionally class-weighted.

    Args:
        logits: raw scores passed to ``F.cross_entropy``.
        targets: targets passed to ``F.cross_entropy``.
        params: sequence; only its length and third entry are read. When it has
            exactly three entries, ``params[2]`` is used as the ``weight``.
        group_size: unused here.

    Returns:
        The loss returned by ``F.cross_entropy``.
    """
    class_weight = None
    if len(params) == 3:
        class_weight = params[2]
    return F.cross_entropy(logits, targets, weight=class_weight)

def logit_adjust_ly(logits,params):
    """Return ``logits * params[0] - params[1]``.

    ``params[0]`` is multiplied in as-is (no sigmoid is applied here) and
    ``params[1]`` is subtracted.

    Args:
        logits: values to adjust.
        params: sequence whose first two entries are the scale and the offset.

    Returns:
        The adjusted logits.
    """
    scale, offset = params[0], params[1]
    return logits * scale - offset

def get_trainable_hyper_params(params):
    """Return a list of the entries of ``params`` whose ``requires_grad`` is truthy, in order."""
    trainable = []
    for candidate in params:
        if candidate.requires_grad:
            trainable.append(candidate)
    return trainable

def assign_hyper_gradient(params,gradient,num_classes):
    """Scatter a flat gradient vector into the ``.grad`` of each trainable parameter.

    Consecutive slices of ``gradient`` are consumed in the order of ``params``;
    parameters with ``requires_grad`` false are skipped and consume nothing.

    Args:
        params: iterable of parameters (tensors).
        gradient: flat tensor holding the gradients of all trainable
            parameters back to back.
        num_classes: unused; kept for backward compatibility with callers.
    """
    offset = 0
    for para in params:
        if not para.requires_grad:
            continue
        count = para.nelement()
        # The reshaped tensor must be kept: torch.reshape is not in-place, and
        # .grad has to match the parameter's shape.
        para.grad = gradient[offset:offset + count].clone().reshape(para.shape)
        offset += count