import torch
import argparse
import logging
import os
import wandb

from eufm_numerics.dsets import OrderedCifar10, OrderedCifar10Test
from eufm_numerics.training import RealNNTrainerAndAnalyzer
from torch.utils.data import DataLoader
from eufm_numerics.nn_models import ExtendedResNet20, ExtendedResNet32, ExtendedResNet44, ExtendedResNet56


# Command-line configuration of the experiment sweep. `args` and `resnets` are read as
# module-level globals by the functions below.
parser = argparse.ArgumentParser()
parser.add_argument('--resnet_depth', type=int, choices=[20, 32, 44, 56], required=True,
                    help='The number of layers of the resnet backbone.')
parser.add_argument('--input_dim_vects', nargs='+', type=int, required=True,
                    help='The dimensions of feature refiner. Start with resnet output size.')
parser.add_argument('--cifar_classes', nargs='+', type=int, help='Which classes from cifar to do the experiment on.')
parser.add_argument('--samples_per_class', type=int, default=5000, help='Number of samples per class.')
parser.add_argument('--starting_weight_decay', type=float, default=0.0001, help='Smallest wd in the set of experiments')
# Run i uses weight decay (increment * i + 1) * starting_weight_decay (see the training loop), so this is
# a linear step factor, not a ratio between consecutive runs. Floats are accepted; integer inputs give
# numerically identical weight decays to the former int-only option.
parser.add_argument('--weight_decay_increment', type=float, default=4,
                    help='Step factor for the per-run weight decays: run i uses '
                         '(weight_decay_increment * i + 1) * starting_weight_decay.')
parser.add_argument('--num_runs', type=int, default=4, help='Number of runs, each with different weight decay.')
parser.add_argument('--batch_norm', action='store_true', help='Indicates whether or not to use the batch norm.')
parser.add_argument('--lr', type=float, default=0.01, help='The training learning rate.')
parser.add_argument('--scheduler_milestone', type=int, default=1000, help='When to decrease the learning rate.')
parser.add_argument('--scheduler_gamma', type=float, default=0.2, help='The multiplication factor of the lr decrease.')
parser.add_argument('--train_batch_size', type=int, default=125, help='The batch size to be used during training.')
parser.add_argument('--test_batch_size', type=int, default=500, help='The batch size to be used during testing.')
parser.add_argument('--eval_batch_size', type=int, default=500,
                    help='The batch size to be used when computing NC metrics.')
parser.add_argument('--num_test_samples', type=int, default=2000, help='Number of test samples.')
parser.add_argument('--device', type=str, default='cuda', help='Which device to compute on.')
parser.add_argument('--num_workers', type=int, default=0, help='Number of workers for the dataloader.')
parser.add_argument('--num_epochs', type=int, default=500, help='Number of big epochs to do.')
parser.add_argument('--log_every', type=int, default=1, help='Each how many epochs to do the analysis.')
parser.add_argument('--save_model', action='store_true', help='Indicates whether to save the model or not.')
parser.add_argument('--exp_name', type=str, default='default', help='Experiment name.')
parser.add_argument('--use_wandb', action='store_true', help='Indicates whether to use wandb.')
parser.add_argument('--penalize_uf', action='store_true',
                    help='Indicates whether weight decay will be applied on unconstrained '
                         'features rather than resnet weights.')
parser.add_argument('--bn_start', action='store_true',
                    help='Indicates if starting batch norm is placed in the feature refiner.')
parser.add_argument('--extra_conv', action='store_true',
                    help='Indicates whether an extra convolutional layer will be placed on top of the resnet.')
parser.add_argument('--extra_fc', action='store_true',
                    help='Indicates, whether to use extra fully connected layer in the end of resnet.')
parser.add_argument('--handput_wds', nargs='+', type=float, default=None,
                    help='If inputted, these weight decays will be used instead of the automatically computed ones.')
parser.add_argument('--fix_fr', action='store_true',
                    help='Indicates if the dufm part of the model is fixed throughout the training.')
parser.add_argument('--pretrained_fr_path', type=str, default=None,
                    help='Gives the path to the pre-trained feature refiner. Requires fix_fr argument to be true.')
args = parser.parse_args()

# --fix_fr loads the refiner weights from --pretrained_fr_path; fail at argument-parsing time
# instead of deep inside the run loop (torch.load(None)) after the model was already built.
if args.fix_fr and args.pretrained_fr_path is None:
    parser.error('--fix_fr requires --pretrained_fr_path.')


# Maps --resnet_depth to the backbone class that is instantiated for each run.
resnets = {20: ExtendedResNet20, 32: ExtendedResNet32, 44: ExtendedResNet44, 56: ExtendedResNet56}


def get_exp_path(exp_name):
    """Create and return a fresh run directory for the current experiment configuration.

    The directory layout is::

        experiments/resnet_depth_<depth>/num_layers_<k>/<with|without>_batch_norm/
            cifar_classes_<classes>/<exp_name>/run_<idx>

    where ``k = len(args.input_dim_vects) - 1`` and ``<classes>`` is ``str(args.cifar_classes)``.
    Missing intermediate directories are created. Reads the module-level ``args``.

    Args:
        exp_name: Name of the experiment directory below the configuration directories.

    Returns:
        A ``(path, suffix)`` tuple: the newly created run directory and its basename ``run_<idx>``.
    """
    batch_norm_dir = 'with_batch_norm' if args.batch_norm else 'without_batch_norm'
    exp_dir = os.path.join('experiments',
                           'resnet_depth_' + str(args.resnet_depth),
                           'num_layers_' + str(len(args.input_dim_vects) - 1),
                           batch_norm_dir,
                           'cifar_classes_' + str(args.cifar_classes),
                           exp_name)
    os.makedirs(exp_dir, exist_ok=True)

    # Numbering starts at the count of existing entries, but an index can already be taken (e.g. an
    # earlier run directory was deleted, leaving a gap). Try to create the directory and move on to the
    # next index on collision instead of failing with FileExistsError; this is also safe if two
    # launches pick the same index concurrently.
    exp_idx = len(os.listdir(exp_dir))
    while True:
        suffix = 'run_' + str(exp_idx)
        path = os.path.join(exp_dir, suffix)
        try:
            os.mkdir(path)
        except FileExistsError:
            exp_idx += 1
            continue
        return path, suffix


def _attach_run_handlers(logger, log_path):
    """Attach a DEBUG-level file handler (writing to ``log_path``) and stream handler to ``logger``.

    Returns both handlers so the caller can detach them when the run ends. Without detaching,
    handlers would pile up on the shared logger across runs: console output would be duplicated,
    and every later run would also write into the ``log.txt`` files of all earlier runs.
    """
    file_handler = logging.FileHandler(log_path)
    file_handler.setLevel(logging.DEBUG)
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.DEBUG)
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    return file_handler, stream_handler


def _detach_run_handlers(logger, handlers):
    """Remove ``handlers`` from ``logger`` and close them (which also closes the run's log file)."""
    for handler in handlers:
        logger.removeHandler(handler)
        handler.close()


def _run_experiment(run_idx, path, suffix, logger):
    """Build, train and analyze one model whose weight decays are selected by ``run_idx``.

    Reads the module-level ``args`` and ``resnets``; ``path`` is the run directory and ``suffix`` its
    basename (used as the wandb job type).
    """
    if args.use_wandb:
        # Runs are grouped by experiment name (group) and run directory (job_type). Within one
        # invocation of this script, consecutive runs differ in their weight decays.
        wandb.init(project='neural-collapse',
                   group=args.exp_name,
                   job_type=suffix,
                   config={
                       "resnet_depth": args.resnet_depth,
                       "dufm_depth": len(args.input_dim_vects) - 1,
                       "classes_used": args.cifar_classes,
                   }
                   )

    logger.info('Everything works as planned.')
    input_dims = args.input_dim_vects
    samples_per_class = args.samples_per_class
    num_layers = len(input_dims) - 1
    # The last entry of --input_dim_vects is what gets passed as the number of classes.
    num_classes = input_dims[num_layers]

    scaling = 1
    if args.handput_wds is not None:
        weight_decays = torch.Tensor(args.handput_wds)
    else:
        # num_layers + 1 equal weight decays, growing linearly with the run index.
        weight_decays = (args.weight_decay_increment * run_idx + 1) * args.starting_weight_decay * torch.ones(num_layers + 1)

    the_model = resnets[args.resnet_depth](bias=False, relu=True, batch_norm=args.batch_norm, num_layers=num_layers,
                                           input_dims=input_dims, num_classes=num_classes,
                                           num_per_class=samples_per_class, weight_decays=weight_decays,
                                           scaling=scaling, dist='kaiming', bn_start=args.bn_start,
                                           extra_conv=args.extra_conv, extra_fc=args.extra_fc)
    the_model.to(args.device)

    if args.fix_fr:
        # NOTE(review): torch.load unpickles arbitrary Python objects, so only load trusted files.
        # On torch >= 2.6 `weights_only` defaults to True, which may refuse a pickled module; confirm.
        pretrained_fr = torch.load(f=args.pretrained_fr_path, map_location=args.device)
        # Copy the weight tensor of each refiner layer from the pretrained module into the model.
        for layer_idx in range(num_layers):
            target_layer = getattr(the_model.fr, the_model.fr.weight_names[layer_idx])
            source_layer = getattr(pretrained_fr, pretrained_fr.weight_names[layer_idx])
            target_layer.weight.data = source_layer.weight.data

    # With --fix_fr only the backbone is handed to the optimizer, so the loaded refiner weights are not
    # updated. Previously this held only together with --penalize_uf; otherwise the "fixed" refiner
    # was trained as well.
    trainable_params = the_model.backbone.parameters() if args.fix_fr else the_model.parameters()
    # With --penalize_uf the optimizer applies no weight decay; the weight decays are instead passed to
    # the trainer together with penalize_unconstrained_features=True.
    optimizer_wd = 0 if args.penalize_uf else float(weight_decays[0])
    optimizer = torch.optim.SGD(trainable_params, lr=args.lr, momentum=0, weight_decay=optimizer_wd)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[args.scheduler_milestone],
                                                     gamma=args.scheduler_gamma)
    loss_fcn = torch.nn.MSELoss()

    data = OrderedCifar10(class_indices=args.cifar_classes)
    test_data = OrderedCifar10Test(class_indices=args.cifar_classes)
    # --num_workers was previously parsed but ignored.
    train_dataloader = DataLoader(data, batch_size=args.train_batch_size, shuffle=True,
                                  num_workers=args.num_workers)
    eval_dataloader = DataLoader(data, batch_size=args.eval_batch_size, shuffle=True,
                                 num_workers=args.num_workers)
    test_dataloader = DataLoader(test_data, batch_size=args.test_batch_size, shuffle=True,
                                 num_workers=args.num_workers)

    trainer = RealNNTrainerAndAnalyzer(train_dataloader=train_dataloader, eval_dataloader=eval_dataloader,
                                       test_dataloader=test_dataloader, model=the_model,
                                       optimizer=optimizer,
                                       scheduler=scheduler, loss_fcn=loss_fcn, wds=weight_decays,
                                       num_classes=num_classes, num_samples=samples_per_class,
                                       num_backbone_layers=args.resnet_depth, num_test_samples=args.num_test_samples,
                                       exp_path=path, logger=logger, device=args.device, use_wandb=args.use_wandb,
                                       penalize_unconstrained_features=args.penalize_uf)

    # The returned metrics are not used here.
    trainer.run_and_analyze(steps=args.num_epochs, analyze=True, log_every=args.log_every,
                            save_model=args.save_model)


if __name__ == '__main__':
    the_logger = logging.getLogger()
    the_logger.setLevel(logging.DEBUG)

    for i in range(args.num_runs):
        path, suffix = get_exp_path(args.exp_name)
        run_handlers = _attach_run_handlers(the_logger, os.path.join(path, 'log.txt'))
        try:
            _run_experiment(i, path, suffix, the_logger)
        finally:
            if args.use_wandb:
                # Finish this run explicitly so the next wandb.init starts a new run. Depending on the
                # wandb version, a second init may otherwise return the still-active run.
                wandb.finish()
            _detach_run_handlers(the_logger, run_handlers)
