import sys,os
import math
import torch
from timeit import default_timer as timer


def evaluate(model, criterion, dataloader, batch_num, load_obj=False, use_normals=False):
    """Run ``model`` over up to ``batch_num`` batches and return averaged losses.

    Each batch is unpacked with ``dataloader.dataset.unpack(...)``, moved to the
    GPU with ``.cuda()``, and fed through ``model`` under ``torch.no_grad()``::

        model(flow_samples[:, 0],
              surface_samples[:, 0],
              surface_samples[:, 1],
              surface_samples[:, 2, :, 0:1],
              [surface_normals[:, 0]])        # only when use_normals

    ``criterion(flow_samples[:, 1], prediction, eval=True)`` must return a
    ``(loss, loss_flow)`` pair. A single-line progress indicator is written to
    stdout while iterating.

    Args:
        model: Callable taking the tensors shown above.
        criterion: Callable returning ``(loss, loss_flow)``. ``loss`` must
            support ``.item()``. ``loss_flow`` is accumulated only when truthy
            (e.g. it is skipped when ``None``).
        dataloader: Object exposing ``.dataset`` (with ``__len__`` and
            ``unpack``) and ``.batch_size``, and iterable over batches.
        batch_num: Maximum number of batches to evaluate; any negative value
            (conventionally ``-1``) means "all batches".
        load_obj: The ``load_obj`` batch layout is not supported here (it
            carries no surface samples); passing ``True`` raises.
        use_normals: Also pass the surface normals as a fifth model argument.

    Returns:
        ``(losses, metrics)`` where ``losses`` is ``{"total": float,
        "flow": float}``, each averaged over the batches actually evaluated
        (``nan`` if no batch was evaluated). ``metrics`` is currently ``{}``.

    Raises:
        NotImplementedError: If ``load_obj`` is true.
    """
    if load_obj:
        # The load_obj tuple (flow_faces, flow_edges, flow_vertices, ...) has
        # no surface samples, yet the forward pass needs them. This path used
        # to die with UnboundLocalError on `surface_samples`; fail fast instead.
        raise NotImplementedError(
            "evaluate(load_obj=True) is not supported: the load_obj batch "
            "layout provides no surface samples for the model forward pass."
        )

    dataset_obj = dataloader.dataset
    total_size = len(dataset_obj)

    # Upper bound on the number of batches the loader can produce.
    max_num_batches = math.ceil(total_size / dataloader.batch_size)
    if batch_num < 0:
        total_num_batches = max_num_batches
    else:
        total_num_batches = min(max_num_batches, batch_num)

    loss_sum = 0.0
    loss_flow_sum = 0.0
    # Batches actually evaluated; can be fewer than total_num_batches
    # (e.g. a loader with drop_last=True), so averages must use this count.
    num_batches = 0

    print()

    # zip() stops as soon as range() is exhausted, so no batch beyond the cap
    # is pulled (and loaded) from the dataloader.
    for i, data in zip(range(total_num_batches), dataloader):
        sys.stdout.write("\r############# Eval iteration: {0} / {1}".format(i + 1, total_num_batches))
        sys.stdout.flush()

        # Batch layout; only a few fields are used for evaluation.
        (_uniform_samples, _near_surface_samples, surface_samples, flow_samples,
         _grid, surface_normals, _, _rotated2gaps,
         _bbox_lower, _bbox_upper, _temp_idx, _sample_idx) = data

        surface_samples = dataset_obj.unpack(surface_samples).cuda()
        flow_samples = dataset_obj.unpack(flow_samples).cuda()

        with torch.no_grad():
            model_inputs = [
                flow_samples[:, 0],
                surface_samples[:, 0],
                surface_samples[:, 1],
                surface_samples[:, 2, :, 0:1],
            ]
            if use_normals:
                surface_normals = dataset_obj.unpack(surface_normals).cuda()
                model_inputs.append(surface_normals[:, 0])

            flow_pred = model(*model_inputs)

            loss, loss_flow = criterion(flow_samples[:, 1], flow_pred, eval=True)

            loss_sum += loss.item()
            # A falsy loss_flow (None, or a zero-valued scalar) contributes nothing.
            if loss_flow:
                loss_flow_sum += loss_flow.item()

        num_batches += 1

    if num_batches:
        loss_avg = loss_sum / num_batches
        loss_flow_avg = loss_flow_sum / num_batches
    else:
        # No data evaluated: report NaN rather than a misleading 0.0 or a crash.
        loss_avg = loss_flow_avg = float("nan")

    losses = {
        "total": loss_avg,
        "flow": loss_flow_avg,
    }

    # Metrics.
    metrics = {
    }

    return losses, metrics


if __name__ == "__main__":
    # Running this file directly does nothing; evaluate() is meant to be imported.
    ...