# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from golearn.core import AlgorithmBase
from golearn.core.utils import ALGORITHMS,IMB_ALGORITHMS
import torch
import random
from torch.utils.data import TensorDataset, DataLoader
from golearn.algorithms.utils.count import count_memory, refresh_memory


@IMB_ALGORITHMS.register('imbfullysupervised')
class FullySupervised(AlgorithmBase):
    """
    Train a fully supervised model on labeled data, keeping a class-balanced
    memory buffer of seen samples that is periodically used for replay.

    Args:
        - args (`argparse`):
            algorithm arguments
        - net_builder (`callable`):
            network loading function
        - tb_log (`TBLog`):
            tensorboard logger
        - logger (`logging.Logger`):
            logger to use
    """

    def __init__(self, args, net_builder, tb_log=None, logger=None):
        super().__init__(args, net_builder, tb_log, logger)
        # Replay memory: stored inputs and their labels, aligned by index.
        # These start as empty 1-D placeholders on the training GPU;
        # update_memory replaces them with the first batch so the memory takes
        # on the batch's shape and dtype.
        self.buffer = torch.tensor([]).cuda(self.gpu)
        self.label = torch.tensor([]).cuda(self.gpu)

    def update_memory(self, x_w, y):
        """Insert a batch of samples into the replay memory.

        While the memory holds fewer than ``self.buffer_size`` samples, incoming
        samples are appended until it is full. Any remaining samples go through
        class-balanced replacement:

        - A sample whose label differs from the majority class reported by
          ``count_memory`` always overwrites the slot returned by
          ``refresh_memory``.
        - A majority-class sample does so with probability ``(n - 1) / n``,
          where ``n`` is that class's count in memory.

        Args:
            x_w: batch of inputs, indexed along the first dimension.
            y: labels for ``x_w``, same length as ``x_w``.
        """
        # NOTE(review): self.buffer_size is not set in this class; presumably
        # provided by AlgorithmBase / args — confirm.
        free = max(self.buffer_size - len(self.buffer), 0)
        if free > 0:
            take = min(free, len(x_w))
            if len(self.buffer) == 0:
                # Replace the empty float placeholder instead of concatenating
                # onto it, so the batch's shape and dtype (e.g. integer labels)
                # are preserved.
                self.buffer = x_w[:take]
                self.label = y[:take]
            else:
                # torch.cat returns a new tensor; it must be assigned back.
                self.buffer = torch.cat((self.buffer, x_w[:take]))
                self.label = torch.cat((self.label, y[:take]))
            # Only samples that did not fit go on to replacement.
            x_w, y = x_w[take:], y[take:]

        # Memory is full: class-balanced replacement for the remaining samples.
        for i, l in enumerate(y):
            # Class counts are derived from the stored labels (not the stored
            # inputs); recomputed each step because replacements change them.
            class_num_list, major_class = count_memory(self.label)
            if l != major_class:
                # presumably refresh_memory picks a slot holding major_class — confirm
                index = refresh_memory(self.label, major_class)
                self.label[index] = l
                self.buffer[index] = x_w[i]
            else:
                # Reservoir sampling
                p = random.random()
                if p < (class_num_list[major_class] - 1) / class_num_list[major_class]:
                    index = refresh_memory(self.label, major_class)
                    self.label[index] = l
                    self.buffer[index] = x_w[i]

    def train_step(self, x_w, y):
        """Compute the supervised cross-entropy loss for one batch.

        Args:
            x_w: batch of inputs passed to ``self.model``.
            y: target labels for ``x_w``.

        Returns:
            ``(out_dict, log_dict)`` where ``out_dict`` carries the loss under
            the key ``loss`` and ``log_dict`` the scalar ``sup_loss``.
        """
        # inference and calculate sup/unsup losses
        with self.amp_cm():
            logits_x_w = self.model(x_w)['logits']
            sup_loss = self.ce_loss(logits_x_w, y, reduction='mean')

        out_dict = self.process_out_dict(loss=sup_loss)
        log_dict = self.process_log_dict(sup_loss=sup_loss.item())
        return out_dict, log_dict

    def train(self):
        """Run one pass over ``loader_dict['train']`` with periodic replay.

        Each batch is used for one supervised step and then inserted into the
        replay memory. Whenever ``self.it % 100 == 1`` (checked after ``it`` is
        incremented), the ``after_train_epoch`` and ``Experience Replay`` hooks
        are called and ``replay`` is run.
        """
        self.model.train()
        self.epoch = 1
        self.call_hook("before_run")
        self.call_hook("before_train_epoch")
        for data in self.loader_dict['train']:
            # NOTE(review): no cap at args.num_train_iter is enforced here; the
            # loop makes exactly one pass over the train loader.
            self.call_hook("before_train_step")
            # Call process_batch once and reuse its result for both the
            # training step and the memory update.
            batch = self.process_batch(**data)
            self.out_dict, self.log_dict = self.train_step(**batch)
            self.call_hook("after_train_step")
            self.it += 1
            self.update_memory(**batch)
            # experience replay
            if self.it % 100 == 1:
                self.call_hook("after_train_epoch")
                self.call_hook("Experience Replay")
                self.replay()
        self.call_hook("after_run")

    def replay(self):
        """Compute the supervised loss on one shuffled minibatch from memory.

        Returns:
            ``(out_dict, log_dict)`` for a single minibatch (size 8) drawn at
            random from the replay memory, or ``None`` if the memory is empty.
        """
        if len(self.buffer) == 0:
            # DataLoader(shuffle=True) raises on an empty dataset.
            return None
        supervised_memory = TensorDataset(self.buffer, self.label)
        memory_loader = DataLoader(dataset=supervised_memory, batch_size=8, shuffle=True)
        x, y = next(iter(memory_loader))
        # Same autocast context as train_step.
        with self.amp_cm():
            logits = self.model(x)['logits']
            rep_loss = self.ce_loss(logits, y, reduction='mean')
        # TODO(review): train() discards this return value and nothing calls
        # backward on rep_loss, so replay does not currently update the model's
        # parameters. The loss is also stored under 'rep_loss', whereas
        # train_step stores its loss under 'loss'.
        out_dict = self.process_out_dict(rep_loss=rep_loss)
        log_dict = self.process_log_dict(sup_loss=rep_loss.item())
        return out_dict, log_dict

# Explicitly register the class under the same key used by the
# @IMB_ALGORITHMS.register('imbfullysupervised') decorator on the class.
# NOTE(review): this looks redundant with the decorator; it is kept because the
# registry's semantics are not visible here — confirm before removing.
IMB_ALGORITHMS['imbfullysupervised'] = FullySupervised