import torch
import torch.nn as nn
import numpy as np


class Reshape(nn.Module):
    """Module wrapper around ``reshape`` with a fixed target shape.

    The positional arguments given at construction time are stored as a
    tuple and passed unchanged to ``x.reshape`` on every forward call, so
    the layer can be dropped into an ``nn.Sequential`` pipeline.
    """

    def __init__(self, *args):
        super().__init__()
        # Target shape, kept as the tuple of positional arguments.
        self.shape = args

    def forward(self, x):
        """Return ``x`` reshaped to the shape stored at construction."""
        target_shape = self.shape
        return x.reshape(target_shape)
    
    
class ConditionLayer(nn.Module):
    """Embed every column of a mixed numerical/categorical input row.

    ``col_cat_count`` holds one entry per input column: ``-1`` marks a
    numerical column, any other value is used as the number of categories
    of a categorical column.

    * Numerical columns are each expanded to ``sample_emb_dim`` values by a
      per-column (grouped, kernel size 1) ``Conv1d`` followed by a sigmoid.
    * Categorical columns share a single ``nn.Embedding`` table of
      ``sum(cat_count)`` rows; each column addresses its own slice of the
      table through a per-column offset. The looked-up vectors then go
      through a per-column ``GroupNorm``, a grouped ``Conv1d`` and a sigmoid.

    Args:
        col_cat_count: per-column category counts, ``-1`` for numerical.
        sample_emb_dim: number of embedding values produced per column.
    """

    def __init__(self, col_cat_count, sample_emb_dim=128):
        super(ConditionLayer, self).__init__()
        (
            self.num_index,
            self.cat_index,
            self.cat_count,
            self.cat_offset,
        ) = self.get_num_cat_index(col_cat_count)
        self.sample_emb_dim = sample_emb_dim
        self.sample_embeddings = self.get_sample_embeddings()

    def get_num_cat_index(self, col_cat_count):
        """Split the column spec into numerical and categorical parts.

        Returns:
            A tuple ``(num_index, cat_index, cat_count, cat_offset)`` where
            ``num_index`` / ``cat_index`` are the positions of numerical /
            categorical columns, ``cat_count`` the category counts of the
            categorical columns, and ``cat_offset`` the exclusive running
            sum of ``cat_count`` (``[0]`` when there is no categorical
            column).
        """
        num_index = []
        cat_index = []
        cat_count = []
        for idx, ele in enumerate(col_cat_count):
            if ele == -1:
                num_index.append(idx)
            else:
                cat_index.append(idx)
                cat_count.append(ele)
        # Exclusive prefix sum: first row of each column's slice in the
        # shared embedding table.
        cat_offset = [0] + np.cumsum(cat_count).tolist()[:-1]
        return num_index, cat_index, cat_count, cat_offset

    def get_sample_embeddings(self, ):
        """Build the embedding sub-modules for the column types present."""
        n_num = len(self.num_index)
        n_cat = len(self.cat_index)
        emb_dim = self.sample_emb_dim

        sample_emb = nn.ModuleDict()
        if n_num:
            sample_emb['num_emb'] = nn.Sequential(
                # input = (b, n_num_columns)
                # output = (b, n_num_columns, sample_emb_dim)
                Reshape(-1, n_num, 1),
                # groups=n_num: every column gets its own weights.
                nn.Conv1d(n_num, n_num * emb_dim, kernel_size=1, groups=n_num),
                nn.Sigmoid(),
                Reshape(-1, n_num, emb_dim),
            )
        if n_cat:
            sample_emb['cat_emb'] = nn.ModuleDict()
            sample_emb['cat_emb']['embedding_layer'] = nn.Embedding(sum(self.cat_count), emb_dim)
            sample_emb['cat_emb']['norm_layer'] = nn.Sequential(
                # input = (b, n_cat_columns, sample_emb_dim)
                # output = (b, n_cat_columns, sample_emb_dim)
                Reshape(-1, n_cat * emb_dim, 1),
                # One normalisation group per categorical column.
                nn.GroupNorm(n_cat, n_cat * emb_dim),
                nn.Conv1d(n_cat * emb_dim, n_cat * emb_dim, kernel_size=1, groups=n_cat),
                nn.Sigmoid(),
                Reshape(-1, n_cat, emb_dim),
            )
        return sample_emb

    def forward(self, x=None):
        """Embed a batch of rows.

        Args:
            x: 2-D tensor ``(batch, n_columns)`` whose columns follow the
                ``col_cat_count`` given at construction. Categorical columns
                are cast to ``long`` and used as indices.

        Returns:
            Tensor of shape ``(batch, sample_emb_dim, n_columns)``. Along the
            last axis the numerical columns come first, followed by the
            categorical columns.

        Raises:
            ValueError: if ``x`` is not 2-D.
        """
        if x.dim() != 2:
            raise ValueError(
                f'expected a 2-D input (batch, n_columns), got shape {tuple(x.shape)}'
            )

        sample_emb_out = []
        if self.num_index:
            num_x = x[:, self.num_index].float()
            sample_emb_out.append(self.sample_embeddings['num_emb'](num_x))
        if self.cat_index:
            # Create the offsets directly on the input's device instead of
            # assuming CUDA, so CPU and any accelerator backend both work.
            offsets = torch.as_tensor(self.cat_offset, dtype=torch.long, device=x.device)
            cat_x = x[:, self.cat_index].long() + offsets
            cat_modules = self.sample_embeddings['cat_emb']
            cat_sample_emb = cat_modules['norm_layer'](cat_modules['embedding_layer'](cat_x))
            sample_emb_out.append(cat_sample_emb)
        sample_emb_out = torch.cat(sample_emb_out, dim=1)  # (b, c, e)
        return sample_emb_out.permute(0, 2, 1)  # (b, e, c)


class OriginalCCC(nn.Module):
    """Tabular model built from column embeddings and randomly sampled subtrees.

    Pipeline (see ``forward``):

    1. ``ConditionLayer`` turns every input column into ``num_cond_per_column``
       values ("conditions").
    2. The flattened conditions are optionally permuted by a fixed random
       permutation and split into groups of ``num_cond_per_subtree``.
    3. A grouped 1x1 convolution stack scores ``num_subtree`` subtrees.
    4. For each estimator a random subset of subtrees is drawn; their learned
       embeddings are averaged with softmax weights from the scores and fed
       to an MLP head.

    Args:
        col_cat_count: per-column spec forwarded to ``ConditionLayer``.
        num_classes: number of classes, or ``-1`` for regression (one output,
            MSE loss).
        num_cond_per_column: conditions per column; defaults to
            ``num_cond_per_subtree * num_cond_per_column_scale``.
        num_cond_per_column_scale: multiplier used for the default above.
        num_cond_per_subtree: number of conditions in one condition set.
        num_subtree_per_condset: subtrees scored per condition set.
        num_subtree_per_estimator: ``-1`` for the built-in heuristic, a value
            greater than 1 for an absolute count, otherwise a fraction of
            ``num_subtree``.
        train_num_estimator: estimators sampled per forward pass in training
            mode.
        test_num_estimator: estimators sampled per forward pass in eval mode.
        subtree_hidden_dim: width of the subtree embeddings and the head.
        dropout: dropout probability; ``0.0`` disables the dropout layers.
        shuffle_condition: whether to apply the condition permutation.
        condition_shuffle_type: ``'random'`` permutes all conditions jointly,
            ``'row'`` permutes the columns within each condition row.
        device: stored on the instance as ``self.device``.

    Raises:
        NotImplementedError: for an unknown ``condition_shuffle_type``.
    """

    def __init__(
        self,
        col_cat_count,
        num_classes,
        num_cond_per_column=None,
        num_cond_per_column_scale=32,
        num_cond_per_subtree=4,
        num_subtree_per_condset=1,
        num_subtree_per_estimator=-1,
        train_num_estimator=1,
        test_num_estimator=100,
        subtree_hidden_dim=128,
        dropout=0.0,
        shuffle_condition=True,
        condition_shuffle_type='random',
        device=torch.device('cpu'),
    ):
        super().__init__()

        self.device = device
        self.is_rgr = num_classes == -1
        self.num_classes = 1 if self.is_rgr else num_classes

        if num_cond_per_column is None:
            num_cond_per_column = num_cond_per_subtree * num_cond_per_column_scale

        self.emb_layer = ConditionLayer(col_cat_count, sample_emb_dim=num_cond_per_column)

        self.num_columns = len(col_cat_count)
        self.shuffle_condition = shuffle_condition
        if condition_shuffle_type == 'row':
            # Permute the columns independently inside every condition row;
            # the added row offsets turn them into flat indices.
            row_perm = torch.rand(num_cond_per_column, self.num_columns).argsort(-1)
            row_base = (torch.arange(num_cond_per_column) * self.num_columns).unsqueeze(-1)
            condition_shuffler = (row_perm + row_base).reshape(-1)
        elif condition_shuffle_type == 'random':
            condition_shuffler = torch.rand(num_cond_per_column * self.num_columns).argsort(-1)
        else:
            raise NotImplementedError
        # The grouped scoring layers are trained against this exact
        # permutation, so it is registered as a buffer: it is saved in
        # ``state_dict`` and follows the module across ``.to(device)`` calls.
        self.register_buffer('condition_shuffler', condition_shuffler)

        self.num_total_conditions = self.num_columns * num_cond_per_column
        num_condset = self.num_total_conditions // num_cond_per_subtree
        self.num_subtree = num_condset * num_subtree_per_condset

        def make_dropout():
            return nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()

        self.subtree_scoring = nn.Sequential(
            nn.GroupNorm(num_condset, self.num_total_conditions),
            make_dropout(),
            nn.Conv1d(self.num_total_conditions, self.num_total_conditions, groups=num_condset, kernel_size=1),
            nn.ReLU(),
            nn.GroupNorm(num_condset, self.num_total_conditions),
            make_dropout(),
            nn.Conv1d(self.num_total_conditions, self.num_subtree, groups=num_condset, kernel_size=1)
        )
        self.subtree_embedding = nn.Embedding(self.num_subtree, subtree_hidden_dim)

        self.train_num_estimator = train_num_estimator
        self.test_num_estimator = test_num_estimator
        self.subtree_chunk_indices = torch.arange(self.num_subtree).reshape(self.num_columns, -1)
        self.max_subtree_chunk = max(2, int(self.num_columns ** 0.5))

        if num_subtree_per_estimator == -1:
            subtree_per_estimator = self.max_subtree_chunk * self.subtree_chunk_indices.shape[1]
        elif num_subtree_per_estimator > 1:
            subtree_per_estimator = num_subtree_per_estimator
        else:
            subtree_per_estimator = int(self.num_subtree * num_subtree_per_estimator)
        # ``forward`` slices at most ``num_subtree`` entries anyway; clamp so
        # the stored and printed value is the number actually used.
        self.num_subtree_per_estimator = min(subtree_per_estimator, self.num_subtree)

        self.downstream = nn.Sequential(
            nn.LayerNorm(subtree_hidden_dim),
            make_dropout(),
            nn.Linear(subtree_hidden_dim, subtree_hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(subtree_hidden_dim),
            make_dropout(),
            nn.Linear(subtree_hidden_dim, self.num_classes)
        )

        print('num_total_conditions: ', self.num_total_conditions)
        print('num_condset: ', num_condset)
        print('num_subtree: ', self.num_subtree)
        print('num_subtree_per_estimator: ', self.num_subtree_per_estimator)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Accept checkpoints saved before ``condition_shuffler`` was a buffer.

        Such checkpoints have no entry for the permutation; the permutation
        drawn at construction is kept in that case so strict loading still
        succeeds.
        """
        key = prefix + 'condition_shuffler'
        if key not in state_dict:
            state_dict[key] = self.condition_shuffler
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs,
        )

    def condition_shuffling(self, emb):
        """Flatten ``emb`` to ``(batch, n_conditions, 1)`` and optionally permute it."""
        flat = emb.reshape(emb.shape[0], -1, 1)
        if self.shuffle_condition:
            return flat[:, self.condition_shuffler, :]
        return flat

    def calc_downstream_loss(self, x, y, reduction='mean'):
        """Return MSE (regression) or cross-entropy (classification) loss.

        For regression the trailing axis of ``x`` is squeezed before the
        comparison with ``y``; for classification ``x`` must carry the class
        axis in position 1, as ``cross_entropy`` expects.
        """
        if self.is_rgr:
            loss = torch.nn.functional.mse_loss(x.squeeze(-1), y.float(), reduction=reduction)
        else:
            loss = torch.nn.functional.cross_entropy(x, y.long(), reduction=reduction)
        return loss

    def forward(self, x, y=None):
        """Run the model and optionally compute the training loss.

        Args:
            x: input batch ``(batch, n_columns)``.
            y: optional targets ``(batch,)``; the same target is used for
                every estimator.

        Returns:
            ``(final_pred, loss)`` where ``final_pred`` is the detached mean
            of the per-estimator outputs, shape ``(batch, num_classes)``, and
            ``loss`` is the loss over all estimators, or a zero scalar when
            ``y`` is ``None``.
        """
        sample_emb = self.emb_layer(x)  # (b, n_cond_per_column, n_column)
        sample_emb = self.condition_shuffling(sample_emb)  # (b, n_total_conditions, 1)
        subtree_score = self.subtree_scoring(sample_emb)  # (b, num_subtree, 1)
        subtree_embed = self.subtree_embedding.weight.unsqueeze(0)  # (1, num_subtree, hidden)

        num_estimator = self.train_num_estimator if self.training else self.test_num_estimator
        # Sampled on CPU (unchanged RNG stream), then moved once to the
        # device of the tensors it indexes.
        rand_subtree = torch.rand(num_estimator, self.num_subtree).argsort(-1)
        rand_subtree = rand_subtree[:, :self.num_subtree_per_estimator].to(subtree_score.device)

        rand_condition_score = subtree_score[:, rand_subtree].softmax(-2)  # (b, n_estimator, n_selected, 1)
        rand_subtree_embed = subtree_embed[:, rand_subtree]  # (1, n_estimator, n_selected, hidden)
        rand_condition_embedding = (rand_condition_score * rand_subtree_embed).sum(-2)  # (b, n_estimator, hidden)

        pred = self.downstream(rand_condition_embedding)  # (b, n_estimator, class)
        final_pred = pred.detach().mean(1)

        if y is not None:
            loss = self.calc_downstream_loss(
                # cross_entropy wants the class axis in position 1.
                pred.permute(0, 2, 1) if not self.is_rgr else pred,
                y.unsqueeze(-1).expand(-1, num_estimator)
            )
        else:
            # Placeholder on the same device as the predictions.
            loss = torch.tensor(0.0, device=pred.device)

        return final_pred, loss