import os
from functools import partial
from datetime import datetime
import argparse
import copy
from pathlib import Path
import pickle
from pcdet.datasets import build_dataloader
from pcdet.config import cfg, cfg_from_list, cfg_from_yaml_file, log_config_to_file
from pcdet.utils import common_utils, box_utils
from pcdet.models.model_utils.model_nms_utils import class_agnostic_nms
from pcdet.datasets.kitti.kitti_object_eval_python import eval as kitti_eval

import torch
import numpy as np
from easydict import EasyDict
from scipy import integrate
from tqdm import tqdm, trange
import torch.nn.functional as F

import sys
sys.path.insert(0, os.path.expanduser("~/repos/edm"))
from data_utils import forward_transform

sys.path.insert(0, os.path.expanduser("~/repos/modest_pp/generate_cluster_mask/discovery_utils/iou3d_nms"))
from iou3d_nms_utils import boxes_iou3d_gpu


def load_pickle(fname):
    """Open ``fname`` in binary mode and return the object unpickled from it.

    NOTE: unpickling can execute arbitrary code, so only call this on files
    produced by a trusted pipeline.
    """
    with open(fname, mode="rb") as stream:
        obj = pickle.load(stream)
    return obj


def get_points_update(
        x_cur, net, t_cur, t_next=0, pp_score=None, mask=None,
        num_steps=64, sigma_min=0.002, sigma_max=80, rho=7,
        S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
):
    """Take one EDM-style sampler step from noise level ``t_cur`` to ``t_next``.

    The noise level is first (optionally) raised from ``t_cur`` to ``t_hat``
    ("churn", governed by ``S_churn``/``S_min``/``S_max``/``S_noise``).
    Then an Euler step of ``dx/dsigma = (x - net(x, sigma)) / sigma`` is taken.
    When ``t_next > 0``, Heun's correction averages the slopes at both ends.

    Returns:
        The displacement ``x_next - x_hat``. ``x_hat`` equals ``x_cur`` when
        ``S_churn == 0``, assuming ``net.round_sigma`` leaves ``t_cur`` unchanged.
        The denoiser output is promoted to float64 before use.

    ``sigma_min``, ``sigma_max`` and ``rho`` are accepted but unused.
    """
    if mask is None:
        # Default: no element is masked.
        mask = torch.zeros(x_cur.shape[:-1], dtype=torch.bool, device=x_cur.device)

    # Stochastic churn: temporarily bump the noise level t_cur -> t_hat.
    in_churn_window = S_min <= t_cur <= S_max
    gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if in_churn_window else 0
    t_hat = net.round_sigma(t_cur + gamma * t_cur)
    noise_std = (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise
    x_hat = x_cur + noise_std * torch.randn_like(x_cur)

    def ode_slope(x, sigma):
        # dx/dsigma of the probability-flow ODE at (x, sigma).
        denoised = net(x, sigma, pp_score=pp_score, mask=mask).to(torch.float64)
        return (x - denoised) / sigma

    step = t_next - t_hat
    slope_hat = ode_slope(x_hat, t_hat)
    euler_update = step * slope_hat
    if t_next > 0:
        # Heun (2nd-order) correction using the slope at the Euler end point.
        slope_next = ode_slope(x_hat + euler_update, t_next)
        return step * (0.5 * slope_hat + 0.5 * slope_next)
    return euler_update


def get_box_update_jac(ptc, box, points_update, pp_score=None, points_mask=None, limit=2.0):
    """Turn a per-point update into a box-parameter update by linear least squares.

    Linearises ``forward_transform(ptc, box, ...)[0]`` around ``box``.
    Then solves ``J @ delta ~= points_update`` for ``delta``.

    Args:
        ptc: point cloud forwarded to ``forward_transform``.
        box: box parameters. The Jacobian is reshaped to ``(-1, 7)``, so a
            7-parameter box is assumed.
        points_update: desired displacement of the transformed points; it is
            flattened into a single column.
        pp_score, points_mask, limit: forwarded unchanged to ``forward_transform``.

    Returns:
        The least-squares solution as a 1-D tensor (dtype of ``ptc``).
        Falls back to ``torch.zeros_like(box)`` if the Jacobian or the solve fails.
    """
    def transformed_points(b):
        return forward_transform(ptc, b, pp_score, points_mask, limit)[0]

    try:
        jac = torch.autograd.functional.jacobian(transformed_points, box)  # [N, 3, 7]
        jac_mat = jac.reshape(-1, 7).to(ptc.dtype)
        # lstsq requires A and B to share a dtype. points_update is often float64
        # (get_points_update promotes the denoiser output to float64). A mismatch
        # would raise here and silently turn into a zero update.
        rhs = points_update.reshape(-1, 1).to(jac_mat.dtype)
        box_update = torch.linalg.lstsq(jac_mat, rhs).solution[:, 0]
    except Exception:
        # Best effort: if linearisation or the solve fails, leave the box as is.
        # (Not a bare `except:`, so KeyboardInterrupt/SystemExit still propagate.)
        box_update = torch.zeros_like(box)
    return box_update


@torch.enable_grad()
def get_box_update_opt(ptc, box, points_update, pp_score=None, points_mask=None, limit=2.0,
                       lr=0.5, max_iter=20):
    """Fit a box update so the transformed points move by ``points_update``.

    The target is ``forward_transform(ptc, box, ...)[0] + points_update``.
    A copy of ``box`` is optimised with L-BFGS to minimise the MSE between
    ``forward_transform(ptc, box_copy, ...)[0]`` and that (detached) target.

    Args:
        ptc: point cloud forwarded to ``forward_transform``.
        box: starting box parameters (not modified).
        points_update: desired displacement of the transformed points.
        pp_score, points_mask, limit: forwarded unchanged to ``forward_transform``.
        lr, max_iter: L-BFGS settings (defaults match the previous hard-coded values).

    Returns:
        ``fitted_box - box``, a tensor shaped like ``box``.
    """
    target = forward_transform(ptc, box, pp_score, points_mask, limit)[0] + points_update
    target = target.detach().contiguous()

    box1 = box.clone().detach().contiguous().requires_grad_(True)
    optimizer = torch.optim.LBFGS([box1], lr=lr, max_iter=max_iter)

    def closure():
        # L-BFGS re-evaluates the loss several times per step.
        optimizer.zero_grad()
        outputs = forward_transform(ptc, box1, pp_score, points_mask, limit)[0]
        loss = F.mse_loss(outputs, target)
        loss.backward()
        return loss

    optimizer.step(closure)
    return box1.detach().contiguous() - box


@torch.enable_grad()
def get_box_update_opt_fixsize(ptc, box, points_update, pp_score=None, points_mask=None, limit=2.0,
                               lr=0.5, max_iter=20):
    """Like ``get_box_update_opt``, but keep the box size fixed.

    Only ``box[[0, 1, 2, 6]]`` are optimised (centre and yaw, as ``main`` uses
    them). ``box[3:6]`` (the size entries) is held at its input value.
    Assumes a 7-parameter box.

    Args:
        ptc: point cloud forwarded to ``forward_transform``.
        box: starting box parameters (not modified).
        points_update: desired displacement of the transformed points.
        pp_score, points_mask, limit: forwarded unchanged to ``forward_transform``.
        lr, max_iter: L-BFGS settings (defaults match the previous hard-coded values).

    Returns:
        ``fitted_box - box``, shaped like ``box``. Entries 3-5 are zero.
    """
    target = forward_transform(ptc, box, pp_score, points_mask, limit)[0] + points_update
    target = target.detach().contiguous()

    # Detached so backward() never writes gradients into the caller's ``box``.
    fixed_size = box[3:6].detach()

    def assemble(free):
        # Re-insert the fixed size between the centre (free[:3]) and yaw (free[3:]).
        return torch.cat([free[:3], fixed_size, free[3:]], dim=0)

    free_params = box[[0, 1, 2, 6]].clone().detach().contiguous().requires_grad_(True)
    optimizer = torch.optim.LBFGS([free_params], lr=lr, max_iter=max_iter)

    def closure():
        optimizer.zero_grad()
        outputs = forward_transform(ptc, assemble(free_params), pp_score, points_mask, limit)[0]
        loss = F.mse_loss(outputs, target)
        loss.backward()
        return loss

    optimizer.step(closure)
    return assemble(free_params.detach().contiguous()) - box


def subsample_indices(n_points, subsample_size=256, seed=-1):
    """Return ``subsample_size`` distinct indices from ``range(n_points)``, sorted ascending.

    A non-negative ``seed`` draws from a fresh, seeded ``np.random.Generator``
    (reproducible). A negative seed uses NumPy's global legacy RNG.
    NumPy raises ``ValueError`` if ``subsample_size > n_points``.
    """
    sampler = np.random.default_rng(seed=seed) if seed >= 0 else np.random
    picked = sampler.choice(n_points, size=subsample_size, replace=False)
    return np.sort(picked)


def to_flattened_numpy(x):
    """Detach ``x`` from autograd, move it to the CPU and return it as a 1-D NumPy array."""
    host_tensor = x.detach().cpu()
    return host_tensor.numpy().reshape((-1,))

def from_flattened_numpy(x, shape):
    """Reshape the NumPy array ``x`` to ``shape`` and wrap it as a torch tensor.

    ``torch.from_numpy`` shares memory with the reshaped array (no copy when
    ``reshape`` can return a view); the result lives on the CPU.
    """
    reshaped = x.reshape(shape)
    return torch.from_numpy(reshaped)

def compute_prior_logp(z, sigma_data: float = 0.42):
    """Per-sample log-density of ``z`` under an isotropic Gaussian ``N(0, sigma_data**2 I)``.

    The first axis of ``z`` is the batch axis; every other axis is summed over.
    Returns a tensor with one entry per batch element.
    """
    n_dims = np.prod(z.shape[1:])
    # Normalising constant: -(N/2) log(2*pi) - N log(sigma_data).
    log_norm = -n_dims / 2. * np.log(2 * np.pi) - n_dims * np.log(sigma_data)
    sq_norm = torch.sum((z / sigma_data) ** 2, dim=tuple(range(1, z.ndim)))
    return log_norm - sq_norm / 2.


def estimate_logp(model, data, pp_score=None, repeats=5):
    """Estimate bits-per-dimension of ``data`` with the probability-flow ODE.

    Integrates ``dx/dsigma = (x - model(x, sigma)) / sigma`` from sigma=0.002 to
    sigma=80 using SciPy's RK45. The divergence of that drift is accumulated
    alongside, using a Hutchinson estimator with a Rademacher probe. The endpoint
    is scored with ``compute_prior_logp`` (default ``sigma_data``).

    Args:
        model: denoiser called as ``model(x, sigma, pp_score=..., mask=...)``;
            ``sigma`` is a 1-D tensor with one entry per batch element.
        data: tensor whose first axis is the batch axis.
        pp_score: forwarded to ``model``.
        repeats: number of independent probe vectors / ODE solves to average.

    Returns:
        float: bits-per-dimension averaged over the batch and over ``repeats``.

    NOTE(review): the endpoint at sigma=80 is scored with sigma_data=0.42 rather
    than a sigma~80 Gaussian; confirm this is intended.
    """
    bpds = []
    with torch.no_grad():
        mask = torch.zeros(data.shape[:-1], dtype=torch.bool, device=data.device)
        shape = data.shape
        # ODE state: flattened sample followed by one log-density accumulator per batch element.
        init = np.concatenate([to_flattened_numpy(data), np.zeros((shape[0],))], axis=0)

        def drift_fn(model, x, t):
            denoised = model(x, t, pp_score=pp_score, mask=mask)
            # ``t`` has one sigma per batch element. Reshape it to broadcast over the
            # remaining axes; dividing by the raw (batch,) vector is only correct for batch 1.
            t_bcast = t.reshape(-1, *([1] * (x.ndim - 1)))
            return (x - denoised) / t_bcast

        def div_fn(model, x, t, noise):
            # Hutchinson estimate of div(drift): noise^T (d drift / dx) noise.
            with torch.enable_grad():
                x.requires_grad_(True)
                fn_eps = torch.sum(drift_fn(model, x, t) * noise)
                grad_fn_eps = torch.autograd.grad(fn_eps, x)[0]
            x.requires_grad_(False)
            return torch.sum(grad_fn_eps * noise, dim=tuple(range(1, len(x.shape))))

        for _ in range(repeats):
            # Rademacher (+/-1) probe, fixed for the whole ODE solve.
            epsilon = torch.randint_like(data, low=0, high=2).to(device=data.device, dtype=data.dtype) * 2 - 1.

            def ode_func(t, x):
                sample = from_flattened_numpy(x[:-shape[0]], shape).to(device=data.device, dtype=data.dtype)
                vec_t = torch.ones(sample.shape[0], device=data.device, dtype=data.dtype) * t
                drift = to_flattened_numpy(drift_fn(model, sample, vec_t))
                logp_grad = to_flattened_numpy(div_fn(model, sample, vec_t, epsilon))
                return np.concatenate([drift, logp_grad], axis=0)

            solution = integrate.solve_ivp(ode_func, (0.002, 80), init, rtol=1e-2, atol=1e-2, method='RK45')
            zp = solution.y[:, -1]
            z = from_flattened_numpy(zp[:-shape[0]], shape).to(device=data.device, dtype=data.dtype)
            delta_logp = from_flattened_numpy(zp[-shape[0]:], (shape[0],)).to(device=data.device, dtype=data.dtype)
            prior_logp = compute_prior_logp(z)
            bpd = -(prior_logp + delta_logp) / np.log(2) / np.prod(shape[1:]) + 8
            # bpd has one entry per batch element; a bare .item() fails for batch > 1.
            bpds.append(bpd.mean().item())
    return np.mean(bpds)


def estimate_logp_v2(
    model,
    data,
    pp_score=None,
    repeats=5,
    num_steps=64,
    device=None,
    sigma_max=80,
    sigma_min=0.002,
    rho=7,
):
    """Euler-discretised variant of ``estimate_logp``: bits-per-dimension of ``data``.

    ``data`` is tiled ``repeats`` times along the first axis via
    ``data.repeat(repeats, 1, 1)``, so a 3-D tensor is assumed. The copies are
    pushed along ``dx/dsigma = (x - model(x, sigma)) / sigma`` with explicit Euler
    steps over the EDM rho-schedule, from ``sigma_min`` up to ``sigma_max``.
    The Hutchinson estimate of the drift divergence is accumulated along the way.

    Args:
        model: denoiser ``model(x, sigma, pp_score=..., mask=...)``; ``sigma`` is a
            1-D tensor with one entry per tiled sample.
        data: input tensor (first axis = batch).
        pp_score: optional per-point scores, tiled like ``data`` (3-D assumed).
        repeats: number of tiled copies (each gets its own Rademacher probe).
        num_steps: number of schedule points (``num_steps - 1`` Euler steps).
        device: device for the sigma schedule; defaults to ``data.device``.
        sigma_max, sigma_min, rho: EDM schedule parameters.

    Returns:
        float: bits-per-dimension averaged over all tiled samples.
    """
    if device is None:
        device = data.device
    with torch.no_grad():
        x_points = data.repeat(repeats, 1, 1).clone()
        n_samples = x_points.shape[0]
        mask = torch.zeros(x_points.shape[:-1], dtype=torch.bool, device=data.device)
        if pp_score is not None:
            pp_score = pp_score.repeat(repeats, 1, 1)
        # Rademacher (+/-1) probe, fixed for the whole trajectory.
        epsilon = torch.randint_like(x_points, low=0, high=2).to(device=data.device, dtype=data.dtype) * 2 - 1.

        # Descending indices make t_steps ascend: t_steps[0] = sigma_min, t_steps[-1] = sigma_max.
        step_indices = torch.arange(num_steps - 1, -1, -1, device=device)
        t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
                sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho

        delta_logp = 0

        for (t_cur, t_next) in zip(t_steps[:-1], t_steps[1:]):
            with torch.enable_grad():
                x_points.requires_grad_(True)
                denoised = model(
                    x_points,
                    torch.ones(n_samples, device=data.device, dtype=data.dtype) * t_cur,
                    pp_score=pp_score,
                    mask=mask)
                d_cur = (x_points - denoised) / t_cur
                points_update = (t_next - t_cur) * d_cur
                # points_update already includes the step size, so
                # eps^T d(points_update)/dx eps ~= (t_next - t_cur) * div(d_cur).
                # That is exactly the Euler increment of the log-density.
                fn_eps = torch.sum(points_update * epsilon)
                grad_fn_eps = torch.autograd.grad(fn_eps, x_points)[0]
            x_points.requires_grad_(False)
            x_points += points_update
            # No extra factor of (t_next - t_cur) here; it is already in grad_fn_eps.
            delta_logp += torch.sum(grad_fn_eps * epsilon, dim=tuple(range(1, len(x_points.shape))))

        prior_logp = compute_prior_logp(x_points)
        bpd = -(prior_logp + delta_logp) / np.log(2) / np.prod(x_points.shape[1:]) + 8
        return torch.mean(bpd).item()


def parse_config():
    """Parse CLI arguments, validate flag combinations and load the dataset config.

    The OpenPCDet YAML selected here is only used to build the test set.
    The detections to refine are chosen elsewhere.

    Returns:
        tuple: ``(args, cfg)``. ``args`` is the parsed ``argparse.Namespace`` with
        derived fields filled in; ``cfg`` is the global pcdet ``cfg``, populated
        from the chosen YAML.

    Raises:
        AssertionError: for unsupported combinations of adaptation flags.
        NotImplementedError: for a dataset without a default config file.
    """
    parser = argparse.ArgumentParser(description='Evaluate denoiser')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--dataset', type=str, default='lyft', choices=["kitti", "lyft", "ithaca365"])
    parser.add_argument('--trained-on-dataset', type=str, default="kitti")
    parser.add_argument('--model', type=str, default='pointrcnn', choices=["pointrcnn", "pv_rcnn"])
    parser.add_argument('--category', type=str, default='car', choices=["car", "cyclist", "pedestrian"])
    parser.add_argument('--ckpt', type=str, default=None)
    parser.add_argument('--subset', type=str, default="all")
    parser.add_argument('--ds-lo', type=int, default=0)
    parser.add_argument('--ds-hi', type=int, default=None)

    parser.add_argument('--num-steps', type=int, default=16)
    parser.add_argument('--early-stop', type=int, default=2)
    parser.add_argument('--sigma_min', type=float, default=0.02)
    parser.add_argument('--sigma_max', type=float, default=80)
    parser.add_argument('--sigma_max_lb', type=float, default=10.)
    parser.add_argument('--shape-weight', type=float, default=0.1)
    parser.add_argument('--min-score', type=float, default=0.0)
    parser.add_argument('--max-score', type=float, default=1.0)
    parser.add_argument('--jac', action="store_true")  # very slow, doesn't make much difference
    parser.add_argument('--use-pp', action="store_true")
    parser.add_argument('--use-sn', action="store_true")
    parser.add_argument('--use-ot', action="store_true")
    parser.add_argument('--use-roteda', action='store_true')
    parser.add_argument('--use-st3d', action='store_true')
    parser.add_argument('--freeze-size', action="store_true")
    parser.add_argument(
        '--noise-src', type=str, default="poly",
        choices=["GT", "fixed", "linear", "poly", "arctanh", "ODE"])
    parser.add_argument(
        '--score-src', type=str, default="original",
        choices=["original", "GT", "ODE"])  # being fixed, don't use now
    parser.add_argument('--rho', type=float, default=1)
    parser.add_argument('--perturb', action='store_true')
    parser.add_argument('--perturb-noise-level', type=float, default=0.3)  # perturb the location
    parser.add_argument('--position-only', action='store_true')  # do not update the size of the predicted box
    parser.add_argument('--repeats', type=int, default=1)
    parser.add_argument('--use-gt-boxes', action='store_true')
    parser.add_argument('--eval-only', action='store_true')
    parser.add_argument('--context-limit', type=float, default=4.0)
    parser.add_argument('--disable-box-norm', action='store_true')
    parser.add_argument('--sn-overfit', action='store_true')
    parser.add_argument('--inc-size-noise', action='store_true')

    parser.add_argument('--custom-det-path', default=None, type=str)
    parser.add_argument('--custom-cfg-file', default=None, type=str)
    parser.add_argument('--extra-tag', default="", type=str)
    parser.add_argument('--apply-nms', action='store_true', default=False)

    args = parser.parse_args()

    # Resolve the fallback before validating, so the checks below see the final value.
    if args.trained_on_dataset is None:
        args.trained_on_dataset = args.dataset

    # At most one adaptation method (SN / OT / ROT-EDA / ST3D) may be enabled.
    assert sum([args.use_sn, args.use_ot, args.use_roteda, args.use_st3d]) <= 1

    if args.use_roteda:
        assert args.trained_on_dataset == 'kitti'
    if args.use_st3d:
        assert args.dataset == 'lyft'
        assert args.trained_on_dataset == 'kitti'
    if args.dataset == args.trained_on_dataset:
        # SN / OT are disallowed when source and target dataset coincide.
        assert not args.use_sn and not args.use_ot

    args.position_only = args.position_only or args.disable_box_norm
    if args.sn_overfit:
        # NOTE(review): use_sn is switched on *after* the checks above, so they do
        # not apply to --sn-overfit; confirm this is intended.
        args.use_sn = True
        args.num_steps = 1
        args.early_stop = 0
        args.max_score = 1.0

    # cfg is only used to load test set, model doesn't matter here
    if args.custom_cfg_file is not None:
        cfg_file = args.custom_cfg_file
    elif args.dataset == "kitti":
        cfg_file = "~/repos/edm/OpenPCDet/tools/cfgs/kitti_models/pv_rcnn.yaml"
    elif args.dataset == "lyft":
        if args.trained_on_dataset == "ithaca365":
            cfg_file = "~/repos/edm/OpenPCDet/tools/cfgs/lyft_models/pointrcnn_xyz2.yaml"
        else:
            cfg_file = "~/repos/edm/OpenPCDet/tools/cfgs/lyft_models/pointrcnn_xyz.yaml"
    elif args.dataset == "ithaca365":
        cfg_file = "~/box-diffusion/OpenPCDet/tools/cfgs/ithaca_models_kitti_format/pointrcnn_xyz_dense.yaml"
    else:
        raise NotImplementedError
    # Expand "~" uniformly. The kitti path used to reach the YAML loader unexpanded,
    # and open() does not expand "~".
    cfg_file = os.path.expanduser(cfg_file)
    cfg_from_yaml_file(cfg_file, cfg)
    cfg.TAG = Path(cfg_file).stem
    cfg.EXP_GROUP_PATH = '/'.join(cfg_file.split('/')[1:-1])

    torch.manual_seed(args.seed)
    # Also seed NumPy's global RNG (e.g. subsample_indices with seed < 0 uses it),
    # so --seed makes NumPy-based sampling reproducible. NumPy needs a seed in [0, 2**32).
    np.random.seed(args.seed % (2 ** 32))

    return args, cfg


def main():
    device = torch.device("cuda")
    args, cfg = parse_config()
    if args.jac:
        get_box_update = get_box_update_jac
    elif args.freeze_size:
        get_box_update = get_box_update_opt_fixsize
    else:
        get_box_update = get_box_update_opt

    S_churn = 0
    S_min = 0
    S_max = float('inf')
    S_noise = 1

    logger = common_utils.create_logger()
    test_set, _, _ = build_dataloader(
        dataset_cfg=cfg.DATA_CONFIG,
        class_names=cfg.CLASS_NAMES,
        batch_size=1,
        dist=False, workers=4, logger=logger, training=False
    )
    eval_gt_annos = [copy.deepcopy(info['annos']) for info in test_set.kitti_infos]
    args.subset = len(test_set) if args.subset == "all" else int(args.subset)
    # eval_gt_annos = eval_gt_annos[:args.subset]
    if args.ds_hi is None:
        args.ds_hi = len(test_set)
    eval_gt_annos = eval_gt_annos[args.ds_lo:args.ds_hi]

    shape_templates = {
        "kitti": {
            "car": torch.tensor([3.89, 1.62, 1.53]).to(device),
            "cyclist": torch.tensor([1.77, 0.57, 1.72]).to(device),
            "pedestrian": torch.tensor([0.82, 0.63, 1.77]).to(device),
        },
        "lyft": {
            "car": torch.tensor([4.74, 1.91, 1.71]).to(device),
            "cyclist": torch.tensor([1.75, 0.61, 1.36]).to(device),
            "pedestrian": torch.tensor([0.80, 0.78, 1.74]).to(device),
        },
        "ithaca365": {
            "car": torch.tensor([4.41, 1.75, 1.55]).to(device),
            "cyclist": torch.tensor([1.70, 0.70, 1.53]).to(device),
            "pedestrian": torch.tensor([0.60, 0.61, 1.70]).to(device),
        },
    }
    shape_template = shape_templates[args.dataset][args.category]

    if args.custom_det_path is not None:
        det_path = args.custom_det_path
    elif args.dataset == "kitti":
        assert not args.use_pp
        if args.trained_on_dataset == "kitti":
            if args.model == "pv_rcnn":
                det_path = os.path.join(
                    "~/box-diffusion/OpenPCDet/output/kitti_models/pv_rcnn_xyz/kitti2kitti_scorethresh0/eval/epoch_80/val/default",
                    "result.pkl"
                )
            elif args.model == "pointrcnn":
                det_path = os.path.join(
                    "~/box-diffusion/OpenPCDet/output/kitti_models/pointrcnn_xyz/kitti2kitti_scorethresh0/eval/epoch_80/val/default",
                    "result.pkl"
                )
        elif args.trained_on_dataset == "lyft":
            if args.use_sn:
                if args.model == "pv_rcnn":
                    raise NotImplementedError
                elif args.model == "pointrcnn":
                    raise NotImplementedError
            else:
                if args.model == "pv_rcnn":
                    raise NotImplementedError
                elif args.model == "pointrcnn":
                    det_path = os.path.join(
                        "~/box-diffusion/OpenPCDet/output/kitti_models/pointrcnn_xyz/sparse_range_lyft2kitti_eval/eval/epoch_no_number/val/default",
                        "result.pkl"
                    )
        else:
            raise NotImplementedError

    elif args.dataset == "lyft":
        if args.trained_on_dataset == "kitti":
            assert not args.use_pp
            if args.use_sn:
                if args.model == "pv_rcnn":
                    det_path = os.path.join(
                        "~/box-diffusion/OpenPCDet/output/lyft_models/pv_rcnn_xyz/test_SN_kitti2lyft/eval/epoch_80/val/default",
                        "result.pkl"
                    )
                elif args.model == "pointrcnn":
                    det_path = os.path.join(
                        "~/box-diffusion/OpenPCDet/output/lyft_models/pointrcnn_xyz_dense/test_SN_kitti2lyft_dense/eval/epoch_80/val/default",
                        "result.pkl"
                    )
            
            elif args.use_roteda:
                if args.model == 'pv_rcnn':
                    raise NotImplementedError
                elif args.model == "pointrcnn":
                    det_path = "~/box-diffusion/OpenPCDet/output/lyft_models/pointrcnn_xyz_dense/kitti2lyft_roteda_scorethresh0/eval/epoch_10/val/default/result.pkl"
                    # det_path = "~/adaptation/downstream/OpenPCDet/output/lyft_models/pointrcnn_dense_point_load_p2_1ep_v4_p2_filter_no_margin/pointrcnn_dense_point_load_p2_alpha03_run1_10/eval/eval_with_train/epoch_1/val/result.pkl"
            
            elif args.use_st3d:
                if args.model == 'pv_rcnn':
                    raise NotImplementedError
                elif args.model == 'pointrcnn':
                    det_path = "~/box-diffusion/OpenPCDet/output/lyft_models/pointrcnn_xyz_dense/kitti2lyft_st3d_scorethresh0/eval/epoch_30/val/default/result.pkl"
            
            else:
                if args.model == "pv_rcnn":
                    det_path = os.path.join(
                        "~/box-diffusion/OpenPCDet/output/lyft_models/pv_rcnn_xyz/kitti2lyft_scorethresh0/eval/epoch_80/val/default",
                        "result.pkl"
                    )
                elif args.model == "pointrcnn":
                    det_path = os.path.join(
                        "~/box-diffusion/OpenPCDet/output/lyft_models/pointrcnn_xyz_dense/dense_kitti2lyft_scorethresh0/eval/epoch_80/val/default",
                        "result.pkl"
                    )
        elif args.trained_on_dataset == "lyft":
            if args.use_pp:
                if args.model == "pv_rcnn":
                    det_path = os.path.join(
                        "~/box-diffusion/OpenPCDet/output/lyft_models/pv_rcnn_p2/range_lyft2lyft_p2_scorethresh0/eval/epoch_60/val/default",
                        "result.pkl"
                    )
                elif args.model == "pointrcnn":
                    raise NotImplementedError
            else:
                if args.model == "pv_rcnn":
                    det_path = os.path.join(
                        "~/box-diffusion/OpenPCDet/output/lyft_models/pv_rcnn_xyz/range_lyft2lyft_scorethresh0/eval/epoch_60/val/default",
                        "result.pkl"
                    )
                elif args.model == "pointrcnn":
                    det_path = os.path.join(
                        "~/box-diffusion/OpenPCDet/output/lyft_models/pointrcnn_xyz_dense/default/eval/eval_with_train/epoch_60/val",
                        "result.pkl"
                    )
        elif args.trained_on_dataset == "ithaca365":
            if args.use_pp:
                if args.model == "pv_rcnn":
                    raise NotImplementedError
                elif args.model == "pointrcnn":
                    det_path = os.path.join(
                        "~/box-diffusion/OpenPCDet/output/lyft_models/pointrcnn_xyz2/range_ithaca2lyft_eval/eval/epoch_60/val/default/",
                        "result.pkl"
                    )
            else:
                if args.model == "pv_rcnn":
                    raise NotImplementedError
                elif args.model == "pointrcnn":
                    raise NotImplementedError

    elif args.dataset == "ithaca365":
        if args.trained_on_dataset == "kitti":
            assert not args.use_pp
            if args.use_sn:
                if args.model == 'pv_rcnn':
                    raise NotImplementedError
                elif args.model == 'pointrcnn':
                    det_path = os.path.join(
                        "~/box-diffusion/OpenPCDet/output/ithaca_models_kitti_format/pointrcnn_xyz_dense/kitti2ithaca_SN/eval/epoch_80/test/default", 
                        "result.pkl"
                    )
            elif args.use_roteda:
                if args.model == 'pv_rcnn':
                    raise NotImplementedError
                elif args.model == 'pointrcnn':
                    det_path = "~/box-diffusion/OpenPCDet/output/ithaca_models_kitti_format/pointrcnn_xyz2_dense/kitti2ithaca_roteda_scorethresh0/eval/epoch_10/test/default/result.pkl"
            else:
                if args.model == 'pv_rcnn':
                    raise NotImplementedError
                elif args.model == 'pointrcnn':
                    det_path = os.path.join(
                        "~/box-diffusion/OpenPCDet/output/ithaca_models_kitti_format/pointrcnn_xyz_dense/kitti2ithaca/eval/epoch_80/test/default",
                        "result.pkl"
                    )
            
        elif args.trained_on_dataset == "lyft":
            if args.use_pp:
                if args.use_sn:
                    if args.model == "pv_rcnn":
                        raise NotImplementedError
                    elif args.model == "pointrcnn":
                        raise NotImplementedError
                else:
                    if args.model == "pv_rcnn":
                        raise NotImplementedError
                    elif args.model == "pointrcnn":
                        det_path = os.path.join(
                            "~/box-diffusion/OpenPCDet/output/ithaca_models_kitti_format/pointrcnn_dense_p2/lyft2ithaca_p2_kittiformat_scorethresh0/eval/epoch_60/test/default",
                            "result.pkl"
                        )
            else:
                if args.use_sn:
                    if args.model == "pv_rcnn":
                        raise NotImplementedError
                    elif args.model == "pointrcnn":
                        raise NotImplementedError
                else:
                    if args.model == "pv_rcnn":
                        raise NotImplementedError
                    elif args.model == "pointrcnn":
                        det_path = os.path.join(
                            "~/box-diffusion/OpenPCDet/output/ithaca_models_kitti_format/pointrcnn_xyz_dense/range_lyft2ithaca_kittiformat_xyz_dense_scorethresh0/eval/epoch_60/test/default",
                            "result.pkl"
                        )
        elif args.trained_on_dataset == "ithaca365":
            if args.use_pp:
                if args.model == "pv_rcnn":
                    raise NotImplementedError
                elif args.model == "pointrcnn":
                    raise NotImplementedError
            else:
                if args.model == "pv_rcnn":
                    raise NotImplementedError
                elif args.model == "pointrcnn":
                    raise NotImplementedError

    if args.ckpt is None:
        if args.trained_on_dataset == "kitti":
            if args.category == "car":
                if args.context_limit == 3.0:
                    args.ckpt = "kitti-v2-car-gpus1-batch128context_3.0-00000/training-state-086220.pt"
                elif args.context_limit == 4.0:
                    args.ckpt = "kitti-v2-car-gpus1-batch128context_4.0-00000/training-state-111820.pt"
                elif args.context_limit == 6.0:
                    args.ckpt = 'kitti-v2-car-gpus1-batch128context_6.0-00000/training-state-128716.pt'
                elif args.freeze_size:
                    args.ckpt = "kitti-v2-car-gpus1-batch128-00005/training-state-079052.pt"
                elif args.disable_box_norm:
                    args.ckpt = "kitti-v2-car-gpus1-batch128-00006/training-state-091340.pt"
                elif args.sn_overfit:
                    args.ckpt = "kitti-v2-car-gpus1-batch128-overfit-00000/training-state-046489.pt"
                elif args.inc_size_noise:
                    args.ckpt = "kitti-v2-car-gpus1-batch128-size_noise++-00001/training-state-060825.pt"
                else:
                    args.ckpt = "kitti-v2-car-gpus1-batch128-00004/training-state-078028.pt"
            elif args.category == "cyclist":
                if args.context_limit == 2.0:
                    args.ckpt = "kitti-v2-cyclist-gpus1-batch128-00002/training-state-045669.pt"
                elif args.context_limit == 4.0:
                    args.ckpt = "kitti-v2-cyclist-gpus1-batch128context_4.0-00000/training-state-100351.pt"
                else:
                    raise NotImplementedError
            elif args.category == "pedestrian":
                if args.context_limit == 2.0:
                    args.ckpt = "kitti-pedestrian-gpus1-batch128-00001/training-state-096050.pt"
                elif args.context_limit == 4.0:
                    args.ckpt = "kitti-v2-pedestrian-gpus1-batch128context_4.0-00000/training-state-052019.pt"
                else:
                    raise NotImplementedError

        elif args.trained_on_dataset == "lyft":
            if args.use_pp:
                if args.category == "car":
                    args.ckpt = "lyft_4-uncond-td-edm-gpus1-batch128-fp32-00005/training-state-066457.pt"
                elif args.category == "cyclist":
                    args.ckpt = "lyft-cyclist-usepp-gpus1-batch128-00001/training-state-079462.pt"
                elif args.category == "pedestrian":
                    args.ckpt = "lyft-pedestrian-usepp-gpus1-batch128-00000/training-state-027750.pt"
            else:
                if args.category == "car":
                    args.ckpt = "lyft-car-gpus1-batch128-00001/training-state-026214.pt"
                elif args.category == "cyclist":
                    args.ckpt = "lyft-cyclist-gpus1-batch128-00001/training-state-078335.pt"
                elif args.category == "pedestrian":
                    args.ckpt = "lyft-pedestrian-gpus1-batch128-00000/training-state-026316.pt"

        elif args.trained_on_dataset == "ithaca365":
            if args.use_pp:
                if args.category == "car":
                    args.ckpt = "ithaca365-v2-car-usepp-gpus1-batch128-00003/training-state-011673.pt"
                elif args.category == "cyclist":
                    args.ckpt = "ithaca365-v2-cyclist-usepp-gpus1-batch128-00001/training-state-011673.pt"
                elif args.category == "pedestrian":
                    args.ckpt = "ithaca365-v2-pedestrian-usepp-gpus1-batch128-00001/training-state-011980.pt"
            else:
                if args.category == "car":
                    args.ckpt = "ithaca365-v2-car-gpus1-batch128-00001/training-state-012083.pt"
                elif args.category == "cyclist":
                    args.ckpt = "ithaca365-v2-cyclist-gpus1-batch128-00000/training-state-052633.pt"
                elif args.category == "pedestrian":
                    args.ckpt = "ithaca365-v2-pedestrian-gpus1-batch128-00001/training-state-012083.pt"
        else:
            raise NotImplementedError

    print(f"{args.trained_on_dataset} -> {args.dataset}, {args.category}")
    print(f"ckpt: {args.ckpt}")
    print(f"det_path: {det_path}")
    print("apply_nms:", args.apply_nms)

    args.ckpt = os.path.join(os.path.expanduser("~/repos/edm/training-runs"), args.ckpt)
    assert os.path.isfile(args.ckpt)
    net = torch.load(args.ckpt)["net"].eval()

    det_annos = load_pickle(det_path)
    assert len(det_annos) == len(test_set)

    # _, ap_dict_b4 = kitti_eval.get_official_eval_result(
    #     eval_gt_annos, det_annos[:args.subset], cfg.CLASS_NAMES)
    _, ap_dict_b4 = kitti_eval.get_official_eval_result(
        eval_gt_annos, det_annos[args.ds_lo:args.ds_hi], cfg.CLASS_NAMES)


    timestep_ious = [[] for _ in range(args.num_steps - args.early_stop + 1)]
    with torch.no_grad():
        for frame_id in trange(len(eval_gt_annos)):
            det_anno = det_annos[frame_id]
            det_mask = det_anno['name'] == args.category.title()
            if np.sum(det_mask) == 0:
                continue

            gt_anno = test_set[frame_id]
            gt_mask = gt_anno['gt_boxes'][:, 7] == cfg.CLASS_NAMES.index(args.category.title()) + 1
            if np.sum(gt_mask) == 0:
                continue
            gt_boxes = torch.from_numpy(gt_anno['gt_boxes'][gt_mask, :7]).float().cuda()

            iou_before_raw = boxes_iou3d_gpu(
                torch.from_numpy(det_anno['boxes_lidar'][det_mask]).float().cuda(),
                gt_boxes)
            iou_before = torch.max(iou_before_raw, dim=0).values.mean().item() if np.sum(gt_mask) > 0 else -1

            ptc = torch.from_numpy(gt_anno['points'][:, :3]).float().cuda()
            pp_score = None

            for obj_cnt, obj_id in enumerate(np.where(det_mask)[0]):
                if det_anno['score'][obj_id] < args.min_score or det_anno['score'][obj_id] > args.max_score:
                    continue
                x_box = torch.from_numpy(det_anno['boxes_lidar'][obj_id]).float().cuda().clone().detach().contiguous()
                if args.use_ot:
                    x_box[3:6] *= shape_templates[args.dataset][args.category] / shape_templates[args.trained_on_dataset][args.category]
                iou = boxes_iou3d_gpu(x_box.unsqueeze(dim=0), gt_boxes)[0]
                gt_idx = torch.argmax(iou).item()
                if iou[gt_idx] <= 1e-6:
                    continue
                timestep_ious[0].append(float(iou[gt_idx]))

                sigma_min = max(args.sigma_min, net.sigma_min)
                if args.use_gt_boxes:
                    x_box = gt_boxes[gt_idx].clone().detach().contiguous()
                    x_box[6] += torch.randn(1)[0].to(x_box.device) * 0.01  # not sure why but without this noise AP will drop A LOT
                elif not args.eval_only:
                    if args.noise_src == "fixed":
                        sigma_max = min(args.sigma_max, net.sigma_max)
                    elif args.noise_src == "linear":
                        sigma_max = (args.sigma_max - args.sigma_min) * (1 - det_anno['score'][obj_id]) + args.sigma_min
                        sigma_min = 0.002
                    elif args.noise_src == "poly":
                        if args.sigma_max_lb is None:
                            sigma_max_lb = sigma_min
                        else:
                            sigma_max_lb = args.sigma_max_lb
                        sigma_max = (args.sigma_max ** (1 / args.rho) + det_anno['score'][obj_id] * (
                                    sigma_max_lb ** (1 / args.rho) - args.sigma_max ** (1 / args.rho))) ** args.rho
                    elif args.noise_src == "arctanh":
                        sigma_max = np.arctanh(min(1 - det_anno['score'][obj_id], 1 - 1e-4)) * args.sigma_max  # arctanh
                    elif args.noise_src == "GT":
                        if gt_boxes.shape[0] == 0:
                            sigma_max = args.sigma_max
                        else:
                            iou = boxes_iou3d_gpu(x_box.unsqueeze(dim=0), gt_boxes)
                            if torch.max(iou).item() == 0:
                                sigma_max = args.sigma_max
                            else:
                                gt_box = gt_boxes[torch.argmax(iou, dim=1).item()]
                                x, y = x_box[0:1] - gt_box[0:1], x_box[1:2] - gt_box[1:2]
                                box_diff = torch.cat([
                                    x * torch.cos(gt_box[6:7]) + y * torch.sin(gt_box[6:7]),
                                    -x * torch.sin(gt_box[6:7]) + y * torch.cos(gt_box[6:7]),
                                    x_box[2:3] - gt_box[2:3],
                                    torch.log(x_box[3:6] / gt_box[3:6]),
                                    (x_box[6:7] - gt_box[6:7] + np.pi / 2) % np.pi - np.pi / 2
                                ])
                                noise_level = torch.tensor((0.7, 0.35, 0.15, 0.2, 0.1, 0.1, 0.3)).to(dtype=x_box.dtype, device=x_box.device)
                                noise = box_diff / 2 / noise_level * 80
                                sigma_max = torch.sqrt(torch.mean(noise * noise)).item()

                                # info = f"score {det_anno['score'][obj_id]:0.3f}, iou {torch.max(iou).item():0.3f}, sigma {sigma_max:0.3f}"
                                # print(info, noise)
                                # with open(f"{args.dataset}_{args.category}.txt", "a") as f:
                                #     f.write(f"{info}\n")

                    elif args.noise_src == "ODE":
                        x_points, x_points_pp, x_points_mask = forward_transform(ptc, x_box, pp_score, limit=args.context_limit)

                        n_points_orig = x_points.shape[0]
                        if x_points.shape[0] > 1024:
                            indices = subsample_indices(n_points_orig, subsample_size=1024, seed=-1)
                            x_points = x_points[indices, :]
                            x_points_pp = x_points_pp[indices] if x_points_pp is not None else None
                            indices_not_selected = np.delete(np.arange(n_points_orig), indices)
                            x_points_mask_indices_not_selected = np.where(x_points_mask.detach().cpu().numpy())[0][
                                indices_not_selected]
                            x_points_mask[x_points_mask_indices_not_selected] = False

                        x_points = x_points.unsqueeze(dim=0)
                        x_points_pp = x_points_pp.unsqueeze(dim=-1).unsqueeze(dim=0) if x_points_pp is not None else None

                        score = 20 - estimate_logp_v2(
                            net,
                            x_points,
                            pp_score=x_points_pp,
                            repeats=5,
                            num_steps=64,
                            device=torch.device("cuda"),
                            sigma_max=80,
                            sigma_min=0.002,
                            rho=7,
                        )
                        sigma_max = max(args.sigma_max, np.clip(score, 2, 80))

                    for r in range(args.repeats):
                        step_indices = torch.arange(args.num_steps, device=device)
                        if args.sn_overfit:
                            # denoise 1 step
                            t_steps = torch.tensor([1, 0]).float().cuda()
                        else:
                            t_steps = (sigma_max ** (1 / args.rho) + step_indices / (args.num_steps - 1) * (
                                        sigma_min ** (1 / args.rho) - sigma_max ** (1 / args.rho))) ** args.rho
                            t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])  # t_N = 0

                        if args.perturb:
                            noise_xyz = args.perturb_noise_level * (1 - det_anno['score'][obj_id])
                            noise_xyz = torch.randn(3).to(x_box.dtype).to(x_box.device) * noise_xyz
                            x_box[:3] += noise_xyz

                        if args.disable_box_norm:
                            original_size = x_box[3:6]
                            x_box[3:6] = torch.ones_like(x_box[3:6])

                        for i, (t_cur, t_next) in list(enumerate(zip(t_steps[:-1], t_steps[1:])))[:args.num_steps-args.early_stop]:
                            x_points, x_points_pp, x_points_mask = forward_transform(ptc, x_box, pp_score, limit=args.context_limit)

                            n_points_orig = x_points.shape[0]
                            if x_points.shape[0] > 1024:
                                indices = subsample_indices(n_points_orig, subsample_size=1024, seed=-1)
                                x_points = x_points[indices, :]
                                x_points_pp = x_points_pp[indices] if x_points_pp is not None else None
                                indices_not_selected = np.delete(np.arange(n_points_orig), indices)
                                x_points_mask_indices_not_selected = np.where(x_points_mask.detach().cpu().numpy())[0][
                                    indices_not_selected]
                                x_points_mask[x_points_mask_indices_not_selected] = False

                            x_points = x_points.unsqueeze(dim=0)
                            x_points_pp = x_points_pp.unsqueeze(dim=-1).unsqueeze(dim=0) if x_points_pp is not None else None

                            points_update = get_points_update(
                                x_points, net, t_cur, t_next, pp_score=x_points_pp,
                                num_steps=args.num_steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=args.rho,
                                S_churn=S_churn, S_min=S_min, S_max=S_max, S_noise=S_noise,
                            )
                            box_update = get_box_update(ptc, x_box, points_update.float()[0], pp_score, x_points_mask, limit=args.context_limit).float()
                            if args.position_only:
                                x_box[:3] = x_box[:3] + box_update[:3]
                                x_box[6:] = x_box[6:] + box_update[6:]
                            else:
                                x_box = x_box + box_update
                                x_box[3:6] = x_box[3:6] * (1 - args.shape_weight) + args.shape_weight * shape_template

                            timestep_ious[i+1].append(float(torch.max(boxes_iou3d_gpu(x_box.unsqueeze(dim=0), gt_boxes)[0]).item()))

                        with open(os.path.expanduser("~/scratch/box_time.pkl"), "wb") as f:
                            pickle.dump(timestep_ious, f)

                        if args.disable_box_norm:
                            x_box[3:6] = original_size

                det_anno['boxes_lidar'][obj_id] = x_box.detach().cpu().numpy()

                if args.score_src == "GT":
                    det_anno['score'][obj_id] = torch.max(boxes_iou3d_gpu(x_box.unsqueeze(dim=0), gt_boxes)).clamp(0, 1).item() if gt_boxes.shape[0] > 0 else 0.0
                elif args.score_src == "ODE":
                    x_points, x_points_pp, x_points_mask = forward_transform(ptc, x_box, pp_score, limit=args.context_limit)

                    n_points_orig = x_points.shape[0]
                    if x_points.shape[0] > 1024:
                        indices = subsample_indices(n_points_orig, subsample_size=1024, seed=-1)
                        x_points = x_points[indices, :]
                        x_points_pp = x_points_pp[indices] if x_points_pp is not None else None
                        indices_not_selected = np.delete(np.arange(n_points_orig), indices)
                        x_points_mask_indices_not_selected = np.where(x_points_mask.detach().cpu().numpy())[0][
                            indices_not_selected]
                        x_points_mask[x_points_mask_indices_not_selected] = False

                    x_points = x_points.unsqueeze(dim=0)
                    x_points_pp = x_points_pp.unsqueeze(dim=-1).unsqueeze(dim=0) if x_points_pp is not None else None

                    # det_anno['score'][obj_id] = estimate_logp(net, x_points, pp_score=x_points_pp)
                    score = 100 - estimate_logp_v2(
                        net,
                        x_points,
                        pp_score=x_points_pp,
                        repeats=5,
                        num_steps=64,
                        device=torch.device("cuda"),
                        sigma_max=80,
                        sigma_min=0.002,
                        rho=7,
                    )
                    print(f"score {det_anno['score'][obj_id]:0.2f} -> {score}")
                    det_anno['score'][obj_id] = score

            if args.score_src != "original":
                score_mask = det_anno['score'] > cfg.MODEL.POST_PROCESSING.SCORE_THRESH
                det_anno['boxes_lidar'] = det_anno['boxes_lidar'][score_mask]
                det_anno['score'] = det_anno['score'][score_mask]
                det_anno['bbox'] = det_anno['bbox'][score_mask]
                det_anno['name'] = det_anno['name'][score_mask]
                det_anno['alpha'] = det_anno['alpha'][score_mask]
                det_anno['rotation_y'] = det_anno['rotation_y'][score_mask]
                det_mask = det_mask[score_mask]

            if args.apply_nms and np.sum(det_mask) > 0:
                selected, _ = class_agnostic_nms(
                    box_scores=torch.from_numpy(det_anno['score']).cuda().float(),
                    box_preds=torch.from_numpy(det_anno['boxes_lidar']).cuda().float(),
                    nms_config=EasyDict(
                        NMS_PRE_MAXSIZE=det_anno['score'].shape[0],
                        NMS_TYPE='nms_gpu',
                        NMS_THRESH=0.1,
                        NMS_POST_MAXSIZE=det_anno['score'].shape[0])
                )
                selected = selected.detach().cpu().numpy()
                det_anno['boxes_lidar'] = det_anno['boxes_lidar'][selected]
                det_anno['score'] = det_anno['score'][selected]
                det_anno['bbox'] = det_anno['bbox'][selected]
                det_anno['name'] = det_anno['name'][selected]
                det_anno['alpha'] = det_anno['alpha'][selected]
                det_anno['rotation_y'] = det_anno['rotation_y'][selected]
                det_mask = det_mask[selected]

            pred_boxes_camera = box_utils.boxes3d_lidar_to_kitti_camera(det_anno['boxes_lidar'], gt_anno['calib'])
            det_anno['dimensions'] = pred_boxes_camera[:, 3:6]
            det_anno['location'] = pred_boxes_camera[:, 0:3]
            det_anno['alpha'] = det_anno['alpha'] + pred_boxes_camera[:, 6] - det_anno['rotation_y']
            det_anno['rotation_y'] = pred_boxes_camera[:, 6]

            iou_after = torch.max(
                boxes_iou3d_gpu(
                    torch.from_numpy(det_anno['boxes_lidar'][det_mask]).float().cuda(),
                    gt_boxes), dim=0
            ).values.mean().item() if np.sum(gt_mask) > 0 and np.sum(det_mask) > 0 else -1
            print(f"{float(iou_before)} -> {float(iou_after)}")

    # _, ap_dict_after = kitti_eval.get_official_eval_result(
    #     eval_gt_annos, det_annos[:args.subset], cfg.CLASS_NAMES)
    _, ap_dict_after = kitti_eval.get_official_eval_result(
        eval_gt_annos, det_annos[args.ds_lo:args.ds_hi], cfg.CLASS_NAMES)

    for k, v in vars(args).items():
        print(f"{k}: {v}")

    for k in ap_dict_after:
        if args.category.title() in k and ("3d" in k or "bev" in k) and "R40" in k:
            print(f"{k}: {ap_dict_b4[k]:0.2f} -> {ap_dict_after[k]:0.2f}")

    info_str = []
    for suffix, iou_threshold in zip(["_R40", "_R40_0.5", "_R40_0.25"], ["0.7", "0.5", "0.25"]):
        s0 = [f"mAP@{iou_threshold}: "]
        for dist in ["0-30m", "30-50m", "50-80m", "0-80m"]:
            s = []
            for metric in ["bev", "3d"]:
                k = f"{args.category.title()}_{metric}/{dist}{suffix}"
                s.append(f"{ap_dict_after[k]:0.2f}")
            s0.append(" / ".join(s))
        info_str.append("\t".join(s0))
    info_str = "\n".join(info_str)

    save_dir = os.path.expanduser("~/scratch/box_diffusion_outputs")
    os.makedirs(save_dir, exist_ok=True)
    # import ipdb; ipdb.set_trace()
    fname = f"{args.dataset}_{args.category}_{args.model}_{datetime.today().strftime('%Y-%m-%d-%H:%M:%S')}"
    save_path = os.path.join(save_dir, fname + ".pkl")
    with open(save_path, "wb") as f:
        pickle.dump(det_annos, f)

    save_path = os.path.join(save_dir, fname + ".txt")
    with open(save_path, "w") as f:
        f.write(f"det_path: {det_path}\n")
        for k, v in vars(args).items():
            f.write(f"{k}: {v}\n")
        f.write(info_str)
        try:
            f.write(f"\n\nextra_tag: {args.extra_tag}\n")
        except:
            pass
    print(info_str)
    # import ipdb; ipdb.set_trace()
    print(f"scp g2:{save_path} . && tail -n 3 {fname}.txt")

if __name__ == "__main__":
    # Script entry point: main() runs only when this file is executed directly,
    # not when it is imported as a module. main() (defined above) ends by running
    # the KITTI-style evaluation and writing .pkl/.txt results under
    # ~/scratch/box_diffusion_outputs.
    main()