import torch

class WeightMatrix:
    """Dense, non-negative square weight matrix over ``n_neurons`` neurons.

    ``weights[i, j]`` scales the contribution of neuron ``j`` to neuron ``i``
    (see :meth:`forward`, which computes ``x @ weights.T``). Rows are
    L1-normalised after construction and after :meth:`initialize`;
    :meth:`update_values` does not re-normalise.
    """

    def __init__(self, n_neurons, dtype=torch.float32, device='cuda:0'):
        """Create a random, row-normalised ``(n_neurons, n_neurons)`` matrix.

        Args:
            n_neurons: Number of neurons (size of both matrix dimensions).
            dtype: Floating-point dtype used for the weights.
            device: Torch device on which the weights are allocated.
        """
        # Squaring standard-normal samples yields non-negative initial weights.
        self.weights = torch.randn(
            size=(n_neurons, n_neurons), dtype=dtype, device=device
        ) ** 2
        self.normalize()

    def initialize(self, is_input_neuron):
        """Restrict connectivity so only input neurons project to non-input ones.

        Rows of input neurons are zeroed (they receive nothing), and columns
        of non-input neurons are zeroed (they contribute nothing). Rows are
        then re-normalised; fully zeroed rows stay zero.

        Args:
            is_input_neuron: Mask over neurons. NOTE(review): assumed to be a
                bool tensor of length ``n_neurons`` -- ``~`` on an integer
                tensor is a bitwise NOT, not a logical negation.
        """
        self.weights[is_input_neuron, :] = 0
        self.weights[:, ~is_input_neuron] = 0
        self.normalize()

    def forward(self, x):
        """Propagate activity ``x`` through the matrix.

        Args:
            x: Tensor whose last dimension has size ``n_neurons`` (e.g. bool
                or float activity). It is cast to the weights' dtype first.

        Returns:
            ``x @ weights.T``, i.e. entry ``i`` of the last dimension is
            ``sum_j weights[i, j] * x[..., j]``.
        """
        # Cast to the weights' own dtype rather than a hard-coded float32 so
        # the product also works for matrices built with another dtype.
        return torch.matmul(x.to(self.weights.dtype), self.weights.t())

    def update_values(self, dw, index):
        """Add ``dw`` to ``weights[index]`` and clip that selection to [0, 1].

        Rows are not re-normalised here; call :meth:`normalize` if needed.

        Args:
            dw: Increment broadcastable to ``weights[index]``.
            index: Any index accepted by tensor indexing (int, slice, mask, ...).
                NOTE(review): with an integer index tensor containing
                duplicates, ``+=`` does not accumulate the repeated updates --
                confirm callers pass unique indices.
        """
        self.weights[index] += dw
        self.weights[index] = self.weights[index].clamp(0, 1)

    def normalize(self):
        """L1-normalise each row (last dimension) so its absolute values sum to 1.

        All-zero rows remain all-zero instead of becoming NaN.
        """
        norm = torch.sum(self.weights.abs(), dim=-1, keepdim=True) + 1e-15
        # In low-precision dtypes (e.g. float16) the 1e-15 guard underflows to
        # 0, so also clamp to the smallest positive normal value of the dtype.
        # For float32/float64 this clamp never changes the result.
        norm = norm.clamp_min(torch.finfo(norm.dtype).tiny)
        self.weights = self.weights / norm

    def visualize_dense(self):
        """Draw the weight matrix as an image in matplotlib figure 20."""
        import matplotlib.pyplot as plt
        plt.figure(20)
        plt.clf()
        plt.imshow(self.weights.cpu().numpy())
        # Short pause lets an interactive backend refresh the window.
        plt.pause(0.0001)