import os
import sys

# Absolute directory of this file. It is put at the very front of sys.path so
# that the `models` / `utils` imports below resolve from this directory first.
HERE = os.path.split(os.path.abspath(__file__))[0]
sys.path[0:0] = [HERE]

import math

def _purge_cached_package(package) :
	"""Remove `package` and every `package.*` submodule from sys.modules.

	After this, the next import of `package` is resolved afresh against
	sys.path instead of returning a module object cached by an earlier import.
	"""
	prefix = package + '.'
	# Snapshot the matching keys first: sys.modules must not be mutated while
	# it is being iterated.
	for name in [n for n in sys.modules if n == package or n.startswith(prefix)] :
		sys.modules.pop(name, None)

# Drop the whole cached `models` package (not just a hand-picked list of its
# submodules) so that `from models import ...` loads the package found first on
# sys.path rather than reusing a same-named package imported earlier.
# NOTE(review): the `utils` package imported further down is not purged, so a
# previously cached `utils` would still be reused -- confirm this is intended.
_purge_cached_package('models')
	
from models import hsm

import numpy as np
import os
import pdb
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import time
from models.submodule import *
from utils.eval import mkdir_p, save_pfm
from utils.preprocess import get_transform

# Undo the earlier prepending of HERE to sys.path. Remove HERE by value rather
# than popping index 0: an import in between may have prepended its own entry,
# in which case pop(0) would discard that entry and leave HERE on sys.path.
if HERE in sys.path :
	sys.path.remove(HERE)

class Bunch(object):
  """Expose the entries of a mapping as instance attributes.

  ``Bunch({'a': 1}).a == 1``. Entries are copied into the instance namespace
  when the object is built; later changes to ``adict`` are not reflected.
  """

  def __init__(self, adict):
    # vars(self) is the instance __dict__, so this is a bulk attribute assignment.
    namespace = vars(self)
    namespace.update(adict)

def getModel(pretrained, level=1) :
	"""Build the `hsm` network, load pretrained weights and return an inference closure.

	Args:
		pretrained: anything accepted by ``torch.load``; the loaded object must
			be a dict with a ``'state_dict'`` entry.
		level: forwarded as the ``level`` keyword argument of ``hsm``.

	Returns:
		``testFunc(imgL, imgR)``, which runs the network in eval mode on a
		single left/right pair and returns element ``[0]`` of the model output.
	"""
	# NOTE(review): 128 and -1 are positional hsm() arguments whose meaning is
	# defined in models.hsm (presumably max disparity and a cleaning
	# threshold) -- confirm there.
	model = hsm(128,-1,level=level)
	model = nn.DataParallel(model, device_ids=[0])
	model.cuda()

	pretrained_dict = torch.load(pretrained)
	# Drop every checkpoint entry whose key contains 'disp'. Because of
	# strict=False below, those parameters/buffers keep the values the freshly
	# constructed model has (presumably because they depend on the constructor
	# arguments rather than on training -- confirm).
	pretrained_dict['state_dict'] =  {k:v for k,v in pretrained_dict['state_dict'].items() if 'disp' not in k}
	model.load_state_dict(pretrained_dict['state_dict'],strict=False)

	model.eval()

	def testFunc(imgL, imgR) :
		"""Run the network on one left/right pair.

		A leading batch axis is added with ``np.newaxis`` and both inputs are
		moved to the GPU with ``.cuda()`` (presumably (C, H, W) torch tensors --
		confirm with callers). The result is computed under ``torch.no_grad()``
		and therefore does not require grad.
		"""
		# Inference only: without no_grad every call would record an autograd
		# graph and keep its intermediate activations alive, wasting GPU memory.
		with torch.no_grad() :
			disparity = model(imgL[np.newaxis,...].cuda(), imgR[np.newaxis,...].cuda())[0]
		return disparity

	return testFunc
