import numpy as np
import torch
import os


class ReplayBuffer:
	"""Ring buffer of (state, action, next_state, reward, not_done) transitions.

	Transitions are stored in NumPy arrays with one row per transition.
	``add`` writes at ``ptr`` and wraps around at ``max_size``, overwriting the
	oldest rows; ``sample`` draws uniformly (with replacement) from the first
	``size`` rows and returns float32 tensors on ``device``.
	"""

	def __init__(self, state_dim, action_dim, max_size=int(1e6)):
		"""Pre-allocate zero-filled storage for ``max_size`` transitions.

		Args:
			state_dim: number of columns of the state / next_state arrays.
			action_dim: number of columns of the action array.
			max_size: capacity of the buffer (number of rows).
		"""
		self.max_size = max_size
		self.ptr = 0   # row that the next add() writes to
		self.size = 0  # number of valid rows, never exceeds max_size

		self.state = np.zeros((max_size, state_dim))
		self.action = np.zeros((max_size, action_dim))
		self.next_state = np.zeros((max_size, state_dim))
		self.reward = np.zeros((max_size, 1))
		self.not_done = np.zeros((max_size, 1))

		self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

	def add(self, state, action, next_state, reward, done):
		"""Store one transition at ``ptr``, overwriting the oldest row when full.

		``done`` is stored inverted as ``not_done = 1 - done``.
		"""
		self.state[self.ptr] = state
		self.action[self.ptr] = action
		self.next_state[self.ptr] = next_state
		self.reward[self.ptr] = reward
		self.not_done[self.ptr] = 1. - done

		self.ptr = (self.ptr + 1) % self.max_size
		self.size = min(self.size + 1, self.max_size)

	def sample(self, batch_size):
		"""Draw ``batch_size`` transitions uniformly at random, with replacement.

		Returns:
			Tuple ``(state, action, next_state, reward, not_done)`` of float32
			tensors on ``self.device``, each with ``batch_size`` rows.

		Raises:
			ValueError: if the buffer holds no transitions.
		"""
		if self.size == 0:
			# np.random.randint(0, 0) would also raise ValueError, but with an
			# opaque "low >= high" message.
			raise ValueError("Cannot sample from an empty ReplayBuffer")

		ind = np.random.randint(0, self.size, size=batch_size)
		fields = (self.state, self.action, self.next_state, self.reward, self.not_done)
		return tuple(
			torch.as_tensor(field[ind], dtype=torch.float32).to(self.device)
			for field in fields
		)

	def convert_D4RL(self, file_path, type="expert", task="lift"):
		"""Replace the buffer contents with trajectories loaded from disk.

		Loads ``{task}_{type}_1000.npy`` from ``file_path``. Each element of
		the loaded array is indexed with the keys ``'obs'``, ``'acts'``,
		``'obs_next'``, ``'rewards'`` and ``'dones'``; the per-trajectory
		arrays are concatenated along axis 0. Rewards and dones get a trailing
		axis added, and dones are stored inverted as ``not_done``.

		After loading, the buffer is treated as exactly full: ``size`` and
		``max_size`` both equal the number of loaded transitions and ``ptr``
		is reset to 0, so a subsequent ``add`` overwrites row 0 instead of
		indexing past the loaded arrays or letting ``size`` exceed them.

		Args:
			file_path: directory containing the ``.npy`` file.
			type: dataset variant used in the file name. (The name shadows the
				builtin ``type`` but is kept for keyword-argument callers.)
			task: task name used in the file name.

		Returns:
			The maximum value found in the loaded action array.
		"""
		self.file_path = file_path
		# SECURITY: allow_pickle=True unpickles arbitrary objects from the file;
		# only load dataset files from a trusted source.
		trajs = np.load(os.path.join(self.file_path, f'{task}_{type}_1000.npy'), allow_pickle=True)

		def stack(key):
			# Concatenate one field of every trajectory along the step axis.
			return np.concatenate([traj[key] for traj in trajs], axis=0)

		self.state = stack('obs')
		self.action = stack('acts')
		self.next_state = stack('obs_next')
		self.reward = np.expand_dims(stack('rewards'), axis=-1)
		self.not_done = np.expand_dims(1. - stack('dones'), axis=-1)

		# Keep the ring-buffer bookkeeping consistent with the new arrays.
		self.size = self.state.shape[0]
		self.max_size = self.size
		self.ptr = 0
		return np.max(self.action)

	def normalize_states(self, eps = 1e-3):
		"""Standardize ``state`` and ``next_state`` column-wise.

		The mean and standard deviation are computed from the stored rows of
		``state`` only (the first ``size`` rows), so unfilled zero rows of a
		partially filled buffer do not bias the statistics. The same mean/std
		are applied to both ``state`` and ``next_state``. Unfilled rows are
		transformed as well, but ``sample`` never reads them.

		Args:
			eps: constant added to the standard deviation to avoid division
				by zero for constant columns.

		Returns:
			``(mean, std)`` with shape ``(1, state_dim)``; ``std`` already
			includes ``eps``.
		"""
		# With an empty buffer fall back to the whole (all-zero) array, which
		# yields mean 0 and std eps instead of NaNs from an empty slice.
		filled = self.state[:self.size] if self.size > 0 else self.state
		mean = filled.mean(0, keepdims=True)
		std = filled.std(0, keepdims=True) + eps
		self.state = (self.state - mean) / std
		self.next_state = (self.next_state - mean) / std
		return mean, std