# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

'''
Adapted from https://github.com/tornadomeet/ResNet/blob/master/symbol_resnet.py
Original author Wei Wu

Implemented the following paper:

Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks"
'''
import mxnet as mx
import numpy as np
from math_ops import *

# Module-level hyperparameters shared by every residual unit built below.
# Forwarded as the `gamma` keyword of each gSGBCNV call.
gamma = 2
# Forwarded as the `target` keyword of each gSGBCNV call; it also offsets
# the `nsigma` argument (nsigma = -6.0 - target).
target = 0.6
# When True, channel_shuffle() is applied to the input of both 3x3
# convolutions inside residual_unit().
is_shuffle = False

def residual_unit(data, num_filter, depth, quantize_w, w_bits, quantize_act, 
        act_bits, groups, stride, dim_match, name, bottle_neck=True, bn_mom=0.9, workspace=256, memonger=False):
    """Return one (non-bottleneck) ResNet unit built from gSGBCNV convolutions.

    The unit is two 3x3 gSGBCNV convolutions plus a shortcut. The shortcut is
    the unit input itself when ``dim_match`` is true, otherwise a 1x1 gSGBCNV
    projection using the same ``stride`` as the first convolution.

    Parameters
    ----------
    data : Symbol
        Input symbol of the unit.
    num_filter : int
        Number of output channels (``nfltr`` of every convolution).
    depth : int
        Passed as ``depth`` to the first convolution and to the projection
        shortcut (presumably the number of input channels -- confirm against
        gSGBCNV). The second convolution always uses ``num_filter``.
    quantize_w, w_bits, quantize_act, act_bits :
        Forwarded unchanged to gSGBCNV (weight / activation quantisation
        settings; exact semantics are defined by gSGBCNV).
    groups : int
        Forwarded to gSGBCNV and channel_shuffle; also sets ``stop=1.0/groups``.
    stride : tuple
        Stride of the first convolution and of the projection shortcut.
    dim_match : Boolean
        True means the input is used directly as the shortcut; False means a
        1x1 projection convolution is created for it.
    name : str
        Base name of the operators (suffixes ``_conv1``, ``_conv2``, ``_sc``).
    bottle_neck : Boolean
        Bottleneck units are not implemented; a true value raises.
    bn_mom : float
        Not used by this function; kept for interface compatibility.
    workspace : int
        Workspace forwarded to gSGBCNV.
    memonger : Boolean
        When true, tags the shortcut symbol with ``mirror_stage='True'``.

    Returns
    -------
    tuple
        ``(output, c_sum, r_sum)`` where ``output`` is ``conv2 + shortcut``,
        ``c_sum`` is the sum of the fourth value (``mac``) returned by each
        gSGBCNV call in this unit and ``r_sum`` is the sum of the second value
        (``rmcc``) returned by each call.

    Raises
    ------
    NotImplementedError
        If ``bottle_neck`` is true.
    """
    if bottle_neck:
        # The original code printed this message (Python 2 print statement)
        # and implicitly returned None, which made callers that unpack the
        # three return values fail with an unrelated TypeError. Fail loudly.
        raise NotImplementedError("not supported yet")

    nsigma = -6.0 - float(target)

    def _gconv(inp, in_depth, kernel, pad, conv_stride, suffix):
        # Single place for the keyword arguments shared by all three
        # convolutions of the unit.
        return gSGBCNV(data=inp, nfltr=num_filter,
                quantize_w=quantize_w, w_bits=w_bits, quantize_act=quantize_act, act_bits=act_bits, depth=in_depth,
                kernel=kernel, nsigma=nsigma, gamma=gamma, groups=groups, stop=1.0/groups, pad=pad, stride=conv_stride,
                target=target, workspace=workspace, tag=name+suffix,
                no_bias=True, level='pixel')

    if is_shuffle:
        # NOTE(review): `data` is rebound here, so the identity shortcut below
        # also sees the shuffled tensor -- confirm this is intended.
        data = channel_shuffle(data, groups)
    conv1, rmcc1, r1, mac1 = _gconv(data, depth, 3, 1, stride, '_conv1')

    # The input of the second conv is conv1's output, so its depth is
    # num_filter rather than `depth`.
    if is_shuffle:
        conv1 = channel_shuffle(conv1, groups)
    conv2, rmcc2, r2, mac2 = _gconv(conv1, num_filter, 3, 1, (1,1), '_conv2')

    if dim_match:
        shortcut = data
        c_sum = mac1+mac2
        r_sum = rmcc1+rmcc2
    else:
        shortcut, rmcc3, r3, mac3 = _gconv(data, depth, 1, 0, stride, '_sc')
        c_sum = mac1+mac2+mac3
        r_sum = rmcc1+rmcc2+rmcc3
    if memonger:
        shortcut._set_attr(mirror_stage='True')
    return conv2+shortcut, c_sum, r_sum

def resnet_gcg(units, num_stage, filter_list, num_class, quantize_act, 
        act_bits, quantize_w, w_bits, groups, data_type, bottle_neck=False, bn_mom=0.9, workspace=256, memonger=False):
    """Return the complete ResNet symbol assembled from residual_unit blocks.

    Parameters
    ----------
    units : list
        Number of units in each stage; its length must equal ``num_stage``.
    num_stage : int
        Number of stages.
    filter_list : list
        Channel sizes: entry 0 is used for the stem convolution and entry
        ``i + 1`` for stage ``i``.
    num_class : int
        Output size of the final fully connected layer.
    quantize_act, act_bits, quantize_w, w_bits, groups :
        Forwarded unchanged to every residual_unit.
    data_type : str
        ``'cifar10'`` selects a 3x3 stride-1 stem; any other value selects a
        7x7 stride-2 stem followed by 3x3 max pooling.
    bottle_neck : Boolean
        Forwarded to residual_unit.
    bn_mom : float
        Momentum of the BatchNorm layers created here.
    workspace : int
        Workspace used in convolution operators.
    memonger : Boolean
        Forwarded to residual_unit.

    Returns
    -------
    Symbol
        The ``SoftmaxOutput`` symbol named ``'softmax'``.
    """
    dtype='float32'  # not referenced anywhere below
    assert len(units) == num_stage
    data = mx.sym.Variable(name='data')
    label = mx.sym.Variable(name='softmax_label')  # not referenced anywhere below
    data = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='bn_data')

    # Stem: first convolution, batch norm and ReLU (plus pooling for the
    # non-cifar10 variant).
    if data_type == 'cifar10':
        net = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(3, 3), stride=(1,1), pad=(1, 1),
                                 no_bias=True, name="conv0", workspace=workspace)
        net = mx.sym.BatchNorm(data=net, fix_gamma=False, momentum=bn_mom, eps=2e-5)
        net = mx.sym.Activation(data=net, act_type='relu')
    else:
        net = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(7, 7), stride=(2,2), pad=(3, 3),
                                 no_bias=True, name="conv0", workspace=workspace)
        net = mx.sym.BatchNorm(data=net, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0')
        net = mx.sym.Activation(data=net, act_type='relu', name='relu0')
        net = mx.sym.Pooling(data=net, kernel=(3, 3), stride=(2,2), pad=(1,1), pool_type='max')

    # Running sums of the two extra values returned by every residual unit.
    # Plain `a = a + b` is used on purpose: in-place addition is not
    # supported on symbols.
    csum = mx.sym.zeros(shape=[1])
    rsum = mx.sym.zeros(shape=[1])
    for stage in range(num_stage):
        out_channels = filter_list[stage + 1]
        lead_stride = (1, 1) if stage == 0 else (2, 2)
        # (depth, stride, dim_match, unit number) for each unit of the stage.
        # The first unit is always created and changes the channel count; the
        # remaining units[stage] - 1 units keep dimensions.
        unit_specs = [(filter_list[stage], lead_stride, False, 1)]
        for extra in range(units[stage] - 1):
            unit_specs.append((out_channels, (1, 1), True, extra + 2))
        for in_depth, unit_stride, same_dim, unit_no in unit_specs:
            net, c, r = residual_unit(net, out_channels, in_depth, quantize_w, w_bits, quantize_act,
                    act_bits, groups, unit_stride, same_dim,
                    name='stage%d_unit%d' % (stage + 1, unit_no),
                    bottle_neck=bottle_neck, workspace=workspace, memonger=memonger)
            csum = csum + c
            rsum = rsum + r

    # The custom 'DMon' operator takes the feature map plus both sums; only
    # its first output is used further on.
    net, flops = mx.sym.Custom(net, rsum, csum, op_type='DMon')
    net = mx.sym.BatchNorm(data=net, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn1')
    net = mx.sym.Activation(data=net, act_type='relu', name='relu1')
    # Although kernel is not used here when global_pool=True, we should put one
    net = mx.sym.Pooling(data=net, global_pool=True, kernel=(7, 7), pool_type='avg', name='pool1')
    net = mx.sym.Flatten(data=net)
    net = mx.sym.FullyConnected(data=net, num_hidden=num_class, name='fc1')
    return mx.sym.SoftmaxOutput(data=net, name='softmax')
