import argparse


def _str2bool(value):
    """
    Convert a command-line token to a real boolean.

    ``type=bool`` must not be used with argparse: it calls ``bool()`` on the raw
    string, so every non-empty value (including "False" and "0") becomes True.
    :param value: Raw command-line string, or an already-converted bool
    :return: The parsed boolean
    :raises argparse.ArgumentTypeError: If the token is not a recognised boolean
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    if token in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false, yes/no, 1/0, on/off), got %r" % (value,))


def fed_args(argv=None):
    """
    Arguments for running federated learning baselines

    Besides the parsed options, the returned namespace carries ``mix_type``
    (a copy of ``fedsd2c_mix_type``). When ``fedsd2c_temperature`` is not given,
    it is set to 4 for "mixup" and 20 for "cutmix"; for any other mix type it
    stays None. The resolved mix type and temperature are printed to stdout.

    :param argv: Optional list of argument strings to parse; when None,
        argparse reads ``sys.argv[1:]``
    :return: Arguments for federated learning baselines
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('-c', '--config', type=str, default="configs/base.yaml", help='path to config file;',)
    parser.add_argument('-sn', '--save_name', type=str, required=True, help='file name for log, statistic and checkpoint')
    parser.add_argument('-g', '--gpu_id', type=str, default="cuda:0", help="Gpu id")
    parser.add_argument('-wsf', '--w_save_freq', type=int, default=None,
                        help="Save the server model every w_save_freq round")
    parser.add_argument('--mr18', action="store_true", default=False)
    parser.add_argument('--using_wandb', action="store_true", default=False)
    parser.add_argument('--wandb_proj_name', type=str, required=True)
    parser.add_argument(
        "--mixup",
        type=float,
        default=0.8,
        help="mixup alpha, mixup enabled if > 0. (default: 0.8)",
    )
    parser.add_argument(
        "--cutmix",
        type=float,
        default=1.0,
        help="cutmix alpha, cutmix enabled if > 0. (default: 1.0)",
    )

    # System-level options. Most have no default here, so they parse to None
    # when omitted (presumably overridden by / falling back to the config
    # file -- verify against the caller).
    parser.add_argument('-nc', '--sys-n_client', type=int, help='Number of the clients')
    parser.add_argument('-ck', '--sys-n_local_class', type=int, help='Number of the classes in each client')
    parser.add_argument('-ds', '--sys-dataset', type=str,
                        help='Dataset name, one of the following four datasets: MNIST, CIFAR-10, FashionMnist, SVHN')
    parser.add_argument('-md', '--sys-model', type=str, help='Model name')
    parser.add_argument('-is', '--sys-i_seed', type=int, help='Seed used in experiments')
    parser.add_argument('-rr', '--sys-res_root', type=str, help='Root directory of the results')
    parser.add_argument('-nr', '--sys-n_round', type=int, help='Number of global communication rounds')
    # Takes an explicit value ("-os true" / "-os false"); parsed by _str2bool
    # so that "false" really yields False.
    parser.add_argument('-os', '--sys-oneshot', type=_str2bool, default=False,
                        help='True if only run with one-shot communication, otherwise false.')
    parser.add_argument('-dda', '--sys-dataset_dir_alpha', type=float, help='Alpha used for partitioning dataset with dirichlet')

    # Server-side training options
    parser.add_argument('-sne', '--server-n_epoch', type=int,
                        help='Number of training epochs in the server')
    parser.add_argument('-sbs', '--server-bs', type=int, help='Batch size in the server')
    parser.add_argument('-slr', '--server-lr', type=float, help='Learning rate in the server')
    parser.add_argument('-smt', '--server-momentum', type=float, help='Momentum in the server')
    parser.add_argument('-snw', '--server-n_worker', type=int, help='Number of workers in the server')
    parser.add_argument('-so', '--server-optimizer', type=str, help='Optimizer for server model')
    parser.add_argument('-sls', '--server-lr_scheduler', type=str, default="cos", help='Lr scheduler in the server')

    # Client-side training options
    parser.add_argument('-cis', '--client-instance', type=str,
                        help='Instance of federated learning algorithm used in clients')
    parser.add_argument('-cil', '--client-instance_lr', type=float, help='Learning rate in clients')
    parser.add_argument('-cib', '--client-instance_bs', type=int, help='Batch size in clients')
    parser.add_argument('-cie', '--client-instance_n_epoch', type=int,
                        help='Number of local training epochs in clients')
    parser.add_argument('-cim', '--client-instance_momentum', type=float, help='Momentum of local training in clients')
    parser.add_argument('-cin', '--client-instance_n_worker', type=int, help='Number of workers in clients')
    parser.add_argument('-ciw', '--client-instance_weight_decay', type=float, help='Weight decay of local training in clients')
    parser.add_argument('-cif', '--client-instance_freeze_bn', action="store_true", default=False, help='Freeze bn parameters for local training')
    # NOTE(review): the next two help texts duplicate the freeze_bn one and look
    # copy-pasted; the intended meaning is not visible here, so they are kept.
    parser.add_argument('--client-instance_rbn_fc', action="store_true", default=False, help='Freeze bn parameters for local training')
    parser.add_argument('--client-instance_rbn', action="store_true", default=False, help='Freeze bn parameters for local training')

    parser.add_argument('-cia', '--client-instance_aug', action="store_true", default=False, help='Dataset augment for local training')
    parser.add_argument('--client-instance_mixup_alpha', type=float, default=0.0)
    parser.add_argument('--client-instance_cutmix_alpha', type=float, default=0.0)
    parser.add_argument('--client-instance_identity_aug', action="store_true", default=False, help='insert an identity into mix aug')

    # FedProx
    parser.add_argument('-proxmu', '--fedprox_mu', type=float, help='Parameter of FedProx')

    # FedSD2C
    parser.add_argument('-fm', '--fedsd2c_mipc', type=int, help='Number of pre-loaded images per class')
    parser.add_argument('-fi', '--fedsd2c_ipc', type=int, help='IPC for client dataset distillation')
    parser.add_argument('-fmp', '--fedsd2c_m_path', type=str, help='Path to teacher model')
    parser.add_argument('-fnc', '--fedsd2c_num_crop', type=int, help='Number of Crop')
    parser.add_argument('-fmt', '--fedsd2c_mix_type', type=str, default=None, help='mixup or cutmix or None')
    parser.add_argument('-ff', '--fedsd2c_factor', type=int, help='Factor of MultiRandomCrop')
    parser.add_argument('-ft', '--fedsd2c_temperature', type=float, help='Temperature of soft label')
    parser.add_argument('-fit', '--fedsd2c_iteration', type=int, help='Number of SRe2L steps')
    parser.add_argument('-fj', '--fedsd2c_jitter', type=int, help='jitter')
    parser.add_argument('-flr', '--fedsd2c_lr', type=float, help='lr for SRe2L stage')
    parser.add_argument('-fls', '--fedsd2c_l2_scale', type=float, help='coefficient for SRe2L l2 loss')
    parser.add_argument('-ftl', '--fedsd2c_tv_l2', type=float, help='coefficient for SRe2L tv l2 loss')
    parser.add_argument('-frb', '--fedsd2c_r_bn', type=float, help='coefficient for SRe2L bn loss')
    parser.add_argument('-frc', '--fedsd2c_r_c', type=float, help='coefficient for other client ce loss')
    parser.add_argument('-fra', '--fedsd2c_r_adv', type=float, help='coefficient for adversarial loss')
    parser.add_argument('-fsi', '--fedsd2c_store_images', action="store_true", default=False,
                        help='Whether to save intermediate images')
    parser.add_argument('-ful', '--fedsd2c_use_ld', action="store_true", default=False,
                        help='Whether to use label distribution as distillation aim')
    parser.add_argument('-fsti', '--fedsd2c_sd_trn_interval', type=str, help='The interval between local data training and server distill data training')
    parser.add_argument('-fsa', '--fedsd2c_sd_alpha', type=float, help='coefficient for server distill training')
    parser.add_argument('-fcdo', '--fedsd2c_client_diff_optim', action="store_true", default=False)
    parser.add_argument('-fssr', '--fedsd2c_sd_start_round', type=int, help='The start epoch for server distillation')
    parser.add_argument('-fii', '--fedsd2c_inputs_init', type=str, default="random")
    parser.add_argument('-fhl', '--fedsd2c_hard_label', action="store_true", default=False)
    parser.add_argument('-fzis', '--fedsd2c_zero_init_scaler', type=str, default="cos")
    parser.add_argument('-fsr', '--fedsd2c_syn_root', type=str)
    parser.add_argument('--fedsd2c_mask_ratio', type=float, default=0.5)
    parser.add_argument('--fedsd2c_patch_size', type=int, default=16)
    parser.add_argument('--fedsd2c_filling_methods', type=str, default="random")
    parser.add_argument('--fedsd2c_gm_iter', type=int, default=3)
    parser.add_argument('--fedsd2c_gm_metric', type=str, default="ours")
    parser.add_argument('--fedmix_batch_size', type=int, default=2)
    parser.add_argument('--fedmix_method', type=str, default="random")
    parser.add_argument('--fedmix_src', type=str, default="self")
    parser.add_argument('--descending_dist', action="store_true", default=False)
    parser.add_argument('--fedsd2c_iter_mode', type=str, default="label")
    parser.add_argument('--fedsd2c_compress', action="store_true", default=False)

    # The default here is the string "None", not the None object.
    parser.add_argument('--fedsd2c_noise_type', type=str, default="None")
    parser.add_argument('--fedsd2c_noise_s', type=float, default=0.)
    parser.add_argument('--fedsd2c_noise_p', type=float, default=0.)

    parser.add_argument('--fedsd2c_clip_client_data', action="store_true", default=False)
    parser.add_argument('--fedsd2c_ipc_min', type=int, default=1, help='Used for determining the minimum number of data used for distillation')
    parser.add_argument('--avg_logits', type=str, default=None,  help='Whether to use avg logits for distillation')
    parser.add_argument('--beta', type=float, default=0, help='Weight for weighted average')
    parser.add_argument('--client_model_root', type=str, default=None, help='If not None load client model from the root')
    parser.add_argument('--cluster_methods', type=str, default="KMeans")

    parser.add_argument('--save_client_model', action="store_true", default=False)
    parser.add_argument('--fourier_lambda', type=float, default=0.9)
    parser.add_argument('--fourier_src', type=str, default="img")

    # FedCVAE
    parser.add_argument('--cvae_z_dim', type=int)
    parser.add_argument('--cvae_beta', type=float)

    # DFKD
    parser.add_argument('--dfkd_z_dim', type=int)
    parser.add_argument('--dfkd_temp', type=float)
    parser.add_argument('--dfkd_r_bn', type=float)
    parser.add_argument('--dfkd_r_adv', type=float)
    parser.add_argument('--dfkd_r_bal', type=float)
    parser.add_argument('--dfkd_r_oh', type=float)
    parser.add_argument('--dfkd_giter', type=int)
    parser.add_argument('--dfkd_miter', type=int)
    parser.add_argument('--dfkd_eiter', type=int)
    parser.add_argument('--dfkd_img_size', type=int)
    parser.add_argument('--dfkd_batch_size', type=int)
    # Defaults to True; _str2bool is what makes "--dfkd_syn_data false" able to
    # turn it off.
    parser.add_argument('--dfkd_syn_data', type=_str2bool, default=True)


    # CoBoost
    parser.add_argument('--cb_z_dim', type=int)
    parser.add_argument('--cb_temp', type=float)
    parser.add_argument('--cb_r_bn', type=float)
    parser.add_argument('--cb_div', type=float)
    parser.add_argument('--cb_giter', type=int)
    parser.add_argument('--cb_miter', type=int)
    parser.add_argument('--cb_witer', type=int)
    parser.add_argument('--cb_hs', type=float)
    parser.add_argument('--cb_oh', type=float)
    parser.add_argument('--cb_mu', type=float)
    parser.add_argument('--cb_wdc', type=float)
    parser.add_argument('--cb_weighted', type=_str2bool, default=False)
    parser.add_argument('--cb_odseta', type=float)

    # FedD3
    parser.add_argument('--fedd3_n_dd', type=int, help='Number of distilled images in clients')
    parser.add_argument('--fedd3_max_n_epoch', type=int, help='Maximal number of epochs in clients')
    parser.add_argument('--fedd3_threshold', type=float, help='Accuracy threshold for dataset distillation in clients')
    parser.add_argument('--fedd3_bs', type=int, help='Batch size in clients')

    # argv=None makes argparse fall back to sys.argv[1:], i.e. the old behavior.
    args = parser.parse_args(argv)

    # Expose the mix type under the shorter alias as well.
    args.mix_type = args.fedsd2c_mix_type

    # Fill in a mix-type-specific soft-label temperature only when the user
    # did not pass one explicitly.
    if args.fedsd2c_temperature is None:
        default_temperatures = {"mixup": 4, "cutmix": 20}
        if args.fedsd2c_mix_type in default_temperatures:
            args.fedsd2c_temperature = default_temperatures[args.fedsd2c_mix_type]

    print(args.fedsd2c_mix_type, args.fedsd2c_temperature)
    return args
