from bc import *
from bc_base import *


class RTSBC(BenchmarkCNN):

    def __init__(self, params, my_params, rts_params, dataset=None, model=None):
        """Initialise the parent benchmark object and remember the RTS settings.

        ``params``, ``my_params``, ``dataset`` and ``model`` are forwarded by
        keyword, unchanged, to the parent constructor.  ``rts_params`` is only
        stored on the instance (as ``self.rts_params``) for the other methods of
        this class to read.
        """
        parent_kwargs = dict(
            params=params,
            my_params=my_params,
            dataset=dataset,
            model=model,
        )
        super(RTSBC, self).__init__(**parent_kwargs)
        self.rts_params = rts_params

    def simple_rts_survey(self, survey_save_file, survey_layer_ids, eval_feed_dict=None):
        """Build, evaluate and save a per-weight score for the selected kernels.

        For every kernel whose index in ``self.get_kernel_variables()`` is
        contained in ``survey_layer_ids`` a score tensor with the kernel's own
        shape is built as the product of

        * ``|kernel|``, where the kernel is first divided by
          ``sqrt(moving_variance + epsilon)`` and multiplied by ``gamma`` when
          those variables exist for the kernel;
        * a "pre" factor: softmax (over the last axis) of the power-mean of the
          ReLU-ed input feature map, reduced over axes 0, 1 and 2;
        * a "post" factor derived from ``|follower kernel| ** power`` of a
          kernel that consumes this kernel's output.

        The score tensors are evaluated for ``self.num_batches`` batches, the
        per-batch values are averaged, written with ``save_hdf5`` and returned.

        Args:
            survey_save_file: destination handed to ``save_hdf5``.
            survey_layer_ids: container of kernel indices to survey.
            eval_feed_dict: optional feed dict used for every ``sess.run`` of
                the score tensors; ``None`` means an empty feed dict.

        Returns:
            dict mapping kernel index -> averaged score array.

        Raises:
            ValueError: if no recorded input feature map can be found for a
                surveyed kernel.
        """
        assert self.my_params.need_record_internal_outputs
        sess = self.sess
        feed_dict = eval_feed_dict or {}

        if self.dataset.queue_runner_required():
            tf.train.start_queue_runners(sess=sess)
        image_producer = None
        if self.input_producer_op is not None:
            image_producer = cnn_util.ImageProducer(
                sess, self.input_producer_op, self.batch_group_size,
                self.params.use_python32_barrier)
            image_producer.start()
        if self.enqueue_ops:
            # Run a growing prefix of the enqueue ops: [op0], [op0, op1], ...
            for num_ops in range(1, len(self.enqueue_ops) + 1):
                sess.run(self.enqueue_ops[:num_ops])
                if image_producer is not None:
                    image_producer.notify_image_consumption()
        loop_start_time = start_time = time.time()

        # Keys are presumably '<idx>$<name>' (conv output) and '<idx>#<name>'
        # (biased output), plus 'input' -- the recording code is not in this
        # file, confirm there.
        internal_outputs_dict = self.fetches['internal_outputs']
        kernels = self.get_kernel_variables()
        power = self.rts_params.power
        # Force true division: with an integer power, `1 / power` would be
        # floor division (i.e. 0) under Python 2.
        inv_power = 1.0 / power

        def get_input(kernel_id):
            """Return the recorded tensor that feeds kernel ``kernel_id``."""
            if kernel_id == 0:
                print('the input to the kernel is the input to the model')
                return internal_outputs_dict['input']
            preced_id = None
            for key, value in self.rts_params.subsequent_strategy.items():
                if hasattr(value, '__len__'):
                    if kernel_id in value:
                        preced_id = key
                        break
                elif value == kernel_id:
                    preced_id = key
                    break
            keyword = '{}#'.format(preced_id)
            for key, value in internal_outputs_dict.items():
                if key.startswith(keyword):
                    return value
            # Previously this fell through and returned None, which only
            # surfaced later as an opaque error inside tf.nn.relu.
            raise ValueError(
                'found no recorded internal output starting with {!r} to use as the '
                'input of kernel {}'.format(keyword, kernel_id))

        def get_follow_kernels(kernel_idx):
            """Return the list of kernels that follow kernel ``kernel_idx``."""
            if kernel_idx in self.rts_params.subsequent_strategy:
                follow_id = self.rts_params.subsequent_strategy[kernel_idx]
            else:
                follow_id = kernel_idx + 3
                print('found no follower for {} in the subsequent_strategy, so we use {} by default'.format(kernel_idx, follow_id))
            if hasattr(follow_id, '__len__'):
                return [kernels[p] for p in follow_id]
            return [kernels[follow_id]]

        def per_output_channel(vector, kernel_shape):
            """Tile a 1-D per-output-channel vector to the kernel's full 4-D shape."""
            expanded = tf.expand_dims(tf.expand_dims(tf.expand_dims(vector, 0), 0), 0)
            return tf.tile(expanded, [kernel_shape[0], kernel_shape[1], kernel_shape[2], 1])

        aggregated_outputs_dict = {}

        assert self.params.data_format == 'NHWC'
        for i, k in enumerate(kernels):
            if i not in survey_layer_ids:
                continue
            sigma_var = self.get_moving_variance_variable_for_kernel(k)
            gamma_var = self.get_gamma_variable_for_kernel(k)

            k_shape = k.get_shape().as_list()
            input_featuremap = tf.nn.relu(get_input(i)) #TODO
            follow_kernels = get_follow_kernels(i)

            if sigma_var is not None:
                print('got the sigma var: ', sigma_var.name)
                epsilon = self.convnet_builder.batch_norm_config['epsilon']
                k = k / tf.sqrt(per_output_channel(sigma_var, k_shape) + epsilon)

            if gamma_var is not None:
                print('got the gamma var: ', gamma_var.name)
                k = k * per_output_channel(gamma_var, k_shape)

            score = tf.abs(k)

            # Power-mean of the input over axes 0, 1, 2 -> shape [1, 1, 1, C];
            # softmax normalises over that last axis.
            pre_factor = tf.pow(
                tf.reduce_mean(tf.abs(input_featuremap) ** power, axis=(0, 1, 2), keepdims=True),
                inv_power)
            pre_factor = tf.nn.softmax(pre_factor)
            # Move the channel axis into the kernel's input-channel position.
            pre_factor = tf.transpose(pre_factor, [0, 1, 3, 2])
            print('shape of pre_factor is ', pre_factor.get_shape())
            pre_factor = tf.tile(pre_factor, [k_shape[0], k_shape[1], 1, k_shape[3]])
            #TODO use softmax for normalization?
            score = score * pre_factor

            # NOTE(review): with several followers only the second one is used
            # (an earlier, commented-out variant summed all of them) -- kept
            # as-is, confirm this is intended.
            if len(follow_kernels) == 1:
                chosen_follower = follow_kernels[0]
            else:
                chosen_follower = follow_kernels[1]
            post_factor = tf.abs(chosen_follower) ** power

            if len(post_factor.get_shape().as_list()) == 4:
                # 4-D follower: keep only axis 2, which is matched against this
                # kernel's last axis below.
                post_factor = tf.reduce_sum(post_factor, axis=(0, 1, 3))
                post_factor = tf.pow(post_factor, inv_power)
                post_factor = tf.nn.softmax(post_factor)
                post_factor = per_output_channel(post_factor, k_shape)
            else:
                num_last_filters = k_shape[3]
                print('num_last_filters=', num_last_filters)
                num_input_neurons = post_factor.get_shape().as_list()[0]
                if num_last_filters == num_input_neurons:
                    print('got num_last_filters == num_input_neurons')
                    post_factor = tf.reduce_sum(post_factor, axis=1)
                    post_factor = tf.pow(post_factor, inv_power)
                    post_factor = tf.nn.softmax(post_factor)
                    post_factor = per_output_channel(post_factor, k_shape)
                else:
                    # Follower has more input rows than this kernel has
                    # filters: rows p, p + F, p + 2F, ... (F = num_last_filters)
                    # are summed into filter p.
                    # The original code could not run: it indexed a tensor with
                    # a numpy array and called tf.concat without an axis on
                    # scalars.  tf.gather / tf.stack express the same intent.
                    # NOTE(review): unlike the branches above, no 1/power root
                    # or softmax is applied here (as in the original) and
                    # num_input_neurons is presumably a multiple of
                    # num_last_filters -- confirm both.
                    base = np.arange(0, num_input_neurons, num_last_filters, dtype=np.int32)
                    print('base = ', base)
                    per_filter_sums = []
                    for p in range(num_last_filters):
                        rows = tf.gather(post_factor, base + p)
                        per_filter_sums.append(tf.reduce_sum(rows))
                    post_factor = tf.stack(per_filter_sums)
                    post_factor = per_output_channel(post_factor, k_shape)

            score = score * post_factor
            aggregated_outputs_dict[i] = score

        accumulate_score_dict = {}

        for step in xrange(self.num_batches):
            results = sess.run(aggregated_outputs_dict, feed_dict=feed_dict)
            for out_idx, out_array in results.items():
                if out_idx in accumulate_score_dict:
                    accumulate_score_dict[out_idx] = accumulate_score_dict[out_idx] + out_array
                else:
                    # Copy so the running total never aliases a fetched array.
                    accumulate_score_dict[out_idx] = np.array(out_array)
            # results = self.model.postprocess(results)
            if (step + 1) % self.params.display_every == 0:
                duration = time.time() - start_time
                examples_per_sec = (
                    self.batch_size * self.params.display_every / duration)
                log_fn('%i\t%.1f examples/sec' % (step + 1, examples_per_sec))
                start_time = time.time()
            if image_producer is not None:
                image_producer.notify_image_consumption()
        loop_end_time = time.time()
        if image_producer is not None:
            image_producer.done()

        elapsed_time = loop_end_time - loop_start_time
        images_per_sec = (self.num_batches * self.batch_size / elapsed_time)
        log_fn('-' * 64)
        log_fn('total images/sec: %.2f' % images_per_sec)
        log_fn('-' * 64)

        # Average the accumulated scores over the number of evaluated batches.
        save_dict = {}
        for idx, score_total in accumulate_score_dict.items():
            save_dict[idx] = score_total / self.num_batches

        print('save the survey hdf5 to {}, where the keys are {}'.format(survey_save_file, save_dict.keys()))
        save_hdf5(save_dict, survey_save_file)
        return save_dict


    #   every layer, under multiple thresholds, and global
    #   mean magnitude, std magnitude
    def do_extra_summaries(self, summary_writer, local_step, sess, graph_info):
        """Write weight-magnitude statistics as summaries every 1000 local steps.

        For each kernel returned by
        ``self.get_kernel_variables(self.rts_params.target_layers)`` whose name
        starts with ``'v0'``, and for all of those kernels pooled together
        (``global/...`` tags), this records:

        * the fraction of weights with ``|w| < 10**p`` for ``p`` in -7..-1;
        * the mean and the standard deviation of ``|w|``.

        Nothing is done when ``local_step`` is not a multiple of 1000, when
        there is no summary writer, or when no kernel matches.

        Args:
            summary_writer: writer receiving the summary; may be ``None``.
            local_step: current local step counter.
            sess: session used to read the global step.
            graph_info: object exposing ``global_step``.
        """
        if local_step % 1000 != 0:
            return
        # The training loop passes None when summaries are disabled or this
        # worker is not the chief; there is nowhere to write to in that case.
        if summary_writer is None:
            return

        thresholds = [-7, -6, -5, -4, -3, -2, -1]

        all_kts = self.get_kernel_variables(self.rts_params.target_layers)
        kts = [t for t in all_kts if t.name.startswith('v0')]
        if not kts:
            # np.concatenate below would raise on an empty list.
            return

        kvs = self.get_value(kts)
        summary = tf.Summary()

        all_abs_weights = []

        for i, (t, v) in enumerate(zip(kts, kvs)):
            absv = np.abs(v)
            all_abs_weights.append(np.ravel(absv))
            layer_tag = 'layer{}_{}'.format(i, t.name.replace('v0/cg', '').replace('/', '+'))
            for thresh_pow in thresholds:
                thresh_value = pow(10, thresh_pow)
                under_cnt = np.sum(absv < thresh_value)
                # float() forces true division; int / int would floor to 0
                # under Python 2.
                summary.value.add(tag='{}/abs_under_1e{}'.format(layer_tag, thresh_pow),
                                  simple_value=float(under_cnt) / np.size(v))
            summary.value.add(tag='{}/abs_mean'.format(layer_tag), simple_value=np.mean(absv))
            summary.value.add(tag='{}/abs_std'.format(layer_tag), simple_value=np.std(absv))

        all_abs_weights = np.concatenate(all_abs_weights)
        cnt_all_weights = np.size(all_abs_weights)
        print('number of all the weights: ', cnt_all_weights)
        for thresh_pow in thresholds:
            thresh_value = pow(10, thresh_pow)
            under_cnt = np.sum(all_abs_weights < thresh_value)
            summary.value.add(tag='global/abs_under_1e{}'.format(thresh_pow),
                              simple_value=float(under_cnt) / cnt_all_weights)
        summary.value.add(tag='global/abs_mean', simple_value=np.mean(all_abs_weights))
        summary.value.add(tag='global/abs_std', simple_value=np.std(all_abs_weights))

        summary_writer.add_summary(summary, sess.run(graph_info.global_step))

    def do_train(self, graph_info):
        """Run the training loop while periodically refreshing the RTS masks.

        When ``self.rts_params.thresh_delay`` is ``None`` this defers entirely
        to the parent implementation.  Otherwise every training step feeds the
        current mask values (``self.gradient_handler.init_mask_dict``, updated
        in place), and on steps where ``local_step % thresh_delay == 0`` the
        step additionally fetches ``self.gradient_handler.var_name_to_mask_op``
        and the fetched masks become the values fed on the following steps.

        Args:
            graph_info: object exposing the graph elements read below
                (``global_step``, ``fetches``, ``enqueue_ops``,
                ``input_producer_op``, ``execution_barrier``,
                ``local_var_init_op_group``, ``mvav_op``).

        Returns:
            dict with ``num_workers``, ``num_steps``, ``average_wall_time`` and
            ``images_per_sec``, plus ``last_average_loss`` when one was
            produced.
        """
        if self.rts_params.thresh_delay is None:
            return super(RTSBC, self).do_train(graph_info)

        if self.params.variable_update == 'horovod':
            import horovod.tensorflow as hvd  # pylint: disable=g-import-not-at-top
            # First worker will be 'chief' - it will write summaries and
            # save checkpoints.
            is_chief = hvd.rank() == 0
        else:
            is_chief = (not self.job_name or self.task_index == 0)

        summary_op = tf.summary.merge_all()

        # Only the chief writes summaries, and only when they are enabled; on
        # every other path summary_writer stays None.
        summary_writer = None
        if (is_chief and self.params.summary_verbosity and self.params.train_dir and
                    self.params.save_summaries_steps > 0):
            summary_writer = tf.summary.FileWriter(self.params.train_dir,
                tf.get_default_graph())

        # We want to start the benchmark timer right after a image_producer barrier
        # and avoids undesired waiting times on barriers.
        if ((self.num_warmup_batches + len(graph_info.enqueue_ops) - 1) %
                self.batch_group_size) != 0:
            self.num_warmup_batches = int(
                math.ceil(
                    (self.num_warmup_batches + len(graph_info.enqueue_ops) - 1.0) /
                    (self.batch_group_size)) * self.batch_group_size -
                len(graph_info.enqueue_ops) + 1)
            log_fn('Round up warm up steps to %d to match batch_group_size' %
                   self.num_warmup_batches)
            assert ((self.num_warmup_batches + len(graph_info.enqueue_ops) - 1) %
                    self.batch_group_size) == 0
        # We run the summaries in the same thread as the training operations by
        # passing in None for summary_op to avoid a summary_thread being started.
        # Running summaries and training operations in parallel could run out of
        # GPU memory.
        if is_chief and not self.forward_only_and_freeze:
            saver = tf.train.Saver(
                self.variable_mgr.savable_variables(),
                save_relative_paths=True,
                max_to_keep=self.params.max_ckpts_to_keep)
        else:
            saver = None
        ready_for_local_init_op = None
        if self.job_name and not (self.single_session or
                                      self.distributed_collective):
            # In distributed mode, we don't want to run local_var_init_op_group until
            # the global variables are initialized, because local_var_init_op_group
            # may use global variables (such as in distributed replicated mode). We
            # don't set this in non-distributed mode, because in non-distributed mode,
            # local_var_init_op_group may itself initialize global variables (such as
            # in replicated mode).
            ready_for_local_init_op = tf.report_uninitialized_variables(
                tf.global_variables())
        if self.params.variable_update == 'horovod':
            import horovod.tensorflow as hvd  # pylint: disable=g-import-not-at-top
            bcast_global_variables_op = hvd.broadcast_global_variables(0)
        else:
            bcast_global_variables_op = None

        if self.params.variable_update == 'collective_all_reduce':
            # It doesn't matter what this collective_graph_key value is,
            # so long as it's > 0 and the same at every worker.
            init_run_options = tf.RunOptions()
            init_run_options.experimental.collective_graph_key = 6
        else:
            init_run_options = tf.RunOptions()
        sv = MySupervisor(
            # For the purpose of Supervisor, all Horovod workers are 'chiefs',
            # since we want session to be initialized symmetrically on all the
            # workers.
            is_chief=is_chief or (self.params.variable_update == 'horovod'
                                  or self.distributed_collective),
            # Log dir should be unset on non-chief workers to prevent Horovod
            # workers from corrupting each other's checkpoints.
            logdir=self.params.train_dir if is_chief else None,
            ready_for_local_init_op=ready_for_local_init_op,
            local_init_op=graph_info.local_var_init_op_group,
            saver=saver,
            global_step=graph_info.global_step,
            summary_op=None,
            save_model_secs=self.params.save_model_secs,
            summary_writer=summary_writer,
            local_init_run_options=init_run_options,
            load_ckpt_full_path=self.my_params.load_ckpt,
            auto_continue=self.my_params.auto_continue)

        step_train_times = []
        start_standard_services = (
            self.params.train_dir or
            self.dataset.queue_runner_required())
        target = self.cluster_manager.get_target() if self.cluster_manager else ''

        sess_context = sv.managed_session(
                master=target,
                config=create_config_proto(self.params),
                start_standard_services=start_standard_services)

        with sess_context as sess:

            self.sess = sess

            if self.params.backbone_model_path is not None:
                self.model.load_backbone_model(sess, self.params.backbone_model_path)
            if bcast_global_variables_op:
                sess.run(bcast_global_variables_op)

            image_producer = None
            if graph_info.input_producer_op is not None:
                image_producer = cnn_util.ImageProducer(
                    sess, graph_info.input_producer_op, self.batch_group_size,
                    self.params.use_python32_barrier)
                image_producer.start()
            if graph_info.enqueue_ops:
                for i in xrange(len(graph_info.enqueue_ops)):
                    sess.run(graph_info.enqueue_ops[:(i + 1)])
                    if image_producer is not None:
                        image_producer.notify_image_consumption()
            self.init_global_step, = sess.run([graph_info.global_step])
            print('the current global step is ', self.init_global_step)
            if self.job_name and not self.params.cross_replica_sync:
                # TODO(zhengxq): Do we need to use a global step watcher at all?
                global_step_watcher = GlobalStepWatcher(
                    sess, graph_info.global_step,
                    self.num_workers * self.num_warmup_batches +
                    self.init_global_step,
                    self.num_workers * (self.num_warmup_batches + self.num_batches) - 1)
                global_step_watcher.start()
            else:
                global_step_watcher = None

            if self.graph_file is not None:
                path, filename = os.path.split(self.graph_file)
                as_text = filename.endswith('txt')
                log_fn('Writing GraphDef as %s to %s' % (  # pyformat break
                    'text' if as_text else 'binary', self.graph_file))
                tf.train.write_graph(sess.graph.as_graph_def(add_shapes=True), path,
                    filename, as_text)

            log_fn('Running warm up')
            local_step = -1 * self.num_warmup_batches
            if self.single_session:
                # In single session mode, each step, the global_step is incremented by
                # 1. In non-single session mode, each step, the global_step is
                # incremented once per worker. This means we need to divide
                # init_global_step by num_workers only in non-single session mode.
                end_local_step = self.num_batches - self.init_global_step
            else:
                end_local_step = self.num_batches - (self.init_global_step /
                                                     self.num_workers)

            if not global_step_watcher:
                # In cross-replica sync mode, all workers must run the same number of
                # local steps, or else the workers running the extra step will block.
                done_fn = lambda: local_step >= end_local_step
            else:
                done_fn = global_step_watcher.done
            if self.params.debugger is not None:
                if self.params.debugger == 'cli':
                    log_fn('The CLI TensorFlow debugger will be used.')
                    sess = tf_debug.LocalCLIDebugWrapperSession(sess)
                    self.sess = sess
                else:
                    log_fn('The TensorBoard debugger plugin will be used.')
                    sess = tf_debug.TensorBoardDebugWrapperSession(sess, self.params.debugger)
                    self.sess = sess
            profiler = tf.profiler.Profiler() if self.params.tfprof_file else None
            loop_start_time = time.time()
            last_average_loss = None

            print('self.lr_boundaries=', self.lr_boundaries)

            # Mask values fed at every step; entries are replaced in place each
            # time the masks are refreshed.
            rts_mask_feed_dict = self.gradient_handler.init_mask_dict #TODO rts delay related

            while not done_fn():
                if local_step == 0:
                    log_fn('Done warm up')
                    if graph_info.execution_barrier:
                        log_fn('Waiting for other replicas to finish warm up')
                        sess.run([graph_info.execution_barrier])

                    # TODO(laigd): rename 'Img' to maybe 'Input'.
                    header_str = ('Step\tImg/sec\t' +
                                  self.params.loss_type_to_report.replace('/', ' '))
                    if self.params.print_training_accuracy or self.params.forward_only:
                        # TODO(laigd): use the actual accuracy op names of the model.
                        header_str += '\ttop_1_accuracy\ttop_5_accuracy'
                    log_fn(header_str)
                    assert len(step_train_times) == self.num_warmup_batches
                    # reset times to ignore warm up batch
                    step_train_times = []
                    loop_start_time = time.time()
                if (summary_writer and (local_step + 1) % self.params.save_summaries_steps == 0):
                    fetch_summary = summary_op
                else:
                    fetch_summary = None

                collective_graph_key = 7 if (self.params.variable_update == 'collective_all_reduce') else 0

                ################    TODO rts delay related

                # On refresh steps the mask ops are fetched in addition to the
                # usual fetches; on all other steps extra_fetch_dict is simply
                # not passed, exactly as before.
                refresh_masks = (local_step % self.rts_params.thresh_delay == 0)
                step_kwargs = dict(
                    benchmark_logger=self.benchmark_logger,
                    collective_graph_key=collective_graph_key,
                    track_mvav_op=graph_info.mvav_op,
                    extra_feed_dict=rts_mask_feed_dict)
                if refresh_masks:
                    step_kwargs['extra_fetch_dict'] = self.gradient_handler.var_name_to_mask_op

                (summary_str, last_average_loss, mask_value_dict) = benchmark_one_step(
                    sess, graph_info.fetches, local_step,
                    self.batch_size * (self.num_workers
                                       if self.single_session else 1), step_train_times,
                    self.trace_filename, self.params.partitioned_graph_file_prefix,
                    profiler, image_producer, self.params, fetch_summary,
                    **step_kwargs)

                if refresh_masks:
                    self._rts_apply_refreshed_masks(
                        mask_value_dict, rts_mask_feed_dict, summary_writer, sess, graph_info)

                local_step += 1

                if summary_str is not None and is_chief:
                    sv.summary_computed(sess, summary_str)

                self.do_extra_summaries(summary_writer = summary_writer, local_step = local_step, sess=sess, graph_info=graph_info)

                if (self.my_params.num_steps_per_hdf5 > 0 and local_step % self.my_params.num_steps_per_hdf5 == 0 and local_step > 0 and is_chief):
                    self.save_hdf5_by_global_step(sess.run(graph_info.global_step))

                if (self.params.save_model_steps and local_step % self.params.save_model_steps == 0 and local_step > 0 and is_chief):
                    sv.saver.save(sess, sv.save_path, sv.global_step)

                # Every 100 local steps: checkpoint when a learning-rate
                # boundary lies fewer than 100 global steps ahead.
                if self.lr_boundaries is not None and local_step % 100 == 0 and local_step > 0 and is_chief:
                    cur_global_step = sess.run(graph_info.global_step)
                    for b in self.lr_boundaries:
                        if b > cur_global_step and b - cur_global_step < 100:
                            sv.saver.save(sess, sv.save_path, sv.global_step)
                            self.save_hdf5_by_global_step(cur_global_step)
                            break

                # Frequent hdf5 snapshots during the last few epochs of training.
                if self.my_params.frequently_save_interval is not None and self.my_params.frequently_save_last_epochs is not None and local_step % self.my_params.frequently_save_interval == 0 and local_step > 0 and is_chief:
                    cur_global_step = sess.run(graph_info.global_step)
                    remain_steps = self.num_batches - cur_global_step
                    remain_epochs = remain_steps * self.batch_size / self.dataset.num_examples_per_epoch(self.subset)
                    if remain_epochs < self.my_params.frequently_save_last_epochs:
                        self.save_hdf5_by_global_step(cur_global_step)

            loop_end_time = time.time()
            # Waits for the global step to be done, regardless of done_fn.
            if global_step_watcher:
                while not global_step_watcher.done():
                    time.sleep(.25)
            if not global_step_watcher:
                elapsed_time = loop_end_time - loop_start_time
                average_wall_time = elapsed_time / local_step if local_step > 0 else 0
                images_per_sec = (self.num_workers * local_step * self.batch_size /
                                  elapsed_time)
                num_steps = local_step * self.num_workers
            else:
                # NOTE: Each worker independently increases the global step. So,
                # num_steps will be the sum of the local_steps from each worker.
                num_steps = global_step_watcher.num_steps()
                elapsed_time = global_step_watcher.elapsed_time()
                average_wall_time = (elapsed_time * self.num_workers / num_steps
                                     if num_steps > 0 else 0)
                images_per_sec = num_steps * self.batch_size / elapsed_time

            if self.my_params.save_hdf5:
                print('start saving the final hdf5 to ', self.my_params.save_hdf5)
                self.save_weights_to_hdf5(self.my_params.save_hdf5)
                if self.my_params.save_mvav:
                    self.save_moving_average_weights_to_hdf5(self.my_params.save_hdf5.replace('.hdf5', '_mvav.hdf5'),
                        moving_averages=self.variable_averages)

            log_fn('-' * 64)
            # TODO(laigd): rename 'images' to maybe 'inputs'.
            log_fn('total images/sec: %.2f' % images_per_sec)
            log_fn('-' * 64)
            if image_producer is not None:
                image_producer.done()
            if is_chief:
                if self.benchmark_logger:
                    self.benchmark_logger.log_metric(
                        'average_examples_per_sec', images_per_sec, global_step=num_steps)

            # Save the model checkpoint.
            if self.params.train_dir is not None and is_chief:
                checkpoint_path = os.path.join(self.params.train_dir, 'model.ckpt')
                if not gfile.Exists(self.params.train_dir):
                    gfile.MakeDirs(self.params.train_dir)
                sv.saver.save(sess, checkpoint_path, graph_info.global_step)

            if graph_info.execution_barrier:
                # Wait for other workers to reach the end, so this worker doesn't
                # go away underneath them.
                sess.run([graph_info.execution_barrier])

        sv.stop()
        if profiler:
            generate_tfprof_profile(profiler, self.params.tfprof_file)
        stats = {
            'num_workers': self.num_workers,
            'num_steps': num_steps,
            'average_wall_time': average_wall_time,
            'images_per_sec': images_per_sec
        }
        if last_average_loss is not None:
            stats['last_average_loss'] = last_average_loss
        return stats

    def _rts_apply_refreshed_masks(self, mask_value_dict, rts_mask_feed_dict, summary_writer, sess, graph_info):
        """Store freshly fetched masks in the feed dict and log "rewake" stats.

        A mask entry counts as "rewoken" when the value fed so far was 0 and
        the freshly fetched value is 1.  ``rts_mask_feed_dict`` is updated in
        place, keyed by the placeholders in
        ``self.gradient_handler.var_name_to_mask_ph``.  The rewake count and
        ratio are written as summaries only when a summary writer exists.
        """
        mask_placeholders = self.gradient_handler.var_name_to_mask_ph
        num_rewake = 0
        num_last_masked = 0
        for var_name, new_mask in mask_value_dict.items():
            placeholder = mask_placeholders[var_name]
            was_masked = (rts_mask_feed_dict[placeholder] == 0)
            num_rewake += np.sum(np.logical_and(new_mask == 1., was_masked))
            num_last_masked += np.sum(was_masked)
            rts_mask_feed_dict[placeholder] = new_mask

        # There is no writer on non-chief workers or when summaries are
        # disabled; the masks above must still be refreshed in that case.
        if summary_writer is None:
            return

        # float() forces true division; int / int would floor to 0 under
        # Python 2.
        rewake_ratio = float(num_rewake) / float(num_last_masked) if num_last_masked != 0 else 0
        summary = tf.Summary()
        summary.value.add(tag='global/rewake_num', simple_value=num_rewake)
        summary.value.add(tag='global/rewake_ratio', simple_value=rewake_ratio)
        summary_writer.add_summary(summary, sess.run(graph_info.global_step))


    def taylor_survey(self, survey_save_file, survey_layer_ids, accumulate_type, eval_feed_dict=None):
        """Accumulate ``gradient * weight`` for selected kernels and save the sums.

        For each kernel selected by ``survey_layer_ids`` the element-wise
        product of the kernel and the gradient of ``self.fetches['loss']`` with
        respect to it is evaluated for ``self.num_batches`` batches and summed
        (not averaged) according to ``accumulate_type``:

        * ``'origin'``: the raw products;
        * ``'abs'``: their absolute values;
        * ``'pos'``: the products with negative entries replaced by 0;
        * ``'neg'``: the products with positive entries replaced by 0.

        Args:
            survey_save_file: destination handed to ``save_hdf5``.
            survey_layer_ids: indices into ``self.get_kernel_variables()``.
            accumulate_type: one of ``'abs'``, ``'pos'``, ``'neg'``, ``'origin'``.
            eval_feed_dict: optional feed dict used for every ``sess.run`` of
                the products; ``None`` means an empty feed dict.

        Returns:
            dict mapping the position of each (gradient, variable) pair in the
            list returned by ``compute_gradients`` (0, 1, ...) to its
            accumulated array.  Note the keys are positions, not the layer ids.

        Raises:
            ValueError: if the loss has no gradient for a surveyed kernel.
        """
        assert accumulate_type in ['abs', 'pos', 'neg', 'origin']
        assert self.params.weight_decay == 0

        sess = self.sess
        # Previously the argument was accepted but silently ignored.
        feed_dict = eval_feed_dict or {}

        if self.dataset.queue_runner_required():
            tf.train.start_queue_runners(sess=sess)

        image_producer = None
        if self.input_producer_op is not None:
            image_producer = cnn_util.ImageProducer(
                sess, self.input_producer_op, self.batch_group_size,
                self.params.use_python32_barrier)
            image_producer.start()
        if self.enqueue_ops:
            # Run a growing prefix of the enqueue ops: [op0], [op0, op1], ...
            for num_ops in range(1, len(self.enqueue_ops) + 1):
                sess.run(self.enqueue_ops[:num_ops])
                if image_producer is not None:
                    image_producer.notify_image_consumption()
        loop_start_time = start_time = time.time()

        kernels = self.get_kernel_variables()
        kernels_to_survey = [kernels[i] for i in survey_layer_ids]
        print('the kernels to taylor-survey are ', kernels_to_survey)
        # The optimizer is only used to obtain gradients; nothing is applied.
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0)
        grads_and_vars = optimizer.compute_gradients(loss=self.fetches['loss'], var_list=kernels_to_survey)
        assert len(grads_and_vars) == len(survey_layer_ids)
        print('grads_and_vars: ', grads_and_vars)

        gv_outputs_dict = {}
        for i, (g, v) in enumerate(grads_and_vars):
            if g is None:
                raise ValueError('the loss has no gradient w.r.t. {}, cannot survey it'.format(v.name))
            gv_outputs_dict[i] = g * v

        def to_accumulate(out_array):
            """Return the per-batch contribution selected by ``accumulate_type``."""
            if accumulate_type == 'origin':
                return out_array
            if accumulate_type == 'abs':
                return np.abs(out_array)
            # 'pos' / 'neg': work on a copy and zero out the unwanted sign.
            clipped = np.array(out_array)
            if accumulate_type == 'pos':
                clipped[out_array < 0] = 0
            else:
                clipped[out_array > 0] = 0
            return clipped

        accum_record_dict = {}

        for step in xrange(self.num_batches):
            results = sess.run(gv_outputs_dict, feed_dict=feed_dict)
            for out_idx, out_array in results.items():
                accum_record_dict[out_idx] = accum_record_dict.get(out_idx, 0) + to_accumulate(out_array)

            # results = self.model.postprocess(results)
            if (step + 1) % self.params.display_every == 0:
                duration = time.time() - start_time
                examples_per_sec = (
                    self.batch_size * self.params.display_every / duration)
                log_fn('%i\t%.1f examples/sec' % (step + 1, examples_per_sec))
                start_time = time.time()
            if image_producer is not None:
                image_producer.notify_image_consumption()

        loop_end_time = time.time()
        if image_producer is not None:
            image_producer.done()

        elapsed_time = loop_end_time - loop_start_time
        images_per_sec = (self.num_batches * self.batch_size / elapsed_time)
        log_fn('-' * 64)
        log_fn('total images/sec: %.2f' % images_per_sec)
        log_fn('-' * 64)

        print('save the survey hdf5 to {}, where the keys are {}'.format(survey_save_file, accum_record_dict.keys()))
        save_hdf5(accum_record_dict, survey_save_file)
        return accum_record_dict

