import math

import torch
import torch.nn as nn
import numpy as np


class embedding_module_log(nn.Module):
    """Log-spaced sinusoidal embedding of an input tensor.

    For every function ``f`` in ``funcs`` (outer loop) and every frequency
    ``w`` in ``2 ** linspace(0, max_freq, num_freqs)`` (inner loop), the term
    ``f(x * w)`` is computed. All terms are concatenated along ``ch_dim``,
    optionally preceded by ``x`` itself. If ``x`` has ``C`` entries along
    ``ch_dim``, the output has ``C * (int(include_in) + len(funcs) * num_freqs)``
    entries along that dimension.

    Args:
        funcs: element-wise callables applied to the scaled input. ``None``
            (the default) means ``[torch.sin, torch.cos]``; a fresh list is
            created for every instance.
        num_freqs: number of frequencies.
        max_freq: base-2 exponent of the largest frequency (the smallest is
            ``2 ** 0 == 1``).
        ch_dim: dimension along which the terms are concatenated.
        include_in: whether to prepend the raw input to the output.
    """

    def __init__(self, funcs=None, num_freqs=20, max_freq=10, ch_dim=1, include_in=True):
        super().__init__()
        # Build the default inside the call: a list literal in the signature
        # would be a single object shared by every instance (mutable default).
        if funcs is None:
            funcs = [torch.sin, torch.cos]
        # ``functions`` and ``funcs`` alias the same object; both attribute
        # names are kept for compatibility with existing callers.
        self.functions = funcs
        self.num_functions = list(range(len(funcs)))
        # Frozen Parameter: never trained, but still moved by ``.to()`` and
        # stored in the state_dict under the key ``freqs``.
        exponents = np.linspace(start=0.0, stop=max_freq, num=num_freqs).astype(np.single)
        self.freqs = torch.nn.Parameter(2.0 ** torch.from_numpy(exponents), requires_grad=False)
        self.ch_dim = ch_dim
        self.funcs = funcs
        self.include_in = include_in

    def forward(self, x_input):
        """Return the embedding of ``x_input`` concatenated along ``self.ch_dim``."""
        out_list = [x_input] if self.include_in else []
        # Channel layout: all frequencies of funcs[0] first, then all
        # frequencies of funcs[1], and so on.
        out_list.extend(func(x_input * freq) for func in self.funcs for freq in self.freqs)
        return torch.cat(out_list, dim=self.ch_dim)

class Sine(nn.Module):
    """Element-wise sine activation: ``forward(x) == sin(w0 * x)``.

    Args:
        w0: constant (non-learned) multiplier applied to the input before
            taking the sine.
    """

    def __init__(self, w0=1.):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        scaled = self.w0 * x
        return torch.sin(scaled)

class basic_project2(nn.Module):
    """Affine projection from ``input_ch`` to ``output_ch`` features.

    A thin wrapper around ``nn.Linear`` (with bias). The layer is stored as
    ``self.proj``, so checkpoint keys are ``proj.weight`` and ``proj.bias``.
    """

    def __init__(self, input_ch, output_ch):
        super().__init__()
        self.proj = nn.Linear(in_features=input_ch, out_features=output_ch, bias=True)

    def forward(self, x):
        projected = self.proj(x)
        return projected

class kernel_linear_act(nn.Module):
    """Pre-activation linear layer: ``PReLU`` followed by a ``basic_project2``.

    Both modules live in ``self.block`` (an ``nn.Sequential``), in that order,
    so checkpoint keys are ``block.0.*`` (PReLU) and ``block.1.proj.*``.
    """

    def __init__(self, input_ch, output_ch):
        super().__init__()
        activation = nn.PReLU()
        projection = basic_project2(input_ch, output_ch)
        self.block = nn.Sequential(activation, projection)

    def forward(self, input_x):
        output = self.block(input_x)
        return output

class att_act(nn.Module):
    """LeakyReLU (slope 0.1) followed by multi-head self-attention.

    The activated input serves as query, key and value. The attention weights
    returned by ``nn.MultiheadAttention`` are discarded. The attention layer
    is built with PyTorch defaults (``batch_first=False``), so inputs are
    expected as ``(seq, batch, input_ch)``.
    """

    def __init__(self, input_ch, num_head):
        super().__init__()
        self.act = nn.LeakyReLU(negative_slope=0.1)
        self.block = nn.MultiheadAttention(input_ch, num_head)

    def forward(self, input_x):
        activated = self.act(input_x)
        attended, _weights = self.block(activated, activated, activated)
        return attended


class AcousticRayTracing(nn.Module):
    """Baseline implicit model for acoustic ray tracing with a point acoustic source.

    The scene is described by a fixed set of ``K`` patch positions (``patches``).
    Three feature sets are built per patch, each projected to ``intermediate_ch``
    (``C``) channels:

    * the source-to-patch offset,
    * the patch position itself,
    * the listener-to-patch offset.

    An embedding of each of the ``T`` time bins is matrix-multiplied against
    every feature set, giving ``(B, T, K)`` scores per set. The three score
    tensors are concatenated into ``(B, T, 3K)`` and decoded by an MLP into a
    ``(B, T, 2, 1)`` output.

    Args:
        n_bins: number of time bins ``T``.
        patches: numpy array of patch coordinates (converted with
            ``torch.from_numpy``). Presumably ``(K, 2)``: the default
            ``*_input_ch=42`` values match 2-D coordinates embedded by
            ``embedding_module_log(num_freqs=10)`` (2 * 21). TODO confirm.
        source_input_ch, patch_input_ch, listen_input_ch: channel counts of
            the corresponding embedded inputs.
        time_input_ch: channel count of the embedded time coordinate
            (1 channel * 21 = 21).
        intermediate_ch: hidden width ``C``.
        patch_block: the patch branch has ``patch_block - 2`` hidden layers.
        all_block: the decoder has ``all_block - 2`` hidden layers and uses
            ``all_block - 1`` left/right and rotation offset vectors.
    """

    def __init__(self, n_bins, patches, source_input_ch=42,
                 time_input_ch=21,
                 patch_input_ch=42,
                 listen_input_ch=42,
                 intermediate_ch=256,
                 patch_block=4, all_block=6):
        super(AcousticRayTracing, self).__init__()
        self.t_bins = n_bins
        # NOTE(review): plain tensor attribute, not a registered buffer. It is
        # neither moved by ``.to(device)`` nor saved in the state_dict, and
        # forward() converts it to float and copies it to ``points.device``
        # on every call.
        self.patches = torch.from_numpy(patches)
        self.num_patch = len(self.patches)
        self.listen_input_ch = listen_input_ch
        # Each embedder maps D input channels to D * 21 channels
        # (raw input + sin and cos at 10 frequencies), concatenated on dim 2.
        # NOTE(review): xyz_embedder is never used in forward().
        self.xyz_embedder = embedding_module_log(num_freqs=10, ch_dim=2, max_freq=7)
        self.dist_embedder = embedding_module_log(num_freqs=10, ch_dim=2, max_freq=7)
        # Uses the embedder's default max_freq (10) instead of 7 like the others.
        self.time_embedder = embedding_module_log(num_freqs=10, ch_dim=2)
        self.patch_embedder = embedding_module_log(num_freqs=10, ch_dim=2, max_freq=7)
        # Time-bin coordinates 2*t/T - 1 for t = 0..T-1, i.e. in [-1, 1); shape (T, 1).
        # Like ``patches``, a plain attribute that forward() copies to the device.
        self.times = 2*(torch.arange(0, self.t_bins))/self.t_bins - 1.0
        self.times = self.times.unsqueeze(1)
        self.source_proj = basic_project2(source_input_ch, intermediate_ch)
        self.time_proj = basic_project2(time_input_ch, intermediate_ch)
        # Patch branch: input projection, (patch_block - 2) PReLU+Linear layers,
        # and a two-layer residual branch fed from the raw patch embedding.
        self.patch_proj = basic_project2(patch_input_ch, intermediate_ch)
        self.patch_residual = nn.Sequential(basic_project2(patch_input_ch, intermediate_ch),
                                            nn.PReLU(),
                                        basic_project2(intermediate_ch, intermediate_ch))
        self.patch_layers = torch.nn.ModuleList()
        for k in range(patch_block - 2):
            self.patch_layers.append(kernel_linear_act(intermediate_ch, intermediate_ch))
        self.patch_blocks = len(self.patch_layers)
        self.listen_proj = basic_project2(self.listen_input_ch, intermediate_ch)
        # Decoder input: the source, patch and listener score tensors
        # concatenated along the last dim, i.e. 3 * K channels.
        self.all_input_ch = self.num_patch*3
        self.all_proj = basic_project2(self.all_input_ch, intermediate_ch)
        self.all_residual = nn.Sequential(basic_project2(self.all_input_ch, intermediate_ch),
                                          nn.PReLU(),
                                             basic_project2(intermediate_ch, intermediate_ch))
        self.all_layers = torch.nn.ModuleList()
        for k in range(all_block - 2):
            self.all_layers.append(kernel_linear_act(intermediate_ch, intermediate_ch))
        self.all_blocks = len(self.all_layers)

        # Learned offsets added at each decoder stage k. left_right_k has shape
        # (1, 1, 2, C) and broadcasts over the size-2 axis of the decoder
        # features (presumably the left/right channels, going by the name;
        # TODO confirm).
        for k in range(all_block - 1):
            self.register_parameter("left_right_{}".format(k), nn.Parameter(torch.randn(1,1,2,intermediate_ch)/math.sqrt(intermediate_ch), requires_grad=True))
        # One rotation embedding per rotation index 0..3, each of shape
        # (all_block - 1, 1, 1, C). Row k is added at decoder stage k;
        # forward() picks one per sample via ``rot_idx``.
        for k in range(4):
            self.register_parameter('rot_{}'.format(k), nn.Parameter(torch.randn(all_block-1, 1,1, intermediate_ch)/math.sqrt(intermediate_ch), requires_grad=True))


        # Maps each C-channel decoder feature vector to a single value.
        self.out_layer = nn.Linear(intermediate_ch, 1)

    def forward(self, source_points, points, source_norm_pos, listen_norm_pos, rot_idx, b_range):
        """Predict one value per (time bin, size-2 axis entry) for each sample.

        Args:
            source_points: not used by this method.
            points: 3-D tensor. Only its batch size ``B`` (dim 0) and its
                device are used.
            source_norm_pos: source position. It is expanded to
                ``(B, K, -1)``, so presumably ``(B, 1, 2)`` or ``(1, 1, 2)``;
                TODO confirm against the caller.
            listen_norm_pos: listener position, same convention as
                ``source_norm_pos``.
            rot_idx: iterable of integer tensors (``.item()`` is called on
                each), presumably one per sample. Every value must be in
                0..3, since only ``rot_0`` .. ``rot_3`` exist.
            b_range: not used by this method.

        Returns:
            Tensor of shape ``(B, T, 2, 1)`` with ``T = n_bins``.
        """
        B, _ , _ = points.shape
        # source module
        # relative distance from source to patch
        patches = self.patches.float().to(points.device)
        p_s_dist = patches - source_norm_pos.expand(B, self.num_patch, -1) # B, K, 2
        p_s_dist_embed = self.dist_embedder(p_s_dist.float())
        source_decom_in = p_s_dist_embed
        source_out = self.source_proj(source_decom_in)
        # Patch embedding is computed once, then tiled across the batch.
        patch_embed = self.patch_embedder(patches.unsqueeze(0)).repeat(B, 1, 1) # B, K, 42
        patch_out = self.patch_proj(patch_embed)
        # NOTE(review): the residual is added after layer index
        # patch_blocks//2 - 1. When patch_blocks < 2 that index is never
        # reached, so patch_residual goes unused.
        for k in range(len(self.patch_layers)):
            patch_out = self.patch_layers[k](patch_out)
            if k == (self.patch_blocks//2 - 1):
                patch_out = patch_out + self.patch_residual(patch_embed)
        
        # listen module
        p_l_dist = patches - listen_norm_pos.float().expand(B, self.num_patch,-1)
        p_l_dist_embed = self.dist_embedder(p_l_dist.float())
        listen_decomp_in = p_l_dist_embed
        listen_out = self.listen_proj(listen_decomp_in)
        times = self.times
        time_embed = self.time_embedder(times.unsqueeze(0).to(points.device)).repeat(B, 1, 1)  # B,T,21
        time_out = self.time_proj(time_embed)

        # (B, T, C) @ (B, C, K) -> (B, T, K): a score per time bin and patch.
        source_out = time_out @ source_out.transpose(1, 2)
        patch_out = time_out @ patch_out.transpose(1, 2)
        listen_out = time_out @ listen_out.transpose(1, 2)

        # (B, T, 3K)
        all_in = torch.cat((source_out,
                            patch_out,
                            listen_out,
                           ), dim=2)

        # (len(rot_idx), all_block - 1, 1, 1, C). The per-sample Python loop
        # with .item() forces a host sync for each element of rot_idx.
        rot_latent = torch.stack([getattr(self, "rot_{}".format(rot_idx_single.item())) for rot_idx_single in rot_idx], dim=0)
        # (B, T, C) -> (B, T, 2, C), plus the stage-0 left/right and rotation offsets.
        out = self.all_proj(all_in).unsqueeze(2).repeat(1,1,2,1) + getattr(self, "left_right_0") + rot_latent[:, 0]
        for k in range(len(self.all_layers)):
            out = self.all_layers[k](out) + getattr(self, "left_right_{}".format(k+1)) + rot_latent[:, k+1]
            # Same residual placement rule as the patch branch (see note above).
            if k == (self.all_blocks//2-1):
                out = out + self.all_residual(all_in).unsqueeze(2).repeat(1, 1, 2, 1)
        # (B, T, 2, C) -> (B, T, 2, 1)
        return self.out_layer(out)