import torch
import numpy as np
import pytorch_lightning as pl
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
import utils
from pytorch_lightning.utilities import AttributeDict

import attr
from functools import partial
from typing import List
from typing import Optional
from typing import Union
from .base_module import Base_Module

__all__ = ['SI_Linear_Module']

class SI_Linear_Module(Base_Module):
    """Linear-regression experiment on data from ``utils.gen_linear_regression``.

    The train/test datasets are generated in ``__init__`` *before*
    ``Base_Module.__init__`` runs, because the base initializer presumably
    relies on them (the original note reads "set dataset first!!").

    Training adds a random term ``noise * <xi, w / ||w||>`` to the loss from
    ``Base_Module.training_step``. Here ``xi`` is a standard-normal draw shaped
    like ``self.model.w``. The gradient of that term w.r.t. ``w`` is orthogonal
    to ``w``, so it perturbs only directions that leave ``||w||`` unchanged to
    first order.
    """

    def __init__(self, hparams=None, **kwargs):
        """Build the loss and datasets, then initialize ``Base_Module``.

        Args:
            hparams: Hyper-parameter object with attribute access. This method
                reads ``loss``, ``D``, ``train_size``, ``test_size``, ``gauss``
                and ``cache``. The dataloaders and ``Base_Module`` read more
                fields. Despite the ``None`` default, it is required.
            **kwargs: Forwarded unchanged to ``Base_Module.__init__``.

        Raises:
            TypeError: If ``hparams`` is ``None``.
            KeyError: If ``hparams.loss`` is not a name defined in ``utils``.
        """
        if hparams is None:
            # Nothing below can run without hparams. Fail with a clear message
            # instead of an AttributeError on ``None``.
            raise TypeError("SI_Linear_Module requires an `hparams` object")

        self.criterion = utils.__dict__[hparams.loss]
        # self.w[0] is the ground-truth weight vector (see the print below).
        # self.w[1] and self.w[2] are presumably directions spanning the
        # solution manifold -- TODO confirm against utils.gen_linear_regression.
        self.train_dataset, self.w = utils.gen_linear_regression(
            D=hparams.D, N=hparams.train_size, gauss=hparams.gauss,
            cache=hparams.cache, split='train')
        print("The ground truth: ", self.w[0])
        print("Projection on solution manifold: ",
              self.w[0].dot(self.w[1]).item(), self.w[0].dot(self.w[2]).item())
        # Pass the training ground truth as w0, presumably so that test
        # targets are generated from the same weights.
        self.test_dataset, _ = utils.gen_linear_regression(
            D=hparams.D, N=hparams.test_size, gauss=hparams.gauss,
            cache=hparams.cache, split='test', w0=self.w[0])

        # The datasets must exist before Base_Module.__init__ runs.
        super().__init__(hparams, **kwargs)
        # self.model.full_trainset = self.train_dataset.tensors[0]

    def train_dataloader(self):
        """Return a shuffled loader over the training split."""
        dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.hparams.train_batch_size,
            num_workers=self.hparams.num_data_workers,
            shuffle=True,
            drop_last=self.hparams.drop_last_batch,
        )
        # dataloader = SimpleLoader(self.train_dataset, batch_size=self.hparams.train_batch_size)
        return dataloader

    def val_dataloader(self):
        """Return an unshuffled loader over the test split (used for validation)."""
        dataloader = DataLoader(
            self.test_dataset,
            batch_size=self.hparams.test_batch_size,
            num_workers=self.hparams.num_data_workers,
        )
        return dataloader

    def training_step(self, batch, batch_idx):
        """Return the base training loss plus ``noise * <xi, w / ||w||>``.

        ``xi`` is drawn with ``torch.randn_like(self.model.w)`` on every call.
        The draw happens even when ``noise == 0``, which keeps the RNG stream
        identical across configurations.
        """
        loss = super().training_step(batch, batch_idx)
        w = self.model.w
        # Out-of-place add. ``loss += ...`` would mutate the tensor returned by
        # the parent, which may still be referenced elsewhere (e.g. logged) or
        # saved by autograd for the backward pass.
        # NOTE(review): if ||w|| == 0 the normalization below yields NaN.
        loss = loss + self.hparams.noise * torch.randn_like(w).dot(w / w.norm())
        # if self.hparams.svag is not None:
        #     loss += self.svag_loss(self.hparams.svag) #hparams.svag is a dictionary 
        return loss
    
    # def svag_loss(self,svag):
    #     dataset = utils.gen_linear_regression(D=self.hparams.D,N=svag.batch_size,gauss = svag.gauss)
    #     input,target = dataset.tensors[0],dataset.tensors[1]
    #     output = self.model.forward(input, normalization_method = svag.normalization_method)
        
    #     loss = self.criterion(output,target) * svag.kappa
    #     return loss*(-1 +2*torch.bernoulli(torch.tensor(0.5)))
        
        

    
# class SimpleLoader(object):
#     def __init__(self,dataset,batch_size): ##dataset should be tensordataset
#         self.dataset = dataset
#         self.batch_size = batch_size
#         self.replace = not (len(self.dataset.tensors[0])==self.batch_size) # no replacement if full batch
        
#     def __len__(self):
#         return float('inf')
        
#     def __iter__(self):
#         return self
    
#     def __next__(self):
#         index = torch.from_numpy(np.random.choice(len(self.dataset.tensors[0]), self.batch_size,replace = self.replace))
#         return [tensor[index] for tensor in self.dataset.tensors]