# coding=utf-8
# Copyright 2019 The Hal Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Feature-wise Linear Modulation Layer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np
import sys

def film_params(sentence_embedding, n_layer_channel):
  """Project a sentence embedding onto the FiLM parameters of every layer.

  A single dense layer maps the embedding to one long vector holding two
  values (gamma and beta) per channel of every layer to be FiLM'ed. That
  vector is then cut along axis 1 into one chunk per layer, so a batch
  dimension is assumed to exist.

  Args:
    sentence_embedding: A tensor containing the batched sentence embedding to
                          be transformed
    n_layer_channel:    A list of integers specifying how many channels are
                          at each hidden layer to be FiLM'ed

  Returns:
    A list of tensors with the same length as n_layer_channel. The i-th
    element has 2 * n_layer_channel[i] columns and contains all gamma_i and
    beta_i for a single hidden layer.
  """
  # Every channel needs one scale and one shift, hence twice the channel count.
  chunk_sizes = [2 * n_channel for n_channel in n_layer_channel]
  projected = tf.layers.dense(sentence_embedding, sum(chunk_sizes))
  return tf.split(projected, chunk_sizes, 1)


def film_pi_network(
    obs, goal, ac_dim, conv_layer_config, dense_layer_config, is_training):
  """Build FiLM policy network.

  Build the graph for a policy in which the observation goes through a stack
  of convolutions, each followed by batch normalization (without learned
  offset or scale), a FiLM modulation whose gamma and beta are generated from
  the goal, and a ReLU. The resulting feature maps are flattened and fed
  through dense ReLU layers and a final linear layer of size ac_dim.

  Args:
    obs: Tensor containing the state observations.
    goal: Tensor containing the goal embedding
    ac_dim: Dimension of the policy
    conv_layer_config: A list of tuple (channel, kernel_size, stride) specifying
                         the topology of convolutional layers
    dense_layer_config: A list of integer specifying the topology of the dense
                          layers
    is_training: An indicator of which phase the model is in; forwarded to
                   batch normalization as its training flag

  Returns:
    Policy conditioned on the observation and goal
  """
  # TODO: add variable scope and add reuse
  channels_per_layer = [layer_cfg[0] for layer_cfg in conv_layer_config]
  per_layer_film = film_params(goal, channels_per_layer)

  # building convnet
  features = obs
  for layer_cfg, film in zip(conv_layer_config, per_layer_film):
    n_filters, kernel_size, strides = layer_cfg[0], layer_cfg[1], layer_cfg[2]
    features = tf.layers.conv2d(
        features, n_filters, kernel_size, strides, padding='SAME')
    # Batch norm only normalizes here; the affine part comes from FiLM below.
    features = tf.layers.batch_normalization(
        features, center=False, scale=False, training=is_training)
    gamma, beta = tf.split(film, 2, axis=1)
    # Two extra axes after the batch axis let the per-channel parameters
    # broadcast over both spatial dimensions of the feature maps.
    features = features * tf.expand_dims(tf.expand_dims(gamma, 1), 1)
    features = features + tf.expand_dims(tf.expand_dims(beta, 1), 1)
    features = tf.nn.relu(features)

  # Flatten everything but the batch axis using the static shape.
  static_shape = features.get_shape()
  hidden = tf.reshape(features, (-1, np.prod(static_shape[1:])))

  # building fully connected net
  for n_units in dense_layer_config:
    pre_activation = tf.layers.dense(hidden, n_units)
    hidden = tf.nn.relu(pre_activation)
  # NOTE: an alternative head (1x1 convolution to ac_dim followed by a mean
  # over the spatial axes) was sketched here previously and is not used.
  return tf.layers.dense(hidden, ac_dim)


def mlp_pi_network(obs, goal, ac_dim, dense_layer_config):
  """Build mlp policy network.

  Build the graph for a policy that concatenates the observation and the goal
  along axis 1 and passes the result through dense ReLU layers followed by a
  final linear layer of size ac_dim.

  Args:
    obs: Tensor containing the state observations.
    goal: Tensor containing the goal embedding
    ac_dim: Dimension of the policy
    dense_layer_config: A list of integer specifying the topology of the dense
                          layers

  Returns:
    Policy conditioned on the observation and goal
  """
  # TODO: add variable scope and add reuse
  hidden = tf.concat([obs, goal], axis=1)
  # building fully connected net
  for n_units in dense_layer_config:
    pre_activation = tf.layers.dense(hidden, n_units)
    hidden = tf.nn.relu(pre_activation)
  return tf.layers.dense(hidden, ac_dim)


def combine_variable_input(obs, goal, description_length, inner_length):
  """Combine variable sized input.

  Yield a fixed length vector description of a variable number of
  observations. The goal is appended to every observation entry, three
  position-wise (1x1 convolution) projections f, g and h are computed, the
  pairwise inner products of f and g are softmax-normalized over the last
  axis and used to mix the h features, and the result is averaged over the
  entries.

  Args:
    obs: Tensor containing the state [B, ?, N]
    goal: Tensor containing the goal embedding. It is expanded at axis 1 and
          tiled to the number of entries in obs, so it is expected to be
          [B, G]
    description_length: Number of channels of the h projection, and therefore
                        the size of the returned description
    inner_length: Number of channels of the f and g projections

  Returns:
    Fixed length description [B, description_length]
  """
  n_entries = tf.shape(obs)[1]
  multiples = tf.convert_to_tensor([1, n_entries, 1])
  tiled_goal = tf.tile(tf.expand_dims(goal, axis=1), multiples=multiples)
  joint = tf.concat([obs, tiled_goal], axis=-1)
  # A dummy width axis lets 1x1 conv2d act independently on every entry.
  joint = tf.expand_dims(joint, axis=2)  # [B, ?, 1, N + G]

  def pointwise_stack(inputs):
    """Two ReLU 1x1 convolutions followed by a linear 1x1 convolution."""
    hidden = inputs
    for _ in range(2):
      hidden = tf.layers.conv2d(
          hidden, inner_length, 1, padding='same', activation=tf.nn.relu)
    return tf.layers.conv2d(hidden, inner_length, 1, padding='same')

  f = pointwise_stack(joint)  # [B, ?, 1, inner_length]
  g = pointwise_stack(joint)  # [B, ?, 1, inner_length]
  h = tf.layers.conv2d(
      joint, description_length, 1, padding='same')  # [B, ?, 1, dl]

  f_flat = tf.squeeze(f, axis=2)  # [B, ?, inner_length]
  g_flat = tf.squeeze(g, axis=2)  # [B, ?, inner_length]
  g_flat_t = tf.transpose(g_flat, perm=[0, 2, 1])  # [B, inner_length, ?]

  pairwise_scores = tf.matmul(f_flat, g_flat_t)  # [B, ?, ?]
  mixing_weights = tf.nn.softmax(pairwise_scores)  # [B, ?, ?]
  h_flat = tf.squeeze(h, axis=2)  # [B, ?, description_length]
  mixed = tf.matmul(mixing_weights, h_flat)  # [B, ?, description_length]
  return tf.reduce_mean(mixed, axis=1)  # [B, description_length]


def film_critic_network(obs, goal, a, out_dim, conv_layer_config,
    ac_dense_layer_config, dense_layer_config):
  """Build FiLM critic network.

  Build the graph for a critic in which every convolution applied to the
  observation is followed by a FiLM modulation (gamma and beta generated from
  the goal) and a ReLU. The action goes through its own dense ReLU layers, is
  concatenated with the flattened convolutional features, and the result is
  passed through further dense ReLU layers and a final linear layer.

  Unlike film_pi_network, the convolutions here use the default padding of
  tf.layers.conv2d and no batch normalization precedes the modulation.

  Args:
    obs: Tensor containing the state observations.
    goal: Tensor containing the goal embedding
    a: Tensor containing the action
    out_dim: dimension of the critic output, usually 1
    conv_layer_config: A list of tuple (channel, kernel_size, stride) specifying
                         the topology of convolutional layers
    ac_dense_layer_config: A list of integers specifying the topology of the
                           dense layers for action alone
    dense_layer_config: A list of integer specifying the topology of the dense
                          layers for concatenate state and action

  Returns:
    Output the critic networks of dimension out_dim
  """
  n_layer_channel = [layer_config[0] for layer_config in conv_layer_config]
  layer_film_params = film_params(goal, n_layer_channel)
  out = obs
  # building convnet
  for cfg, param in zip(conv_layer_config, layer_film_params):
    out = tf.layers.conv2d(out, cfg[0], cfg[1], cfg[2])
    # param is [B, 2 * channels]; split it along the channel axis into gamma
    # and beta, following the same convention as film_pi_network. Indexing
    # param[0] / param[1] would pick batch rows rather than the two halves.
    gamma, beta = tf.split(param, 2, axis=1)
    # Two extra axes after the batch axis let the per-channel parameters
    # broadcast over both spatial dimensions of the feature maps.
    out *= tf.expand_dims(tf.expand_dims(gamma, 1), 1)
    out += tf.expand_dims(tf.expand_dims(beta, 1), 1)
    out = tf.nn.relu(out)
  # building action fully connected layer
  for cfg in ac_dense_layer_config:
    a = tf.nn.relu(tf.layers.dense(a, cfg))
  # building fully connected net
  # Flatten with the static feature count when it is known so that the last
  # dimension stays statically defined for the dense layers below; otherwise
  # fall back to a fully dynamic reshape.
  feature_dims = out.get_shape().as_list()[1:]
  if None in feature_dims:
    out = tf.reshape(out, [tf.shape(out)[0], -1])
  else:
    out = tf.reshape(out, [-1, int(np.prod(feature_dims))])
  out = tf.concat([out, a], axis=1)
  for cfg in dense_layer_config:
    out = tf.nn.relu(tf.layers.dense(out, cfg))
  return tf.layers.dense(out, out_dim)
