import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

            
class BinaryAlphaLossLogits(nn.Module):
    """Binary alpha-loss computed from classifier logits.

    With ``s`` the positive-class logit and ``y`` the target mapped from
    {0, 1} to {-1, +1}, the margin is ``z = s * y`` and the per-sample loss is

        alpha / (alpha - 1) * (1 - sigmoid(z) ** (1 - 1 / alpha))

    The special values of ``alpha`` are:

    - ``alpha == 1``: the logistic (binary cross-entropy) loss, the limit of
      the formula as ``alpha -> 1``.
    - ``alpha == inf``: the sigmoid loss ``1 - sigmoid(z)``.
    - ``alpha == 0.5``: the exponential loss ``exp(-z)``.

    Parameters
    ----------
    alpha : float
        Loss hyperparameter, expected in ``(0, inf]``. ``alpha == 0`` raises
        ``ZeroDivisionError`` in ``forward``.
    """

    def __init__(self, alpha: float):
        super().__init__()
        self.alpha = alpha

    def forward(self, logits, targets, sample_weight=None):
        """Return the mean (optionally sample-weighted) binary alpha-loss.

        Parameters
        ----------
        logits : torch.Tensor
            Unnormalized classifier logits indexed as ``logits[:, 1]``. Only
            column 1 is used, as the positive-class logit; column 0 is
            ignored. Presumably shape (N, C) with C >= 2 -- confirm with
            callers.
        targets : torch.Tensor
            Target labels in {0, 1}, any shape with N elements.
        sample_weight : torch.Tensor, optional
            Per-sample weights, any shape with N elements.

        Returns
        -------
        torch.Tensor
            Scalar ``mean_i(w_i * loss_i)`` over the batch. Weights are NOT
            normalized by their sum.
        """
        scores = logits[:, 1]
        # Flatten so both branches produce an (N,) loss; keep the logits dtype.
        labels = targets.reshape(-1).to(scores.dtype)

        if self.alpha == 1:
            loss = F.binary_cross_entropy_with_logits(scores, labels, reduction='none')
        else:
            margin = scores * (2 * labels - 1)  # labels {0,1} -> {-1,+1}
            if self.alpha == float('inf'):
                # alpha/(alpha-1) would be inf/inf = nan; use the analytic limit.
                loss = torch.sigmoid(-margin)
            else:
                # (1 + exp(-z)) ** (1/alpha - 1) == exp((1/alpha - 1) * softplus(-z)).
                # softplus avoids exp(-z) overflowing to inf, which produced nan
                # gradients; expm1 keeps precision for small exponents.
                exponent = 1.0 / self.alpha - 1.0
                scale = self.alpha / (self.alpha - 1)
                loss = -scale * torch.expm1(exponent * F.softplus(-margin))

        if sample_weight is not None:
            # Flatten to (N,) so weights pair element-wise with the (N,) loss;
            # an (N, 1) weight against an (N,) loss would broadcast to (N, N).
            loss = loss * sample_weight.reshape(-1)
        return loss.mean()
        
