import torch
import numpy as np

from ImageLoader import ImageLoader
from NeuralSheet import NeuralSheet

def latency_spike_times(img, n_bins):
    """Return the time step at which each input value emits its spike.

    Latency (time-to-first-spike) coding: the input is quantised into
    ``n_bins`` levels and larger values fire earlier, so the most salient
    inputs spike first. A value ``v`` fires at step
    ``n_bins - floor(v * n_bins) - 1``; in [0, 1) this maps to steps
    ``0 .. n_bins - 1``.

    The quantised level is clamped to ``n_bins - 1`` so that saturated inputs
    (``v >= 1``) fire at step 0. Without the clamp they were assigned step -1,
    which no simulation step ever reaches, so they were silently dropped.
    Only the upper end is clamped; values below 0 are encoded exactly as
    before.

    Args:
        img: tensor of input intensities; presumably in [0, 1] (TODO confirm
            against ImageLoader).
        n_bins: number of quantisation levels, i.e. the number of distinct
            spike steps for inputs in [0, 1].

    Returns:
        Tensor with the same shape as ``img`` holding each element's spike
        step (floating point, as produced by ``torch.floor``).
    """
    level = torch.clamp(torch.floor(img * n_bins), max=n_bins - 1)
    return n_bins - level - 1


if __name__ == '__main__':
    # size of the sheet
    shape = (32, 32)
    # size of input patch, ON-CELL / OFF-CELL / Empty-Cells (easier to reshape to image)
    patch_size = (7, 7, 3)
    # size of lateral excitation/inhibition
    n_lateral = 32
    # mode passed to NeuralSheet; other values previously used here: "SNN", "SOM"
    mode = "SONS"
    sheet = NeuralSheet(shape, patch_size, n_lateral, mode)
    # data
    dataloader = ImageLoader(patch_size=patch_size)

    # how long we keep the sim running after image is presented; for
    # non-negative inputs, steps >= n_bins receive no input spikes
    duration = 16
    # how many latency bins are used to convert intensities into spike times
    n_bins = 8
    # how many images to present; the modulo on the index below wraps around
    # the dataset if this is raised above len(dataloader)
    n_iterations = len(dataloader)
    print(mode, shape, patch_size, n_lateral, duration)
    for i in range(n_iterations):
        if i % 100 == 0:
            print(mode, "processing", i)
        # fetch image patch
        img, label = dataloader[i % len(dataloader)]
        img = img.to(sheet.device)
        # reset neural states each new image
        sheet.reset_states()
        # spike timing and polarity depend only on the image, not on t,
        # so compute them once per image (most salient inputs fire first)
        spike_times = latency_spike_times(img, n_bins)
        polarity = torch.sign(img)  # zero inputs never spike
        # loop over time
        for t in range(duration):
            img_spiked = (t == spike_times) * polarity
            # forward pass; squeeze(2) presumably drops a singleton axis of the
            # loaded patch — TODO confirm shape against ImageLoader
            sheet.forward(img_spiked.float().squeeze(2), x_raw=img)
            # update weights
            sheet.update()
            # visualize spiking behaviour for the first two images of
            # every 5000-image window
            if i % 5000 < 2:
                sheet.visualize_spiking(i, t, mode)
        # visualize features
        if i % 5000 == 0:
            print("saving figures")
            sheet.visualize_features(i, mode)
            sheet.reset_tracking(img)