import os
import sys

# Absolute directory of this file. It is placed first on the import search
# path so that the ``models`` import further down resolves to the package
# that lives alongside this file.
HERE = os.path.split(os.path.abspath(__file__))[0]
sys.path[:0] = [HERE]

import os
import random
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
import time
import math


# Evict any previously imported ``models`` package (and the submodules named
# here) so the import below loads the copy next to this file rather than a
# same-named package already cached in sys.modules.
# NOTE(review): only these three entries are evicted; any other cached
# ``models.*`` submodule would be reused as-is — confirm that is intended.
for _stale in ('models', 'models.utils', 'models.submodule'):
	sys.modules.pop(_stale, None)
del _stale

try:
	from models import *
finally:
	# Undo the sys.path.insert(0, HERE) made at the top of the file even when
	# the import raises, so HERE does not stay at the front of the search path
	# and shadow unrelated imports.
	sys.path.pop(0)

class Bunch(object):
    """Turn the entries of ``adict`` into instance attributes.

    ``adict`` may be anything ``dict.update`` accepts: a mapping, or an
    iterable of key/value pairs. Each entry is stored directly in the
    instance namespace.
    """

    def __init__(self, adict):
        # vars(self) is the instance __dict__, so this sets all entries at once.
        vars(self).update(adict)

def getModel(pretrained, level=-1) :
	"""Build a ``stackhourglass(192)`` model, load its weights, and return an inference closure.

	Args:
		pretrained: path to a checkpoint readable by ``torch.load`` whose
			``'state_dict'`` entry matches the ``nn.DataParallel``-wrapped model
			(keys are therefore expected to carry DataParallel's prefix).
		level: index applied to the model's output to select what the closure
			returns (default ``-1``, i.e. the last element along its first axis).

	Returns:
		``testFunc(imgL, imgR)``: runs the model on one left/right pair on the
		GPU and returns the selected output with size-1 dimensions squeezed.
	"""
	# 192 is passed straight to stackhourglass; presumably the maximum
	# disparity — TODO confirm against the models package.
	model = stackhourglass(192)
	model = nn.DataParallel(model, device_ids=[0])
	model.cuda()

	# NOTE(review): torch.load unpickles arbitrary objects; only load
	# checkpoints from trusted sources.
	# map_location='cpu' avoids placing a second copy of the weights on the GPU
	# (or on a GPU index that may not exist here); load_state_dict then copies
	# the tensors into the CUDA parameters.
	checkpoint = torch.load(pretrained, map_location='cpu')
	model.load_state_dict(checkpoint['state_dict'])

	model.eval()

	def testFunc(imgL, imgR) :
		"""Run the model on a single image pair and return the squeezed selected output."""
		# Inference only: no_grad stops autograd from recording the forward pass
		# and holding its activations in GPU memory on every call.
		with torch.no_grad():
			# [np.newaxis, ...] prepends a batch axis of size 1 to each input.
			disparity = model(imgL[np.newaxis,...].cuda(), imgR[np.newaxis,...].cuda())[level]
		return torch.squeeze(disparity)

	return testFunc
