from collections import defaultdict
from typing import Dict, List
import importlib
import inspect

import torch
import torch.nn as nn


from workbench.utils.adam_cpr_fast import AdamCPR, group_parameters_for_cpr_optimizer




def group_parameters_for_cpr(model, optim_hps, avoid_keywords=None,
                             embedding_regularization=False,
                             bias_regularization=False,
                             normalization_regularization=False):
    """Split the trainable parameters of ``model`` into optimizer param groups.

    Every parameter with ``requires_grad=True`` ends up in exactly one group:

    * parameters carrying an ``_optim`` attribute ("special" parameters) are
      grouped by the content of that attribute;
    * parameters matched by an exclusion rule go to the ``apply_cpr=False``
      group (owned by a blacklisted module type, full name contains one of
      ``avoid_keywords``, or name ends with ``'bias'`` while
      ``bias_regularization`` is off);
    * all remaining parameters go to the ``apply_cpr=True`` group.

    Args:
        model: ``nn.Module`` whose parameters are grouped.
        optim_hps: mapping of hyperparameters merged into the two regular
            groups (not into the special groups).
        avoid_keywords: iterable of substrings; a parameter whose full name
            contains any of them is excluded from CPR. ``None`` or an empty
            value disables the rule.
        embedding_regularization: if true, ``nn.Embedding`` is whitelisted
            instead of blacklisted.
        bias_regularization: if false, parameters whose name ends with
            ``'bias'`` are excluded from CPR.
        normalization_regularization: if true, the normalization layers are
            whitelisted instead of blacklisted.

    Returns:
        A list of dicts. The first one always has ``apply_cpr=True`` (its
        ``params`` list may be empty). A second one with ``apply_cpr=False``
        follows only when at least one parameter was excluded. Both carry
        ``params``, ``names`` (sorted full parameter names) and the entries of
        ``optim_hps``. After them comes one dict per distinct ``_optim``
        setting, holding ``params`` plus the entries of that setting, ordered
        by the sorted name of the first parameter using it.

    Raises:
        AssertionError: if a parameter was classified both as CPR and as
            non-CPR (e.g. it is reachable through a whitelisted and a
            blacklisted module).
    """
    if not avoid_keywords:
        avoid_keywords = []

    normalization_modules = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                             nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d,
                             nn.GroupNorm, nn.SyncBatchNorm,
                             nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
                             nn.LayerNorm, nn.LocalResponseNorm)

    whitelist_weight_modules = (nn.Linear, nn.Conv2d)
    blacklist_weight_modules = ()
    if embedding_regularization:
        whitelist_weight_modules += (nn.Embedding,)
    else:
        blacklist_weight_modules += (nn.Embedding,)
    if normalization_regularization:
        whitelist_weight_modules += normalization_modules
    else:
        blacklist_weight_modules += normalization_modules

    apply_cpr = set()
    apply_no_cpr = set()
    special = set()

    param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
    for mn, m in model.named_modules():
        # NOTE: named_parameters() recurses, so a parameter is visited once for
        # its owning module and once for every ancestor module; the rules below
        # are applied on each visit.
        for pn, p in m.named_parameters():
            fpn = f"{mn}.{pn}" if mn else pn  # full param name
            # In case of parameter sharing, some parameters show up here but
            # are not in param_dict.keys()
            if not p.requires_grad or fpn not in param_dict:
                continue  # frozen weights
            if hasattr(p, '_optim'):
                special.add(fpn)
            elif isinstance(m, blacklist_weight_modules):
                apply_no_cpr.add(fpn)
            elif any(keyword in fpn for keyword in avoid_keywords):
                apply_no_cpr.add(fpn)
            elif not bias_regularization and pn.endswith('bias'):
                apply_no_cpr.add(fpn)
            elif isinstance(m, whitelist_weight_modules):
                apply_cpr.add(fpn)

    # Anything no rule excluded is regularized by default.
    apply_cpr |= (param_dict.keys() - apply_no_cpr - special)

    # validate that we considered every parameter
    inter_params = apply_cpr & apply_no_cpr
    union_params = apply_cpr | apply_no_cpr
    missing_params = param_dict.keys() - special - union_params
    assert len(inter_params) == 0, f"Parameters {str(inter_params)} made it into both apply_cpr/apply_no_cpr sets!"
    assert len(missing_params) == 0, (f"parameters {str(missing_params)} "
                                      f" were not separated into either apply_cpr/apply_no_cpr set!")

    def _make_group(names, use_cpr):
        # Sorted names keep the parameter order inside a group reproducible.
        ordered = sorted(names)
        return {"params": [param_dict[pn] for pn in ordered],
                "names": ordered, 'apply_cpr': use_cpr, **optim_hps}

    param_groups = [_make_group(apply_cpr, True)]
    if apply_no_cpr:
        param_groups.append(_make_group(apply_no_cpr, False))

    # Add parameters with special hyperparameters, one group per distinct
    # `_optim` setting. De-duplicate in sorted-name order (rather than through
    # a set of frozensets) so the group order is identical in every process;
    # optimizer state is matched to param groups by position.
    special_groups = []  # list of (hyperparameter dict, parameters) pairs
    for pn in sorted(special):
        param = param_dict[pn]
        hp = dict(param._optim.items())
        for known_hp, members in special_groups:
            if known_hp == hp:
                members.append(param)
                break
        else:
            special_groups.append((hp, [param]))
    for hp, members in special_groups:
        param_groups.append({"params": members, **hp})

    return param_groups




def apply_CPR(model, optimizer_cls, kappa_init_param, kappa_init_method='warm_start', reg_function='l2',
              kappa_adapt=False, kappa_update=1.0, apply_lr=False,
              normalization_regularization=False, bias_regularization=False, embedding_regularization=False,
              **optimizer_args):
    """Build an ``AdamCPR`` optimizer for ``model``.

    The parameter groups come from ``group_parameters_for_cpr_optimizer(model)``
    and are passed to ``AdamCPR`` together with the kappa / regularization
    settings given here.

    Args:
        model: model whose parameters are optimized.
        optimizer_cls: currently unused; ``AdamCPR`` is always instantiated.
        kappa_init_param: forwarded to ``AdamCPR`` as ``kappa_init_param``.
        kappa_init_method: forwarded to ``AdamCPR`` as ``kappa_init_method``.
        reg_function: forwarded to ``AdamCPR`` as ``reg_function``.
        kappa_adapt: currently unused.
        kappa_update: forwarded to ``AdamCPR`` as ``kappa_update``.
        apply_lr: currently unused.
        normalization_regularization: currently unused.
        bias_regularization: currently unused.
        embedding_regularization: currently unused.
        **optimizer_args: must contain ``'lr'``; every other entry is ignored.

    Returns:
        The constructed ``AdamCPR`` instance.

    Raises:
        KeyError: if ``optimizer_args`` has no ``'lr'`` entry.
    """
    # NOTE(review): the three *_regularization flags are not handed to the
    # grouping helper, so grouping always uses that helper's defaults --
    # confirm against its signature before wiring them through.
    parameters = group_parameters_for_cpr_optimizer(model)

    # Forward the caller's kappa_update / reg_function instead of hard-coding
    # 1.0 / 'l2'; default calls behave exactly as before.
    return AdamCPR(parameters, lr=optimizer_args['lr'],
                   kappa_init_param=kappa_init_param,
                   kappa_update=kappa_update,
                   kappa_init_method=kappa_init_method,
                   reg_function=reg_function,
                   )

