""" mazes_data.py
    Maze-related dataloaders

    Borrowed from code for DeepThinking project
"""

import torch
from torch.utils import data
from easy_to_hard_data import MazeDataset

def prepare_maze_loader(train_batch_size, test_batch_size, train_data, test_data, shuffle=True,
                        *, data_root="../../../data", train_fraction=0.8, seed=42,
                        num_workers=0):
    """Build the train / validation / test dataloaders for the maze dataset.

    The training ``MazeDataset`` is split into a training and a validation
    subset with a seeded generator, so the split is reproducible for a given
    ``seed``. The test ``MazeDataset`` is used as-is.

    Args:
        train_batch_size: Batch size of the training loader.
        test_batch_size: Batch size of both the validation and test loaders.
        train_data: Forwarded as ``size`` to the training ``MazeDataset``
            (presumably the maze dimension -- confirm against
            ``easy_to_hard_data``).
        test_data: Forwarded as ``size`` to the test ``MazeDataset``.
        shuffle: Whether the training loader shuffles its samples. The
            validation and test loaders never shuffle.
        data_root: Directory handed to ``MazeDataset`` as its root. The
            default is relative to the current working directory.
        train_fraction: Fraction of the training dataset kept for training;
            the remainder becomes the validation subset. Must be in [0, 1].
        seed: Seed of the generator that drives the train/val split.
        num_workers: ``num_workers`` forwarded to every ``DataLoader``.

    Returns:
        A dict with the keys ``"train"``, ``"test"`` and ``"val"``, each
        mapping to the corresponding ``DataLoader``.

    Raises:
        ValueError: If ``train_fraction`` is outside [0, 1].
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction!r}")

    # Use a distinct local name so the `train_data` argument is not shadowed.
    full_trainset = MazeDataset(data_root, train=True, size=train_data, download=True)
    testset = MazeDataset(data_root, train=False, size=test_data, download=True)

    # The validation subset takes whatever truncation leaves over, so the two
    # lengths always sum to the dataset length, as random_split requires.
    num_train = int(train_fraction * len(full_trainset))
    num_val = len(full_trainset) - num_train

    trainset, valset = data.random_split(
        full_trainset,
        [num_train, num_val],
        generator=torch.Generator().manual_seed(seed),
    )

    # Only the training loader drops the final incomplete batch; evaluation
    # loaders keep every sample.
    trainloader = data.DataLoader(trainset,
                                  num_workers=num_workers,
                                  batch_size=train_batch_size,
                                  shuffle=shuffle,
                                  drop_last=True)
    valloader = data.DataLoader(valset,
                                num_workers=num_workers,
                                batch_size=test_batch_size,
                                shuffle=False,
                                drop_last=False)
    testloader = data.DataLoader(testset,
                                 num_workers=num_workers,
                                 batch_size=test_batch_size,
                                 shuffle=False,
                                 drop_last=False)

    return {"train": trainloader, "test": testloader, "val": valloader}
