import os
import numpy as np
from pkg_resources import split_sections
import torch
import torch.nn as nn
from timeit import default_timer as timer

import config as cfg

class FlowArbitrary(nn.Module):
    """Map query points from pose 0 to a canonical pose, then deform them to pose 1.

    Composes two caller-supplied sub-models:

    * ``model_canonicalize(points, surface_samples[:, 0], None, None[, normals0])``
      is applied to the pose-0 query points and to the pose-0 surface samples.
    * ``model_deform(points_canonical, surface_samples0_canonical,
      surface_samples[:, 1], surface_mask[:, :, 0:1][, normals0])`` is applied to
      the canonicalized query points.

    The trailing ``normals0`` argument (``surface_normals[:, 0]``) is passed only
    when ``surface_normals`` is given. Sub-models that are called without normals
    therefore receive one fewer positional argument.
    """

    def __init__(self, model_canonicalize, model_deform, split_size=10000):
        super(FlowArbitrary, self).__init__()

        self.model_canonicalize = model_canonicalize
        self.model_deform = model_deform
        # Maximum number of query points handed to a sub-model in one call.
        # Larger query sets are processed chunk by chunk to bound peak memory.
        self.split_size = split_size

    def _split_points(self, points):
        """Split ``points`` along dim 1 into chunks of at most ``self.split_size``.

        Returns a list containing ``points`` itself when no split is needed.
        """
        if points.shape[1] > self.split_size:
            return list(torch.split(points, self.split_size, dim=1))
        return [points]

    @staticmethod
    def _merge_points(chunks):
        """Concatenate per-chunk results along dim 1; a single chunk is returned as-is."""
        if len(chunks) == 1:
            return chunks[0]
        return torch.cat(chunks, dim=1).contiguous()

    def forward(self, query_points, surface_samples, surface_mask, surface_normals=None):
        """Canonicalize pose-0 query points and deform them into pose 1.

        Args:
            query_points: B x 2 x N x 3 (only pose 0, ``[:, 0]``, is used).
            surface_samples: B x 2 x N_s x 3 surface samples for poses 0 and 1.
            surface_mask: B x N_s x 3; only the first channel (``[:, :, 0:1]``)
                is passed on to ``model_deform``.
            surface_normals: optional B x 2 x N_s x 3; only pose 0 (``[:, 0]``)
                is forwarded to both sub-models.

        Returns:
            Tuple ``(query_points0_canonicalize, surface_samples0_canonicalize,
            query_points1_deformed)`` as produced by the sub-models. Query
            results are computed in chunks of ``self.split_size`` points and
            concatenated along dim 1 when N exceeds that size.
        """
        # Forward normals only when they exist, so the sub-models are called
        # with exactly the same positional arguments in either case.
        extra_args = () if surface_normals is None else (surface_normals[:, 0],)

        surface_samples0 = surface_samples[:, 0]
        surface_samples1 = surface_samples[:, 1]
        surface_mask0 = surface_mask[:, :, 0:1]

        # From pose 0 to canonical pose.
        # Chunk along the point axis: dim 1 of the per-pose slice (dim 2 of
        # query_points). Dim 1 of query_points itself is the pose axis (size 2).
        query_chunks0_canonicalize = [
            self.model_canonicalize(chunk, surface_samples0, None, None, *extra_args)
            for chunk in self._split_points(query_points[:, 0])
        ]
        query_points0_canonicalize = self._merge_points(query_chunks0_canonicalize)
        surface_samples0_canonicalize = self.model_canonicalize(
            surface_samples0, surface_samples0, None, None, *extra_args
        )

        # From canonical pose to pose 1, reusing the same query chunks.
        query_chunks1_deformed = [
            self.model_deform(
                chunk, surface_samples0_canonicalize, surface_samples1, surface_mask0, *extra_args
            )
            for chunk in query_chunks0_canonicalize
        ]
        query_points1_deformed = self._merge_points(query_chunks1_deformed)

        return query_points0_canonicalize, surface_samples0_canonicalize, query_points1_deformed