import torch
from golearn.algorithms.utils.count import count_memory, refresh_memory
import random


def reservoir_update(self, x_w, y):
    """Insert a batch of samples into the bounded replay memory.

    The memory lives in ``self.buffer`` (samples) and ``self.label``
    (their labels) and is meant to hold at most ``self.buffer_size``
    entries.

    * Empty memory: the batch is stored as-is. The passed tensors are
      referenced, not copied, and no truncation to ``buffer_size`` happens.
    * Free capacity left: as many leading samples of the batch as still fit
      are appended, so the memory actually reaches ``buffer_size``.
    * Overflow: each remaining sample overwrites one stored entry. A sample
      whose label differs from the current majority class is always
      written; a sample of the majority class is written only with
      probability ``(n - 1) / n``, where ``n`` is the majority-class count
      reported by ``count_memory``.

    Args:
        self: Object exposing ``buffer``, ``label`` and ``buffer_size``.
        x_w: Batch of samples, indexable along its first dimension
            (a tensor whenever it has to be concatenated).
        y: Labels aligned with ``x_w``.

    Returns:
        None. ``self.buffer`` and ``self.label`` are updated in place.
    """
    if len(self.buffer) == 0:
        self.buffer = x_w
        self.label = y
        return

    # Top up the memory first. Sending the whole batch to the replacement
    # path as soon as it does not fit entirely would leave the free slots
    # unused forever.
    n_fill = max(0, min(self.buffer_size - len(self.buffer), len(x_w)))
    if n_fill:
        self.buffer = torch.cat((self.buffer, x_w[:n_fill]))
        self.label = torch.cat((self.label, y[:n_fill]))

    # Class-balanced replacement for whatever did not fit.
    for sample, new_label in zip(x_w[n_fill:], y[n_fill:]):
        # Recounted for every sample because the previous overwrite may have
        # changed the class distribution.
        # NOTE(review): count_memory presumably returns per-class counts and
        # the most frequent class of the stored labels - confirm in helper.
        class_num_list, major_class = count_memory(self.label)

        if new_label != major_class:
            replace = True
        else:
            # Majority-class samples are only admitted probabilistically.
            major_count = class_num_list[major_class]
            replace = random.random() < (major_count - 1) / major_count

        if replace:
            # NOTE(review): refresh_memory presumably picks the position of a
            # stored majority-class entry to evict - confirm in helper.
            index = refresh_memory(self.label, major_class)
            self.label[index] = new_label
            self.buffer[index] = sample
