import os
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import math
import torch.optim as optim
import numpy as np

__all__ = ['si_linear_net']



class si_linear_net(nn.Module):
    """Linear model whose weight vector is normalized before use.

    The single learnable parameter ``w`` is a length-``D`` double-precision
    vector. ``forward`` divides ``w`` by ``w.norm() + epsilon`` and
    matrix-multiplies the input with the result, so with ``epsilon == 0``
    the output does not depend on the magnitude of ``w``, only on its
    direction.

    Args:
        D: Number of entries in the weight vector.
        epsilon: Constant added to the norm of ``w`` in the denominator.
        cache: When greater than zero, the initial weight vector is read
            from (or, if the file does not exist yet, written to) the
            relative path ``model_si_linear_d{D}_{cache}.pt``. When zero
            or negative, no file is read or written.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, D=40, epsilon=0, cache=0, **kwargs):
        super().__init__()
        self.D = D
        self.epsilon = epsilon

        caching = cache > 0
        cache_file = f'model_si_linear_d{D}_{cache}.pt' if caching else None

        if caching and os.path.exists(cache_file):
            # Reuse a previously saved initialization.
            initial_w = torch.load(cache_file)
        else:
            # Fresh standard-normal draw, cast to float64 after sampling.
            initial_w = torch.randn(self.D).double()
            if caching:
                torch.save(initial_w, cache_file)

        # Registering as a Parameter makes ``w`` visible to optimizers.
        self.w = nn.Parameter(initial_w, requires_grad=True)

    def forward(self, x):
        """Return ``x @ (w / (||w|| + epsilon))``.

        NOTE(review): assumes the last dimension of ``x`` equals ``D`` and
        that its dtype is compatible with the float64 ``w`` -- confirm
        against callers.
        """
        denominator = self.w.norm() + self.epsilon
        direction = self.w / denominator
        return x @ direction
    
    
