
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch import nn

from models.psmnet import PSMNet
from models.gwcnet import GwcNet
from models import __models__, __loss__

from datasets.exrDatasetLoader import exrImagePairDataset

import os
import argparse
import collections

from datasets import exrDatasetLoader

if __name__ == "__main__":
    # Finetune a pretrained cascade stereo model (selected via --model) on an
    # EXR image-pair dataset, then save the finetuned weights to --output.

    parser = argparse.ArgumentParser(
        description='Cascade Stereo Network (CasStereoNet) finetuning')

    parser.add_argument("--traindata", required=True,
                        help="Path to the training images")
    parser.add_argument("--numepochs", default=10, type=int,
                        help="Number of epochs to run")
    parser.add_argument("--batchsize", default=4, type=int,
                        help="Batch size")
    parser.add_argument("--numworkers", default=4, type=int,
                        help="Number of workers threads used for loading dataset")
    parser.add_argument('--learningrate', default=1e-2, type=float,
                        help="Learning rate for the optimizer")
    parser.add_argument('--ramcache', action="store_true",
                        help="cache the whole dataset into ram. Do this only if you are certain it can fit.")

    parser.add_argument('--model', default='gwcnet-c',
                        help='select a model structure',
                        choices=__models__.keys())

    parser.add_argument('-p', '--pretrained',
                        default='./models/cas-gwcnet-c-kitti2015.ckpt',
                        help="Pretrained weights")
    parser.add_argument('-o', '--output',
                        default='./models/cas-gwcnet-c-apstereo.pth',
                        help="Output path for the finetuned weights")

    args = parser.parse_args()

    # Disparity range the network is built for. It is shared between model
    # construction and the loss mask, so the two cannot drift apart.
    MAX_DISP = 192

    ModelClass = __models__[args.model]
    model = ModelClass(maxdisp=MAX_DISP,
                       ndisps=[48, 24],
                       disp_interval_pixel=[4, 1],
                       cr_base_chs=[32, 32, 16],
                       grad_method='detach',
                       using_ns=True,
                       ns_size=13)

    # The weights are loaded while the model sits on the GPU. Training below
    # then runs on the CPU: both the model and the batches stay on the CPU.
    # NOTE(review): CPU training is much slower than GPU training; confirm
    # this is intended.
    # torch.load unpickles the file, so only load checkpoints you trust.
    model.cuda()
    checkpoint = torch.load(args.pretrained)
    model.load_state_dict(checkpoint['model'])
    model = model.cpu()
    model.train()

    dataset = exrImagePairDataset(imagedir=args.traindata,
                                  left_nir_channel='Left.SimulatedNir.A',
                                  right_nir_channel='Right.SimulatedNir.A',
                                  cache=False,
                                  ramcache=args.ramcache,
                                  direction='l2r')

    # Fail fast with a clear message instead of a ZeroDivisionError when the
    # per-epoch average is computed over zero batches.
    if len(dataset) == 0:
        parser.error(f"no training samples found in {args.traindata}")

    loader = DataLoader(dataset,
                        batch_size=args.batchsize,
                        shuffle=True,
                        num_workers=args.numworkers)

    optimizer = Adam(model.parameters(), lr=args.learningrate)
    model_loss = __loss__[args.model]

    for ep in range(args.numepochs):
        total_loss = 0.0
        num_batches = 0

        for sample in loader:
            img_left = sample['frameLeft']
            img_right = sample['frameRight']
            gt_disp = sample['trueDisparity'].squeeze(1)

            # Replicate dim 1 three times. This presumably turns the
            # single-channel NIR images into the 3-channel input the
            # network expects (confirm against the dataset loader).
            img_left = torch.cat((img_left, img_left, img_left), dim=1)
            img_right = torch.cat((img_right, img_right, img_right), dim=1)

            outputs = model(img_left, img_right)

            # Only supervise pixels with valid ground truth that the network
            # can actually regress, i.e. disparity in (0, MAX_DISP).
            mask = (gt_disp > 0) & (gt_disp < MAX_DISP)
            # dlossw presumably holds one weight per cascade stage (ndisps
            # has two entries above); confirm against the loss implementation.
            loss = model_loss(outputs, gt_disp, mask, dlossw=[0.5, 2.0])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        print(f"Epoch {ep}: avg loss = {total_loss / num_batches}")

    # The model is moved back to the GPU before saving, so the saved tensors
    # are CUDA tensors. NOTE(review): presumably this matches the layout of
    # the pretrained checkpoint; confirm.
    model = model.cuda()
    torch.save({"model": model.state_dict()}, args.output)
