import argparse
import os
import numpy as np
import trimesh
import torch
from tqdm import tqdm

from model.net.pointnet2_utils import index_points, farthest_point_sample


# Point counts sampled from every scan; one output directory is created per
# count, named '<dst_dataset_path>_<count>'.
NUM_SAMPLES = (8_192, 16_384, 32_768, 65_536, 131_072)


def load_subject_list(list_path):
    """Read subject identifiers from a whitespace-delimited text file.

    Args:
        list_path: Path to the subject list (``all.txt`` in the dataset root).

    Returns:
        A ``list`` of ``str`` identifiers, also when the file holds exactly
        one entry.
    """
    # ndmin=1 matters: without it a single-entry file yields a 0-d array whose
    # .tolist() is a bare str, and the caller would iterate over its characters.
    return np.loadtxt(list_path, dtype=str, ndmin=1).tolist()


def main():
    """Run farthest point sampling on every scan and save the sampled points.

    For each subject listed in ``<src_dataset_path>/all.txt`` the mesh at
    ``<src_dataset_path>/scans/<subject>/<subject>.obj`` is loaded, its
    vertices are moved to the GPU, and for every count in ``NUM_SAMPLES`` the
    sampled points are written to ``<dst_dataset_path>_<count>/<subject>.npy``.
    """
    parser = argparse.ArgumentParser('FPS THuman raw scans')
    parser.add_argument(
        '--src_dataset_path',
        type=str,
        default='data/thuman2',
        help='Path to THuman dataset.'
    )
    parser.add_argument(
        '--dst_dataset_path',
        type=str,
        default='data/thuman2_sampling',
        help='Directory path to store FPS data.',
    )
    args = parser.parse_args()

    # Create all output directories up front so a bad destination fails before
    # any scan is processed.
    dst_dirs = []
    for num_sample in NUM_SAMPLES:
        dst_dir = f'{args.dst_dataset_path}_{num_sample}'
        os.makedirs(dst_dir, exist_ok=True)
        dst_dirs.append(dst_dir)

    subject_list = load_subject_list(os.path.join(args.src_dataset_path, 'all.txt'))

    for subject in tqdm(subject_list):
        scan_path = os.path.join(args.src_dataset_path, f'scans/{subject}/{subject}.obj')
        scan = trimesh.load(scan_path)

        # float32 copy of the vertices with a leading batch dimension of 1.
        # NOTE(review): index_points / farthest_point_sample presumably expect
        # (batch, num_points, 3) input -- confirm against pointnet2_utils.
        vertices = torch.tensor(np.asarray(scan.vertices), dtype=torch.float32)
        vertices = vertices.cuda().unsqueeze(0)

        for num_sample, dst_dir in zip(NUM_SAMPLES, dst_dirs):
            sample_idx = farthest_point_sample(vertices, num_sample)
            sampled_points = index_points(vertices, sample_idx)

            # Drop the batch dimension and bring the result back to host memory.
            sampled_points = sampled_points[0].cpu().numpy()

            np.save(os.path.join(dst_dir, f'{subject}.npy'), sampled_points)


if __name__ == '__main__':
    main()
