import numpy as np

# VGG-16 conv-layer output widths (13 layers), listed in the same order as VGG_CONV_NAMES.
# Defined once, as an int32 array.
VGG_ORIGIN_DEPS = np.array([64, 64, 128, 128, 256, 256, 256, 512, 512, 512, 512, 512, 512], dtype=np.int32)
VGG_CONV_NAMES = ['block1_conv1', 'block1_conv2',
                  'block2_conv1', 'block2_conv2',
                  'block3_conv1', 'block3_conv2', 'block3_conv3',
                  'block4_conv1', 'block4_conv2', 'block4_conv3',
                  'block5_conv1', 'block5_conv2', 'block5_conv3']
VGG_FC_NAMES = ['fc1', 'fc2', 'error']
VGG_FC_OUTS = [4096, 4096, 1000]

VFS_FC_NAMES = ['fc1', 'error']
VFS_FC_OUTS = [512, 10]

VH_FC_OUTS = [512, 100]

IMAGENET_DIR = '/home/dataset/ILSVRC2015_TFRecords'

# Conv layer i maps to layer i + 1; the last conv (12) maps to index 13, one past the last conv.
VGG_SUBSEQUENT_STRATEGY = {i: i + 1 for i in range(len(VGG_CONV_NAMES))}

LENET_ORIGIN_DEPS = np.array([20, 50], dtype=np.int32)

LENET300_ORIGIN_DEPS = np.array([300, 100], dtype=np.int32)

CFQK_ORIGIN_DEPS = np.array([32, 32, 64], dtype=np.int32)

#############################   begin   ResNet56 ####################################
# ResNet-56 widths nested as [stem, stage1, stage2, stage3]; each stage holds 9 blocks
# written as [first_conv, second_conv].
RESNET56_ORIGIN_DEPS = [16,
                        [[16, 16]] * 9,
                        [[32, 32]] * 9,
                        [[64, 64]] * 9]
# The same widths flattened into 57 layers, 19 per stage.
RESNET56_ORIGIN_DEPS_FLATTENED = [16] * 19 + [32] * 19 + [64] * 19
# Flattened index that opens each stage.
RESNET56_PACESETTER_IDXES = [0, 19, 38]
# First conv of every block: pacesetter + 1, + 3, ..., + 17.
RESNET56_INTERNAL_KERNEL_IDXES = [ps + 1 + 2 * blk
                                  for ps in RESNET56_PACESETTER_IDXES
                                  for blk in range(9)]
# layer -> the pacesetter it follows: the pacesetter itself plus every second conv of its stage.
RESNET56_FOLLOW_DICT = {ps + 2 * blk: ps
                        for ps in RESNET56_PACESETTER_IDXES
                        for blk in range(10)}
# layer -> subsequent layer index, or a two-index list for the last layer of stages 1 and 2.
RESNET56_SUBSEQUENT_STRATEGY = dict((idx, idx + 1) for idx in RESNET56_INTERNAL_KERNEL_IDXES)
RESNET56_SUBSEQUENT_STRATEGY[0] = 1
for i in RESNET56_FOLLOW_DICT:
    if i in RESNET56_PACESETTER_IDXES:
        continue  # 19 and 38 get no entry; 0 keeps the value assigned just above
    if i + 1 in RESNET56_PACESETTER_IDXES:
        # 18 and 37: the next pacesetter and the layer right after it.
        RESNET56_SUBSEQUENT_STRATEGY[i] = [i + 1, i + 2]
    else:
        RESNET56_SUBSEQUENT_STRATEGY[i] = i + 1
def convert_flattened_resnet56_deps(flattened_deps):
    """Regroup 57 flattened ResNet-56 widths into ``[stem, stage1, stage2, stage3]``.

    Flattened layout: index 0 is the stem (and first pacesetter); 19 and 38 are the
    pacesetters of stages 2 and 3.  Each stage contributes 9 blocks taken as the pairs
    ``[pacesetter + 1 + 2k, pacesetter + 2 + 2k]``.  The widths at 19 and 38 are
    checked but not emitted.

    Raises:
        AssertionError: if the length is not 57, or if a pacesetter's width differs from
            its first follower (``pacesetter + 2``), which RESNET56_FOLLOW_DICT pairs with it.
    """
    assert len(flattened_deps) == 57
    # Check the pacesetter/first-follower pairs (0, 2), (19, 21) and (38, 40).
    # Indices 39 and 41 are independent first-in-block convs
    # (see RESNET56_INTERNAL_KERNEL_IDXES) and may legitimately differ.
    for pacesetter in (0, 19, 38):
        assert flattened_deps[pacesetter] == flattened_deps[pacesetter + 2], \
            'width of layer %d must match its follower %d' % (pacesetter, pacesetter + 2)
    d = [flattened_deps[0]]
    for pacesetter in (0, 19, 38):
        d.append([[flattened_deps[pacesetter + 1 + 2 * k], flattened_deps[pacesetter + 2 + 2 * k]]
                  for k in range(9)])
    return d
#############################   end   ResNet56 ####################################

#############################   begin   ResNet50 ####################################
# ResNet-50 widths nested as [stem, stage1..stage4]; every block is [conv1, conv2, conv3].
RESNET50_ORIGIN_DEPS = [64, [[64, 64, 256]] * 3,
                        [[128, 128, 512]] * 4,
                        [[256, 256, 1024]] * 6,
                        [[512, 512, 2048]] * 3]
# Flattened (53 layers): stem, then per stage one pacesetter entry (same width as the
# stage's block output) followed by the three convs of every block.
RESNET50_ORIGIN_DEPS_FLATTENED = ([64]
                                  + [256] + [64, 64, 256] * 3
                                  + [512] + [128, 128, 512] * 4
                                  + [1024] + [256, 256, 1024] * 6
                                  + [2048] + [512, 512, 2048] * 3)
RESNET50_ALL_CONV_LAYERS = range(len(RESNET50_ORIGIN_DEPS_FLATTENED))
# First two convs of every block; the third one follows its stage's pacesetter.
RESNET50_INTERNAL_KERNEL_IDXES = [2, 3, 5, 6, 8, 9,
                                  12, 13, 15, 16, 18, 19, 21, 22,
                                  25, 26, 28, 29, 31, 32, 34, 35, 37, 38, 40, 41,
                                  44, 45, 47, 48, 50, 51]
RESNET50_PACESETTER_IDXES = [1, 11, 24, 43]
RESNET50_ALL_SURVEY_LAYERS = [0] + RESNET50_INTERNAL_KERNEL_IDXES + RESNET50_PACESETTER_IDXES
# layer -> pacesetter it follows.  Grouped view of the same relation:
# {1: [1, 4, 7, 10], 11: [11, 14, 17, 20, 23], 24: [24, 27, 30, 33, 36, 39, 42], 43: [43, 46, 49, 52]}
RESNET50_FOLLOW_DICT = {1: 1, 4: 1, 7: 1, 10: 1,
                        11: 11, 14: 11, 17: 11, 20: 11, 23: 11,
                        24: 24, 27: 24, 30: 24, 33: 24, 36: 24, 39: 24, 42: 24,
                        43: 43, 46: 43, 49: 43, 52: 43}
idxes_before_pacesetters = [ps - 1 for ps in RESNET50_PACESETTER_IDXES]
RESNET50_SUBSEQUENT_STRATEGY = dict((idx, idx + 1) for idx in RESNET50_INTERNAL_KERNEL_IDXES)
RESNET50_SUBSEQUENT_STRATEGY[0] = [1, 2]
for i in RESNET50_FOLLOW_DICT:
    if i in RESNET50_PACESETTER_IDXES:
        continue
    # A layer just before a pacesetter maps to both the pacesetter and the layer after it.
    RESNET50_SUBSEQUENT_STRATEGY[i] = [i + 1, i + 2] if i in idxes_before_pacesetters else i + 1
# The pacesetter entries deps[1], deps[11], deps[24], deps[43] are only checked, never emitted.
def convert_flattened_resnet50_deps(deps):
    """Regroup 53 flattened ResNet-50 widths into [stem, stage1, stage2, stage3, stage4].

    Every stage is a list of [conv1, conv2, conv3] triples taken from the entries right after
    its pacesetter.  Each pacesetter width must equal the third conv of its stage's first block.
    """
    assert len(deps) == 53
    for pacesetter in (1, 11, 24, 43):
        assert deps[pacesetter] == deps[pacesetter + 3]
    nested = [deps[0]]
    for pacesetter, num_blocks in ((1, 3), (11, 4), (24, 6), (43, 3)):
        first = pacesetter + 1
        nested.append([list(deps[first + 3 * b:first + 3 * b + 3]) for b in range(num_blocks)])
    return nested
#############################   end   ResNet50 ####################################


# ResNet-18 flattened widths (21 layers): stem, then per stage one leading entry followed by
# two [conv1, conv2] blocks.
RESNET18_ORIGIN_DEPS_FLATTENED = ([64]
                                  + [64] * 5
                                  + [128] * 5
                                  + [256] * 5
                                  + [512] * 5)


def convert_flattened_resnet18_deps(deps):
    """Regroup 21 flattened ResNet-18 widths into [stem, stage1, stage2, stage3, stage4].

    Every stage is two [conv1, conv2] pairs.  The stage-leading entries deps[1], deps[6],
    deps[11], deps[16] must equal deps[3], deps[8], deps[13], deps[18] and are not emitted.
    """
    assert len(deps) == 21
    for lead in (1, 6, 11, 16):
        assert deps[lead] == deps[lead + 2]
    grouped = [deps[0]]
    for lead in (1, 6, 11, 16):
        grouped.append([list(deps[lead + 1 + 2 * b:lead + 3 + 2 * b]) for b in range(2)])
    return grouped

#############################   begin   ResNet34 ####################################
# ResNet-34 flattened widths (37 layers): stem, then per stage one pacesetter followed by
# 3 / 4 / 6 / 3 blocks of two convs each.
RESNET34_ORIGIN_DEPS_FLATTENED = ([64]
                                  + [64] * (1 + 2 * 3)
                                  + [128] * (1 + 2 * 4)
                                  + [256] * (1 + 2 * 6)
                                  + [512] * (1 + 2 * 3))
RESNET34_PACESETTER_IDXES = [1, 8, 17, 30]
# First conv of every block.
RESNET34_INTERNAL_KERNEL_IDXES = [2, 4, 6,
                                  9, 11, 13, 15,
                                  18, 20, 22, 24, 26, 28,
                                  31, 33, 35]
# layer -> pacesetter it follows (the pacesetter plus the second conv of each block in its stage).
RESNET34_FOLLOW_DICT = {1: 1, 3: 1, 5: 1, 7: 1,
                        8: 8, 10: 8, 12: 8, 14: 8, 16: 8,
                        17: 17, 19: 17, 21: 17, 23: 17, 25: 17, 27: 17, 29: 17,
                        30: 30, 32: 30, 34: 30, 36: 30}
idxes_before_pacesetters = [ps - 1 for ps in RESNET34_PACESETTER_IDXES]
RESNET34_SUBSEQUENT_STRATEGY = dict((idx, idx + 1) for idx in RESNET34_INTERNAL_KERNEL_IDXES)
RESNET34_SUBSEQUENT_STRATEGY[0] = [1, 2]
for i in RESNET34_FOLLOW_DICT:
    if i in RESNET34_PACESETTER_IDXES:
        continue
    # A layer just before a pacesetter maps to both the pacesetter and the layer after it.
    RESNET34_SUBSEQUENT_STRATEGY[i] = [i + 1, i + 2] if i in idxes_before_pacesetters else i + 1
# The pacesetter entries deps[1], deps[8], deps[17], deps[30] are only checked, never emitted.
def convert_flattened_resnet34_deps(deps):
    """Regroup 37 flattened ResNet-34 widths into [stem, stage1, stage2, stage3, stage4].

    Stages hold 3 / 4 / 6 / 3 [conv1, conv2] pairs taken right after their pacesetter.  Each
    pacesetter width must equal the second conv of its stage's first block.
    """
    assert len(deps) == 37
    for pacesetter in (1, 8, 17, 30):
        assert deps[pacesetter] == deps[pacesetter + 2]
    grouped = [deps[0]]
    for pacesetter, num_blocks in ((1, 3), (8, 4), (17, 6), (30, 3)):
        grouped.append([list(deps[pacesetter + 1 + 2 * b:pacesetter + 3 + 2 * b])
                        for b in range(num_blocks)])
    return grouped
#############################   end   ResNet34 ####################################


#############################   begin   DenseNet121 ####################################
# DenseNet-121 widths (120 layers): a 64-wide first layer, then four stages of 6 / 12 / 24 / 16
# [128, 32] pairs separated by single layers of width 128, 256 and 512 (indices 13, 38, 87).
DENSENET121_FLATTENED_DEPS = np.concatenate([[64],
                                             [128, 32] * 6, [128],
                                             [128, 32] * 12, [256],
                                             [128, 32] * 24, [512],
                                             [128, 32] * 16])
DENSENET121_ALL_CONV_LAYERS = np.arange(120)
# Index of the 32-wide (second) layer of every pair, per stage.
DENSENET121_STAGE1_INCRE_LAYERS = np.arange(2, 13, 2)
DENSENET121_STAGE2_INCRE_LAYERS = np.arange(15, 38, 2)
DENSENET121_STAGE3_INCRE_LAYERS = np.arange(40, 87, 2)
DENSENET121_STAGE4_INCRE_LAYERS = np.arange(89, 120, 2)
# Index of the 128-wide (first) layer of every pair: one before each *_INCRE_LAYERS entry.
DENSENET121_INTERNAL_KERNEL_IDXES = np.concatenate([DENSENET121_STAGE1_INCRE_LAYERS,
                                                    DENSENET121_STAGE2_INCRE_LAYERS,
                                                    DENSENET121_STAGE3_INCRE_LAYERS,
                                                    DENSENET121_STAGE4_INCRE_LAYERS]) - 1


#   converted like:  [64, [[128, 32], [128, 32] ....... [128,32]], 128, [[], [], ....[]], 256, [[],[],...[]], 512, [[],[],....,[]]]
def convert_flattened_densenet121_deps(deps):
    """Regroup 120 flattened DenseNet-121 widths into the nested layout shown above.

    The result alternates between single widths (index 0 and the in-between layers at
    13, 38, 87) and lists of [first, second] pairs for the 6 / 12 / 24 / 16 pairs of each stage.
    """
    assert len(deps) == 120
    converted = [deps[0]]
    cursor = 1
    for stage, num_pairs in enumerate((6, 12, 24, 16)):
        if stage > 0:
            # Single layer sitting between two stages.
            converted.append(deps[cursor])
            cursor += 1
        converted.append([[deps[cursor + 2 * k], deps[cursor + 2 * k + 1]] for k in range(num_pairs)])
        cursor += 2 * num_pairs
    return converted
#############################   end   DenseNet121 ####################################

#############################   begin   AlexNet ##########################
# conv1/kernel:0 (11, 11, 3, 96)
# conv1/bias:0 (96,)
#
# conv3/kernel:0 (3, 3, 256, 384)
# conv3/bias:0 (384,)
# conv4_1/kernel:0 (3, 3, 192, 192)
# conv4_1/bias:0 (192,)
# conv4_2/kernel:0 (3, 3, 192, 192)
# conv4_2/bias:0 (192,)
# conv5_1/kernel:0 (3, 3, 192, 128)
# conv5_1/bias:0 (128,)
# conv5_2/kernel:0 (3, 3, 192, 128)
# conv5_2/bias:0 (128,)
# fc6/kernel:0 (9216, 4096)
# fc6/bias:0 (4096,)
# fc7/kernel:0 (4096, 4096)
# fc7/bias:0 (4096,)
# fc8/kernel:0 (4096, 1000)
# fc8/bias:0 (1000,)
#
# ALEXNET_DEPS = [48, 48, 128, 128, 192, 192, 192, 192, 128, 128]
# ALEXNET_FC_OUTS = [4096, 4096, 1000]
# ALEXNET_SUBSEQUENT_STRATEGY = {0:2, 1:3, 2:[4,5], 3:[4,5], 4:6, 5:7, 6:8, 7:9, 8:10, 9:10}
# ALEXNET_CONV_LAYER_IDXES = range(0,10)
# ALEXNET_DEPS = [48, 48, 128, 128, 192, 192, 192, 192, 128, 128]
# ALEXNET_NISP_A_DEPS = [24,24,64,64,96,96,96,96,64,64]
# ALEXNET_NISP_B_DEPS = [24,24,64,64,96,96,96,96,128,128]
# ALEXNET_NISP_C_DEPS = [24,24,64,64,96,96,192, 192, 128, 128]
# ALEXNET_NISP_D_DEPS = [48, 48,64,64,96,96,192, 192, 128, 128]
# def alexnet_subsequent_offset(cur_itr_remain_filters_of_conv_2_1, that_of_conv_5_1=ALEXNET_DEPS[-2]):
#     return {3:cur_itr_remain_filters_of_conv_2_1, 9:that_of_conv_5_1}
#############################   end     AlexNet ##########################



#############################   BDS training settings   ##################
# Per-layer settings, one entry per layer index 0..12.
VFS_LAYER_TO_INIT_GRANU_SETTING_1 = [2, 2, 4, 4, 8, 8, 8, 16, 16, 16, 32, 32, 32]
VFS_LAYER_TO_INIT_GRANU_SETTING_2 = [4, 4, 8, 8, 16, 16, 16, 32, 32, 32, 32, 32, 32]
VFS_LAYER_TO_INIT_GRANU_SETTING_3 = [8, 8, 16, 16, 32, 32, 32, 64, 64, 64, 64, 64, 64]
VFS_LAYER_TO_INIT_GRANU_SETTING_INF = [99999] * 13
VFS_LAYER_TO_LOSS_INC_LIMIT_SETTING = dict.fromkeys(range(13), 0.25)
# A comprehension (not dict.fromkeys) so every layer gets its own nested list.
VFS_LAYER_TO_EXAMPLES_PER_HALF_SETTING = {layer: [[20000]] for layer in range(13)}





########################### MARSRES
# MARSRES: 16 layer widths.
MARSRES_ORIGIN_DEPS = [32] * 6 + [64] * 5 + [128] * 5

MR2_ORIGIN_DEPS = [16, 16,
                   32, 16,
                   32, 16,
                   64, 48, 48,
                   64, 48,
                   128, 96, 96,
                   128, 96]

# layer -> subsequent layer, or a list of two layers (entries 5 and 10).
MARSRES_SUBSEQUENT_STRATEGY = {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: [6, 8],
                               6: 7, 7: 9,
                               9: 10, 10: [11, 13],
                               11: 12, 12: 14,
                               14: 15, 15: 16}
# follower -> pacesetter.
MARSRES_FOLLOW_DICT = {3: 1, 5: 1, 7: 8, 10: 8, 12: 13, 15: 13}
MARSRES_PACESETTERS = [1, 8, 13]
# Pacesetters together with the keys of MARSRES_FOLLOW_DICT, in ascending order.
MARSRES_PACESETTERS_AND_FOLLOWERS = [1, 3, 5, 7, 8, 10, 12, 13, 15]
MARSRES_INTERNAL_LAYERS = [0, 2, 4, 6, 9, 11, 14]

################    GooGleNet
# GoogLeNet widths (57 entries): three leading widths, then nine rows of six widths each
# (presumably one row per inception module -- TODO confirm branch order).
GOOGLENET_ORIGIN_DEPS = np.array([64, 64, 192]
                                 + [64, 96, 128, 16, 32, 32]
                                 + [128, 128, 192, 32, 96, 64]
                                 + [192, 96, 208, 16, 48, 64]
                                 + [160, 112, 224, 24, 64, 64]
                                 + [128, 128, 256, 24, 64, 64]
                                 + [112, 144, 288, 32, 64, 64]
                                 + [256, 160, 320, 32, 128, 128]
                                 + [256, 160, 320, 32, 128, 128]
                                 + [384, 192, 384, 48, 128, 128])




##################### general Resnet on CIFAR-10
# Blocks-per-stage values ``n`` that the rc_* helpers accept.
_RC_SUPPORTED_N = (9, 12, 18, 27, 200)


def rc_origin_deps_flattened(n):
    """Return the unpruned flattened widths of a CIFAR ResNet with ``n`` blocks per stage.

    The three stages have widths 16, 32 and 64 and ``2 * n + 1`` layers each, so the
    result is a NumPy array of ``6 * n + 3`` ints.
    """
    assert n in _RC_SUPPORTED_N
    filters_in_each_stage = n * 2 + 1
    stage1 = [16] * filters_in_each_stage
    stage2 = [32] * filters_in_each_stage
    stage3 = [64] * filters_in_each_stage
    return np.array(stage1 + stage2 + stage3)


def rc_convert_flattened_deps(flattened):
    """Regroup flattened CIFAR-ResNet widths into ``[stem, stage1, stage2, stage3]``.

    ``n`` is inferred from ``len(flattened) == 6 * n + 3``.  Each stage is a list of ``n``
    ``[first, second]`` pairs taken right after that stage's pacesetter.

    Raises:
        AssertionError: if the length is not ``6 * n + 3`` for a supported ``n``, or if a
            pacesetter's width differs from its first follower (``pacesetter + 2``).
    """
    # Integer arithmetic, so lengths that are not three stages of 2n + 1 layers are
    # rejected instead of being silently truncated.
    filters_in_each_stage, remainder = divmod(len(flattened), 3)
    n = (filters_in_each_stage - 1) // 2
    assert remainder == 0 and filters_in_each_stage == 2 * n + 1, \
        'flattened deps of length %d do not form three stages of 2n+1 layers' % len(flattened)
    assert n in _RC_SUPPORTED_N
    pacesetters = rc_pacesetter_idxes(n)
    result = [flattened[0]]
    for ps in pacesetters:
        assert flattened[ps] == flattened[ps + 2]
        stage_deps = []
        for i in range(n):
            stage_deps.append([flattened[ps + 1 + 2 * i], flattened[ps + 2 + 2 * i]])
        result.append(stage_deps)
    return result


def rc_pacesetter_idxes(n):
    """Return the flattened index that opens each of the three stages."""
    assert n in _RC_SUPPORTED_N
    filters_in_each_stage = n * 2 + 1
    pacesetters = [0, int(filters_in_each_stage), int(2 * filters_in_each_stage)]
    return pacesetters


def rc_internal_layers(n):
    """Return the index of the first conv of every block (pacesetter + 1 + 2i), stage by stage."""
    assert n in _RC_SUPPORTED_N
    pacesetters = rc_pacesetter_idxes(n)
    result = []
    for ps in pacesetters:
        for i in range(n):
            result.append(ps + 1 + 2 * i)
    return result


def rc_all_survey_layers(n):
    """Return the pacesetter indices followed by all internal-layer indices."""
    return rc_pacesetter_idxes(n) + rc_internal_layers(n)


def rc_all_cov_layers(n):
    """Return ``range(6 * n + 3)``: every flattened conv-layer index."""
    return range(0, 6 * n + 3)


def rc_follow_dict(n):
    """Map each pacesetter and every second conv of its stage (ps + 2i, i = 0..n) to that pacesetter."""
    assert n in _RC_SUPPORTED_N
    pacesetters = rc_pacesetter_idxes(n)
    result = {}
    for ps in pacesetters:
        for i in range(0, n + 1):
            result[ps + 2 * i] = ps
    return result


def rc_subsequent_strategy(n):
    """Map each layer to its subsequent layer index.

    The layer just before the stage-2 and stage-3 pacesetters maps to
    ``[pacesetter, pacesetter + 1]``.  Those two pacesetters get no entry; layer 0 maps to 1.
    """
    assert n in _RC_SUPPORTED_N
    internal_layers = rc_internal_layers(n)
    result = {i: (i + 1) for i in internal_layers}
    result[0] = 1
    follow_dic = rc_follow_dict(n)
    pacesetters = rc_pacesetter_idxes(n)
    layer_before_pacesetters = [i - 1 for i in pacesetters]
    for i in follow_dic.keys():
        if i in layer_before_pacesetters:
            result[i] = [i + 1, i + 2]
        elif i not in pacesetters:
            result[i] = i + 1
    return result


def rc_fc_layer_idx(n):
    """Return the index right after the last conv layer (``6 * n + 3``), used for the FC layer."""
    assert n in _RC_SUPPORTED_N
    return 6 * n + 3


def rc_stage_to_pacesetter_idx(n):
    """Map stage keys 2, 3, 4 to the pacesetter index of the first, second and third stage."""
    ps_ids = rc_pacesetter_idxes(n)
    return {2: ps_ids[0], 3: ps_ids[1], 4: ps_ids[2]}

# NOTE: re-binds the ResNet-56 pacesetter/follower tables to the same values as the
# ResNet56 section near the top of the file.
RESNET56_PACESETTER_IDXES = [0, 19, 38]
# layer -> pacesetter it follows: each pacesetter plus every second conv of its stage.
RESNET56_FOLLOW_DICT = {ps + 2 * offset: ps
                        for ps in RESNET56_PACESETTER_IDXES
                        for offset in range(10)}




# validate
def is_dict_equal(d1, d2):
    """Return True if ``d1`` and ``d2`` have the same keys with equal values (order ignored).

    A key of ``d1`` that is missing from ``d2`` makes the result False instead of
    raising KeyError.
    """
    if len(d1) != len(d2):
        return False
    for k, v in d1.items():
        if k not in d2 or v != d2[k]:
            return False
    return True

def is_array_or_list_equal(a1, a2):
    """Return True if both sequences have the same length and pairwise-equal elements."""
    if len(a1) != len(a2):
        return False
    return not any(left != right for left, right in zip(a1, a2))



# Re-definition of the ResNet-56 tables (same values as the ResNet56 section near the top).
# Widths nested as [stem, stage1, stage2, stage3]; each stage holds 9 [first, second] blocks.
RESNET56_ORIGIN_DEPS = [16,
                        [[16, 16]] * 9,
                        [[32, 32]] * 9,
                        [[64, 64]] * 9]
# The same widths flattened into 57 layers, 19 per stage.
RESNET56_ORIGIN_DEPS_FLATTENED = [16] * 19 + [32] * 19 + [64] * 19
# Flattened index that opens each stage.
RESNET56_PACESETTER_IDXES = [0, 19, 38]
# First conv of every block: pacesetter + 1, + 3, ..., + 17.
RESNET56_INTERNAL_KERNEL_IDXES = [ps + 1 + 2 * blk
                                  for ps in RESNET56_PACESETTER_IDXES
                                  for blk in range(9)]
# layer -> the pacesetter it follows: the pacesetter itself plus every second conv of its stage.
RESNET56_FOLLOW_DICT = {ps + 2 * blk: ps
                        for ps in RESNET56_PACESETTER_IDXES
                        for blk in range(10)}
# layer -> subsequent layer index, or a two-index list for the last layer of stages 1 and 2.
RESNET56_SUBSEQUENT_STRATEGY = dict((idx, idx + 1) for idx in RESNET56_INTERNAL_KERNEL_IDXES)
RESNET56_SUBSEQUENT_STRATEGY[0] = 1
for i in RESNET56_FOLLOW_DICT:
    if i in RESNET56_PACESETTER_IDXES:
        continue  # 19 and 38 get no entry; 0 keeps the value assigned just above
    if i + 1 in RESNET56_PACESETTER_IDXES:
        # 18 and 37: the next pacesetter and the layer right after it.
        RESNET56_SUBSEQUENT_STRATEGY[i] = [i + 1, i + 2]
    else:
        RESNET56_SUBSEQUENT_STRATEGY[i] = i + 1
def convert_flattened_resnet56_deps(flattened_deps):
    """Regroup 57 flattened ResNet-56 widths into ``[stem, stage1, stage2, stage3]``.

    Flattened layout: index 0 is the stem (and first pacesetter); 19 and 38 are the
    pacesetters of stages 2 and 3.  Each stage contributes 9 blocks taken as the pairs
    ``[pacesetter + 1 + 2k, pacesetter + 2 + 2k]``.  The widths at 19 and 38 are
    checked but not emitted.

    Raises:
        AssertionError: if the length is not 57, or if a pacesetter's width differs from
            its first follower (``pacesetter + 2``), which RESNET56_FOLLOW_DICT pairs with it.
    """
    assert len(flattened_deps) == 57
    # Check the pacesetter/first-follower pairs (0, 2), (19, 21) and (38, 40).
    # Indices 39 and 41 are independent first-in-block convs
    # (see RESNET56_INTERNAL_KERNEL_IDXES) and may legitimately differ.
    for pacesetter in (0, 19, 38):
        assert flattened_deps[pacesetter] == flattened_deps[pacesetter + 2], \
            'width of layer %d must match its follower %d' % (pacesetter, pacesetter + 2)
    d = [flattened_deps[0]]
    for pacesetter in (0, 19, 38):
        d.append([[flattened_deps[pacesetter + 1 + 2 * k], flattened_deps[pacesetter + 2 + 2 * k]]
                  for k in range(9)])
    return d
#
# # validate
# print(is_array_or_list_equal(rc_pacesetter_idxes(9), RESNET56_PACESETTER_IDXES))
# print(is_array_or_list_equal(rc_origin_deps_flattened(9), RESNET56_ORIGIN_DEPS_FLATTENED))
# print(is_array_or_list_equal(rc_internal_layers(9), RESNET56_INTERNAL_KERNEL_IDXES))
# print(is_dict_equal(rc_follow_dict(9), RESNET56_FOLLOW_DICT))
# print(is_dict_equal(rc_subsequent_strategy(9), RESNET56_SUBSEQUENT_STRATEGY))
# print(convert_flattened_resnet56_deps(RESNET56_ORIGIN_DEPS_FLATTENED))
# print(rc_convert_flattened_deps(RESNET56_ORIGIN_DEPS_FLATTENED))


# MobileNet widths (27 layers).  Index 0 stands alone; the rest come in pairs (2k+1, 2k+2)
# in which the odd layer has the same width as the even layer just before it.
MOBILENET_ORIGIN_DEPS = np.array([32]
                                 + [32, 64]
                                 + [64, 128]
                                 + [128, 128]
                                 + [128, 256]
                                 + [256, 256]
                                 + [256, 512]
                                 + [512, 512] * 5
                                 + [512, 1024]
                                 + [1024, 1024])
MOBILENET_ALL_CONV_LAYERS = range(27)
# Even layers 0, 2, ..., 26 (14 layers).  No layer follows (mimics) layer 26; it is independent.
MOBILENET_ALL_SURVEY_LAYERS = range(0, 27, 2)
# Odd layers 1, 3, ..., 25 (13 layers); each follows the layer immediately before it.
MOBILENET_FOLLOWERS = range(1, 27, 2)
MOBILENET_FOLLOW_DICT = {follower: follower - 1 for follower in MOBILENET_FOLLOWERS}
MOBILENET_SUBSEQUENT_STRATEGY = 'simple'
# Placeholder index: the last layer is still a conv (on a 1x1 feature map), so it is handled
# like a conv layer rather than as an FC layer.
MOBILENET_FC_LAYERS = [99999]



# DC40 widths (39 layers): index 0, then three runs of 12 layers of width 12 separated by the
# transition layers at indices 13 (width 160) and 26 (width 304).
DC40_ORIGIN_DEPS = [16] + [12] * 12 + [160] + [12] * 12 + [304] + [12] * 12
DC40_FC_LAYERS = [39]
DC40_SUBSEQUENT_STRATEGY = None
DC40_FOLLOW_DICT = None
DC40_ALL_CONV_LAYERS = range(len(DC40_ORIGIN_DEPS))


# DC40 uses its own pruning helpers instead of the generic subsequent/follow tables.
def customized_dc40_deps(arg):
    """Build DC40 widths from a spec string such as ``'a-b-c'`` or ``'a-b-c-trans'``.

    ``a``, ``b`` and ``c`` are the widths of the three 12-layer runs.  Everything from the first
    ``'-rep'`` onwards is ignored.  With a ``'-trans'`` suffix the transition widths are
    recomputed by ``trans_dc40_deps``; otherwise they keep their original values.
    Returns a NumPy int array of 39 widths.
    """
    if '-rep' in arg:
        arg = arg.split('-rep', 1)[0]
    if arg.endswith('-trans'):
        print('customized dc40 deps with transition layers modified')
        return trans_dc40_deps(arg)
    print('no modification to transition layers')
    widths = [int(part) for part in arg.split('-')[:3]]
    deps = np.array(DC40_ORIGIN_DEPS)
    deps[1:13] = widths[0]
    deps[14:26] = widths[1]
    deps[27:39] = widths[2]
    return deps


def trans_dc40_deps(arg):
    """Build DC40 widths from ``'a-b-c-trans'`` and also resize the two transition layers.

    Each transition width becomes the width of the layer before its run plus 12 times the run's
    width: ``deps[13] = deps[0] + 12 * a`` and ``deps[26] = deps[13] + 12 * b``.
    """
    assert arg.endswith('-trans')
    parts = arg.split('-')
    deps = np.array(DC40_ORIGIN_DEPS)
    for run, (start, stop) in enumerate(((1, 13), (14, 26), (27, 39))):
        width = int(parts[run])
        deps[start:stop] = width
        if stop < len(deps):
            # Transition layer right after this run (indices 13 and 26).
            deps[stop] = deps[start - 1] + 12 * width
    return deps