'''
 *
 *     ICTP: Irreducible Cartesian Tensor Potentials
 *
 *        File:  layers.py
 *
 *     Authors: Deleted for purposes of anonymity 
 *
 *     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 * 
 * The software and its source code contain valuable trade secrets and shall be maintained in
 * confidence and treated as confidential information. The software may only be used for 
 * evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 * license agreement or nondisclosure agreement with the proprietor of the software. 
 * Any unauthorized publication, transfer to third parties, or duplication of the object or
 * source code---either totally or in part---is strictly prohibited.
 *
 *     Copyright (c) 2024 Proprietor: Deleted for purposes of anonymity
 *     All Rights Reserved.
 *
 * THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR 
 * IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY 
 * AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT 
 * DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION. 
 * 
 * NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 * IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE 
 * LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 * FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 * OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 * ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 * TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGES.
 * 
 * For purposes of anonymity, the identity of the proprietor is not given herewith. 
 * The identity of the proprietor will be given once the review of the 
 * conference submission is completed. 
 *
 * THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 *
'''
from typing import Tuple, List, Optional

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.o3.tensor_product import WeightedTensorProduct
from src.o3.linear_transform import LinearTransform
from src.o3.product_basis import WeightedProductBasis

from src.nn.radial import BesselRBF, PolynomialCutoff

from src.utils.math import segment_sum
from src.utils.torch_geometric import Data


class RescaledSiLULayer(nn.Module):
    """SiLU activation multiplied by a fixed constant.

    NOTE(review): the factor 1.6765324703310907 presumably rescales SiLU so that its output has
    unit second moment for standard-normal inputs (cf. e3nn's ``normalize2mom``) — confirm.
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``1.6765324703310907 * silu(x)``, computed element-wise.

        Args:
            x (torch.Tensor): Input tensor (any shape).

        Returns:
            torch.Tensor: Rescaled SiLU of ``x``, same shape as ``x``.
        """
        activated = F.silu(x)
        return 1.6765324703310907 * activated


class LinearLayer(nn.Module):
    """Simple linear layer with a fan-in normalized weight matrix.

    The raw weights are drawn from a standard normal distribution. In every forward pass they are
    divided by ``sqrt(in_features)``, so the stored parameters stay unit-scale while the effective
    weight matrix is normalized by the fan-in.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        bias (bool, optional): If True, add a learnable bias (initialized to zero).
            Defaults to False.
    """
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = False):
        super(LinearLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # raw weights ~ N(0, 1); the 1/sqrt(in_features) normalization is applied in `forward`
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            # keep `self.bias` defined (as None) so `forward` can always pass it to F.linear
            self.register_buffer('bias', None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies the linear layer to the input tensor.

        Args:
            x (torch.Tensor): Input tensor whose last dimension has size ``in_features``.

        Returns:
            torch.Tensor: ``x @ (weight / sqrt(in_features)).T + bias``; the last dimension has
                size ``out_features``.
        """
        return F.linear(x, self.weight / (self.in_features) ** 0.5, self.bias)

    def __repr__(self) -> str:
        return (f"{self.__class__.__name__} ({self.in_features} -> {self.out_features} | {self.weight.numel()} weights)")


class RadialEmbeddingLayer(nn.Module):
    """Radial embedding: a Bessel radial basis multiplied by a polynomial cutoff envelope.

    Adapted from MACE (https://github.com/ACEsuit/mace/blob/main/mace/modules/blocks.py).

    Args:
        r_cutoff (float): Cutoff radius shared by the basis and the envelope.
        n_basis (int): Number of radial basis functions (stored as ``out_dim``).
        n_polynomial_cutoff (int): Parameter `p` of the envelope function.
    """
    def __init__(self,
                 r_cutoff: float,
                 n_basis: int,
                 n_polynomial_cutoff: int):
        super(RadialEmbeddingLayer, self).__init__()
        # presumably the width of the embedding produced by `forward` — equals n_basis
        self.out_dim = n_basis
        self.bessel_fn = BesselRBF(r_cutoff=r_cutoff, n_basis=n_basis)
        self.cutoff_fn = PolynomialCutoff(r_cutoff=r_cutoff, p=n_polynomial_cutoff)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embeds ``x`` with the radial basis, damped by the cutoff envelope.

        Args:
            x (torch.Tensor): Input tensor (presumably edge lengths — confirm against callers).

        Returns:
            torch.Tensor: Element-wise product of the basis values and the envelope values.
        """
        # both functions see the same input; the envelope is applied multiplicatively
        return self.bessel_fn(x) * self.cutoff_fn(x)
    

class ProductBasisLayer(nn.Module):
    """Equivariant product-basis layer.

    Contractions are based on the tensor product between irreducible Cartesian tensors (Cartesian
    harmonics). The node features are combined with the node attributes by a weighted product
    basis and then mixed by a linear transform. Optionally, a self-connection tensor is added to
    the result.

    Args:
        l_max_node_feats (int): Maximal rank of the irreducible node features (node embeddings).
        l_max_target_feats (int): Maximal rank of the irreducible target features (may differ
                                  from that of the node features).
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        n_species (int): Number of species/atom types.
        correlation (int): Correlation order, i.e., the number of contracted tensors. It also
                           corresponds to the many-body order + 1.
        coupled_feats (bool): If True, mix channels when computing the product basis.
        symmetric_product (bool): If True, exploit the symmetry of the tensor product to reduce
                                  the number of tensor contractions.
        use_sc (bool): If True, add the self-connection passed to ``forward`` (when it is not None).
    """
    def __init__(self,
                 l_max_node_feats: int,
                 l_max_target_feats: int,
                 in_features: int,
                 out_features: int,
                 n_species: int,
                 correlation: int,
                 coupled_feats: bool,
                 symmetric_product: bool,
                 use_sc: bool):
        super(ProductBasisLayer, self).__init__()
        self.use_sc = use_sc

        # weighted product basis of node features (in1) and node attributes (in2_features = n_species)
        self.product_basis = WeightedProductBasis(
            in1_l_max=l_max_node_feats,
            out_l_max=l_max_target_feats,
            in1_features=in_features,
            in2_features=n_species,
            correlation=correlation,
            coupled_feats=coupled_feats,
            symmetric_product=symmetric_product,
        )

        # channel mixing with the same maximal rank on input and output: in_features -> out_features
        self.linear = LinearTransform(
            in_l_max=l_max_target_feats,
            out_l_max=l_max_target_feats,
            in_features=in_features,
            out_features=out_features,
        )

    def forward(self,
                node_feats: torch.Tensor,
                sc: Optional[torch.Tensor],
                node_attrs: torch.Tensor) -> torch.Tensor:
        """Computes the (linearly mixed) product basis.

        Args:
            node_feats (torch.Tensor): Node features (node embeddings).
            sc (torch.Tensor, optional): Self-connection (residual) to add; ignored when
                ``use_sc`` is False or when it is None.
            node_attrs (torch.Tensor): Node attributes, e.g., one-hot encoded species.

        Returns:
            torch.Tensor: Product basis after the linear transform, plus ``sc`` if used.
        """
        basis = self.product_basis(node_feats, node_attrs)
        projected = self.linear(basis)
        if self.use_sc and sc is not None:
            projected = projected + sc
        return projected


class InteractionLayer(nn.Module):
    """Base class for equivariant interaction layers with a convolution based on the tensor product
    between irreducible Cartesian tensors/Cartesian harmonics.

    All hyper-parameters are stored as attributes, after which ``_setup`` is called to build the
    submodules. Concrete subclasses must implement both ``_setup`` and ``forward``.

    Args:
        l_max_node_feats (int): Maximal rank of the irreducible node features (node embeddings).
        l_max_edge_attrs (int): Maximal rank of the irreducible edge attributes (Cartesian harmonics).
        l_max_target_feats (int): Maximal rank of the irreducible features (target ones, can be
                                  different from node features).
        l_max_hidden_feats (int): Maximal rank of the irreducible hidden features (for the first
                                  layer can be different from node features).
        n_basis (int): Number of radial basis functions.
        n_species (int): Number of species/elements.
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        avg_n_neighbors (float): Average number of neighbors.
        radial_MLP (List[int]): List of hidden features for the radial embedding network.
    """
    def __init__(self,
                 l_max_node_feats: int,
                 l_max_edge_attrs: int,
                 l_max_target_feats: int,
                 l_max_hidden_feats: int,
                 n_basis: int,
                 n_species: int,
                 in_features: int,
                 out_features: int,
                 avg_n_neighbors: float,
                 radial_MLP: List[int]):
        super(InteractionLayer, self).__init__()
        self.l_max_node_feats = l_max_node_feats
        self.l_max_edge_attrs = l_max_edge_attrs
        self.l_max_target_feats = l_max_target_feats
        self.l_max_hidden_feats = l_max_hidden_feats
        self.n_basis = n_basis
        self.n_species = n_species
        self.in_features = in_features
        self.out_features = out_features
        self.avg_n_neighbors = avg_n_neighbors
        self.radial_MLP = radial_MLP

        # called last so subclasses can read every hyper-parameter stored above
        self._setup()

    def _setup(self):
        """Builds the submodules of a concrete interaction layer (must be overridden)."""
        raise NotImplementedError(f"{type(self).__name__} must implement `_setup`.")

    def forward(self,
                node_attrs: torch.Tensor,
                node_feats: torch.Tensor,
                edge_attrs: torch.Tensor,
                edge_feats: torch.Tensor,
                idx_i: torch.Tensor,
                idx_j: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes the output of the interaction layer (must be overridden).

        Args:
            node_attrs (torch.Tensor): Node attributes, e.g., one-hot encoded species.
            node_feats (torch.Tensor): Node features (node embeddings).
            edge_attrs (torch.Tensor): Edge attributes (Cartesian harmonics).
            edge_feats (torch.Tensor): Edge features (radial basis).
            idx_i (torch.Tensor): Receivers (central nodes).
            idx_j (torch.Tensor): Senders (neighboring nodes).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Node messages and residual connections.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement `forward`.")


class RealAgnosticResidualInteractionLayer(InteractionLayer):
    """Equivariant interaction layer with a residual (skip) connection.

    Sender features are linearly mixed and combined with the edge attributes by a tensor product
    whose per-edge weights come from a radial MLP. The results are summed per receiver, linearly
    mixed again and divided by the average number of neighbors. A separate tensor product of
    node features and node attributes provides the residual branch.
    """
    def _setup(self):
        # NOTE: submodules are created in a fixed order. Creation order determines how random
        # initial weights are drawn (e.g. LinearLayer uses torch.randn), so keep it unchanged.
        l_node = self.l_max_node_feats
        l_target = self.l_max_target_feats
        n_feats = self.in_features

        # 1) channel mixing of node features before the convolution
        self.linear_first = LinearTransform(in_l_max=l_node, out_l_max=l_node,
                                            in_features=n_feats, out_features=n_feats)

        # 2) tensor product of sender features with the Cartesian harmonics; the weights are
        #    not owned by the module but supplied per edge (internal/shared weights disabled)
        self.conv_tp = WeightedTensorProduct(in1_l_max=l_node, in2_l_max=self.l_max_edge_attrs,
                                             out_l_max=l_target,
                                             in1_features=n_feats, in2_features=1,
                                             out_features=n_feats,
                                             connection_mode='uvu', internal_weights=False,
                                             shared_weights=False)

        # 3) radial MLP producing those per-edge weights:
        #    Linear, (SiLU, Linear)*, with no activation after the final Linear
        widths = [self.n_basis] + self.radial_MLP + [self.conv_tp.n_total_paths * n_feats]
        mlp_modules = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if depth > 0:
                mlp_modules.append(RescaledSiLULayer())
            mlp_modules.append(LinearLayer(fan_in, fan_out))
        self.conv_tp_weights = torch.nn.Sequential(*mlp_modules)

        # 4) channel mixing after aggregation; consumes every output path of the convolution
        self.linear_second = LinearTransform(in_l_max=l_target, out_l_max=l_target,
                                             in_features=n_feats, out_features=self.out_features,
                                             in_paths=self.conv_tp.n_paths)

        # 5) residual branch: node features x node attributes (rank-0 second operand)
        # NOTE(review): output channels are in_features while linear_second produces
        # out_features; this presumably relies on in_features == out_features — confirm.
        self.skip_tp = WeightedTensorProduct(in1_l_max=l_node, in2_l_max=0,
                                             out_l_max=self.l_max_hidden_feats,
                                             in1_features=n_feats, in2_features=self.n_species,
                                             out_features=n_feats,
                                             connection_mode='uvw', internal_weights=True,
                                             shared_weights=True)

    def forward(self,
                node_attrs: torch.Tensor,
                node_feats: torch.Tensor,
                edge_attrs: torch.Tensor,
                edge_feats: torch.Tensor,
                idx_i: torch.Tensor,
                idx_j: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # residual branch, computed from the un-mixed node features
        skip = self.skip_tp(node_feats, node_attrs)

        mixed = self.linear_first(node_feats)

        # per-edge tensor-product weights: n_neighbors x n_total_paths x n_hidden_feats x 1
        edge_weights = self.conv_tp_weights(edge_feats).view(-1, self.conv_tp.n_total_paths, self.in_features, 1)

        # gather sender features and form one message per edge
        senders = mixed.index_select(0, idx_j)
        edge_messages = self.conv_tp(senders, edge_attrs, edge_weights)

        # sum messages onto their receivers (one row per node)
        aggregated = segment_sum(edge_messages, idx_i, mixed.shape[0], 0)

        return self.linear_second(aggregated) / self.avg_n_neighbors, skip


class ScaleShiftLayer(nn.Module):
    """Re-scales and shifts atomic energies predicted by the model.

    Each atom's value is transformed as ``scale[s] * x + shift[s]``, where ``s`` is the atom's
    species index taken from ``graph.species``. A scalar (0-d) scale or shift is applied to
    every atom unchanged.

    Args:
        shift_params (np.ndarray): Per-species shifts indexed by species, or a scalar.
        scale_params (np.ndarray): Per-species scales indexed by species, or a scalar.
    """
    def __init__(self,
                 shift_params: np.ndarray,
                 scale_params: np.ndarray):
        super().__init__()
        # registered as buffers: they follow the module across devices/dtypes and are saved in
        # the state dict, but are not trainable parameters
        self.register_buffer("scale_params", torch.tensor(scale_params, dtype=torch.get_default_dtype()))
        self.register_buffer("shift_params", torch.tensor(shift_params, dtype=torch.get_default_dtype()))

    def forward(self,
                x: torch.Tensor,
                graph: Data) -> torch.Tensor:
        """Re-scales and shifts the output of the atomistic model.

        Args:
            x (torch.Tensor): Output of the atomistic model.
            graph (Data): Atomic data graph; ``graph.species`` holds one species index per atom.

        Returns:
            torch.Tensor: Re-scaled and shifted output of the atomistic model.
        """
        scale = self.scale_params
        shift = self.shift_params
        # index_select on a 0-d tensor only accepts a single index, so scalar parameters
        # (e.g. a plain float passed to __init__) are broadcast instead of gathered
        if scale.dim() > 0:
            scale = scale.index_select(0, graph.species)
        if shift.dim() > 0:
            shift = shift.index_select(0, graph.species)
        return scale * x + shift

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(scale_params={self.scale_params}({self.scale_params.dtype}), shift_params={self.shift_params}({self.shift_params.dtype}))'
