'''
 *
 *     ICTP: Irreducible Cartesian Tensor Potentials
 *
 *        File:  tensor_product.py
 *
 *     Authors: Deleted for purposes of anonymity 
 *
 *     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 * 
 * The software and its source code contain valuable trade secrets and shall be maintained in
 * confidence and treated as confidential information. The software may only be used for 
 * evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 * license agreement or nondisclosure agreement with the proprietor of the software. 
 * Any unauthorized publication, transfer to third parties, or duplication of the object or
 * source code---either totally or in part---is strictly prohibited.
 *
 *     Copyright (c) 2024 Proprietor: Deleted for purposes of anonymity
 *     All Rights Reserved.
 *
 * THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR 
 * IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY 
 * AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT 
 * DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION. 
 * 
 * NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 * IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE 
 * LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 * FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 * OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 * ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 * TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGES.
 * 
 * For purposes of anonymity, the identity of the proprietor is not given herewith. 
 * The identity of the proprietor will be given once the review of the 
 * conference submission is completed. 
 *
 * THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 *
'''
from typing import List, Optional

import torch
import torch.nn as nn

import torch.fx
import opt_einsum_fx

from src.utils.o3 import get_slices, get_shapes


# Highest rotational order (tensor rank) handled by the hard-coded contractions of PlainTensorProduct;
# larger in/out ranks are rejected in its constructor.
L_MAX: int = 3
# Leading (batch) dimension of the random example inputs passed to opt_einsum_fx when optimizing the traced einsums.
BATCH_SIZE: int = 10


class PlainTensorProduct(nn.Module):
    """Basic class representing irreducible tensor product without learnable weights. This class allows contracting tensors 
    obtained using various contraction paths in preceding steps.

    Args:
        in1_l_max (int): Maximal rotational order/rank of the first input tensor.
        in2_l_max (int): Maximal rotational order/rank of the second input tensor.
        out_l_max (int): Maximal rotational order/rank of the output tensor.
        in1_features (int): Number of features for the first input tensor.
        in2_features (int): Number of features for the second input tensor.
        out_features (int): Number of features for the output tensor.
        in1_paths (Optional[List[int]], optional): Number of contraction paths used to obtain the first input tensor. 
                                                   Defaults to None.
        in2_paths (Optional[List[int]], optional): Number of contraction paths used to obtain the second input tensor. 
                                                   Defaults to None.
        symmetric_product (bool, optional): If True, skip the calculation of symmetric contractions. Defaults to False.
    """
    def __init__(self,
                 in1_l_max: int,
                 in2_l_max: int,
                 out_l_max: int,
                 in1_features: int, 
                 in2_features: int,
                 out_features: int,
                 in1_paths: Optional[List[int]] = None,
                 in2_paths: Optional[List[int]] = None,
                 symmetric_product: bool = False):
        """Validates the configuration and traces/optimizes every einsum contraction used by `forward`.

        Args:
            in1_l_max (int): Maximal rotational order/rank of the first input tensor.
            in2_l_max (int): Maximal rotational order/rank of the second input tensor.
            out_l_max (int): Maximal rotational order/rank of the output tensor.
            in1_features (int): Number of features for the first input tensor.
            in2_features (int): Number of features for the second input tensor.
            out_features (int): Number of features for the output tensor.
            in1_paths (Optional[List[int]], optional): Number of contraction paths per rank of the first input 
                                                       tensor. Defaults to one path per rank.
            in2_paths (Optional[List[int]], optional): Number of contraction paths per rank of the second input 
                                                       tensor. Defaults to one path per rank.
            symmetric_product (bool, optional): If True, contractions that only differ from another one by swapping 
                                                the two inputs are skipped (`forward` then requires x == y). 
                                                Defaults to False.

        Raises:
            RuntimeError: If the numbers of input/output features differ or any rank exceeds `L_MAX`.

        Note:
            `forward` consumes `self.contractions` through a running index, so the contractions must be appended 
            here in exactly the order in which `forward` applies them.
        """
        super(PlainTensorProduct, self).__init__()
        self.in1_l_max = in1_l_max
        self.in2_l_max = in2_l_max
        self.out_l_max = out_l_max
        
        # check the number of features in the input and output tensors
        if in1_features != in2_features or in1_features != out_features:
            raise RuntimeError('The number of input and output features has to be the same.')
        
        self.in1_features = in1_features
        self.in2_features = in2_features
        self.out_features = out_features
        
        # number of paths that produced the rank-l irreducible tensors of the first and the second input tensor 
        # (one path per rank if not given)
        self.in1_paths = [1] * (in1_l_max + 1) if in1_paths is None else in1_paths
        self.in2_paths = [1] * (in2_l_max + 1) if in2_paths is None else in2_paths
        
        self.symmetric_product = symmetric_product
        
        if self.out_l_max > L_MAX or self.in1_l_max > L_MAX or self.in2_l_max > L_MAX:
            raise RuntimeError(f'Tensor product is implemented for l <= {L_MAX=}.')
        
        # slices and shapes for tensors of rank l in flattened input tensors
        self.in1_slices = get_slices(in1_l_max, in1_features, self.in1_paths)
        self.in2_slices = get_slices(in2_l_max, in2_features, self.in2_paths)
        self.in1_shapes = get_shapes(in1_l_max, in1_features, self.in1_paths, use_prod=False)
        self.in2_shapes = get_shapes(in2_l_max, in2_features, self.in2_paths, use_prod=False)
        
        # dimensions of the flattened input tensors (sum over l of 3^l * features * paths) for sanity checks
        self.in1_dim = sum((3 ** l) * in1_features * self.in1_paths[l] for l in range(in1_l_max + 1))
        self.in2_dim = sum((3 ** l) * in2_features * self.in2_paths[l] for l in range(in2_l_max + 1))
        
        # define the number of paths (total and specific for l <= out_l_max); `_get_n_paths` appears to count 
        # the paths cumulatively up to rank l, so the per-rank count is the increment over the previous ranks
        self.n_total_paths = _get_n_paths(in1_l_max, in2_l_max, out_l_max, in1_paths=self.in1_paths, 
                                          in2_paths=self.in2_paths, symmetric_product=symmetric_product)
        self.n_paths = []
        for l in range(out_l_max + 1):
            n_paths = _get_n_paths(in1_l_max, in2_l_max, l, in1_paths=self.in1_paths, 
                                   in2_paths=self.in2_paths, symmetric_product=symmetric_product)
            self.n_paths.append(n_paths - sum(self.n_paths))
        
        # a requested output rank l >= 1 that no contraction produces still gets one slot; `forward` fills it 
        # with zeros so that the output layout stays fixed (out_l_max <= L_MAX was checked above)
        for l in range(1, out_l_max + 1):
            if self.n_paths[l] == 0:
                self.n_paths[l] = 1

        # define identity matrix
        self.register_buffer('eye', torch.eye(3))
        
        # trace and optimize contractions; the order of the calls below defines the indices used in `forward`
        self.contractions = torch.nn.ModuleList()

        def add_contraction(equation: str, *example_inputs: torch.Tensor) -> None:
            # trace the binary einsum `equation`, optimize it for the example shapes and register it
            traced = torch.fx.symbolic_trace(lambda x, y: torch.einsum(equation, x, y))
            self.contractions.append(opt_einsum_fx.optimize_einsums_full(model=traced, 
                                                                         example_inputs=example_inputs))

        def example_in1(l: int) -> torch.Tensor:
            # random stand-in for the rank-l part of the first input: BATCH_SIZE x (l axes of size 3) x feats x paths
            return torch.randn(BATCH_SIZE, *([3] * l), in1_features, self.in1_paths[l])

        def example_in2(l: int) -> torch.Tensor:
            # random stand-in for the rank-l part of the second input: BATCH_SIZE x (l axes of size 3) x feats x paths
            return torch.randn(BATCH_SIZE, *([3] * l), in2_features, self.in2_paths[l])

        def example_pair(rank: int, l1: int, l2: int) -> torch.Tensor:
            # random stand-in for a rank-`rank` contraction of the (l1, l2) input pair, whose path axis 
            # has in1_paths[l1] * in2_paths[l2] entries
            return torch.randn(BATCH_SIZE, *([3] * rank), out_features, self.in1_paths[l1] * self.in2_paths[l2])
                
        # l = 0, shape: n_neighbors x n_feats x n_paths x n_paths
        add_contraction('aup,aur->aupr', example_in1(0), example_in2(0))
        if in1_l_max > 0 and in2_l_max > 0:
            add_contraction('aiup,aiur->aupr', example_in1(1), example_in2(1))
        if in1_l_max > 1 and in2_l_max > 1:
            add_contraction('aijup,aijur->aupr', example_in1(2), example_in2(2))
        if in1_l_max > 2 and in2_l_max > 2:
            add_contraction('aijkup,aijkur->aupr', example_in1(3), example_in2(3))
        
        # l = 1, shape: n_neighbors x 3 x n_feats x n_paths x n_paths
        if out_l_max > 0:
            if in1_l_max > 0:
                add_contraction('aiup,aur->aiupr', example_in1(1), example_in2(0))
            if in1_l_max > 1 and in2_l_max > 0:
                add_contraction('aijup,ajur->aiupr', example_in1(2), example_in2(1))
            if in1_l_max > 2 and in2_l_max > 1:
                add_contraction('aijkup,ajkur->aiupr', example_in1(3), example_in2(2))
            
            # counterparts of the contractions above with the roles of the two inputs swapped
            if not symmetric_product:
                if in2_l_max > 0:
                    add_contraction('aup,aiur->aiupr', example_in1(0), example_in2(1))
                if in1_l_max > 0 and in2_l_max > 1:
                    add_contraction('aiup,aijur->ajupr', example_in1(1), example_in2(2))
                if in1_l_max > 1 and in2_l_max > 2:
                    add_contraction('aijup,aijkur->akupr', example_in1(2), example_in2(3))
        
        # l = 2, shape: n_neighbors x 3 x 3 x n_feats x n_paths x n_paths; each product of two equal-rank 
        # tensors is followed by the outer product of their full (l = 0) contraction with the identity, 
        # which `forward` subtracts to remove the trace
        if out_l_max > 1:
            if in1_l_max > 1:
                add_contraction('aijup,aur->aijupr', example_in1(2), example_in2(0))
            if in1_l_max > 0 and in2_l_max > 0:
                add_contraction('aiup,ajur->aijupr', example_in1(1), example_in2(1))
                add_contraction('aup,ij->aijup', example_pair(0, 1, 1), self.eye)
            if in1_l_max > 1 and in2_l_max > 1:
                add_contraction('aijup,ajkur->aikupr', example_in1(2), example_in2(2))
                add_contraction('aup,ij->aijup', example_pair(0, 2, 2), self.eye)
            if in1_l_max > 2 and in2_l_max > 2:
                add_contraction('aijkup,ajklur->ailupr', example_in1(3), example_in2(3))
                add_contraction('aup,ij->aijup', example_pair(0, 3, 3), self.eye)
            if in1_l_max > 2 and in2_l_max > 0:
                add_contraction('aijkup,akur->aijupr', example_in1(3), example_in2(1))
            
            if not symmetric_product:
                if in2_l_max > 1:
                    add_contraction('aup,aijur->aijupr', example_in1(0), example_in2(2))
                if in1_l_max > 0 and in2_l_max > 2:
                    add_contraction('aiup,aijkur->ajkupr', example_in1(1), example_in2(3))
        
        # l = 3, shape: n_neighbors x 3 x 3 x 3 x n_feats x n_paths x n_paths
        # NOTE(review): the identity terms are given the shape of the rank-1 contraction of the same input pair 
        # (xy_211, xy_321, xy_121, xy_231 in `forward`, built under the same guards), mirroring the l = 2 case
        if out_l_max > 2:
            if in1_l_max > 2:
                add_contraction('aijkup,aur->aijkupr', example_in1(3), example_in2(0))
            if in1_l_max > 1 and in2_l_max > 0:
                add_contraction('aijup,akur->aijkupr', example_in1(2), example_in2(1))
                add_contraction('aiup,jk->aijkup', example_pair(1, 2, 1), self.eye)
            if in1_l_max > 2 and in2_l_max > 1:
                add_contraction('aijkup,aklur->aijlupr', example_in1(3), example_in2(2))
                add_contraction('aiup,jk->aijkup', example_pair(1, 3, 2), self.eye)
                
            if not symmetric_product:
                if in2_l_max > 2:
                    add_contraction('aup,aijkur->aijkupr', example_in1(0), example_in2(3))
                if in1_l_max > 0 and in2_l_max > 1:
                    add_contraction('aiup,ajkur->aijkupr', example_in1(1), example_in2(2))
                    add_contraction('aiup,jk->aijkup', example_pair(1, 1, 2), self.eye)
                if in1_l_max > 1 and in2_l_max > 2:
                    add_contraction('aijup,ajklur->aiklupr', example_in1(2), example_in2(3))
                    add_contraction('aiup,jk->aijkup', example_pair(1, 2, 3), self.eye)
    
    def forward(self, 
                x: torch.Tensor, 
                y: torch.Tensor) -> torch.Tensor:
        """Computes tensor products/contractions between the two input tensors.

        Args:
            x (torch.Tensor): First input tensor.
            y (torch.Tensor): Second input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        torch._assert(x.shape[-1] == self.in1_dim, 'Incorrect last dimension for x.')
        torch._assert(y.shape[-1] == self.in2_dim, 'Incorrect last dimension for y.')
        
        if self.symmetric_product: torch._assert(torch.equal(x, y), 'Symmetric product is possible only if x == y.')
        
        x_0 = x[:, self.in1_slices[0]].view(x.shape[0], *self.in1_shapes[0])
        y_0 = y[:, self.in2_slices[0]].view(y.shape[0], *self.in2_shapes[0])
        if self.in1_l_max > 0: x_1 = x[:, self.in1_slices[1]].view(x.shape[0], *self.in1_shapes[1])
        if self.in2_l_max > 0: y_1 = y[:, self.in2_slices[1]].view(y.shape[0], *self.in2_shapes[1])
        if self.in1_l_max > 1: x_2 = x[:, self.in1_slices[2]].view(x.shape[0], *self.in1_shapes[2])
        if self.in2_l_max > 1: y_2 = y[:, self.in2_slices[2]].view(y.shape[0], *self.in2_shapes[2])
        if self.in1_l_max > 2: x_3 = x[:, self.in1_slices[3]].view(x.shape[0], *self.in1_shapes[3])
        if self.in2_l_max > 2: y_3 = y[:, self.in2_slices[3]].view(x.shape[0], *self.in2_shapes[3])
        
        i_contraction = 0
        
        cp_0 = [_norm_l1l2l3(0, 0, 0) * self.contractions[i_contraction](x_0, y_0).view(x.shape[0], self.out_features, self.in1_paths[0] * self.in2_paths[0])]
        i_contraction += 1
        
        if self.in1_l_max > 0 and self.in2_l_max > 0:
            xy_110 = self.contractions[i_contraction](x_1, y_1).view(x.shape[0], self.out_features, self.in1_paths[1] * self.in2_paths[1])
            i_contraction += 1
            
            cp_0.append(_norm_l1l2l3(1, 1, 0) * xy_110)
        
        if self.in1_l_max > 1 and self.in2_l_max > 1:
            xy_220 =  self.contractions[i_contraction](x_2, y_2).view(x.shape[0], self.out_features, self.in1_paths[2] * self.in2_paths[2])
            i_contraction += 1
            
            cp_0.append(_norm_l1l2l3(2, 2, 0) * xy_220)
        
        if self.in1_l_max > 2 and self.in2_l_max > 2:
            xy_330 = self.contractions[i_contraction](x_3, y_3).view(x.shape[0], self.out_features, self.in1_paths[3] * self.in2_paths[3])
            i_contraction += 1
            
            cp_0.append(_norm_l1l2l3(3, 3, 0) * xy_330)

        # shape: n_neighbors x (out_features * n_paths)
        cp_0 = torch.cat(cp_0, dim=-1).view(x.shape[0], self.out_features * self.n_paths[0])
        
        if self.out_l_max == 0:
            return cp_0
        
        cp_1 = []
        
        if self.in1_l_max > 0:
            cp_1.append(_norm_l1l2l3(1, 0, 1) * self.contractions[i_contraction](x_1, y_0).view(x.shape[0], 3, self.out_features, self.in1_paths[1] * self.in2_paths[0]))
            i_contraction += 1
           
        if self.in1_l_max > 1 and self.in2_l_max > 0:
            xy_211 = self.contractions[i_contraction](x_2, y_1).view(x.shape[0], 3, self.out_features, self.in1_paths[2] * self.in2_paths[1])
            i_contraction += 1
            
            cp_1.append(_norm_l1l2l3(2, 1, 1) * xy_211)
        
        if self.in1_l_max > 2 and self.in2_l_max > 1:
            xy_321 = self.contractions[i_contraction](x_3, y_2).view(x.shape[0], 3, self.out_features, self.in1_paths[3] * self.in2_paths[2])
            i_contraction += 1
            
            cp_1.append(_norm_l1l2l3(3, 2, 1) * xy_321)
        
        if not self.symmetric_product:
            if self.in2_l_max > 0:
                cp_1.append(_norm_l1l2l3(0, 1, 1) * self.contractions[i_contraction](x_0, y_1).view(x.shape[0], 3, self.out_features, self.in1_paths[0] * self.in2_paths[1]))
                i_contraction += 1
                
            if self.in1_l_max > 0 and self.in2_l_max > 1:
                xy_121 = self.contractions[i_contraction](x_1, y_2).view(x.shape[0], 3, self.out_features, self.in1_paths[1] * self.in2_paths[2])
                i_contraction += 1
                
                cp_1.append(_norm_l1l2l3(1, 2, 1) * xy_121)
            
            if self.in1_l_max > 1 and self.in2_l_max > 2:
                xy_231 = self.contractions[i_contraction](x_2, y_3).view(x.shape[0], 3, self.out_features, self.in1_paths[2] * self.in2_paths[3])
                i_contraction += 1
                
                cp_1.append(_norm_l1l2l3(2, 3, 1) * xy_231)
        
        # append zeros if out_l_max > 0 is requested
        if len(cp_1) == 0:
            cp_1.append(torch.zeros(cp_0.shape[0], 3, self.out_features, device=cp_0.device, dtype=cp_0.dtype))
        
        # shape: n_neighbors x 3 x (out_features * n_paths)
        cp_1 = torch.cat(cp_1, dim=-1).view(x.shape[0], 3, self.out_features * self.n_paths[1])
        
        if self.out_l_max == 1:
            return torch.cat([cp_0, 
                              cp_1.view(x.shape[0], 3 * self.out_features * self.n_paths[1])], -1)
        
        cp_2 = []
        
        if self.in1_l_max > 1:
            cp_2.append(_norm_l1l2l3(2, 0, 2) * self.contractions[i_contraction](x_2, y_0).view(x.shape[0], 3, 3, self.out_features, self.in1_paths[2] * self.in2_paths[0]))
            i_contraction += 1
           
        if self.in1_l_max > 0 and self.in2_l_max > 0:
            x_x = self.contractions[i_contraction](x_1, y_1).view(x.shape[0], 3, 3, self.out_features, self.in1_paths[1] * self.in2_paths[1])
            x_x = x_x + x_x.permute(0, 2, 1, 3, 4)
            i_contraction += 1
            
            xx_e = self.contractions[i_contraction](xy_110, self.eye)
            i_contraction += 1
            
            cp_2.append(_norm_l1l2l3(1, 1, 2) * (x_x - 2. / 3. * xx_e))
        
        if self.in1_l_max > 1 and self.in2_l_max > 1:
            x_x = self.contractions[i_contraction](x_2, y_2).view(x.shape[0], 3, 3, self.out_features, self.in1_paths[2] * self.in2_paths[2])
            x_x = x_x + x_x.permute(0, 2, 1, 3, 4)
            i_contraction += 1
            
            xx_e = self.contractions[i_contraction](xy_220, self.eye)
            i_contraction += 1
            
            cp_2.append(_norm_l1l2l3(2, 2, 2) * (x_x - 2. / 3. * xx_e))
           
        if self.in1_l_max > 2 and self.in2_l_max > 2:
            x_x = self.contractions[i_contraction](x_3, y_3).view(x.shape[0], 3, 3, self.out_features, self.in1_paths[3] * self.in2_paths[3])
            x_x = x_x + x_x.permute(0, 2, 1, 3, 4)
            i_contraction += 1
            
            xx_e = self.contractions[i_contraction](xy_330, self.eye)
            i_contraction += 1
            
            cp_2.append(_norm_l1l2l3(3, 3, 2) * (x_x - 2. / 3. * xx_e))
            
        if self.in1_l_max > 2 and self.in2_l_max > 0:
            cp_2.append(_norm_l1l2l3(3, 1, 2) * self.contractions[i_contraction](x_3, y_1).view(x.shape[0], 3, 3, self.out_features, self.in1_paths[3] * self.in2_paths[1]))
            i_contraction += 1
        
        if not self.symmetric_product:
            if self.in2_l_max > 1:
                cp_2.append(_norm_l1l2l3(0, 2, 2) * self.contractions[i_contraction](x_0, y_2).view(x.shape[0], 3, 3, self.out_features, self.in1_paths[0] * self.in2_paths[2]))
                i_contraction += 1
            
            if self.in1_l_max > 0 and self.in2_l_max > 2:
                cp_2.append(_norm_l1l2l3(1, 3, 2) * self.contractions[i_contraction](x_1, y_3).view(x.shape[0], 3, 3, self.out_features, self.in1_paths[1] * self.in2_paths[3]))
                i_contraction += 1
        
        # append zeros if out_l_max > 0 is requested
        if len(cp_2) == 0:
            cp_2.append(torch.zeros(cp_0.shape[0], 3, 3, self.out_features, device=cp_0.device, dtype=cp_0.dtype))
        
        # shape: n_neighbors x 3 x 3 x (out_features * n_paths)
        cp_2 = torch.cat(cp_2, dim=-1).view(x.shape[0], 3, 3, self.out_features * self.n_paths[2])
        
        if self.out_l_max == 2:
            return torch.cat([cp_0, 
                              cp_1.view(x.shape[0], 3 * self.out_features * self.n_paths[1]),
                              cp_2.view(x.shape[0], (3 ** 2) * self.out_features * self.n_paths[2])], -1)
        
        cp_3 = []
        
        if self.in1_l_max > 2:
            cp_3.append(_norm_l1l2l3(3, 0, 3) * self.contractions[i_contraction](x_3, y_0).view(x.shape[0], 3, 3, 3, self.out_features, self.in1_paths[3] * self.in2_paths[0]))
            i_contraction += 1
           
        if self.in1_l_max > 1 and self.in2_l_max > 0:
            x_x = self.contractions[i_contraction](x_2, y_1).view(x.shape[0], 3, 3, 3, self.out_features, self.in1_paths[2] * self.in2_paths[1])
            x_x = x_x + x_x.permute(0, 2, 3, 1, 4, 5) + x_x.permute(0, 3, 1, 2, 4, 5)
            i_contraction += 1
            
            xx_e = self.contractions[i_contraction](xy_211, self.eye)
            xx_e = xx_e + xx_e.permute(0, 2, 3, 1, 4, 5) + xx_e.permute(0, 3, 1, 2, 4, 5)
            i_contraction += 1
            
            cp_3.append(_norm_l1l2l3(2, 1, 3) * (x_x - 2. / 5. * xx_e))
           
        if self.in1_l_max > 2 and self.in2_l_max > 1:
            x_x = self.contractions[i_contraction](x_3, y_2).view(x.shape[0], 3, 3, 3, self.out_features, self.in1_paths[3] * self.in2_paths[2])
            x_x = x_x + x_x.permute(0, 2, 3, 1, 4, 5) + x_x.permute(0, 3, 1, 2, 4, 5)
            i_contraction += 1
            xx_e = self.contractions[i_contraction](xy_321, self.eye)
            xx_e = xx_e + xx_e.permute(0, 2, 3, 1, 4, 5) + xx_e.permute(0, 3, 1, 2, 4, 5)
            i_contraction += 1
            
            cp_3.append(_norm_l1l2l3(3, 2, 3) * (x_x - 2. / 5. * xx_e))
        
        if not self.symmetric_product:
            if self.in2_l_max > 2:
                cp_3.append(_norm_l1l2l3(0, 3, 3) *  self.contractions[i_contraction](x_0, y_3).view(x.shape[0], 3, 3, 3, self.out_features, self.in1_paths[0] * self.in2_paths[3]))
                i_contraction += 1

            if self.in1_l_max > 0 and self.in2_l_max > 1:
                x_x = self.contractions[i_contraction](x_1, y_2).view(x.shape[0], 3, 3, 3, self.out_features, self.in1_paths[1] * self.in2_paths[2])
                x_x = x_x + x_x.permute(0, 2, 3, 1, 4, 5) + x_x.permute(0, 3, 1, 2, 4, 5)
                i_contraction += 1
                
                xx_e = self.contractions[i_contraction](xy_121, self.eye)
                xx_e = xx_e + xx_e.permute(0, 2, 3, 1, 4, 5) + xx_e.permute(0, 3, 1, 2, 4, 5)
                i_contraction += 1
                
                cp_3.append(_norm_l1l2l3(1, 2, 3) * (x_x - 2. / 5. * xx_e))
                
            if self.in1_l_max > 1 and self.in2_l_max > 2:
                x_x = self.contractions[i_contraction](x_2, y_3).view(x.shape[0], 3, 3, 3, self.out_features, self.in1_paths[2] * self.in2_paths[3])
                x_x = x_x + x_x.permute(0, 2, 3, 1, 4, 5) + x_x.permute(0, 3, 1, 2, 4, 5)
                i_contraction += 1
                
                xx_e = self.contractions[i_contraction](xy_231, self.eye)
                xx_e = xx_e + xx_e.permute(0, 2, 3, 1, 4, 5) + xx_e.permute(0, 3, 1, 2, 4, 5)
                
                cp_3.append(_norm_l1l2l3(2, 3, 3) * (x_x - 2. / 5. * xx_e))
        
        if len(cp_3) == 0:
            # append zeros if out_l_max > 0 is requested
            cp_3.append(torch.zeros(cp_0.shape[0], 3, 3, 3, self.out_features, device=cp_0.device, dtype=cp_0.dtype))
        
        # shape: n_neighbors x 3 x 3 x 3 x (out_features * n_paths)
        cp_3 = torch.cat(cp_3, dim=-1).view(x.shape[0], 3, 3, 3, self.out_features * self.n_paths[3])
        
        if self.out_l_max == 3:
            return torch.cat([cp_0, 
                              cp_1.view(x.shape[0], 3 * self.out_features * self.n_paths[1]),
                              cp_2.view(x.shape[0], (3 ** 2) * self.out_features * self.n_paths[2]),
                              cp_3.view(x.shape[0], (3 ** 3) * self.out_features * self.n_paths[3])], -1)
            
    def __repr__(self) -> str:
        """Return a one-line summary of the product's ranks and path counts.

        The summary contains the class name, the maximal ranks of both inputs
        and of the output (``in1 x in2 -> out``), the total number of paths,
        and the per-output-rank path counts stored in ``self.n_paths``.
        """
        cls_name = type(self).__name__
        ranks = f"{self.in1_l_max} x {self.in2_l_max} -> {self.out_l_max}"
        paths = f"{self.n_total_paths} total paths | {self.n_paths} paths"
        return f"{cls_name} ({ranks} | {paths})"


class WeightedTensorProduct(nn.Module):
    """Basic class representing irreducible tensor product with learnable weights. 

    Args:
        in1_l_max (int): Maximal rotational order/rank of the first input tensor.
        in2_l_max (int): Maximal rotational order/rank of the second input tensor.
        out_l_max (int): Maximal rotational order/rank of the output tensor.
        in1_features (int): Number of features for the first input tensor.
        in2_features (int): Number of features for the second input tensor.
        out_features (int): Number of features for the output tensor.
        symmetric_product (bool, optional): If True, skip the calculation of symmetric contractions. Defaults to False.
        connection_mode (str, optional): Connection mode for computing the products with learnable weights. 
                                         Defaults to 'uvu'. 'uvw' and 'uvu' are the possible choises, in line with the 
                                         e3nn code.
        internal_weights (bool, optional): If True, use internal weights. Defaults to True.
        shared_weights (bool, optional): If True, share weights across the batch dimension. Defaults to True.
    """
    def __init__(self,
                 in1_l_max: int,
                 in2_l_max: int,
                 out_l_max: int,
                 in1_features: int, 
                 in2_features: int,
                 out_features: int,
                 symmetric_product: bool = False,
                 connection_mode: str = 'uvu',
                 internal_weights: bool = True,
                 shared_weights: bool = True):
        super(WeightedTensorProduct, self).__init__()
        self.in1_l_max = in1_l_max
        self.in2_l_max = in2_l_max
        self.out_l_max = out_l_max
        
        self.in1_features = in1_features
        self.in2_features = in2_features
        self.out_features = out_features
        
        self.symmetric_product = symmetric_product
                
        self.connection_mode = connection_mode
        self.internal_weights = internal_weights
        self.shared_weights = shared_weights
        
        if connection_mode not in ['uvu', 'uvw']:
            raise RuntimeError(f'{connection_mode=} is not implemented. Use "uvu" or "uvw" instead.')
        
        if self.out_l_max > L_MAX or self.in1_l_max > L_MAX or self.in2_l_max > L_MAX:
            raise RuntimeError(f'Tensor product is implemented for l <= {L_MAX=}.')
        
        # slices and shapes for tensors of rank l in flattened input tensors
        self.in1_slices = get_slices(in1_l_max, in1_features)
        self.in2_slices = get_slices(in2_l_max, in2_features)
        self.in1_shapes = get_shapes(in1_l_max, in1_features)
        self.in2_shapes = get_shapes(in2_l_max, in2_features)
        
        # dimensions of the input tensors for sanity checks
        self.in1_dim = sum([(3 ** l) * in1_features for l in range(in1_l_max + 1)])
        self.in2_dim = sum([(3 ** l) * in2_features for l in range(in2_l_max + 1)])
        
        # define the number of paths (total and specific for l <= out_l_max)
        self.n_total_paths = _get_n_paths(in1_l_max, in2_l_max, out_l_max, symmetric_product=symmetric_product)
        self.n_paths = []
        for out_l in range(out_l_max + 1):
            self.n_paths.append(_get_n_paths(in1_l_max, in2_l_max, out_l, symmetric_product=symmetric_product) - sum(self.n_paths))
        
        # correct the number of paths for the case where out_l_max cannot be obtained by the tensor product 
        if out_l_max > 0 and self.n_paths[1] == 0:
            self.n_paths[1] = 1
        if out_l_max > 1 and self.n_paths[2] == 0:
            self.n_paths[2] = 1
        if out_l_max > 2 and self.n_paths[3] == 0:
            self.n_paths[3] = 1
        
        # define prefix and postfix for einsum
        if connection_mode == 'uvw':
            self.prefix, self.postfix = 'uvw,', 'w'
        else:
            self.prefix, self.postfix = 'uv,', 'u'
        
        # add the batch dimension
        if not self.shared_weights:
            self.prefix = 'a' + self.prefix
            
        # define normalization
        self.alpha = []
        for n_paths in self.n_paths:
            if connection_mode == 'uvw':
                self.alpha.extend([(n_paths * in1_features * in2_features) ** (-0.5)] * n_paths)
            else:
                self.alpha.extend([(n_paths * in2_features) ** (-0.5)] * n_paths)
                
        # define weight for the tensor product
        if internal_weights:
            assert self.shared_weights, 'Having internal weights impose shared weights'
            if connection_mode == 'uvw':
                self.weight = nn.ParameterList([])
                for _ in range(self.n_total_paths):
                    self.weight.append(nn.Parameter(torch.randn(in1_features, in2_features, out_features)))
            else:
                assert in1_features == out_features
                self.weight = nn.ParameterList([])
                for _ in range(self.n_total_paths):
                    self.weight.append(nn.Parameter(torch.randn(in1_features, in2_features)))
        else:
            self.register_buffer('weight', torch.Tensor())

        # define identity matrix
        self.register_buffer('eye', torch.eye(3))
        
        # trace and optimize contractions
        self.contractions = torch.nn.ModuleList()
        
        if self.shared_weights:
            if self.connection_mode == 'uvw':
                example_weight = torch.randn(in1_features, in2_features, out_features)
            else:
                example_weight = torch.randn(in1_features, in2_features)
        else:
            if self.connection_mode == 'uvw':
                example_weight = torch.randn(BATCH_SIZE, in1_features, in2_features, out_features)
            else:
                example_weight = torch.randn(BATCH_SIZE, in1_features, in2_features)
                
        # l = 0, shape: n_neighbors x n_feats
        contraction_tr = torch.fx.symbolic_trace(lambda w, x, y: torch.einsum(f'{self.prefix}au,av->a{self.postfix}', w, x, y))
        contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr, 
                                                             example_inputs=(example_weight,
                                                                             torch.randn(BATCH_SIZE, in1_features),
                                                                             torch.randn(BATCH_SIZE, in2_features)))
        self.contractions.append(contraction_op)
        
        if self.in1_l_max > 0 and self.in2_l_max > 0:
            contraction_tr = torch.fx.symbolic_trace(lambda x, y: torch.einsum('aiu,aiv->auv', x, y))
            contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr, 
                                                                 example_inputs=(torch.randn(BATCH_SIZE, 3, in1_features),
                                                                                 torch.randn(BATCH_SIZE, 3, in2_features)))
            self.contractions.append(contraction_op)
            
            contraction_tr = torch.fx.symbolic_trace(lambda w, x: torch.einsum(f'{self.prefix}auv->a{self.postfix}', w, x))
            contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr, 
                                                                 example_inputs=(example_weight,
                                                                                 torch.randn(BATCH_SIZE, in1_features, in2_features)))
            self.contractions.append(contraction_op)
            
        if self.in1_l_max > 1 and self.in2_l_max > 1:
            contraction_tr = torch.fx.symbolic_trace(lambda x, y: torch.einsum('aiju,aijv->auv', x, y))
            contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr, 
                                                                 example_inputs=(torch.randn(BATCH_SIZE, 3, 3, in1_features),
                                                                                 torch.randn(BATCH_SIZE, 3, 3, in2_features)))
            self.contractions.append(contraction_op)
            
            contraction_tr = torch.fx.symbolic_trace(lambda w, x: torch.einsum(f'{self.prefix}auv->a{self.postfix}', w, x))
            contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr, 
                                                                 example_inputs=(example_weight,
                                                                                 torch.randn(BATCH_SIZE, in1_features, in2_features)))
            self.contractions.append(contraction_op)
        
        if self.in1_l_max > 2 and self.in2_l_max > 2:
            contraction_tr = torch.fx.symbolic_trace(lambda x, y: torch.einsum('aijku,aijkv->auv', x, y))
            contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr, 
                                                                 example_inputs=(torch.randn(BATCH_SIZE, 3, 3, 3, in1_features),
                                                                                 torch.randn(BATCH_SIZE, 3, 3, 3, in2_features)))
            self.contractions.append(contraction_op)
            
            contraction_tr = torch.fx.symbolic_trace(lambda w, x: torch.einsum(f'{self.prefix}auv->a{self.postfix}', w, x))
            contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr, 
                                                                 example_inputs=(example_weight,
                                                                                 torch.randn(BATCH_SIZE, in1_features, in2_features)))
            self.contractions.append(contraction_op)
        
        # l = 1, shape: n_neighbors x 3 x n_feats
        if self.out_l_max > 0:
            if self.in1_l_max > 0:
                contraction_tr = torch.fx.symbolic_trace(lambda w, x, y: torch.einsum(f'{self.prefix}aiu,av->ai{self.postfix}', w, x, y))
                contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                     example_inputs=(example_weight,
                                                                                     torch.randn(BATCH_SIZE, 3, in1_features),
                                                                                     torch.randn(BATCH_SIZE, in2_features)))
                self.contractions.append(contraction_op)
            
            if self.in1_l_max > 1 and self.in2_l_max > 0:
                contraction_tr = torch.fx.symbolic_trace(lambda x, y: torch.einsum('aiju,ajv->aiuv', x, y))
                contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                     example_inputs=(torch.randn(BATCH_SIZE, 3, 3, in1_features),
                                                                                     torch.randn(BATCH_SIZE, 3, in2_features)))
                self.contractions.append(contraction_op)
                
                contraction_tr = torch.fx.symbolic_trace(lambda w, x: torch.einsum(f'{self.prefix}aiuv->ai{self.postfix}', w, x))
                contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                     example_inputs=(example_weight,
                                                                                     torch.randn(BATCH_SIZE, 3, in1_features, in2_features)))
                self.contractions.append(contraction_op)
            
            if self.in1_l_max > 2 and self.in2_l_max > 1:
                contraction_tr = torch.fx.symbolic_trace(lambda x, y: torch.einsum('aijku,ajkv->aiuv', x, y))
                contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                     example_inputs=(torch.randn(BATCH_SIZE, 3, 3, 3, in1_features),
                                                                                     torch.randn(BATCH_SIZE, 3, 3, in2_features)))
                self.contractions.append(contraction_op)
                
                contraction_tr = torch.fx.symbolic_trace(lambda w, x: torch.einsum(f'{self.prefix}aiuv->ai{self.postfix}', w, x))
                contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                     example_inputs=(example_weight,
                                                                                     torch.randn(BATCH_SIZE, 3, in1_features, in2_features)))
                self.contractions.append(contraction_op)
                
            if not self.symmetric_product:
                if self.in2_l_max > 0:
                    contraction_tr = torch.fx.symbolic_trace(lambda w, x, y: torch.einsum(f'{self.prefix}au,aiv->ai{self.postfix}', w, x, y))
                    contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                         example_inputs=(example_weight,
                                                                                         torch.randn(BATCH_SIZE, in1_features),
                                                                                         torch.randn(BATCH_SIZE, 3, in2_features)))
                    self.contractions.append(contraction_op)
                    
                if self.in1_l_max > 0 and self.in2_l_max > 1:
                    contraction_tr = torch.fx.symbolic_trace(lambda x, y: torch.einsum('aiu,aijv->ajuv', x, y))
                    contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr, 
                                                                         example_inputs=(torch.randn(BATCH_SIZE, 3, in1_features),
                                                                                         torch.randn(BATCH_SIZE, 3, 3, in2_features)))
                    self.contractions.append(contraction_op)
                    
                    contraction_tr = torch.fx.symbolic_trace(lambda w, x: torch.einsum(f'{self.prefix}aiuv->ai{self.postfix}', w, x))
                    contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                         example_inputs=(example_weight,
                                                                                         torch.randn(BATCH_SIZE, 3, in1_features, in2_features)))
                    self.contractions.append(contraction_op)
                
                if self.in1_l_max > 1 and self.in2_l_max > 2:
                    contraction_tr = torch.fx.symbolic_trace(lambda x, y: torch.einsum('aiju,aijkv->akuv', x, y))
                    contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                         example_inputs=(torch.randn(BATCH_SIZE, 3, 3, in1_features),
                                                                                         torch.randn(BATCH_SIZE, 3, 3, 3, in2_features)))
                    self.contractions.append(contraction_op)
                    
                    contraction_tr = torch.fx.symbolic_trace(lambda w, x: torch.einsum(f'{self.prefix}aiuv->ai{self.postfix}', w, x))
                    contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                         example_inputs=(example_weight,
                                                                                         torch.randn(BATCH_SIZE, 3, in1_features, in2_features)))
                    self.contractions.append(contraction_op)
        
        # l = 2, shape: n_neighbors x 3 x 3 x n_feats
        if self.out_l_max > 1:
            if self.in1_l_max > 1:
                contraction_tr = torch.fx.symbolic_trace(lambda w, x, y: torch.einsum(f'{self.prefix}aiju,av->aij{self.postfix}', w, x, y))
                contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                     example_inputs=(example_weight,
                                                                                     torch.randn(BATCH_SIZE, 3, 3, in1_features),
                                                                                     torch.randn(BATCH_SIZE, in2_features)))
                self.contractions.append(contraction_op)
            
            if self.in1_l_max > 0 and self.in2_l_max > 0:
                contraction_tr = torch.fx.symbolic_trace(lambda w, x, y: torch.einsum(f'{self.prefix}aiu,ajv->aij{self.postfix}', w, x, y))
                contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                     example_inputs=(example_weight,
                                                                                     torch.randn(BATCH_SIZE, 3, in1_features),
                                                                                     torch.randn(BATCH_SIZE, 3, in2_features)))
                self.contractions.append(contraction_op)
                
                contraction_tr = torch.fx.symbolic_trace(lambda w, x, e: torch.einsum(f'{self.prefix}auv,ij->aij{self.postfix}', w, x, e))
                contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr, 
                                                                     example_inputs=(example_weight,
                                                                                     torch.randn(BATCH_SIZE, in1_features, in2_features),
                                                                                     self.eye))
                self.contractions.append(contraction_op)
            
            if self.in1_l_max > 1 and self.in2_l_max > 1:
                contraction_tr = torch.fx.symbolic_trace(lambda w, x, y: torch.einsum(f'{self.prefix}aiju,ajkv->aik{self.postfix}', w, x, y))
                contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                     example_inputs=(example_weight,
                                                                                     torch.randn(BATCH_SIZE, 3, 3, in1_features),
                                                                                     torch.randn(BATCH_SIZE, 3, 3, in2_features)))
                self.contractions.append(contraction_op)
                
                contraction_tr = torch.fx.symbolic_trace(lambda w, x, e: torch.einsum(f'{self.prefix}auv,ij->aij{self.postfix}', w, x, e))
                contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr, 
                                                                     example_inputs=(example_weight,
                                                                                     torch.randn(BATCH_SIZE, in1_features, in2_features),
                                                                                     self.eye))
                self.contractions.append(contraction_op)
                
            if self.in1_l_max > 2 and self.in2_l_max > 2:
                contraction_tr = torch.fx.symbolic_trace(lambda w, x, y: torch.einsum(f'{self.prefix}aijku,ajklv->ail{self.postfix}', w, x, y))
                contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr, 
                                                                     example_inputs=(example_weight,
                                                                                     torch.randn(BATCH_SIZE, 3, 3, 3, in1_features),
                                                                                     torch.randn(BATCH_SIZE, 3, 3, 3, in2_features)))
                self.contractions.append(contraction_op)
                
                contraction_tr = torch.fx.symbolic_trace(lambda w, x, e: torch.einsum(f'{self.prefix}auv,ij->aij{self.postfix}', w, x, e))
                contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr, 
                                                                     example_inputs=(example_weight,
                                                                                     torch.randn(BATCH_SIZE, in1_features, in2_features),
                                                                                     self.eye))
                self.contractions.append(contraction_op)
                
            if self.in1_l_max > 2 and self.in2_l_max > 0:
                contraction_tr = torch.fx.symbolic_trace(lambda w, x, y: torch.einsum(f'{self.prefix}aijku,akv->aij{self.postfix}', w, x, y))
                contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                     example_inputs=(example_weight,
                                                                                     torch.randn(BATCH_SIZE, 3, 3, 3, in1_features),
                                                                                     torch.randn(BATCH_SIZE, 3, in2_features)))
                self.contractions.append(contraction_op)
            
            if not self.symmetric_product:
                if self.in2_l_max > 1:
                    contraction_tr = torch.fx.symbolic_trace(lambda w, x, y: torch.einsum(f'{self.prefix}au,aijv->aij{self.postfix}', w, x, y))
                    contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                         example_inputs=(example_weight,
                                                                                         torch.randn(BATCH_SIZE, in1_features),
                                                                                         torch.randn(BATCH_SIZE, 3, 3, in2_features)))
                    self.contractions.append(contraction_op)
                    
                if self.in1_l_max > 0 and self.in2_l_max > 2:
                    contraction_tr = torch.fx.symbolic_trace(lambda w, x, y: torch.einsum(f'{self.prefix}aiu,aijkv->ajk{self.postfix}', w, x, y))
                    contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                         example_inputs=(example_weight,
                                                                                         torch.randn(BATCH_SIZE, 3, in1_features),
                                                                                         torch.randn(BATCH_SIZE, 3, 3, 3, in2_features)))
                    self.contractions.append(contraction_op)
        
        # l = 3, shape: n_neighbors x 3 x 3 x 3 x n_feats
        if self.out_l_max > 2:
            if self.in1_l_max > 2:
                contraction_tr = torch.fx.symbolic_trace(lambda w, x, y: torch.einsum(f'{self.prefix}aijku,av->aijk{self.postfix}', w, x, y))
                contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                     example_inputs=(example_weight,
                                                                                     torch.randn(BATCH_SIZE, 3, 3, 3, in1_features),
                                                                                     torch.randn(BATCH_SIZE, in2_features)))
                self.contractions.append(contraction_op)
            
            if self.in1_l_max > 1 and self.in2_l_max > 0:
                contraction_tr = torch.fx.symbolic_trace(lambda w, x, y: torch.einsum(f'{self.prefix}aiju,akv->aijk{self.postfix}', w, x, y))
                contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                     example_inputs=(example_weight,
                                                                                     torch.randn(BATCH_SIZE, 3, 3, in1_features),
                                                                                     torch.randn(BATCH_SIZE, 3, in2_features)))
                self.contractions.append(contraction_op)
                
                contraction_tr = torch.fx.symbolic_trace(lambda w, x, e: torch.einsum(f'{self.prefix}aiuv,jk->aijk{self.postfix}', w, x, e))
                contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                     example_inputs=(example_weight,
                                                                                     torch.randn(BATCH_SIZE, 3, in1_features, in2_features),
                                                                                     self.eye))
                self.contractions.append(contraction_op)
            
            if self.in1_l_max > 2 and self.in2_l_max > 1:
                contraction_tr = torch.fx.symbolic_trace(lambda w, x, y: torch.einsum(f'{self.prefix}aijku,aklv->aijl{self.postfix}', w, x, y))
                contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                     example_inputs=(example_weight,
                                                                                     torch.randn(BATCH_SIZE, 3, 3, 3, in1_features),
                                                                                     torch.randn(BATCH_SIZE, 3, 3, in2_features)))
                self.contractions.append(contraction_op)
                
                contraction_tr = torch.fx.symbolic_trace(lambda w, x, e: torch.einsum(f'{self.prefix}aiuv,jk->aijk{self.postfix}', w, x, e))
                contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                     example_inputs=(example_weight,
                                                                                     torch.randn(BATCH_SIZE, 3, in1_features, in2_features),
                                                                                     self.eye))
                self.contractions.append(contraction_op)
            
            if not self.symmetric_product:
                if self.in2_l_max > 2:
                    contraction_tr = torch.fx.symbolic_trace(lambda w, x, y: torch.einsum(f'{self.prefix}au,aijkv->aijk{self.postfix}', w, x, y))
                    contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                         example_inputs=(example_weight,
                                                                                         torch.randn(BATCH_SIZE, in1_features),
                                                                                         torch.randn(BATCH_SIZE, 3, 3, 3, in2_features)))
                    self.contractions.append(contraction_op)
                    
                if self.in1_l_max > 0 and self.in2_l_max > 1:
                    contraction_tr = torch.fx.symbolic_trace(lambda w, x, y: torch.einsum(f'{self.prefix}aiu,ajkv->aijk{self.postfix}', w, x, y))
                    contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                         example_inputs=(example_weight,
                                                                                         torch.randn(BATCH_SIZE, 3, in1_features),
                                                                                         torch.randn(BATCH_SIZE, 3, 3, in2_features)))
                    self.contractions.append(contraction_op)
                        
                    contraction_tr = torch.fx.symbolic_trace(lambda w, x, e: torch.einsum(f'{self.prefix}aiuv,jk->aijk{self.postfix}', w, x, e))
                    contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                        example_inputs=(example_weight,
                                                                                        torch.randn(BATCH_SIZE, 3, in1_features, in2_features),
                                                                                        self.eye))
                    self.contractions.append(contraction_op)
                
                if self.in1_l_max > 1 and self.in2_l_max > 2:
                    contraction_tr = torch.fx.symbolic_trace(lambda w, x, y: torch.einsum(f'{self.prefix}aiju,ajklv->aikl{self.postfix}', w, x, y))
                    contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                         example_inputs=(example_weight,
                                                                                         torch.randn(BATCH_SIZE, 3, 3, in1_features),
                                                                                         torch.randn(BATCH_SIZE, 3, 3, 3, in2_features)))
                    self.contractions.append(contraction_op)
                    
                    contraction_tr = torch.fx.symbolic_trace(lambda w, x, e: torch.einsum(f'{self.prefix}aiuv,jk->aijk{self.postfix}', w, x, e))
                    contraction_op = opt_einsum_fx.optimize_einsums_full(model=contraction_tr,
                                                                        example_inputs=(example_weight,
                                                                                        torch.randn(BATCH_SIZE, 3, in1_features, in2_features),
                                                                                        self.eye))
                    self.contractions.append(contraction_op)
    
    def _get_weights(self, weight: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Prepares weights before computing the tensor product.

        Args:
            weight (Optional[torch.Tensor], optional): Weight tensor. Defaults to None.

        Returns:
            torch.Tensor: Weight tensor.
        """
        if weight is None:
            if not self.internal_weights:
                raise RuntimeError('Weights must be provided if no `internal_weights` are defined.')
            return self.weight
        else:
            if self.shared_weights:
                if self.connection_mode == 'uvw':
                    assert weight.shape == (self.n_total_paths, self.in1_features, self.in2_features, self.out_features), 'Invalid weight shape.'
                else:
                    assert weight.shape == (self.n_total_paths, self.in1_features, self.in2_features), 'Invalid weight shape.'
                return weight
            else:
                if self.connection_mode == 'uvw':
                    assert weight.shape[1:] == (self.n_total_paths, self.in1_features, self.in2_features, self.out_features), 'Invalid weight shape.'
                    assert weight.ndim == 5, 'When shared weights is False, weights must have batch dimension.'
                else:
                    assert weight.shape[1:] == (self.n_total_paths, self.in1_features, self.in2_features), 'Invalid weight shape.'
                    assert weight.ndim == 4, 'When shared weights is False, weights must have batch dimension.'
                return weight.permute(1, 0, *[i for i in range(2, len(weight.shape))])
    
    def forward(self, 
                x: torch.Tensor, 
                y: torch.Tensor, 
                weight: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Computes tensor product between input tensors `x` and `y`. Both tensors must contain flattened 
        irreducible Cartesian tensors/Cartesian harmonics that are accessed using pre-computed 
        slices and shapes.

        For each output rank l3 = 0, ..., `out_l_max`, every admissible pair of input ranks (l1, l2)
        contributes one path. The inputs are contracted with the next entry of `self.contractions`,
        weighted by `weight[i_path] * self.alpha[i_path]`, and scaled by `_norm_l1l2l3(l1, l2, l3)`.
        The per-path results are then concatenated along the last (feature) axis. Paths with
        l1 < l2 are skipped when `self.symmetric_product` is True.

        The order of the `self.contractions[...]` calls below must match the order in which
        `self.contractions` was populated. Both counters (`i_contraction`, `i_path`) advance
        sequentially, so adding or reordering a path here requires the same change where the
        contractions are built.

        Args:
            x (torch.Tensor): First input tensor. It is indexed below as a 2-D (batch, in1_dim) tensor;
                its last dimension must equal `self.in1_dim`.
            y (torch.Tensor): Second input tensor, indexed as (batch, in2_dim); its last dimension must
                equal `self.in2_dim`. It must equal `x` when `self.symmetric_product` is True.
            weight (Optional[torch.Tensor], optional): Optional external weights, prepared by
                `self._get_weights`. Defaults to None (internal weights are used).

        Returns:
            torch.Tensor: Irreducible Cartesian tensors obtained as products of input tensors. Each
                          rank block (rank 0 first) is flattened per batch entry, and the blocks are
                          concatenated along the last dimension. The number of features is larger by
                          the number of paths leading to the respective tensor rank contained in the
                          flattened output tensor.

        Raises:
            AssertionError: If the last dimension of `x` or `y` is incorrect, or if
                `symmetric_product` is True and `x` and `y` differ.
        """
        torch._assert(x.shape[-1] == self.in1_dim, 'Incorrect last dimension for x.')
        torch._assert(y.shape[-1] == self.in2_dim, 'Incorrect last dimension for y.')
        
        # NOTE: torch.equal compares the full tensors element-wise on every call.
        if self.symmetric_product: torch._assert(torch.equal(x, y), 'Symmetric product is possible only if x == y.')
        
        # weight shape ('uvw'): n_paths x (n_neighbors x) in1_features x in2_features x out_features 
        # weight shape ('uvu'): n_paths x (n_neighbors x) in1_features x in2_features
        weight = self._get_weights(weight)
        
        # Split the flattened inputs into per-rank views (x_l, y_l) using the precomputed slices/shapes.
        # Higher-rank views exist only when the corresponding l_max allows them.
        # NOTE(review): y_0 uses view(-1, ...) while all other views use x.shape[0]; this looks
        # equivalent when x and y share the batch size, but is inconsistent — confirm intent.
        x_0 = x[:, self.in1_slices[0]].view(x.shape[0], *self.in1_shapes[0])
        y_0 = y[:, self.in2_slices[0]].view(-1, *self.in2_shapes[0])
        if self.in1_l_max > 0: x_1 = x[:, self.in1_slices[1]].view(x.shape[0], *self.in1_shapes[1])
        if self.in2_l_max > 0: y_1 = y[:, self.in2_slices[1]].view(x.shape[0], *self.in2_shapes[1])
        if self.in1_l_max > 1: x_2 = x[:, self.in1_slices[2]].view(x.shape[0], *self.in1_shapes[2])
        if self.in2_l_max > 1: y_2 = y[:, self.in2_slices[2]].view(x.shape[0], *self.in2_shapes[2])
        if self.in1_l_max > 2: x_3 = x[:, self.in1_slices[3]].view(x.shape[0], *self.in1_shapes[3])
        if self.in2_l_max > 2: y_3 = y[:, self.in2_slices[3]].view(x.shape[0], *self.in2_shapes[3])
        
        # running indices into the weights/alpha (one per path) and into self.contractions
        i_path = 0
        i_contraction = 0
        
        # ---- output rank l3 = 0 ----
        cp_0 = [_norm_l1l2l3(0, 0, 0) * self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], x_0, y_0)]
        i_contraction += 1
        i_path += 1
        
        if self.in1_l_max > 0 and self.in2_l_max > 0:
            # unweighted contraction of x_1 with y_1; also reused below for the l3 = 2 trace term
            xy_110 = self.contractions[i_contraction](x_1, y_1)
            i_contraction += 1
            
            cp_0.append(_norm_l1l2l3(1, 1, 0) * self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], xy_110))
            i_contraction += 1
            i_path += 1
        
        if self.in1_l_max > 1 and self.in2_l_max > 1:
            # reused below for the l3 = 2 trace term
            xy_220 = self.contractions[i_contraction](x_2, y_2)
            i_contraction += 1
            
            cp_0.append(_norm_l1l2l3(2, 2, 0) * self.contractions[i_contraction](weight[i_path] * self.alpha[i_path] , xy_220))
            i_contraction += 1
            i_path += 1
        
        if self.in1_l_max > 2 and self.in2_l_max > 2:
            # reused below for the l3 = 2 trace term
            xy_330 = self.contractions[i_contraction](x_3, y_3)
            i_contraction += 1
            
            cp_0.append(_norm_l1l2l3(3, 3, 0) * self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], xy_330))
            i_contraction += 1
            i_path += 1

        # shape: n_neighbors x (out_features * n_paths)
        cp_0 = torch.cat(cp_0, dim=-1)
        
        if self.out_l_max == 0:
            return cp_0
        
        # ---- output rank l3 = 1 ----
        cp_1 = []
        
        if self.in1_l_max > 0:
            cp_1.append(_norm_l1l2l3(1, 0, 1) * self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], x_1, y_0))
            i_contraction += 1
            i_path += 1
           
        if self.in1_l_max > 1 and self.in2_l_max > 0:
            # reused below for the l3 = 3 trace term
            xy_211 = self.contractions[i_contraction](x_2, y_1)
            i_contraction += 1
            
            cp_1.append(_norm_l1l2l3(2, 1, 1) * self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], xy_211))
            i_contraction += 1
            i_path += 1
        
        if self.in1_l_max > 2 and self.in2_l_max > 1:
            # reused below for the l3 = 3 trace term
            xy_321 = self.contractions[i_contraction](x_3, y_2)
            i_contraction += 1
            
            cp_1.append(_norm_l1l2l3(3, 2, 1) * self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], xy_321))
            i_contraction += 1
            i_path += 1
            
        # paths with l1 < l2 are only computed for non-symmetric products
        if not self.symmetric_product:
            if self.in2_l_max > 0:
                cp_1.append(_norm_l1l2l3(0, 1, 1) * self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], x_0, y_1))
                i_contraction += 1
                i_path += 1
                
            if self.in1_l_max > 0 and self.in2_l_max > 1:
                # reused below for the l3 = 3 trace term
                xy_121 = self.contractions[i_contraction](x_1, y_2)
                i_contraction += 1
                
                cp_1.append(_norm_l1l2l3(1, 2, 1) * self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], xy_121))
                i_contraction += 1
                i_path += 1
                
            if self.in1_l_max > 1 and self.in2_l_max > 2:
                # reused below for the l3 = 3 trace term
                xy_231 = self.contractions[i_contraction](x_2, y_3)
                i_contraction += 1
                
                cp_1.append(_norm_l1l2l3(2, 3, 1) * self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], xy_231))
                i_contraction += 1
                i_path += 1
        
        if len(cp_1) == 0:
            # append zeros if out_l_max > 0 is requested
            cp_1.append(torch.zeros(cp_0.shape[0], 3, self.out_features, device=cp_0.device, dtype=cp_0.dtype))
        
        # shape: n_neighbors x 3 x (out_features * n_paths)
        cp_1 = torch.cat(cp_1, dim=-1)
        
        if self.out_l_max == 1:
            return torch.cat([cp_0, 
                              cp_1.view(x.shape[0], 3 * self.out_features * self.n_paths[1])], -1)
        
        # ---- output rank l3 = 2 ----
        cp_2 = []
        
        if self.in1_l_max > 1:
            cp_2.append(_norm_l1l2l3(2, 0, 2) * self.contractions[i_contraction](weight[i_path] * self.alpha[i_path],  x_2, y_0))
            i_contraction += 1
            i_path += 1
           
        if self.in1_l_max > 0 and self.in2_l_max > 0:
            # symmetrize over the two Cartesian axes (dims 1 and 2)
            x_x = self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], x_1, y_1)
            x_x = x_x  + x_x.permute(0, 2, 1, 3)
            i_contraction += 1
            
            # contracted part times the identity; the same path weight is used for both terms
            xx_e = self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], xy_110, self.eye)
            i_contraction += 1
            i_path += 1
            
            # subtracting 2/3 of the identity term presumably removes the trace of the symmetrized product
            cp_2.append(_norm_l1l2l3(1, 1, 2) * (x_x - 2. / 3. * xx_e))
        
        if self.in1_l_max > 1 and self.in2_l_max > 1:
            x_x = self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], x_2, y_2)
            x_x = x_x + x_x.permute(0, 2, 1, 3)
            i_contraction += 1
            
            xx_e = self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], xy_220, self.eye)
            i_contraction += 1
            i_path += 1
            
            cp_2.append(_norm_l1l2l3(2, 2, 2) * (x_x - 2. / 3. * xx_e))
           
        if self.in1_l_max > 2 and self.in2_l_max > 2:
            x_x = self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], x_3, y_3)
            x_x = x_x + x_x.permute(0, 2, 1, 3)
            i_contraction += 1
            
            xx_e = self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], xy_330, self.eye)
            i_contraction += 1
            i_path += 1
            
            cp_2.append(_norm_l1l2l3(3, 3, 2) * (x_x - 2. / 3. * xx_e))
            
        if self.in1_l_max > 2 and self.in2_l_max > 0:
            cp_2.append(_norm_l1l2l3(3, 1, 2) * self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], x_3, y_1))
            i_contraction += 1
            i_path += 1
            
        # paths with l1 < l2 are only computed for non-symmetric products
        if not self.symmetric_product:
            if self.in2_l_max > 1:
                cp_2.append(_norm_l1l2l3(0, 2, 2) * self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], x_0, y_2))
                i_contraction += 1
                i_path += 1
            
            if self.in1_l_max > 0 and self.in2_l_max > 2:
                cp_2.append(_norm_l1l2l3(1, 3, 2) * self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], x_1, y_3))
                i_contraction += 1
                i_path += 1
        
        if len(cp_2) == 0:
            # append zeros if out_l_max > 0 is requested
            cp_2.append(torch.zeros(cp_0.shape[0], 3, 3, self.out_features, device=cp_0.device, dtype=cp_0.dtype))
        
        # shape: n_neighbors x 3 x 3 x (out_features * n_paths)
        cp_2 = torch.cat(cp_2, dim=-1)
        
        if self.out_l_max == 2:
            return torch.cat([cp_0, 
                              cp_1.view(x.shape[0], 3 * self.out_features * self.n_paths[1]),
                              cp_2.view(x.shape[0], (3 ** 2) * self.out_features * self.n_paths[2])], -1)
        
        # ---- output rank l3 = 3 ----
        cp_3 = []
        
        if self.in1_l_max > 2:
            cp_3.append(_norm_l1l2l3(3, 0, 3) * self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], x_3, y_0))
            i_contraction += 1
            i_path += 1
           
        if self.in1_l_max > 1 and self.in2_l_max > 0:
            # sum over the cyclic permutations of the three Cartesian axes (dims 1, 2, 3)
            x_x = self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], x_2, y_1)
            x_x = x_x + x_x.permute(0, 2, 3, 1, 4) + x_x.permute(0, 3, 1, 2, 4)
            i_contraction += 1
            
            # identity term built from the l3 = 1 intermediate, cyclically symmetrized the same way
            xx_e = self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], xy_211, self.eye)
            xx_e = xx_e + xx_e.permute(0, 2, 3, 1, 4) + xx_e.permute(0, 3, 1, 2, 4)
            i_contraction += 1
            i_path += 1
            
            # the 2/5 factor presumably removes the trace part of the rank-3 product
            cp_3.append(_norm_l1l2l3(2, 1, 3) * (x_x - 2. / 5. * xx_e))
            
           
        if self.in1_l_max > 2 and self.in2_l_max > 1:
            x_x = self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], x_3, y_2)
            x_x = x_x + x_x.permute(0, 2, 3, 1, 4) + x_x.permute(0, 3, 1, 2, 4)
            i_contraction += 1
            
            xx_e = self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], xy_321, self.eye)
            xx_e = xx_e + xx_e.permute(0, 2, 3, 1, 4) + xx_e.permute(0, 3, 1, 2, 4)
            i_contraction += 1
            i_path += 1
            
            cp_3.append(_norm_l1l2l3(3, 2, 3) * (x_x - 2. / 5. * xx_e))
            
        
        # paths with l1 < l2 are only computed for non-symmetric products
        if not self.symmetric_product:
            if self.in2_l_max > 2:
                cp_3.append(_norm_l1l2l3(0, 3, 3) *  self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], x_0, y_3))
                i_contraction += 1
                i_path += 1
                
            if self.in1_l_max > 0 and self.in2_l_max > 1:
                x_x = self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], x_1, y_2)
                x_x = x_x + x_x.permute(0, 2, 3, 1, 4) + x_x.permute(0, 3, 1, 2, 4)
                i_contraction += 1
                
                xx_e = self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], xy_121, self.eye)
                xx_e = xx_e + xx_e.permute(0, 2, 3, 1, 4) + xx_e.permute(0, 3, 1, 2, 4)
                i_contraction += 1
                i_path += 1
                
                cp_3.append(_norm_l1l2l3(1, 2, 3) * (x_x - 2. / 5. * xx_e))
                
            if self.in1_l_max > 1 and self.in2_l_max > 2:
                x_x = self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], x_2, y_3)
                x_x = x_x + x_x.permute(0, 2, 3, 1, 4) + x_x.permute(0, 3, 1, 2, 4)
                i_contraction += 1
                
                xx_e = self.contractions[i_contraction](weight[i_path] * self.alpha[i_path], xy_231, self.eye)
                xx_e = xx_e + xx_e.permute(0, 2, 3, 1, 4) + xx_e.permute(0, 3, 1, 2, 4)
                i_contraction += 1
                i_path += 1
                
                cp_3.append(_norm_l1l2l3(2, 3, 3) * (x_x - 2. / 5. * xx_e))
        
        if len(cp_3) == 0:
            # append zeros if out_l_max > 0 is requested
            cp_3.append(torch.zeros(cp_0.shape[0], 3, 3, 3, self.out_features, device=cp_0.device, dtype=cp_0.dtype))
        
        # shape: n_neighbors x 3 x 3 x 3 x (out_features * n_paths)
        cp_3 = torch.cat(cp_3, dim=-1)
        
        # NOTE(review): there is no return after this branch, so out_l_max > 3 falls through and
        # returns None — presumably out_l_max is restricted to <= 3 elsewhere; confirm.
        if self.out_l_max == 3:
            return torch.cat([cp_0, 
                              cp_1.view(x.shape[0], 3 * self.out_features * self.n_paths[1]),
                              cp_2.view(x.shape[0], (3 ** 2) * self.out_features * self.n_paths[2]),
                              cp_3.view(x.shape[0], (3 ** 3) * self.out_features * self.n_paths[3])], -1)
            
    def __repr__(self) -> str:
        if self.connection_mode == 'uvw':
            weight_numel = self.n_total_paths * self.in1_features * self.in2_features * self.out_features
        else:
            weight_numel = self.n_total_paths * self.in1_features * self.in2_features
        return (f"{self.__class__.__name__} ({self.in1_l_max} x {self.in2_l_max} -> {self.out_l_max} | {self.n_total_paths} total paths | {self.n_paths} paths | {weight_numel} weights)")


def _get_n_paths(in1_l_max: int, 
                 in2_l_max: int, 
                 out_l_max: int,
                 in1_paths: Optional[List[int]] = None,
                 in2_paths: Optional[List[int]] = None,
                 symmetric_product: bool = False) -> int:
    """Counts the number of re-coupling paths for Cartesian harmonics depending on the rank of input tensors 
    and the maximal rank of the expected output tensor. 
    
    We count the number of paths obtained by re-coupling Cartesian harmonics, but also paths obtained through 
    re-coupling/contracting Cartesian harmonics in previous calculations.


    Args:
        in1_l_max (int): Maximal rotational order/rank of the first input tensor.
        in2_l_max (int): Maximal rotational order/rank of the second input tensor.
        out_l_max (int): Maximal rotational order/rank of the output tensor.
        in1_paths (Optional[List[int]], optional): Number of contraction paths used to obtain the first input tensor. 
                                                   Defaults to None.
        in2_paths (Optional[List[int]], optional): Number of contraction paths used to obtain the second input tensor. 
                                                   Defaults to None.
        symmetric_product (bool, optional): If True, skip the calculation of symmetric contractions. Defaults to False.

    Returns:
        int: Total number of contraction paths.
    """
    if in1_paths is None: in1_paths = [1 for _ in range(in1_l_max + 1)]
    if in2_paths is None: in2_paths = [1 for _ in range(in2_l_max + 1)]
    
    # input tensors have by default scalar features
    n_paths = in1_paths[0] * in2_paths[0]
    
    # count paths leading to l = 0
    if in1_l_max > 0 and in2_l_max > 0: n_paths += in1_paths[1] * in2_paths[1]
    if in1_l_max > 1 and in2_l_max > 1: n_paths += in1_paths[2] * in2_paths[2]
    if in1_l_max > 2 and in2_l_max > 2: n_paths += in1_paths[3] * in2_paths[3]
    if out_l_max == 0: return n_paths
    
    # count paths leading to l = 1
    if in1_l_max > 0: n_paths += in1_paths[1] * in2_paths[0]
    if in1_l_max > 1 and in2_l_max > 0: n_paths += in1_paths[2] * in2_paths[1]
    if in1_l_max > 2 and in2_l_max > 1: n_paths += in1_paths[3] * in2_paths[2]
    if not symmetric_product:
        if in2_l_max > 0: n_paths += in1_paths[0] * in2_paths[1]
        if in1_l_max > 0 and in2_l_max > 1: n_paths += in1_paths[1] * in2_paths[2]
        if in1_l_max > 1 and in2_l_max > 2: n_paths += in1_paths[2] * in2_paths[3]
    if out_l_max == 1: return n_paths
    
    # count paths leading to l=2
    if in1_l_max > 1: n_paths += in1_paths[2] * in2_paths[0]
    if in1_l_max > 0 and in2_l_max > 0: n_paths += in1_paths[1] * in2_paths[1]
    if in1_l_max > 1 and in2_l_max > 1: n_paths += in1_paths[2] * in2_paths[2]
    if in1_l_max > 2 and in2_l_max > 2: n_paths += in1_paths[3] * in2_paths[3]
    if in1_l_max > 2 and in2_l_max > 0: n_paths += in1_paths[3] * in2_paths[1]
    if not symmetric_product:
        if in2_l_max > 1: n_paths += in1_paths[0] * in2_paths[2]
        if in1_l_max > 0 and in2_l_max > 2: n_paths += in1_paths[1] * in2_paths[3]
    if out_l_max == 2: return n_paths
    
    # count paths leading to l=3
    if in1_l_max > 2: n_paths += in1_paths[3] * in2_paths[0]
    if in1_l_max > 1 and in2_l_max > 0: n_paths += in1_paths[2] * in2_paths[1]
    if in1_l_max > 2 and in2_l_max > 1: n_paths += in1_paths[3] * in2_paths[2]
    if not symmetric_product:
        if in2_l_max > 2: n_paths += in1_paths[0] * in2_paths[3]
        if in1_l_max > 0 and in2_l_max > 1: n_paths += in1_paths[1] * in2_paths[2]
        if in1_l_max > 1 and in2_l_max > 2: n_paths += in1_paths[2] * in2_paths[3]
    if out_l_max == 3: return n_paths


def _factorial(n: int) -> int:
    """Computes factorial.

    Args:
        n (int): Input integer.

    Returns:
        int: Output integer.
    """
    if n == 0:
        return 1
    else:
        return n * _factorial(n-1)


def _doublefactorial(n: int) -> int:
    """Computes double factorial.

    Args:
        n (int): Input integer.

    Returns:
        int: Output integer.
    """
    if n <= 0:
        return 1
    else:
        return n * _doublefactorial(n-2)


def _norm_l1l2l3(l1: int,
                 l2: int,
                 l3: int) -> float:
    """Computes the normalization factor for the irreducible recoupling of Cartesian harmonics.
    This function implements the normalization factor for an even recoupling.

    Args:
        l1 (int): Rank of the first input Cartesian harmonics.
        l2 (int): Rank of the second input Cartesian harmonics.
        l3 (int): Rank of the output Cartesian harmonics.

    Raises:
        ValueError: If `l1 + l2 - l3` is odd (not an even recoupling) or if the ranks violate the
            triangle inequality |l1 - l2| <= l3 <= l1 + l2.

    Returns:
        float: Normalization factor.
    """
    # explicit checks instead of `assert`, which is stripped under `python -O`
    if (l1 + l2 - l3) % 2 != 0:
        raise ValueError(f'Recoupling ({l1}, {l2}) -> {l3} is not even: l1 + l2 - l3 must be even.')
    # without this check, J1 or J2 below could make _factorial receive a negative argument
    if not abs(l1 - l2) <= l3 <= l1 + l2:
        raise ValueError(f'Ranks ({l1}, {l2}, {l3}) violate the triangle inequality |l1 - l2| <= l3 <= l1 + l2.')
    J = l1 + l2 + l3
    J1 = J - 2 * l1 - 1
    J2 = J - 2 * l2 - 1
    J3 = J - 2 * l3 - 1
    # parity and triangle checks guarantee that J, J1 + 1 and J2 + 1 are even and non-negative,
    # so integer division is exact
    num = _factorial(l1) * _factorial(l2) * _doublefactorial(2 * l3 - 1) * _factorial((J1 + 1) // 2) * _factorial((J2 + 1) // 2)
    den = _factorial(l3) * _doublefactorial(J1) * _doublefactorial(J2) * _doublefactorial(J3) * _factorial(J // 2)
    return num / den
