# coding=utf-8
# Copyright 2019 The Hal Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Environment for high level policy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import cPickle as pickle

from gym import spaces
import numpy as np
import tensorflow as tf


from hal.low_level_policy.model import Model
from hal.low_level_policy.model import VariableInputModel
from hal.low_level_policy.model import ImageModel
from hal.low_level_policy.config import get_config
import hal.language_processing_utils.word_vectorization as wv


# Checkpoint locations of the pre-trained low-level policies. They are None
# here; set them (or pass an explicit checkpoint path to the environment)
# before a low-level policy is restored.
default_low_level_model_ckpt_path = None  # fallback path used when no obs-type specific path applies

self_attention_low_level_model_ckpt_path = None  # path to state-based low-level policy

image_factor_low_level_model_ckpt_path = None  # path to factor image-based low-level policy

diverse_image_factor_low_level_model_ckpt_path = None  # path to diverse factor image-based low-level policy

# Directory containing this module. Joining '..' onto __file__ itself would
# produce '<module>.py/../assets/...', which treats a regular file as a
# directory and is rejected by POSIX path resolution unless normalized first.
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))

vocab_path = os.path.join(_MODULE_DIR, 'assets', 'vocab.txt')

diverse_vocab_path = os.path.join(
    _MODULE_DIR, 'assets', 'variable_input_vocab.txt')



class HighLevelEnv(object):
  """Environment for training high level policy.

  Wraps a low-level environment together with a pre-trained low-level policy
  network that is restored from a checkpoint. Attributes that are not found
  on this object are looked up on the wrapped environment, so the wrapper can
  be used in place of it. Subclasses supply the task by implementing
  `_reward` and `_is_done` (and `step`).
  """

  def __init__(
      self,
      low_level_env,
      tf_session,
      obs_type='direct',
      self_attention_low_level=False,
      diverse=False,
      low_level_obs_type='order_invariant',
      low_level_action_type='perfect',
      low_level_step=3,
      low_level_policy_name='target',
      text_embedding_path=None,
      low_level_model_ckpt_path=None):
    """Builds the low-level policy network and restores its weights.

    Args:
      low_level_env: environment being wrapped; stored as `self.env`.
      tf_session: TensorFlow session used to restore and run the network.
      obs_type: stored as `self.obs_type`.
      self_attention_low_level: passed to `get_config`; together with
        `low_level_obs_type` it selects which network class is built.
      diverse: if true, the variable-input vocabulary and the diverse image
        checkpoint are used.
      low_level_obs_type: observation type fed to the low-level policy
        ('image', 'direct', anything else is treated as order invariant when
        fetching observations).
      low_level_action_type: stored as `self.low_level_action_type`.
      low_level_step: stored as `self.low_level_step`.
      low_level_policy_name: `name` given to `Model`/`VariableInputModel`
        (the image model always uses 'model').
      text_embedding_path: not used by this constructor.
      low_level_model_ckpt_path: checkpoint to restore. When empty, a
        module-level default matching `low_level_obs_type` is used.

    Raises:
      ValueError: if no supported network matches the arguments, or no
        checkpoint path could be determined.
    """
    print('===========create high level environment=======================')
    # Load only the vocabulary that is actually used.
    vocab_list = wv.load_vocab_list(
        diverse_vocab_path if diverse else vocab_path)
    v2i, i2v = wv.create_look_up_table(vocab_list)
    self.decode_fn = wv.decode_with_lookup_table(i2v)
    self.encode_fn = wv.encode_text_with_lookup_table(v2i)

    self.vocab_size = len(vocab_list)
    # NOTE(review): larger vocabularies are counted as one entry smaller;
    # presumably this matches the size the checkpoint was trained with --
    # confirm against the vocabulary files.
    if self.vocab_size > 40:
      self.vocab_size -= 1

    self.env = low_level_env
    self.sess = tf_session
    self.diverse = diverse

    self.text = None  # load encoded text which is the actions

    # loading low level policy
    self.low_level_action_type = low_level_action_type
    self.low_level_obs_type = low_level_obs_type
    self.cfg = get_config(self_attention_low_level)
    cfg = self.cfg

    if not self_attention_low_level and low_level_obs_type != 'image':
      self.low_level_policy_network = Model(
          input_dim=cfg.input_dim, ac_dim=cfg.ac_dim,
          vocab_size=cfg.vocab_size, embedding_size=cfg.embedding_size,
          conv_layer_config=cfg.conv_layer_config,
          dense_layer_config=cfg.dense_layer_config,
          encoder_n_unit=cfg.encoder_n_unit, seq_length=cfg.max_len,
          name=low_level_policy_name, direct_obs=cfg.direct_obs,
          variable_input=cfg.obs_type == 'order_invariant',
          reuse=tf.AUTO_REUSE
      )
    elif low_level_obs_type == 'order_invariant':
      print('using self attention low level model')
      self.low_level_policy_network = VariableInputModel(
          input_dim=cfg.input_dim, vocab_size=cfg.vocab_size,
          embedding_size=cfg.embedding_size,
          name=low_level_policy_name, des_len=cfg.des_len,
          inner_len=cfg.inner_len, encoder_n_unit=cfg.encoder_n_unit,
          per_input_ac_dim=cfg.ac_dim, reuse=tf.AUTO_REUSE
      )
    elif low_level_obs_type == 'image':
      print('using image model')
      conv_layer = [(48, 8, 2), (128, 5, 2), (64, 3, 1)]
      self.low_level_policy_network = ImageModel(
          input_dim=[64, 64, 3],
          name='model',
          ac_dim=[800],
          vocab_size=self.vocab_size,
          embedding_size=cfg.embedding_size,
          conv_layer_config=conv_layer,
          dense_layer_config=[512, 512],
          encoder_n_unit=cfg.encoder_n_unit,
          action_type='discrete',
          action_parameterization='factor'
      )
    else:
      raise ValueError(('only Model, VariableInputModel '
                        'and ImageModel are supported.'))

    print('******created low level network******')

    self.restore_low_level_policy(low_level_model_ckpt_path)

    print('******loaded low level model******')

    self._rew = self._reward()
    self._done = self._is_done(self._rew)
    # `_get_obs` is not defined on this class; it resolves through
    # `__getattr__` to the wrapped environment.
    self._obs = self._get_obs()
    self.low_level_step = low_level_step
    self.obs_type = obs_type

  def __getattr__(self, attr):
    """Delegates attributes missing on the wrapper to the wrapped env."""
    # __getattr__ only runs after normal lookup has failed. If `env` itself
    # is missing (instance created without __init__, e.g. while copying or
    # unpickling), reading `self.env` below would re-enter this method and
    # recurse without bound, so fail with a regular AttributeError instead.
    if attr == 'env':
      raise AttributeError(
          "'{}' object has no attribute 'env'".format(type(self).__name__))
    return getattr(self.env, attr)

  def _resolve_low_level_ckpt_path(self, low_level_model_ckpt_path=None):
    """Returns the checkpoint path to restore the low-level policy from.

    An explicitly supplied path wins. Otherwise the module-level path that
    matches `self.low_level_obs_type` (and `self.diverse` for images) is
    returned; the result may be None if that constant has not been set.
    """
    if low_level_model_ckpt_path:
      return low_level_model_ckpt_path
    if self.low_level_obs_type == 'image':
      if self.diverse:
        return diverse_image_factor_low_level_model_ckpt_path
      return image_factor_low_level_model_ckpt_path
    if self.low_level_obs_type == 'order_invariant':
      return self_attention_low_level_model_ckpt_path
    # The fallback constant may not be defined by this module; look it up
    # defensively so a missing definition yields None rather than NameError.
    return globals().get('default_low_level_model_ckpt_path')

  def restore_low_level_policy(self, low_level_model_ckpt_path=None):
    """Restores the low-level policy variables from a checkpoint.

    Args:
      low_level_model_ckpt_path: checkpoint to load; when empty the
        module-level default for the current observation type is used.

    Raises:
      ValueError: if no checkpoint path is available.
    """
    ckpt_path = self._resolve_low_level_ckpt_path(low_level_model_ckpt_path)
    if not ckpt_path:
      raise ValueError(
          'No low-level checkpoint path available for low_level_obs_type '
          '{!r}; pass low_level_model_ckpt_path or set the module-level '
          'checkpoint path.'.format(self.low_level_obs_type))
    low_level_network_saver = tf.train.Saver(
        var_list=self.low_level_policy_network.variables)
    low_level_network_saver.restore(self.sess, ckpt_path)

  def get_low_level_action(self, instruction):
    """Runs the low-level policy on the current observation.

    Args:
      instruction: encoded instruction fed to the network's `word_inputs`.

    Returns:
      The squeezed output of the network's `predict` op, or, for the image
      policy, the index of the largest value of its `Q_` output.
    """
    is_image = self.low_level_obs_type == 'image'
    if is_image:
      obs = self.get_image_obs()
    elif self.low_level_obs_type == 'direct':
      obs = self.get_direct_obs()
    else:
      obs = self.get_order_invariant_obs()

    network = self.low_level_policy_network
    # The network is fed a batch of size one.
    feed_dict = {network.inputs: [obs],
                 network.word_inputs: [instruction],
                 network.is_training: False}

    if not is_image:
      action = self.sess.run(network.predict, feed_dict=feed_dict)
    else:
      action_q = np.squeeze(self.sess.run(network.Q_, feed_dict=feed_dict))
      # Last index of the ascending argsort, i.e. the highest Q-value; kept
      # as argsort so tie-breaking is unchanged.
      action = action_q.argsort()[-1]
    return np.squeeze(action)

  def step(self, action):
    """Advances the environment; must be provided by subclasses."""
    raise NotImplementedError('Step function has not been implemented.')

  def reset(self):
    """Resets the wrapped environment and returns its observation."""
    obs = self.env.reset()
    return obs

  def _reward(self):
    """Returns the task reward for the current state (subclass hook)."""
    raise NotImplementedError('please implement specific goal')

  def _is_done(self, reward):
    """Returns whether `reward` marks the task finished (subclass hook)."""
    raise NotImplementedError('please implement finish condition')


class SortingXCoordEnv(HighLevelEnv):
  """High-level task whose goal is objects sorted by x coordinate."""

  def _reward(self):
    """Returns 0. when x coordinates strictly increase in `obj_name` order.

    Otherwise returns -3. The coordinates are element 0 of
    `get_body_com(name)` for each name in `self.obj_name`.
    """
    positions = [self.get_body_com(name) for name in self.obj_name]
    is_sorted = all(
        right[0] > left[0]
        for left, right in zip(positions, positions[1:]))
    return (-1. + float(is_sorted)) * 3.

  def _is_done(self, reward):
    """The task is finished exactly when the reward is zero."""
    return reward == 0.

  def reset(self, max_reset=50):
    """Resets the env, re-sampling while the layout is already sorted.

    The previous implementation called `self._complete()`, which neither
    this class nor `HighLevelEnv` defines, and compared it against a
    threshold (-8) that this task's reward (-3. or 0.) never reaches.

    Args:
      max_reset: maximum number of additional resets to attempt.

    Returns:
      The observation returned by the last reset of the wrapped env.
    """
    obs = self.env.reset()
    rep = 0
    # An episode that starts in the solved state gives nothing to learn
    # from, so draw again (bounded by max_reset).
    while self._is_done(self._reward()) and rep < max_reset:
      obs = self.env.reset()
      rep += 1
    return obs


class MultiStatementEnv(HighLevelEnv):
  """High-level task defined by a set of statements with target answers.

  Each entry of `self.objective_program` is passed to `self.get_answer` and
  the answer is compared against the matching entry of
  `self.objective_goal`.
  """

  def _reward(self):
    """Returns 0. when every statement matches its goal, otherwise -10."""
    shortfall = self._complete()
    if shortfall < 0:
      return -10.
    return shortfall

  def _complete(self):
    """Returns (number of matching statements) - (number of statements).

    The value is zero when all statements match their goal and negative
    otherwise.
    """
    program = self.objective_program
    answers = [self.get_answer(statement) for statement in program]
    mismatched = np.logical_xor(answers, self.objective_goal)
    matched = np.logical_not(mismatched).astype(float)
    return np.sum(matched) - len(self.objective_program)

  def _is_done(self, reward):
    """The task is finished exactly when the reward is zero."""
    return reward == 0.

  def reset(self, max_reset=50):
    """Resets the env until enough statements start out unsatisfied.

    Resetting repeats while `_complete()` is greater than -8, for at most
    `max_reset` additional attempts.

    Args:
      max_reset: maximum number of additional resets to attempt.

    Returns:
      The observation returned by the last reset of the wrapped env.
    """
    obs = self.env.reset()
    attempts = 0
    while True:
      shortfall = self._complete()
      if not (shortfall > -8 and attempts < max_reset):
        break
      obs = self.env.reset()
      attempts += 1
    return obs
