# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

'''
This folder and its code are modified from the VideoPose3D code base,
https://github.com/facebookresearch/VideoPose3D
and provide the VPose model for the single-frame setting.
'''


import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.utils import get_pose_features

class TemporalModelBase(nn.Module):
    """
    Shared output head and pose-reconstruction logic for the VPose lifting models.

    Do not instantiate this class directly: subclasses must define
    ``expand_conv``, ``layers_conv``, ``layers_bn``, ``causal_shift`` and
    ``_forward_blocks``.
    """

    def __init__(self, num_joints_in, in_features, num_joints_out,
                 filter_widths, causal, dropout, channels):
        super().__init__()

        # Validate input
        for fw in filter_widths:
            assert fw % 2 != 0, 'Only odd filter widths are supported'

        self.num_joints_in = num_joints_in
        self.in_features = in_features
        self.num_joints_out = num_joints_out
        self.filter_widths = filter_widths

        self.drop = nn.Dropout(dropout)
        self.relu = nn.ReLU(inplace=True)

        self.pad = [filter_widths[0] // 2]
        self.expand_bn = nn.BatchNorm1d(channels, momentum=0.1)
        # Output head with 56 + 32 channels. forward() reads channels 0-15 as
        # per-joint depths and 16-30 as per-bone depth signs; the remaining
        # channels are not read by forward().
        self.shrink = nn.Conv1d(channels, 56 + 32, 1)
        self.root = nn.Conv1d(channels, 1, 1)
        self.sign = nn.Conv1d(channels, 40, 1)
        # Constant tensors are registered as non-persistent buffers (instead of
        # the former hard-coded `.cuda()`): they follow `model.to(device)` /
        # `.cuda()` / `.cpu()`, work on CPU-only hosts, and do NOT add keys to
        # the state_dict, so existing checkpoints still load.
        # `index`: the 8 possible sign patterns of a group of 3 bones.
        self.register_buffer(
            'index',
            torch.tensor([[1, 1, 1], [1, 1, -1], [1, -1, 1], [-1, 1, 1], [1, -1, -1],
                          [-1, 1, -1], [-1, -1, 1], [-1, -1, -1]], dtype=torch.float),
            persistent=False)
        # `bone`: 15 constants, presumably reference bone lengths -- not read by
        # this class's forward(); TODO confirm their use.
        self.register_buffer(
            'bone',
            torch.tensor([0.1335, 0.4479, 0.4446, 0.1335, 0.4479, 0.4446, 0.2393, 0.2541,
                          0.1854, 0.1599, 0.2801, 0.2485, 0.1599, 0.2801, 0.2483],
                         dtype=torch.float),
            persistent=False)

    def set_bn_momentum(self, momentum):
        """Set the momentum of every BatchNorm layer (expand + per-block)."""
        self.expand_bn.momentum = momentum
        for bn in self.layers_bn:
            bn.momentum = momentum

    def receptive_field(self):
        """
        Return the total receptive field of this model as # of frames.
        """
        frames = 0
        for f in self.pad:
            frames += f
        return 1 + 2 * frames

    def total_causal_shift(self):
        """
        Return the asymmetric offset for sequence padding.
        The returned value is typically 0 if causal convolutions are disabled,
        otherwise it is half the receptive field.
        """
        frames = self.causal_shift[0]
        next_dilation = self.filter_widths[0]
        for i in range(1, len(self.filter_widths)):
            frames += self.causal_shift[i] * next_dilation
            next_dilation *= self.filter_widths[i]
        return frames

    def one_hot(self, logits):
        """
        Straight-through arg-max over the last dimension of 2-D ``logits``.

        The forward value is the hard one-hot of the arg-max of the
        softplus-normalised scores; gradients flow through the soft
        probabilities. Returns ``(one_hot, 0)``; the second element is a
        placeholder kept for the caller's tuple unpacking.
        """
        logits = nn.Softplus()(logits)
        y_soft = logits / logits.sum(dim=-1, keepdim=True)

        sampled_indices = torch.argmax(y_soft, dim=-1)[:, None]
        y_hard = F.one_hot(sampled_indices, num_classes=logits.size(1))[:, 0]

        # hard value in the forward pass, soft gradient in the backward pass
        return y_hard - y_soft.detach() + y_soft, 0

    def get_sign(self, s):
        """
        Turn 40 sign logits per sample (5 groups x 8 patterns) into 15 bone signs.

        Each group of 3 bones selects one of the 8 patterns in ``self.index``.
        Sign 3 is then forced to be the negation of sign 0.
        """
        depth, _ = self.one_hot(s.reshape(-1, 8))

        sign = torch.matmul(depth.reshape(-1, 5, 8), self.index).reshape(-1, 15)
        sign[:, 3] = -sign[:, 0]

        return sign

    def get_pose_from_features(self, pose_lenth, pose_sign, p2d, root):
        """
        Rebuild a 3D pose joint by joint along camera rays.

        Joint ``i`` is placed at ``t * (x, y, 1)``, where ``(x, y)`` are its 2D
        coordinates (focal length ``f`` fixed to 1). ``t`` solves
        ``|t * ray - parent| = bone length``, choosing the root of the quadratic
        indicated by ``pose_sign``.

        Args:
          pose_lenth: (N, 15) bone lengths, in ``bone_inx`` order (root skipped).
          pose_sign: (N, 15) +/-1 root selector per bone.
          p2d: (N, 16, 2) 2D joints.
          root: (N, 1) root-joint depth.
        Returns:
          (N, 16, 3) reconstructed 3D joints.
        """
        p2d = p2d.permute(0, 2, 1).reshape(-1, 2, self.num_joints_in)
        p3d_updated = torch.zeros_like(torch.cat([p2d, p2d[:, 0:1]], dim=1))
        n = 0

        f = 1
        p3d_updated[:, :, 0] = torch.cat([p2d[:, :, 0] / f * root, root], dim=1)
        # parent joint of each joint; -1 marks the root
        bone_inx = [-1, 0, 1, 2, 0, 4, 5, 0, 7, 8, 8, 10, 11, 8, 13, 14]

        for i, j in enumerate(bone_inx):
            if j == -1:
                continue
            xf = p2d[:, 0, i] / f
            yf = p2d[:, 1, i] / f
            D = pose_lenth[:, n]
            sign = pose_sign[:, n]

            # .clone(): p3d_updated is written in place below, which must not
            # invalidate values autograd saved from earlier reads.
            a = xf ** 2 + yf ** 2 + 1
            b = (xf * p3d_updated[:, 0, j].clone() + yf * p3d_updated[:, 1, j].clone()
                 + p3d_updated[:, 2, j].clone())
            c = (p3d_updated[:, 0, j].clone() ** 2 + p3d_updated[:, 1, j].clone() ** 2
                 + p3d_updated[:, 2, j].clone() ** 2 - D ** 2)
            d = b ** 2 - a * c

            # value equals sqrt(max(d, 0)); the gradient still flows through d / 2
            t = (b + sign * torch.sqrt(d / 2 + abs(d).detach() / 2 + 1e-9)) / a

            p3d_updated[:, :, i] = torch.stack([xf * t, yf * t, t], dim=1)
            n += 1

        return p3d_updated.permute(0, 2, 1)

    def forward(self, x):
        """
        Lift a single frame of 2D joints to 3D.

        input: bx16x2 / bx32
        output (eval): bx16x3 reconstructed pose
        output (train): (p3d, bone_lenth, depth_sign, root_depth)
        """
        if len(x.shape) == 2:
            x = x.view(x.shape[0], 16, 2)
        # pre-processing: add a length-1 temporal axis -> (b, 1, 16, 2)
        x = x.view(x.shape[0], 1, 16, 2)
        p2d = x.clone()[:, 0]
        assert len(x.shape) == 4
        assert x.shape[-2] == self.num_joints_in
        assert x.shape[-1] == self.in_features

        # (b, 1, 32) -> (b, 32, 1): joints*features become the conv channels
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)

        x = self._forward_blocks(x).squeeze(2)

        # Channels 0-15: per-joint depth z; 3D joint = (u * z, v * z, z).
        p3d = torch.zeros_like(x[:, :48].reshape(-1, 16, 3))
        depth = x[:, :16]
        p3d[:, :, 2] = depth
        p3d[:, :, :2] = p2d * depth.unsqueeze(2)

        root_depth = p3d[:, 0, 2].unsqueeze(1)
        bone_lenth, _ = get_pose_features(p3d)

        # Channels 16-30: per-bone depth sign, mapped to +/-1 (exact zeros -> -1)
        # with a straight-through normalisation.
        raw_sign = x[:, 16:31]
        depth_sign = torch.where(raw_sign == 0, -1, raw_sign)
        depth_sign = depth_sign / abs(depth_sign).detach()
        depth_sign[:, 3] = -depth_sign[:, 0]

        if self.training:
            return p3d, bone_lenth, depth_sign, root_depth
        return self.get_pose_from_features(bone_lenth, depth_sign, p2d, root_depth)

class TemporalModel(TemporalModelBase):
    """
    Reference 3D pose estimation model with temporal convolutions.
    This implementation can be used for all use-cases.
    """

    def __init__(self, num_joints_in, in_features, num_joints_out,
                 filter_widths, causal=False, dropout=0.25, channels=1024, dense=False):
        """
        Initialize this model.

        Arguments:
        num_joints_in -- number of input joints (e.g. 17 for Human3.6M)
        in_features -- number of input features for each joint (typically 2 for 2D input)
        num_joints_out -- number of output joints (can be different than input)
        filter_widths -- list of convolution widths, which also determines the # of blocks and receptive field
        causal -- use causal convolutions instead of symmetric convolutions (for real-time applications)
        dropout -- dropout probability
        channels -- number of convolution channels
        dense -- use regular dense convolutions instead of dilated convolutions (ablation experiment)
        """
        super().__init__(num_joints_in, in_features, num_joints_out, filter_widths, causal, dropout, channels)

        self.expand_conv = nn.Conv1d(num_joints_in*in_features, channels, filter_widths[0], bias=False)

        layers_conv = []
        layers_bn = []

        self.causal_shift = [(filter_widths[0]) // 2 if causal else 0]
        next_dilation = filter_widths[0]
        for i in range(1, len(filter_widths)):
            self.pad.append((filter_widths[i] - 1)*next_dilation // 2)
            self.causal_shift.append((filter_widths[i]//2 * next_dilation) if causal else 0)

            layers_conv.append(nn.Conv1d(channels, channels,
                                         filter_widths[i] if not dense else (2*self.pad[-1] + 1),
                                         dilation=next_dilation if not dense else 1,
                                         bias=False))
            layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))
            layers_conv.append(nn.Conv1d(channels, channels, 1, dilation=1, bias=False))
            layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))

            next_dilation *= filter_widths[i]

        self.layers_conv = nn.ModuleList(layers_conv)
        self.layers_bn = nn.ModuleList(layers_bn)

    def _forward_blocks(self, x):
        """
        Run the dilated temporal convolution stack.

        x: (B, num_joints_in * in_features, T). Returns the output of the shared
        ``shrink`` head, i.e. (B, 56 + 32, T - receptive_field() + 1).
        """
        x = self.drop(self.relu(self.expand_bn(self.expand_conv(x))))

        for i in range(len(self.pad) - 1):
            pad = self.pad[i+1]
            shift = self.causal_shift[i+1]
            # crop the residual so it lines up with the un-padded conv output
            res = x[:, :, pad + shift : x.shape[2] - pad + shift]

            x = self.drop(self.relu(self.layers_bn[2*i](self.layers_conv[2*i](x))))
            x = res + self.drop(self.relu(self.layers_bn[2*i + 1](self.layers_conv[2*i + 1](x))))

        # Use the shared `shrink` head (as Vpose_BDCS does). `self.bone` is a
        # constant tensor, not a layer, so calling it raised TypeError, and the
        # former root/bone/sign concatenation did not produce the channel
        # layout that forward() slices.
        return self.shrink(x)
    
class Vpose_BDCS(TemporalModelBase):
    """
    Single-frame-optimised variant of the temporal model: every block consumes
    exactly its receptive field and emits one output step.

    Dilated convolutions are replaced by strided convolutions (kernel == stride),
    so no unused intermediate frames are computed. Parameter shapes match the
    dilated reference implementation, so weights are interchangeable with it.
    Per the upstream notes, this configuration applies to training with stride == 1.
    """

    def __init__(self, num_joints_in, in_features, num_joints_out,
                 filter_widths, causal=False, dropout=0.25, channels=1024):
        """
        Build the strided convolution stack.

        Arguments:
        num_joints_in -- number of input joints (e.g. 17 for Human3.6M)
        in_features -- number of input features for each joint (typically 2 for 2D input)
        num_joints_out -- number of output joints (can be different than input)
        filter_widths -- list of convolution widths; its length fixes the number of blocks and the receptive field
        causal -- use causal convolutions instead of symmetric ones (for real-time applications)
        dropout -- dropout probability
        channels -- number of convolution channels
        """
        super().__init__(num_joints_in, in_features, num_joints_out, filter_widths, causal, dropout, channels)

        first_width = filter_widths[0]
        self.expand_conv = nn.Conv1d(num_joints_in * in_features, channels, first_width,
                                     stride=first_width, bias=False)

        self.causal_shift = [(first_width // 2) if causal else 0]
        convs, norms = [], []
        dilation = first_width
        for width in filter_widths[1:]:
            self.pad.append((width - 1) * dilation // 2)
            self.causal_shift.append((width // 2) if causal else 0)

            # temporal strided conv followed by a 1x1 channel mixer, each with its own BN
            convs.append(nn.Conv1d(channels, channels, width, stride=width, bias=False))
            norms.append(nn.BatchNorm1d(channels, momentum=0.1))
            convs.append(nn.Conv1d(channels, channels, 1, dilation=1, bias=False))
            norms.append(nn.BatchNorm1d(channels, momentum=0.1))
            dilation *= width

        self.layers_conv = nn.ModuleList(convs)
        self.layers_bn = nn.ModuleList(norms)

    def _forward_blocks(self, x):
        """
        x: (B, num_joints_in * in_features, T). Returns ``shrink`` applied to the
        final features: (B, 56 + 32, T_out).
        """
        x = self.drop(self.relu(self.expand_bn(self.expand_conv(x))))

        for blk in range(len(self.pad) - 1):
            width = self.filter_widths[blk + 1]
            # residual: the centre frame (plus causal offset) of each strided window
            start = self.causal_shift[blk + 1] + width // 2
            residual = x[:, :, start::width]

            temporal_conv, mixer_conv = self.layers_conv[2 * blk], self.layers_conv[2 * blk + 1]
            temporal_bn, mixer_bn = self.layers_bn[2 * blk], self.layers_bn[2 * blk + 1]

            x = self.drop(self.relu(temporal_bn(temporal_conv(x))))
            x = residual + self.drop(self.relu(mixer_bn(mixer_conv(x))))

        return self.shrink(x)
    

# from curses import reset_shell_mode
import os
from re import I
# from this import d
import time
import math
import copy
import pickle
import argparse
import functools
from collections import deque

import numpy as np
from scipy import integrate

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.utils import make_grid, save_image
from torch.autograd import Variable
from torch import autograd
from utils.utils import get_pose_features
from torch.distributions import Normal, Independent
import torch.distributions as dist
import matplotlib.pyplot as plt

# Default compute device: the GPU when torch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class GaussianFourierProjection(nn.Module):
    """Encode scalar time values with fixed random Fourier features.

    ``embed_dim // 2`` frequencies are drawn once from N(0, scale**2) at
    construction. They are kept as a frozen (``requires_grad=False``) parameter
    ``W``, so they are saved with the model but never trained. The output is
    ``[sin(2*pi*x*W), cos(2*pi*x*W)]``, of width ``2 * (embed_dim // 2)``.
    """

    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        frequencies = torch.randn(embed_dim // 2) * scale
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        """x: 1-D tensor of B values -> (B, 2 * (embed_dim // 2)) features."""
        angles = x[:, None] * self.W[None, :] * 2 * np.pi
        sin_part = torch.sin(angles)
        cos_part = torch.cos(angles)
        return torch.cat([sin_part, cos_part], dim=-1)


class Dense(nn.Module):
    """Linear layer whose output gains two trailing singleton axes.

    Maps (..., input_dim) to (..., output_dim, 1, 1), presumably so the result
    can be broadcast-added to (N, C, H, W) feature maps.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        projected = self.dense(x)
        return projected.unsqueeze(-1).unsqueeze(-1)


class ClassifierFreeSampler(nn.Module):
    """Classifier-free guidance wrapper around a conditional model.

    Returns ``out_cond + w * (out_cond - out_uncond)``, where the unconditional
    pass uses an all-zero condition. ``w == 0`` disables guidance; the original
    paper uses values in [0, 4.0].
    """

    def __init__(self, model, w):
        super().__init__()
        self.model = model  # wrapped model: called as model(batch, t, condition)
        self.w = w  # guidance strength

    def forward(self, batch, t, condition):
        """
        batch: [B, j, 3] or [B, j, 1]
        t: [B, 1]
        condition: [B, j, 2]
        Return: [B, j, 3] or [B, j, 1] same dim as batch
        """
        conditioned = self.model(batch, t, condition)
        # TODO: fine-grained zero-out
        null_condition = torch.zeros_like(condition)
        unconditioned = self.model(batch, t, null_condition)
        guidance = conditioned - unconditioned
        return conditioned + self.w * guidance

# @torch.no_grad()
# def unit_norm_clipper(module):
#     if hasattr(module, 'weight'):
#         w = module.weight.data
#         norm = torch.linalg.norm(w)
#         if norm > 1:
#             w.div_(norm)

def get_sigmas(config):
    """Return the SMLD noise levels.

    The levels are 1000 sigmas, geometrically spaced from 50 down to 0.01.

    Args:
      config: accepted for interface compatibility; not read.
    Returns:
      sigmas: float64 numpy array of shape (1000,), in decreasing order.
    """
    sigma_max, sigma_min, num_scales = 50, 0.01, 1000
    log_sigmas = np.linspace(np.log(sigma_max), np.log(sigma_min), num_scales)
    return np.exp(log_sigmas)


def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    """Build sinusoidal (transformer-style) embeddings for a batch of timesteps.

    Args:
      timesteps: 1-D tensor of B time values (cast to float32).
      embedding_dim: width of each embedding; odd widths get a zero last column.
      max_positions: sets the lowest frequency, ``1 / max_positions``.
    Returns:
      Tensor of shape (B, embedding_dim) laid out as ``[sin(t*w_k) ..., cos(t*w_k) ...]``.
      The ``embedding_dim // 2`` frequencies ``w_k`` are geometrically spaced from 1
      down to ``1 / max_positions``; a single frequency of 1 is used when
      ``embedding_dim`` is 2 or 3.
    Raises:
      ValueError: if ``timesteps`` is not 1-D.
    """
    if timesteps.dim() != 1:
        raise ValueError(f"timesteps must be 1-D, got shape {tuple(timesteps.shape)}")
    half_dim = embedding_dim // 2
    # Magic number 10000 is from transformers. The denominator is clamped to 1
    # so that embedding_dim in {2, 3} (half_dim == 1) no longer divides by zero.
    emb = math.log(max_positions) / max(half_dim - 1, 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb

class res_block(nn.Module):
    """Stack of ``n_blocks`` residual MLP blocks conditioned on time and a condition vector.

    Each block has two stages. A stage computes
    ``SiLU(GroupNorm(dense(x) + dense_t(temb) + dense_cond(cond)))``, and the
    block output ``h + stage2(dropout(stage1(h)))`` is the next block's input.
    The second stage's output is dropped out too, except in the last block.
    Layers are registered as ``b{k}_dense{s}``, ``b{k}_dense{s}_t``,
    ``b{k}_dense{s}_cond`` and ``b{k}_gnorm{s}`` (k = 1..n_blocks, s = 1, 2).
    """

    def __init__(self, n_blocks, hidden_dim, embed_dim):
        super(res_block, self).__init__()
        self.n_blocks = n_blocks
        self.act = nn.SiLU()
        self.dropout = nn.Dropout(p=0.25)
        for blk in range(1, n_blocks + 1):
            for stage in (1, 2):
                setattr(self, f'b{blk}_dense{stage}', nn.Linear(hidden_dim, hidden_dim))
                setattr(self, f'b{blk}_dense{stage}_t', nn.Linear(embed_dim, hidden_dim))
                setattr(self, f'b{blk}_dense{stage}_cond', nn.Linear(hidden_dim, hidden_dim))
                setattr(self, f'b{blk}_gnorm{stage}', nn.GroupNorm(32, num_channels=hidden_dim))

    def _stage(self, blk, stage, x, temb, cond):
        """One conditioned dense -> GroupNorm -> SiLU stage of block ``blk``."""
        prefix = f'b{blk}_dense{stage}'
        y = getattr(self, prefix)(x)
        y = y + getattr(self, prefix + '_t')(temb)
        y = y + getattr(self, prefix + '_cond')(cond)
        y = getattr(self, f'b{blk}_gnorm{stage}')(y)
        return self.act(y)

    def forward(self, h, temb, cond):
        last = self.n_blocks - 1
        for idx in range(self.n_blocks):
            branch = self.dropout(self._stage(idx + 1, 1, h, temb, cond))
            branch = self._stage(idx + 1, 2, branch, temb, cond)
            if idx != last:
                branch = self.dropout(branch)
            h = h + branch
        return h

class GFPose(nn.Module):
    """
    2D-to-3D pose lifting network with independent condition feature
    projection layers for each residual block.

    The network is conditioned on the scaled 2D joints plus the ground-truth
    root depth. Three residual branches on a shared pre-processing layer predict:

    * ``bone_head``: a Gaussian mixture over the ``n_joints - 1`` bone lengths,
    * ``p2d_head``: a Gaussian mixture over the ``2 * n_joints`` scaled 2D coordinates,
    * ``sign_head``: 2 root-sign logits plus 5 groups of 8 sign-pattern logits.

    A 3D pose is then rebuilt joint by joint along camera rays from the sampled
    bone lengths and 2D positions (see ``get_pose_from_features``).
    """

    # Number of mixture samples drawn per input by ``mdn_fn`` in eval mode.
    eval_num_samples = 200

    def __init__(self, config,
                 n_joints=17, joint_dim=3, hidden_dim=1024, embed_dim=512, cond_dim=3,
                 n_blocks=2):
        super(GFPose, self).__init__()

        self.config = config
        self.n_joints = n_joints
        self.joint_dim = joint_dim
        self.n_blocks = n_blocks

        self.act = nn.SiLU()

        self.pre_dense = nn.Linear(n_joints * 3, hidden_dim)
        self.pre_dense_t = nn.Linear(embed_dim, hidden_dim)
        self.pre_dense_cond = nn.Linear(hidden_dim, hidden_dim)
        self.pre_gnorm = nn.GroupNorm(32, num_channels=hidden_dim)
        self.dropout = nn.Dropout(p=0.25)

        # time embedding
        self.gauss_proj = GaussianFourierProjection(embed_dim=embed_dim)
        self.shared_time_embed = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            self.act,
        )
        self.register_buffer('sigmas', torch.tensor(get_sigmas(config)))

        # conditional embedding
        self.cond_embed = nn.Sequential(
            nn.Linear(n_joints * cond_dim, hidden_dim),
            self.act
        )

        self.gauss_num = config.gauss_num
        self.res_bone = res_block(n_blocks, hidden_dim, embed_dim)
        self.res_sign = res_block(n_blocks, hidden_dim, embed_dim)
        self.res_p2d = res_block(n_blocks, hidden_dim, embed_dim)

        # Mixture heads emit 3 values per component: (mu, sigma pre-activation,
        # mixture logit); see mdn_fn.
        self.bone_head = nn.Linear(hidden_dim, self.gauss_num * (self.n_joints - 1) * 3)
        self.sign_head = nn.Linear(hidden_dim, 2 + 5 * 8)
        self.p2d_head = nn.Linear(hidden_dim, self.gauss_num * 2 * self.n_joints * 3)

        # The 8 sign patterns of a group of 3 bones. This replaces the former
        # hard-coded `.cuda()`. As a non-persistent buffer it follows
        # `model.to(device)` and leaves the state_dict (and existing
        # checkpoints) unchanged.
        self.register_buffer(
            'index',
            torch.tensor([[1, 1, 1], [1, 1, -1], [1, -1, 1], [-1, 1, 1], [1, -1, -1],
                          [-1, 1, -1], [-1, -1, 1], [-1, -1, -1]], dtype=torch.float),
            persistent=False)

        self.cond_pose_mask_prob = 0.1
        self.cond_part_mask_prob = 0.1
        self.cond_joint_mask_prob = 0.1

        if self.cond_part_mask_prob > 0:
            self.part_mask = self.generate_part_mask()  # [p, j]

    def one_hot(self, logits):
        """
        Straight-through categorical sample over the last dimension of 2-D ``logits``.

        Probabilities are the softplus scores normalised to sum to 1. One index
        per row is drawn with ``torch.multinomial``. The forward value is the
        hard one-hot; gradients flow through the soft probabilities. Returns
        ``(one_hot, 0)``; the second element is a placeholder kept for the
        callers' tuple unpacking.
        """
        logits = nn.Softplus()(logits)
        y_soft = logits / logits.sum(dim=-1, keepdim=True)

        sampled_indices = torch.multinomial(y_soft, 1, replacement=True)
        y_hard = F.one_hot(sampled_indices, num_classes=logits.size(1))[:, 0]

        return y_hard - y_soft.detach() + y_soft, 0

    def random_mask_condition(self, condition):
        """
        Randomly drop parts of the condition (classifier-free-style training).

        The method always masks; ``forward`` only calls it in training mode.
        Whole poses, whole body parts (``self.part_mask``) and individual
        joints are zeroed with probabilities ``cond_pose_mask_prob``,
        ``cond_part_mask_prob`` and ``cond_joint_mask_prob``.

        condition: [B, j * c], joint-major (the c values of a joint are adjacent).
        Returns the masked condition with the same shape.
        """
        batch_size = condition.shape[0]
        # mask poses
        if self.cond_pose_mask_prob > 0:
            mask = torch.bernoulli(
                torch.ones(batch_size, device=condition.device) * self.cond_pose_mask_prob
            ).view(batch_size, 1)  # 1-> use null condition, 0-> use real condition
            condition = condition * (1. - mask)

        # mask parts
        if self.cond_part_mask_prob > 0:
            final_mask = np.ones((batch_size, self.n_joints), dtype=np.float32)
            mask = torch.bernoulli(
                torch.ones(batch_size, self.num_parts) * self.cond_part_mask_prob
            ).numpy().astype(bool)  # [b, p]
            for idx, row in enumerate(mask):
                if np.sum(row) > 0:
                    selected_mask = self.part_mask[row]  # [s, j]
                    overlapped_mask = np.prod(selected_mask, axis=0)  # [j]
                    final_mask[idx] = overlapped_mask
            final_mask = torch.tensor(final_mask, device=condition.device).unsqueeze(-1)  # [b, j, 1]
            condition = condition.view(batch_size, self.n_joints, -1) * final_mask
            condition = condition.view(batch_size, -1)

        # mask joints
        if self.cond_joint_mask_prob > 0:
            mask = torch.bernoulli(
                torch.ones((batch_size, self.n_joints, 1), device=condition.device) * self.cond_joint_mask_prob
            )  # 1-> use null condition, 0-> use real condition
            condition = condition.view(batch_size, self.n_joints, -1) * (1. - mask)
            condition = condition.view(batch_size, -1)

        return condition

    def get_pose_from_features(self, pose_lenth, pose_sign, p2d, root):
        """
        Rebuild a 3D pose joint by joint along camera rays.

        Joint ``i`` is placed at ``t * (x, y, 1)``, where ``(x, y)`` are its 2D
        coordinates (focal length ``f`` fixed to 1). ``t`` solves
        ``|t * ray - parent| = bone length``, i.e.
        ``t = (b +/- sqrt(b^2 - a*c)) / a``, with the root chosen by ``pose_sign``.

        Args:
          pose_lenth: (N, n_joints - 1) bone lengths in ``bone_inx`` order.
          pose_sign: (N, n_joints - 1) +/-1 root selector per bone.
          p2d: (N, 2 * n_joints), reshaped to (N, 2, n_joints).
          root: (N, 1) root-joint depth.
        Returns:
          (N, 3, n_joints) reconstructed 3D joints.
        """
        p2d = p2d.reshape(-1, 2, self.n_joints)
        p3d_updated = torch.zeros_like(torch.cat([p2d, p2d[:, 0:1]], dim=1))
        n = 0

        f = 1
        p3d_updated[:, :, 0] = torch.cat([p2d[:, :, 0] / f * root, root], dim=1)
        # parent joint of each joint; -1 marks the root
        bone_inx = [-1, 0, 1, 2, 0, 4, 5, 0, 7, 8, 9, 8, 11, 12, 8, 14, 15]

        for i, j in enumerate(bone_inx):
            if j == -1:
                continue
            xf = p2d[:, 0, i] / f
            yf = p2d[:, 1, i] / f
            D = pose_lenth[:, n]
            sign = pose_sign[:, n]

            # Clone the parent: p3d_updated is written in place below, which
            # must not invalidate values autograd saved from this read.
            px = p3d_updated[:, 0, j].clone()
            py = p3d_updated[:, 1, j].clone()
            pz = p3d_updated[:, 2, j].clone()

            a = xf ** 2 + yf ** 2 + 1
            b = xf * px + yf * py + pz
            c = px ** 2 + py ** 2 + pz ** 2 - D ** 2
            d = b ** 2 - a * c

            # value equals sqrt(max(d, 0)); the gradient still flows through d / 2
            t = (b + sign * torch.sqrt(d / 2 + abs(d).detach() / 2 + 1e-9)) / a

            p3d_updated[:, :, i] = torch.stack([xf * t, yf * t, t], dim=1)
            n += 1

        return p3d_updated

    def generate_part_mask(self):
        """
        Build the [num_parts, n_joints] keep-mask for the fixed body parts below
        (0 = joint belongs to the part). Also sets ``self.num_parts``.
        """
        part_list = [[1, 2, 3], [4, 5, 6], [11, 12, 13],
                     [14, 15, 16], [0, 7, 8, 9, 10]]
        self.num_parts = len(part_list)

        part_mask = np.ones((self.num_parts, self.n_joints))  # [p, 17]
        for idx, part in enumerate(part_list):
            part_mask[idx][part] = 0

        return part_mask

    def mdn_fn(self, outs, y):
        '''
        Sample from a per-dimension Gaussian mixture and score the targets.

        outs: [B, D, n_gaussians, 3]; the last axis holds
          0 - component means,
          1 - sigma before activation (``elu + 1``, floored at 1e-10),
          2 - mixture logits (softmax over n_gaussians).
        y: [B, D] targets.
        Returns:
          samples: [B * S, D], with S = 1 in training and ``eval_num_samples``
            in eval; rows are ordered ``b * S + k``.
          loss: mean negative log-likelihood of ``y``.
        '''
        mu = outs[..., 0]
        sigma = torch.max(F.elu(outs[..., 1]) + torch.ones_like(mu), 1e-10 * torch.ones_like(mu))
        pi = nn.Softmax(dim=-1)(outs[..., 2])

        mixture_distribution = dist.Categorical(probs=pi)
        component_distribution = dist.Normal(loc=mu, scale=sigma)

        mixture = dist.MixtureSameFamily(mixture_distribution=mixture_distribution,
                                         component_distribution=component_distribution)

        n_samples = 1 if self.training else self.eval_num_samples
        samples = mixture.sample(sample_shape=torch.Size([n_samples]))
        samples = samples.permute(1, 0, 2).reshape(-1, samples.shape[-1])
        loss = -mixture.log_prob(y).mean()

        return samples, loss

    def get_sign(self, s_root, s):
        """
        Sample per-bone depth signs: (B, 2) root logits + (B, 40) group logits -> (B, 16).

        Each of the 5 groups of 3 bones picks one of the 8 patterns in
        ``self.index``. The sampled root sign (+1 / -1) is inserted at position
        6, and sign 3 is forced to be the negation of sign 0.
        """
        depth, _ = self.one_hot(s.reshape(-1, 8))

        r_depth, _ = self.one_hot(s_root)
        # was a hard-coded `.cuda()` tensor; follow the input's device instead
        root_signs = torch.tensor([[1., -1.]], device=r_depth.device, dtype=r_depth.dtype)
        r_depth = (r_depth * root_signs).sum(-1, keepdim=True)

        sign = torch.matmul(depth.reshape(-1, 5, 8), self.index).reshape(-1, 15)
        sign = torch.cat([sign[:, :6], r_depth, sign[:, 6:]], dim=1)
        sign[:, 3] = -sign[:, 0]

        return sign

    def forward(self, p2d, gt=None):
        """
        Predict 3D poses from 2D joints.

        Args:
          p2d: 2D joints, reshaped to (B, 2, n_joints). Presumably in the same
            units as ``gt[:, :2] / gt[:, 2:3]`` (perspective-divided coordinates);
            TODO confirm.
          gt: ground-truth 3D joints, reshaped to (B, 3, n_joints). Required:
            the root-depth condition and the bone-sign targets are read from it.
        Returns:
          ``[p3d_flat, p2d, p2d_scaled, p3d, [bone, sign, 0, p2d], loss]`` where
          ``p3d`` is (B * S, 3, n_joints) and ``p3d_flat`` is ``p3d`` reshaped to
          (B * S, 1, 3 * n_joints). S is 1 in training and ``eval_num_samples``
          in eval. ``loss`` is the summed mixture NLL of the bone lengths and
          the scaled 2D coordinates.
        Raises:
          ValueError: if ``gt`` is None.
        """
        if gt is None:
            raise ValueError('GFPose.forward requires `gt`: the root depth and the '
                             'bone-sign targets are read from it.')

        scale_p2d = 30
        p2d_scaled = p2d.reshape(-1, 2, self.n_joints) * scale_p2d
        batch_size = p2d_scaled.shape[0]
        gt = gt.reshape(-1, 3, self.n_joints)
        gt_p2d = gt[:, :2] / gt[:, 2:3] * scale_p2d
        gt_root = gt[:, 2, 0]
        root = gt[:, 2, :1]

        # Condition: scaled 2D joints in rows 0-1, GT root depth at [2, 0],
        # zeros elsewhere.
        pred_3d_stack = torch.zeros((batch_size, 3, self.n_joints), device=p2d_scaled.device)
        pred_3d_stack[:, 2, 0] = gt_root
        pred_3d_stack[:, :2] = p2d_scaled

        # Time embedding at a fixed noise level (sigma = 1, so log sigma = 0).
        t = torch.zeros((batch_size,), device=p2d_scaled.device)
        used_sigmas = t + 1
        temb = self.gauss_proj(torch.log(used_sigmas))
        temb = self.shared_time_embed(temb)

        if self.training:
            # random_mask_condition expects a joint-major layout; convert there and back.
            pred_3d_stack = pred_3d_stack.permute(0, 2, 1).reshape(-1, self.n_joints * 3)
            pred_3d_stack = self.random_mask_condition(pred_3d_stack)  # [B, j*3]
            pred_3d_stack = pred_3d_stack.view(-1, self.n_joints, 3).permute(0, 2, 1).reshape(-1, self.n_joints * 3)
        else:
            pred_3d_stack = pred_3d_stack.reshape(-1, self.n_joints * 3)

        # cond embedding; pre_dense only contributes its bias (input is zeros)
        cond = self.cond_embed(pred_3d_stack)  # [B, hidden]
        h = self.pre_dense(torch.zeros_like(pred_3d_stack))
        h += self.pre_dense_cond(cond)
        h += self.pre_dense_t(temb)
        h = self.pre_gnorm(h)
        h = self.act(h)
        h = self.dropout(h)

        b = self.res_bone(h, temb, cond)
        b = self.bone_head(b).reshape(-1, self.n_joints - 1, self.gauss_num, 3)

        p = self.res_p2d(h, temb, cond)
        p = self.p2d_head(p).reshape(-1, self.n_joints * 2, self.gauss_num, 3)

        s = self.res_sign(h, temb, cond)
        s = self.sign_head(s)
        s_root = s[:, :2]
        s = s[:, 2:]

        gt_bone, gt_sign = get_pose_features(gt)

        bone, b_loss = self.mdn_fn(b, gt_bone)
        pred_2d, p2_loss = self.mdn_fn(p, gt_p2d.reshape(-1, self.n_joints * 2))

        loss = b_loss + p2_loss

        self.multi_num = 1
        sign = self.get_sign(s_root, s)

        # In eval, mdn_fn returns several samples per input (rows b * S + k).
        # Expand the per-input tensors to match; otherwise B > 1 cannot broadcast.
        n_samples = bone.shape[0] // batch_size
        if n_samples > 1:
            root = root.repeat_interleave(n_samples, dim=0)
            gt_sign = gt_sign.repeat_interleave(n_samples, dim=0)
        # NOTE(review): reconstruction uses the ground-truth signs (gt_sign),
        # not the predicted `sign`; confirm this is intended for evaluation.
        p3d = self.get_pose_from_features(bone, gt_sign, pred_2d / scale_p2d, root)

        return [p3d.reshape(-1, self.multi_num, 3 * self.n_joints), p2d, p2d_scaled, p3d,
                [bone, sign, 0, p2d], loss]


        


