-- Module table. torch.class below is given M, and the file ends by
-- returning M.MnistDataset from it.
local M = {}
-- Register the 'resnet.MnistDataset' class; the methods below are defined on
-- the returned MnistDataset table. NOTE(review): assumes the global `torch`
-- has already been loaded by the caller (it is not required in this file).
local MnistDataset= torch.class('resnet.MnistDataset', M)
-- Needed for torch.CudaTensor, which __init allocates on its GPU path.
require 'cutorch'

--- Build an in-memory, flattened copy of one dataset split.
-- Every sample of `imageInfo[split].data` is stored as one 784-element row
-- of `self.imageInfo.data`. When `opt.netType == 'FFNN_resnet_single_layer'`
-- the rows live in a torch.CudaTensor, otherwise in a torch.FloatTensor.
-- NOTE(review): the flag is named `randomProj`, but no projection happens
-- here; it only selects GPU storage. Presumably used downstream — confirm.
-- @tparam table imageInfo table keyed by split name; each entry must provide
--   `data` (samples indexed along dimension 1, 784 elements each) and `labels`
-- @tparam table opt options table; only `opt.netType` is read
-- @param split key into `imageInfo` (raises if the entry is missing)
function MnistDataset:__init(imageInfo, opt, split)
   -- 784 elements per flattened sample (28*28, presumably one MNIST digit).
   local FLAT_DIM = 784

   local info = imageInfo[split]
   assert(info, split)

   self.randomProj = (opt.netType == 'FFNN_resnet_single_layer')
   self.imageInfo = {labels = info.labels}

   local numSamples = info.data:size(1)
   -- Both constructors are tables, so the and/or idiom is safe here.
   local Tensor = self.randomProj and torch.CudaTensor or torch.FloatTensor
   self.imageInfo.data = Tensor(numSamples, FLAT_DIM)

   -- A single bulk copy replaces the former per-sample loop (one transfer per
   -- sample on the CUDA path). Tensor:copy only requires equal element
   -- counts, so (N,28,28) / (N,1,28,28) / (N,784) sources all land in the
   -- same row-major order as copying data[i] into row i one at a time.
   if numSamples > 0 then
      self.imageInfo.data:copy(info.data)
   end

   self.split = split
end

--- Fetch a single sample by index.
-- @tparam number i 1-based sample index
-- @treturn table `{input = self.imageInfo.data[i], target = self.imageInfo.labels[i]}`
function MnistDataset:get(i)
   local info = self.imageInfo
   return { input = info.data[i], target = info.labels[i] }
end

--- Number of samples held by this split.
-- @treturn number length of dimension 1 of the stored data tensor
function MnistDataset:size()
   local data = self.imageInfo.data
   local count = data:size(1)
   return count
end

--- Return the per-sample preprocessing transform.
-- No preprocessing is applied to these samples, so the transform is the
-- identity.
-- @treturn function a function that returns its argument unchanged
function MnistDataset:preprocess()
   local function identity(input)
      return input
   end
   return identity
end

-- Export the entry that torch.class stored in M. NOTE(review): presumably
-- this is the callable constructor (distinct from the methods table above).
return M.MnistDataset
