from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import os
import numpy as np
sys.path.append('.')
sys.path.append('..')
from methods.lips_bound import lips_bound_evaluation, estimate_rpi

def fun(x):
    """Return sum(sqrt(x^2 + x*sin(x) + 1)) taken over the last axis of `x`.

    For an input of shape (..., d) the result has shape (...).
    """
    radicand = x*x + x*np.sin(x) + 1
    return np.sum(np.sqrt(radicand), axis=-1)
def Q_true(s,a):
    """Closed-form Q-function of the toy problem: fun(s) + fun(a - pi/2).

    s: states, shape n*d (reduced over the last axis by `fun`)
    a: actions, shape n*d; the caller's array is not modified
    """
    shifted_action = a - np.pi/2
    state_term = fun(s)
    action_term = fun(shifted_action)
    return state_term + action_term

class Linear_Gaussian_Policy(object):
    """Isotropic Gaussian policy whose mean is affine in the state.

    Actions are sampled as ``w @ state + b + exp(logvar / 2) * eps`` where
    ``eps`` is standard normal noise drawn from NumPy's global RNG, i.e. every
    action component has variance ``exp(logvar)``.

    Parameters
    ----------
    w : array of shape (action_dim, state_dim)
    b : array of shape (action_dim,)
    logvar : scalar log-variance shared by all action components
    """

    def __init__(self, w, b, logvar):
        self.w = w
        self.b = b
        self.logvar = logvar
        self.state_dim = w.shape[1]
        self.action_dim = w.shape[0]

    def choose_action(self, state):
        """Sample one action for a single state of shape (state_dim,)."""
        mean = np.matmul(self.w, state) + self.b
        return mean + np.exp(self.logvar/2) * np.random.randn(self.action_dim)

    def choose_actions(self, states):
        """Sample one action per row of `states` (shape (n, state_dim))."""
        n = states.shape[0]
        mean = np.matmul(states, (self.w).T) + self.b
        return mean + np.exp(self.logvar/2) * np.random.randn(n, self.action_dim)

    def _log_density(self, diff):
        """Log-density of residual `diff` (mean - action), reduced over the last axis.

        The additive constant -0.5 * action_dim * log(2*pi) is omitted.
        """
        # Each of the action_dim independent components contributes
        # -0.5 * logvar to the normaliser, so the term scales with action_dim
        # (for action_dim == 1 this is exactly -0.5 * logvar).
        log_norm = -0.5 * self.action_dim * self.logvar
        return log_norm - 0.5*np.sum(diff*diff, axis = -1)/np.exp(self.logvar)

    def logpis(self, states, actions):
        """Return log pi(a_i | s_i) for each row pair; result has shape (n,)."""
        diff = np.matmul(states, (self.w).T) + self.b - actions
        return self._log_density(diff)

    def logpi(self, state, action):
        """Return log pi(action | state) for a single state/action pair."""
        diff = np.matmul(self.w, state) + self.b - action
        return self._log_density(diff)

def get_policy_behavior():
    """Build the behavior policy: mean 1.3*s + 0.0, log-variance -1."""
    weight = np.array([[1.3]])
    bias = np.array([0.0])
    log_variance = -1.
    return Linear_Gaussian_Policy(weight, bias, log_variance)

def get_policy_target():
    """Build the target policy: mean 1.5*s - 0.1, log-variance -10."""
    weight = np.array([[1.5]])
    bias = np.array([-0.1])
    log_variance = -10.
    return Linear_Gaussian_Policy(weight, bias, log_variance)

class reverse_q_toy(object):
    """Toy environment with linear dynamics and a reward derived from Q_true.

    Next state is W1 @ s + W2 @ a + b (plus optional Gaussian noise). The
    reward is Q_true(s, a) minus gamma times a Monte-Carlo average of Q_true
    at the next state under the target policy.
    """

    def __init__(self, W1, W2, b, gamma):
        self.W1, self.W2 = W1, W2
        self.state_dim = W1.shape[1]
        self.action_dim = W2.shape[1]
        self.b = b
        self.gamma = gamma
        self.policy_target = get_policy_target()

    def set_initial(self, state):
        """Overwrite the current state with `state`."""
        self.state = state

    def sample_initials(self, n):
        """Draw n initial states, shape (n, state_dim), without touching self.state."""
        scale = 0.001
        return scale * np.random.randn(n, self.state_dim)

    def reset(self):
        """Draw a fresh initial state, store it as the current state and return it."""
        fresh_state = 0.001 * np.random.randn(self.state_dim)
        self.state = fresh_state
        return fresh_state

    def get_reward(self, state, action, repeat = 5):
        """Return Q_true(s, a) - gamma * mean_k Q_true(s', a'_k).

        s' comes from `transition`; each a'_k is sampled from the target
        policy at s', with `repeat` samples in total.
        """
        q_now = Q_true(state, action)
        successor = self.transition(state, action)
        backup = 0.0
        for _ in range(repeat):
            sampled_action = self.policy_target.choose_action(successor)
            backup += Q_true(successor, sampled_action)
        backup /= repeat
        return q_now - self.gamma * backup

    def transition(self, state, action, sigma = 0.00):
        """Return the next state; `sigma` scales additive standard-normal noise."""
        drift = np.matmul(self.W1, state) + np.matmul(self.W2, action) + self.b
        # The noise is always drawn (even for sigma == 0) so the global RNG
        # stream advances the same way regardless of sigma.
        noise = sigma * np.random.randn(self.state_dim)
        return drift + noise

    def step(self, action):
        """Apply `action` to the current state; return (next_state, reward)."""
        current = self.state
        reward = self.get_reward(current, action)
        self.state = self.transition(current, action)
        return self.state, reward

def get_reverse_q_env_policy(gamma):
    """Create the 1-D toy environment plus its behavior and target policies.

    Returns the tuple (env, policy_behavior, policy_target); `gamma` is the
    discount factor handed to the environment.
    """
    dynamics = dict(
        W1=np.array([[0.8]]),
        W2=np.array([[-0.4]]),
        b=np.array([-0.1]),
    )
    env = reverse_q_toy(gamma=gamma, **dynamics)
    return env, get_policy_behavior(), get_policy_target()

def get_initial_data(env, num_trajectory, seedID):
    """Collect `num_trajectory` initial states by repeatedly resetting `env`.

    NumPy's global RNG is seeded with `seedID` first, so the same seed always
    yields the same initial states (provided `env.reset` draws from that RNG).

    Returns an array stacking the states returned by `env.reset()`, one row
    per reset.
    """
    # Must be seed(), not rand(): rand(seedID) only draws and discards seedID
    # samples, leaving the result dependent on whatever RNG state came before.
    np.random.seed(seedID)
    return np.array([env.reset() for _ in range(num_trajectory)])

def get_reverse_q_transition_data(env, num_trajectory, truncate_size, policy, seedID):
    """Roll out `policy` in `env` and record the visited transitions.

    NumPy's global RNG is seeded with `seedID`, then `num_trajectory`
    episodes of `truncate_size` steps each are simulated.

    Returns the list [S, A, SN, REW] where S and SN have shape
    (num_trajectory, truncate_size, state_dim), A has shape
    (num_trajectory, truncate_size, action_dim) and REW has shape
    (num_trajectory, truncate_size).
    """
    np.random.seed(seedID)

    burn_in = 0  # number of leading steps per episode that are not recorded
    d_state = policy.state_dim
    d_action = policy.action_dim

    S = np.zeros((num_trajectory, truncate_size, d_state))
    A = np.zeros((num_trajectory, truncate_size, d_action))
    SN = np.zeros((num_trajectory, truncate_size, d_state))
    REW = np.zeros((num_trajectory, truncate_size))

    for episode in range(num_trajectory):
        state = env.reset()
        for t in range(burn_in + truncate_size):
            action = policy.choose_action(state).reshape(d_action)
            next_state, reward = env.step(action)
            slot = t - burn_in
            if slot >= 0:
                S[episode, slot, :] = state
                A[episode, slot, :] = action
                SN[episode, slot, :] = next_state
                REW[episode, slot] = reward
            state = next_state
    return [S, A, SN, REW]

class toy_config(object):
    """Experiment settings and entry points for the 1-D toy domain.

    The environment and both policies are built once, at class-definition
    time, and shared by every instance.
    """

    # domain parameters
    state_dim = 1
    action_dim = 1

    # default experiment settings
    gamma = 0.95
    num_trajectory = 30
    truncate_size = 100
    eta = 2.0
    subsample_size = 500

    # presumably the value grids swept over num_trajectory, eta and
    # subsample_size respectively -- TODO confirm against the driver script
    NT = [1,2,4,6,10,20,30]
    ETA = [1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2, 2.4]
    SSIZE = [100, 200, 300, 400, 500, 600, 800, 1000, 1500]

    result_path = './results/toy_results/'
    data_path = './transition_data/toy_data/'
    # NOTE(review): looks like the reference value of the target policy that
    # the estimated intervals are compared against -- confirm
    ground_truth = 0.16678
    figure_name = 'toy.pdf'

    max_iteration = 100
    env, policy_behavior, policy_target = get_reverse_q_env_policy(gamma)

    def feature_naive(self, S, A):
        """Feature map that simply stacks states and actions horizontally."""
        return np.hstack((S, A))

    def get_transition(self, num_trajectory, truncate_size, seedID):
        """Collect behavior-policy transitions, flattened across trajectories.

        Returns (S, A, SN, REW) with the trajectory and time axes merged into
        a single leading axis.
        """
        S, A, SN, REW = get_reverse_q_transition_data(
            self.env, num_trajectory, truncate_size, self.policy_behavior, seedID)
        state_shape = (-1, self.state_dim)
        action_shape = (-1, self.action_dim)
        return S.reshape(state_shape), A.reshape(action_shape), SN.reshape(state_shape), REW.reshape(-1)

    def interval_estimation(self, num_trajectory, eta, subsample_size, seedID):
        """Compute lower/upper estimates, without and with double sampling.

        Transitions come from the behavior policy (`self.truncate_size` steps
        per trajectory) and 500 initial states are drawn from the environment.
        Progress and results are printed to stdout.

        Returns (est_lower, est_upper, est_lower2, est_upper2), where the
        second pair is obtained with double_sample = True.
        """
        print('======== Current Setting for toy =========')
        print('---nt = {}, ts = {}, eta = {}, sample_size = {}, seed = {}---'.format(
            num_trajectory, self.truncate_size, eta, subsample_size, seedID))

        # Transitions are collected before the initial states so the random
        # stream is consumed in a fixed order.
        S, A, SN, REW = self.get_transition(num_trajectory, self.truncate_size, seedID)
        S0 = get_initial_data(self.env, 500, seedID)

        target = self.policy_target
        features = self.feature_naive

        # Pass 1: default sampling.
        q_lo, q_hi = lips_bound_evaluation(
            S0, [S, A, SN, REW], target, features, self.gamma, eta,
            subsample_size = subsample_size,
            max_iteration = self.max_iteration)
        est_lower, est_upper = estimate_rpi(
            S0, target, features, S, A, q_lo, q_hi, self.gamma, eta)

        # Pass 2: same computation with double sampling switched on.
        q_lo, q_hi = lips_bound_evaluation(
            S0, [S, A, SN, REW], target, features, self.gamma, eta,
            subsample_size = subsample_size,
            double_sample = True,
            max_iteration = self.max_iteration)
        est_lower2, est_upper2 = estimate_rpi(
            S0, target, features, S, A, q_lo, q_hi, self.gamma, eta)

        print('-----end calculation-----')
        print('lower = {}, upper = {}'.format(est_lower, est_upper))
        print('double sample: lower = {}, upper = {}'.format(est_lower2, est_upper2))
        print('============================')
        sys.stdout.flush()
        return est_lower, est_upper, est_lower2, est_upper2
