import torch
from .hook import Hook


class ParamUpdateHook(Hook):
    """
    Parameter Update Hook

    Runs the backward pass and the optimizer step after every train step.

    Two update modes are supported, selected by the content of
    ``algorithm.out_dict``:

    * only ``'loss'`` present: plain update (optionally with AMP loss scaling
      and gradient clipping);
    * ``'reg'`` present as well: gradient-aligned update. The gradients of
      ``reg`` and ``loss`` are computed separately; when their dot product is
      negative, the component of the ``loss`` gradient that lies along the
      ``reg`` gradient is removed from the model gradients before stepping.
    """

    def before_train_step(self, algorithm):
        """Start the CUDA run-time measurement if the algorithm has a ``start_run`` event."""
        if hasattr(algorithm, 'start_run'):
            torch.cuda.synchronize()
            algorithm.start_run.record()

    # call after each train_step to update parameters
    def after_train_step(self, algorithm):
        """Back-propagate the step outputs and update the parameters.

        Reads ``algorithm.out_dict['loss']`` (required) and
        ``algorithm.out_dict['reg']`` (optional). After the optimizer step the
        scheduler (if any) is stepped, the model gradients are cleared and,
        when the algorithm has an ``end_run`` event, the elapsed time in
        seconds is stored in ``algorithm.log_dict['train/run_time']``.
        """
        out_dict = algorithm.out_dict
        loss = out_dict['loss']

        # Key the aligned path on the presence of 'reg' itself rather than on
        # the number of outputs, so extra outputs cannot trigger a KeyError.
        if 'reg' in out_dict:
            self._aligned_update(algorithm, loss, out_dict['reg'])
        else:
            self._plain_update(algorithm, loss)

        # Shared epilogue: it has to run for both update modes, otherwise the
        # aligned path never advances the schedule and keeps stale gradients.
        if algorithm.scheduler is not None:
            algorithm.scheduler.step()
        algorithm.model.zero_grad()

        if hasattr(algorithm, 'end_run'):
            algorithm.end_run.record()
            torch.cuda.synchronize()
            # elapsed_time() is divided by 1000 to log seconds.
            algorithm.log_dict['train/run_time'] = algorithm.start_run.elapsed_time(algorithm.end_run) / 1000.

    def _plain_update(self, algorithm, loss):
        """Back-propagate ``loss`` and step the optimizer (AMP-aware)."""
        clip_enabled = algorithm.clip_grad > 0
        if algorithm.use_amp:
            algorithm.loss_scaler.scale(loss).backward()
            if clip_enabled:
                # Gradients must be unscaled before their norm is clipped.
                algorithm.loss_scaler.unscale_(algorithm.optimizer)
                torch.nn.utils.clip_grad_norm_(algorithm.model.parameters(), algorithm.clip_grad)
                torch.nn.utils.clip_grad_norm_(algorithm.classifier.parameters(), algorithm.clip_grad)
            algorithm.loss_scaler.step(algorithm.optimizer)
            algorithm.loss_scaler.update()
        else:
            loss.backward()
            if clip_enabled:
                torch.nn.utils.clip_grad_norm_(algorithm.model.parameters(), algorithm.clip_grad)
                torch.nn.utils.clip_grad_norm_(algorithm.classifier.parameters(), algorithm.clip_grad)
            algorithm.optimizer.step()

    def _aligned_update(self, algorithm, loss, reg):
        """Step the optimizer with the ``loss`` gradient aligned to the ``reg`` gradient.

        Both gradients are taken over the parameters of ``algorithm.model``
        and ``algorithm.classifier``. If their dot product is non-negative the
        ``loss`` gradient is used as is (clipped when ``clip_grad > 0``);
        otherwise ``(g_loss . g_reg) / (g_reg . g_reg) * g_reg`` is subtracted
        from the model gradients before the step.
        """
        model_params = list(algorithm.model.parameters())
        head_params = list(algorithm.classifier.parameters())

        # Start from clean gradients so that nothing left over from an earlier
        # step leaks into the reference (reg) gradient.
        algorithm.model.zero_grad()
        algorithm.classifier.zero_grad()
        # NOTE(review): if `reg` and `loss` share graph nodes this call needs
        # retain_graph=True -- confirm against the algorithm's train_step.
        reg.backward()
        reg_grads_model = self._snapshot_grads(model_params)
        reg_grads_head = self._snapshot_grads(head_params)

        algorithm.model.zero_grad()
        algorithm.classifier.zero_grad()
        loss.backward()

        # align = <g_loss, g_reg>, ref = <g_reg, g_reg>. A missing gradient
        # counts as zero, so such parameters are simply skipped.
        align, ref = 0, 0
        all_pairs = zip(model_params + head_params, reg_grads_model + reg_grads_head)
        for param, reg_grad in all_pairs:
            if reg_grad is None:
                continue
            flat_reg = torch.flatten(reg_grad)
            ref = ref + torch.dot(flat_reg, flat_reg)
            if param.grad is not None:
                align = align + torch.dot(torch.flatten(param.grad), flat_reg)

        if align >= 0:
            # Same guard as the plain path: a non-positive clip_grad disables
            # clipping instead of scaling the gradients by a bogus factor.
            if algorithm.clip_grad > 0:
                torch.nn.utils.clip_grad_norm_(model_params, algorithm.clip_grad)
                torch.nn.utils.clip_grad_norm_(head_params, algorithm.clip_grad)
        else:
            # align < 0 implies ref > 0, so the division is well defined.
            coef = align / ref
            # Only the model gradients are projected; the classifier keeps its
            # plain loss gradient. The correction goes into `.grad` -- writing
            # it into `.data` would replace the weights themselves.
            for param, reg_grad in zip(model_params, reg_grads_model):
                if reg_grad is None:
                    continue
                if param.grad is None:
                    param.grad = -coef * reg_grad
                else:
                    param.grad.sub_(coef * reg_grad)

        algorithm.optimizer.step()

    @staticmethod
    def _snapshot_grads(params):
        """Return a copy of each parameter's gradient, or ``None`` where it has none."""
        return [None if p.grad is None else p.grad.detach().clone() for p in params]

