import torch
from ..smoothing import ExactSmoothingLayer, ApproxSmoothingLayer
import re

class MILAttentionPool(torch.nn.Module):
    """Attention-based MIL pooling with optional smoothing.

    Each instance of a bag gets a scalar attention logit
    ``f = fc2(tanh(fc1(x)))``. The logits are normalised with a masked softmax
    over the bag, and the bag embedding is the attention-weighted sum of the
    instance features.

    When smoothing is enabled, a smoothing layer (``ExactSmoothingLayer`` or
    ``ApproxSmoothingLayer``) is called with the adjacency matrix at one of
    three points, selected by ``smooth_where``:

    * ``'representation'``     -- on the input features ``X``;
    * ``'att_representation'`` -- on ``fc1(X)``, before the ``tanh``;
    * ``'att_values'``         -- on the attention logits ``f``.

    Args:
        in_dim: Feature dimension ``D`` of each instance.
        att_dim: Hidden dimension of the attention MLP.
        alpha: Forwarded as ``alpha`` to the smoothing layer. ``None`` or ``0``
            disables smoothing entirely.
        smooth_mode: ``'exact'`` or ``'approx_<num_steps>'``.
        smooth_where: ``'none'``, ``'att_values'``, ``'att_representation'``
            or ``'representation'``.
        spectral_norm: If True and smoothing is enabled, wrap the linear
            layers applied after the smoothing step with spectral
            normalization (``fc2`` for ``'att_representation'``, both ``fc1``
            and ``fc2`` for ``'representation'``).
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If smoothing is enabled and ``smooth_where`` or
            ``smooth_mode`` is not one of the accepted values.
    """

    # Placements that actually apply a smoothing layer ('none' disables it).
    _SMOOTH_PLACES = ('att_values', 'att_representation', 'representation')

    def __init__(
            self,
            in_dim,
            att_dim=50,
            alpha=None,
            smooth_mode='approx_10',
            smooth_where='att_values',
            spectral_norm=False,
            **kwargs
        ):
        super().__init__()
        self.in_dim = in_dim
        self.att_dim = att_dim
        self.alpha = alpha
        self.smooth_mode = smooth_mode
        self.smooth_where = smooth_where
        self.spectral_norm = spectral_norm

        self.fc1 = torch.nn.Linear(in_dim, att_dim)
        self.fc2 = torch.nn.Linear(att_dim, 1, bias=False)

        # Note: 0 == 0.0, so this also covers a float zero alpha.
        if self.alpha not in (0, None) and self.smooth_where != 'none':
            # Validate the placement before building (and announcing) a layer.
            if self.smooth_where not in self._SMOOTH_PLACES:
                raise ValueError("smooth_where must be 'none', 'att_values', 'att_representation' or 'representation'")

            if self.smooth_mode == 'exact':
                print('Using ExactSmoothingLayer')
                self.smooth_layer = ExactSmoothingLayer(alpha=self.alpha)
            else:
                approx_match = re.search(r'approx_(\d+)', self.smooth_mode)  # approx_<num_steps>
                if approx_match is None:
                    raise ValueError("smooth_mode must be 'exact' or 'approx_<num_steps>'")
                num_steps = int(approx_match.group(1))
                print('Using ApproxSmoothingLayer with num_steps={}'.format(num_steps))
                self.smooth_layer = ApproxSmoothingLayer(alpha=self.alpha, num_steps=num_steps)

            # NOTE(review): this block only runs when smoothing is enabled, so
            # spectral_norm=True has no effect when alpha is 0/None or
            # smooth_where == 'none'. Confirm that is intended.
            if self.spectral_norm:
                sn = torch.nn.utils.parametrizations.spectral_norm
                if self.smooth_where == 'att_representation':
                    self.fc2 = sn(self.fc2)
                elif self.smooth_where == 'representation':
                    self.fc1 = sn(self.fc1)
                    self.fc2 = sn(self.fc2)

            self.use_smooth = True
        else:
            self.smooth_layer = None
            self.use_smooth = False
            self.smooth_mode = ''
            self.smooth_where = ''

    def forward(self, X, adj_mat=None, mask=None, return_att=False):
        """Pool each bag of instances into a single embedding.

        input:
            X: tensor (batch_size, bag_size, D)
            adj_mat: sparse coo tensor (batch_size, bag_size, bag_size); only
                passed to the smoothing layer, so it is unused when smoothing
                is disabled.
            mask: tensor (batch_size, bag_size). Positions where the mask is 0
                get zero attention. Defaults to all ones.
            return_att: if True, also return the attention logits.
        output:
            z: tensor (batch_size, D), attention-weighted sum of X.
            f: tensor (batch_size, bag_size), pre-softmax (possibly smoothed)
                attention logits; only returned when return_att is True.

        A bag whose mask is entirely 0 yields NaN in ``z`` (0 / 0).
        """
        batch_size, bag_size, D = X.shape

        if mask is None:
            mask = torch.ones(batch_size, bag_size, device=X.device, dtype=X.dtype)
        mask = mask.unsqueeze(dim=-1)  # (batch_size, bag_size, 1)

        if self.use_smooth and self.smooth_where == 'representation':
            X = self.smooth_layer(X, adj_mat)  # (batch_size, bag_size, D)

        H = self.fc1(X.reshape(-1, D)).view(batch_size, bag_size, -1)  # (batch_size, bag_size, att_dim)
        if self.use_smooth and self.smooth_where == 'att_representation':
            H = self.smooth_layer(H, adj_mat)  # (batch_size, bag_size, att_dim)
        H = torch.tanh(H)  # (batch_size, bag_size, att_dim)

        f = self.fc2(H.reshape(-1, self.att_dim)).view(batch_size, bag_size, -1)  # (batch_size, bag_size, 1)
        if self.use_smooth and self.smooth_where == 'att_values':
            f = self.smooth_layer(f, adj_mat)  # (batch_size, bag_size, 1)

        # Numerically stable masked softmax over the bag dimension. Softmax is
        # shift-invariant, so subtracting the per-bag max of the valid logits
        # keeps every exp() <= 1 (no overflow to inf / NaN). Masked positions
        # are set to -inf before exp() so a huge padded logit cannot produce
        # inf * 0 = NaN.
        invalid = mask == 0
        neg_inf = float('-inf')
        f_max = f.detach().masked_fill(invalid, neg_inf).amax(dim=1, keepdim=True)  # (batch_size, 1, 1)
        f_max = f_max.masked_fill(torch.isneginf(f_max), 0.0)  # fully-masked bag
        exp_f = torch.exp((f - f_max).masked_fill(invalid, neg_inf)) * mask  # (batch_size, bag_size, 1)
        sum_exp_f = torch.sum(exp_f, dim=1, keepdim=True)  # (batch_size, 1, 1)
        s = exp_f / sum_exp_f  # (batch_size, bag_size, 1)
        z = torch.bmm(X.transpose(1, 2), s).squeeze(dim=2)  # (batch_size, D)

        if return_att:
            return z, f.squeeze(dim=2)
        return z

    def compute_loss(self, *args, **kwargs):
        """Return this module's auxiliary losses; it has none, so always ``{}``."""
        return {}