'''
 *
 *     ICTP: Irreducible Cartesian Tensor Potentials
 *
 *        File:  calculators.py
 *
 *     Authors: Deleted for purposes of anonymity 
 *
 *     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 * 
 * The software and its source code contain valuable trade secrets and shall be maintained in
 * confidence and treated as confidential information. The software may only be used for 
 * evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 * license agreement or nondisclosure agreement with the proprietor of the software. 
 * Any unauthorized publication, transfer to third parties, or duplication of the object or
 * source code---either totally or in part---is strictly prohibited.
 *
 *     Copyright (c) 2024 Proprietor: Deleted for purposes of anonymity
 *     All Rights Reserved.
 *
 * THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR 
 * IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY 
 * AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT 
 * DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION. 
 * 
 * NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 * IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE 
 * LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 * FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 * OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 * ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 * TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGES.
 * 
 * For purposes of anonymity, the identity of the proprietor is not given herewith. 
 * The identity of the proprietor will be given once the review of the 
 * conference submission is completed. 
 *
 * THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 *
'''
from typing import Dict, Union, Any, List, Tuple

import torch

from src.model.forward import ForwardAtomisticNetwork

from src.utils.math import segment_sum
from src.utils.torch_geometric import Data


class TorchCalculator:
    """Interface for calculators that evaluate atomic properties (e.g., total energy, atomic forces, stress).

    Concrete calculators override `__call__`, `get_device`, and `to`; the implementations provided here
    only raise `NotImplementedError`.
    """

    def __call__(self,
                 graph: Data,
                 **kwargs: Any) -> Dict[str, Union[torch.Tensor, Any]]:
        """Evaluates properties for an atomic (batch) graph.

        Args:
            graph (Data): Atomic data graph.
            **kwargs: Calculator-specific options.

        Returns:
            Dict[str, Union[torch.Tensor, Any]]: Mapping from property names to computed values.

        Raises:
            NotImplementedError: Always; subclasses provide the actual computation.
        """
        raise NotImplementedError

    def get_device(self) -> str:
        """Reports the device used for the calculations.

        Returns:
            str: Device on which calculations are performed.

        Raises:
            NotImplementedError: Always; subclasses provide the actual implementation.
        """
        raise NotImplementedError

    def to(self, device: str) -> 'TorchCalculator':
        """Transfers the calculator to another device.

        Args:
            device: Target device.

        Returns:
            TorchCalculator: The calculator itself.

        Raises:
            NotImplementedError: Always; subclasses provide the actual implementation.
        """
        raise NotImplementedError


def prepare_gradients(graph: Data,
                      forces: bool = False,
                      stress: bool = False,
                      virials: bool = False,
                      **kwargs: Any) -> Tuple[Data, List[str]]:
    """Enables gradient tracking for the graph attributes needed by the requested derivatives.

    Forces require gradients w.r.t. `graph.positions`. Stress and virials require gradients w.r.t.
    `graph.strain`. In that case the cell, the positions, and the shifts are additionally displaced by
    the symmetrized strain, so that the energy becomes a function of it. This follows Knuth et al.,
    Comput. Phys. Commun. 190, 33-50, 2015; similar implementations exist in NequIP
    (https://github.com/mir-group/nequip) and SchNetPack
    (https://github.com/atomistic-machine-learning/schnetpack). The displacement is applied only if
    `graph.strain` does not already require gradients.

    Args:
        graph (Data): Atomic data graph; its attributes are modified in place.
        forces (bool): If True, gradients with respect to positions/coordinates are requested.
                       Defaults to False.
        stress (bool): If True, gradients with respect to strain deformations are requested.
                       Defaults to False.
        virials (bool): If True, gradients with respect to strain deformations are requested.
                        Defaults to False.
        **kwargs: Ignored.

    Returns:
        Tuple[Data, List[str]]: The (same) graph and the names of its attributes requiring gradients,
            ordered as 'positions' (if requested) followed by 'strain' (if requested).
    """
    targets: List[str] = []

    if forces:
        targets.append('positions')
        coordinates = graph.positions
        if not coordinates.requires_grad:
            # request gradients wrt. positions/coordinates
            coordinates.requires_grad_(True)

    if not (stress or virials):
        return graph, targets

    targets.append('strain')
    if graph.strain.requires_grad:
        # strain is already tracked; the displacements below are not applied again
        return graph, targets

    strain = graph.strain
    strain.requires_grad_(True)
    # symmetrize to account for possible numerical issues
    eps = 0.5 * (strain + strain.transpose(-1, -2))
    # distribute the per-structure strain to atoms (via `batch`) and then to edges (via edge_index[0])
    eps_atoms = eps.index_select(0, graph.batch)
    eps_edges = eps_atoms.index_select(0, graph.edge_index[0, :])

    # displace cell, positions, and shifts: x -> x + x @ eps
    graph.cell = graph.cell + graph.cell @ eps
    graph.positions = graph.positions + (graph.positions.unsqueeze(-2) @ eps_atoms).squeeze(-2)
    graph.shifts = graph.shifts + (graph.shifts.unsqueeze(-2) @ eps_edges).squeeze(-2)
    return graph, targets


class StructurePropertyCalculator(TorchCalculator):
    """Calculates total energy, atomic forces, stress tensors, and virials from atomic energies.

    Args:
        model (ForwardAtomisticNetwork): Forward atomistic neural network object (provides atomic/node energies).
        training (bool): If True, the model is used as given. Otherwise, it is wrapped with
                         `torch.compile(model, backend="inductor")`. Defaults to False.
        **config: Additional keyword arguments; accepted but not used by this calculator.
    """
    def __init__(self,
                 model: ForwardAtomisticNetwork,
                 training: bool = False,
                 **config: Any):
        if training:
            self.model = model
        else:
            self.model = torch.compile(model, backend="inductor")

    def __call__(self,
                 graph: Data,
                 forces: bool = False,
                 stress: bool = False,
                 virials: bool = False,
                 create_graph: bool = False,
                 **kwargs: Any) -> Dict[str, torch.Tensor]:
        """Performs calculations for the atomic data graph.

        Note that `graph` is modified in place by `prepare_gradients`. Positions may get
        `requires_grad=True`. When stress or virials are requested, the cell, positions, and shifts are
        replaced by strain-displaced tensors.

        Args:
            graph (Data): Atomic data graph.
            forces (bool): If True, atomic forces are computed. Defaults to False.
            stress (bool): If True, stress tensor is computed. Defaults to False.
            virials (bool): If True, virials = - stress * volume are computed. Defaults to False.
            create_graph (bool): If True, computational graph is created allowing the computation of 
                                 backward pass for multiple times. Defaults to False.

        Returns: 
            Dict[str, torch.Tensor]: Results dict. It always contains 'atomic_energies' and 'energy'
                (per-structure sums). It also contains 'forces', 'virials', and/or 'stress' if requested.
                Derivatives with respect to inputs the energy does not depend on are returned as zeros.
        """
        results = {}
        # prepare graph and the list containing graph attributes requiring gradients
        graph, require_gradients = prepare_gradients(graph=graph, forces=forces, stress=stress, virials=virials)
        # compute atomic energy
        atomic_energies = self.model(graph)
        results['atomic_energies'] = atomic_energies
        # sum up atomic contributions for a structure
        total_energies = segment_sum(atomic_energies, idx_i=graph.batch, dim_size=graph.n_atoms.shape[0])
        # write total energy to results
        results['energy'] = total_energies
        if not require_gradients:
            return results
        # compute gradients wrt. positions, strain, etc.; `allow_unused=True` returns None (instead of
        # raising) for inputs the energy does not depend on, which is handled below by zero tensors
        grads = torch.autograd.grad([atomic_energies], [getattr(graph, key) for key in require_gradients],
                                    torch.ones_like(atomic_energies), create_graph=create_graph,
                                    allow_unused=True)
        if forces:
            # compute forces as negative of the gradient wrt. positions (always the first entry)
            grad_positions = grads[0]
            if grad_positions is not None:
                results['forces'] = torch.neg(grad_positions)
            else:
                results['forces'] = torch.zeros_like(graph.positions)
        if stress or virials:
            # the gradient wrt. strain is always the last entry (see `prepare_gradients`)
            grad_strain = grads[-1]
            if grad_strain is None:
                grad_strain = torch.zeros_like(graph.cell)
            if virials:
                # compute virials as negative of the gradient wrt. strain (note that other conventions are
                # possible, but here we use virials = -1 * stress * volume)
                results['virials'] = torch.neg(grad_strain)
            if stress:
                # compute stress as -1 * virials / volume; the volume is the absolute value of the triple
                # product so that left-handed cells do not flip the sign of the stress
                volume = torch.einsum('bi, bi -> b', graph.cell[:, 0, :],
                                      torch.cross(graph.cell[:, 1, :], graph.cell[:, 2, :], dim=1)).abs()
                results['stress'] = grad_strain / volume[:, None, None]
        return results

    def get_device(self) -> str:
        """Returns the device reported by the wrapped model's `get_device()`."""
        return self.model.get_device()

    def to(self, device: str) -> TorchCalculator:
        """Moves the wrapped model to `device` and returns this calculator."""
        self.model.to(device)
        return self
