#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
import sys
import time
import json

sys.path.append('./')
sys.path.append('../')
from src.Env.simUtils import *
from src.utils import *

import matplotlib.pyplot as plt
from moviepy.editor import ImageSequenceClip, ImageClip, concatenate_videoclips

import grpc
import argparse
import pickle
import os

import hashlib
import random

def string_to_seed(input_string):
    """Map a string to a deterministic integer seed in [0, 100000009).

    The string is UTF-8 encoded and hashed with SHA-256; the hex digest is
    read as a base-16 integer and reduced modulo the prime 100000009, which
    keeps the seed within the 32-bit range accepted by ``np.random.seed``.
    """
    return int(hashlib.sha256(input_string.encode()).hexdigest(), base=16) % 100000009


# Command-line interface. --output_path, --n_objs and --event have no usable
# default (the script fails later without them), so they are required up front
# and --event is restricted to the supported event names.
parser = argparse.ArgumentParser(description="")

parser.add_argument('--host', type=str)
parser.add_argument('--output_path', type=str, required=True,
                    help='episode output directory; also used to seed the RNGs')
parser.add_argument('--data_info', type=str, default='')
parser.add_argument('--n_objs', type=int, required=True,
                    help='number of objects placed on the desk per episode')
parser.add_argument('--handSide', type=str)
# Keep in sync with the keys of the `events` table.
parser.add_argument('--event', type=str, required=True,
                    choices=['graspTargetObj', 'placeTargetObj', 'moveNear', 'knockOver',
                             'pushFront', 'pushLeft', 'pushRight'])
args = parser.parse_args()


# Connect to the simulator server (scene_num=1, map_id=2) and create the
# action interface for scene 0.
host = args.host
scene_num, map_id = 1, 2
server = SimServer(host, scene_num=scene_num, map_id=map_id)
sim = SimAction(host, scene_id=0)

# Seed both RNGs from the output path, so runs with the same path repeat the
# same sequence of random draws.
seed = string_to_seed(args.output_path)
random.seed(seed)
np.random.seed(seed)

# Event name -> {'act': sim method performing the event,
#                'check': sim method named 'check' + capitalized event name}.
events = {
    name: {
        'act': getattr(sim, name),
        'check': getattr(sim, 'check' + name[0].upper() + name[1:]),
    }
    for name in ('graspTargetObj', 'placeTargetObj', 'moveNear', 'knockOver',
                 'pushFront', 'pushLeft', 'pushRight')
}
event = events[args.event]

# Resolve output locations.
# NOTE(review): output_path is rooted at os.sep, so a relative --output_path
# ends up directly under the filesystem root — confirm this is intended.
dir_name = args.output_path
output_path = os.sep + dir_name
data_info = args.data_info
meta_data_path = output_path + os.sep + 'meta_data.json'
n_objs = args.n_objs
handSide = args.handSide

can_list = list(sim.can_list)


def _ensure_dir(path):
    """Create *path* (with parents) unless something already exists there."""
    if not os.path.exists(path):
        os.makedirs(path)


_ensure_dir(output_path)

# Resume from recorded progress when a meta file exists; otherwise start a
# fresh progress record and persist it immediately.
if os.path.exists(meta_data_path):
    with open(meta_data_path, 'rb') as f:
        meta_data = json.load(f)
else:
    meta_data = {
        "collected_num": 0,
        "start_index": 0,
        "info": data_info,
    }
    with open(meta_data_path, 'w') as f:
        json.dump(meta_data, f)

# Per-episode preview images and videos go to fixed directories under the CWD.
before_grasp_images_path = './before_grasp_images'
video_path = './log_videos'
_ensure_dir(before_grasp_images_path)
_ensure_dir(video_path)

from copy import deepcopy
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import random

def Resize(mat):
    """Resize an RGB image array to 224x224 and divide its values by 255.

    The result is a float64 numpy array. The input is presumably a uint8
    RGB array (so the output lies in [0, 1]) — TODO confirm with callers.
    """
    resized = Image.fromarray(mat, mode='RGB').resize((224, 224))
    return np.asarray(resized, dtype=np.float64) / 255.0


from tqdm import tqdm

import re
# Load the table of stored initial object locations for this object count
# (Imitation_data/<n_objs>_objs_locs.pkl). Any unsupported count fails with a
# clear message instead of an unbound-name error.
# NOTE: pickle.load can execute arbitrary code — only load trusted local data.
locs_path = f'Imitation_data/{n_objs}_objs_locs.pkl'
if not os.path.isfile(locs_path):
    raise SystemExit(f'no object-location data for --n_objs={n_objs}: {locs_path} not found')
with open(locs_path, 'rb') as f:
    df = pickle.load(f)
print('load data over')


import traceback

collected_num = meta_data['collected_num']
start_index = meta_data['start_index']

# moveNear relates the target to a second object, which is only chosen when
# n_objs > 1; fail up front rather than on an unbound other_obj_index.
if args.event == 'moveNear' and n_objs < 2:
    raise SystemExit('--event moveNear requires --n_objs >= 2')


def _save_preview_image(img, path):
    """Save *img* as a PNG figure at *path* and close the figure.

    Using a fresh figure per call keeps matplotlib from piling one image
    artist per episode onto the shared global axes.
    """
    fig, ax = plt.subplots()
    ax.imshow(img)
    fig.savefig(path, format='png')
    plt.close(fig)


def _save_episode_video(trajectory, path):
    """Write the frames' 'img' entries to *path* as a 3 fps video, then release the clip."""
    clips = [ImageClip(frame['img'], duration=1 / 3) for frame in trajectory]
    clip = concatenate_videoclips(clips)
    try:
        clip.write_videofile(path, fps=3)
    finally:
        clip.close()


def _save_episode_frames(trajectory, episode_dir):
    """Write each frame as NNN.jpg + NNN.json and drop those two fields from the frame dict."""
    for frame_id, frame in enumerate(trajectory):
        Image.fromarray(frame['img']).save(episode_dir + f'/{frame_id:03d}.jpg')
        with open(episode_dir + f'/{frame_id:03d}.json', 'w') as fp:
            json.dump(frame['json_data'], fp)
        # Stored as separate files; remove them so the pickled episode stays small.
        del frame['img']
        del frame['json_data']


def _save_mask_visualization(json_path, episode_tag, vis_dir='../outputs/visualizations'):
    """Write an [image | mask overlay] preview for the frame annotated by *json_path*.

    The output file name is prefixed with *episode_tag* so frames from
    different episodes do not overwrite each other.
    """
    img_path = json_path.replace(".json", ".jpg")
    img = cv2.imread(img_path)
    if img is None:  # cv2.imread signals a missing/unreadable file with None
        print(f'cannot read {img_path}; skipping visualization')
        return
    img = img[:, :, ::-1]  # OpenCV loads BGR; flip to RGB
    mask, _, _ = get_mask_from_json(json_path, img)
    # Green for target (mask == 1), red for ignore (mask == 255).
    valid_mask = (mask == 1).astype(np.float32)[:, :, None]
    ignore_mask = (mask == 255).astype(np.float32)[:, :, None]
    vis_img = img * (1 - valid_mask) * (1 - ignore_mask) + (
        (np.array([0, 255, 0]) * 0.6 + img * 0.4) * valid_mask
        + (np.array([255, 0, 0]) * 0.6 + img * 0.4) * ignore_mask
    )
    vis_img = np.concatenate([img, vis_img], 1)
    os.makedirs(vis_dir, exist_ok=True)
    vis_name = episode_tag + '_' + os.path.basename(json_path).replace(".json", ".jpg")
    cv2.imwrite(os.path.join(vis_dir, vis_name), vis_img[:, :, ::-1])


for epoch in range(1):
    print('Epoch:', epoch)
    offline_data = dict()
    # The shuffled copy is never read, but random.shuffle advances the seeded
    # RNG; it is kept so desk/object draws match earlier runs with the same seed.
    shuffled_list = df.copy()
    random.shuffle(shuffled_list)
    for index in tqdm(range(10000)):
        if index < start_index:
            continue  # already handled by a previous run (resume)
        if index >= len(df):
            break

        # --- Reset the robot and scene ----------------------------------
        sim.reset()
        sim.bow_head()
        time.sleep(1)
        sim.grasp('release', handSide=handSide)
        time.sleep(1)
        sim.removeObjects('all')
        objs = sim.getObjsInfo()  # NOTE(review): result unused
        desk_height = 98
        desk_id = random.choice(sim.desks.ID.values)
        sim.addDesk(desk_id, h=desk_height)

        # --- Place one object per stored location -----------------------
        ids = random.sample(list(can_list), n_objs)
        if len(df[index]) != n_objs:
            raise ValueError(
                f'data error: entry {index} has {len(df[index])} objects, expected {n_objs}')
        # Row per object: id, first two stored location values, desk_height + 1.
        objList = [[obj_id, *loc[:2], desk_height + 1] for obj_id, loc in zip(ids, df[index])]
        sim.addObjects(objList)

        # Object indices are 1-based (objList[target_obj_index - 1]).
        target_obj_index = random.randint(1, n_objs)
        if n_objs > 1:
            other_obj_index = random.choice(
                [x for x in range(1, n_objs + 1) if x != target_obj_index])
        target_origin_loc = sim.getObjsInfo()[target_obj_index]['location']  # NOTE(review): unused
        target_obj_id = objList[target_obj_index - 1][0]
        target_obj = sim.objs[sim.objs.ID == target_obj_id].Name.values[0]
        observation = sim.getObservation()
        sx, sy = observation.location.X, observation.location.Y
        ox, oy, oz = sim.getSensorsData(handSide=handSide)[0]

        offline_data['from_file'] = index
        offline_data['robot_location'] = (sx, sy, 90)
        offline_data['deskInfo'] = {'id': desk_id, 'height': sim.desk_height}
        offline_data['objList'] = objList
        offline_data['targetObjID'] = target_obj_id
        offline_data['target_obj_index'] = target_obj_index
        offline_data['initState'] = sim.getState()
        offline_data['initLoc'] = (ox - sx, oy - sy, oz)
        offline_data['handSide'] = handSide
        offline_data['event'] = args.event
        offline_data['trajectory'] = []

        # placeTargetObj starts with the target in hand: grasp it first and
        # skip the episode if that fails.
        if args.event == 'placeTargetObj':
            for _ in sim.graspTargetObj(obj_id=target_obj_index, handSide=handSide):
                pass
            if not sim.checkGraspTargetObj(obj_id=target_obj_index):
                print(f'pre-grasp failed for index={index}; skipping')
                continue

        # Observation preceding the first recorded action. Taken after the
        # optional pre-grasp so frame 0's img/state match what the action acts on.
        last_imgs = sim.getImage()
        last_state = sim.getState()

        file_output_path = output_path + f'/{index:06d}'
        file_prefix = dir_name + f'/{index:06d}'
        # moveNear relates two objects; every other event acts on the target alone.
        if args.event == 'moveNear':
            obj_kwargs = {'obj1_id': target_obj_index, 'obj2_id': other_obj_index}
        else:
            obj_kwargs = {'obj_id': target_obj_index}

        # --- Execute the event, recording one frame per action ----------
        try:
            # NOTE(review): assumes moveNear, like the other events, yields
            # (action, action_description) pairs — confirm against SimAction.
            steps = event['act'](handSide=handSide, **obj_kwargs)
            for frame_id, (action, action_description) in enumerate(steps):
                each_frame = {
                    'img': last_imgs[0],
                    'json_data': data_processing(sim, last_imgs[1], target_obj_id,
                                                 action_description, file_prefix, frame_id),
                    'state': last_state,
                    'action': action,
                }
                time.sleep(0.05)
                last_imgs = sim.getImage()
                last_state = sim.getState()
                each_frame['after_state'] = last_state
                offline_data['trajectory'].append(each_frame)
        except Exception:
            # Best effort: keep the frames recorded so far; the success check
            # below decides whether the episode counts.
            traceback.print_exc()
            print(f'error index={index} and target is {target_obj}')

        is_success = bool(event['check'](**obj_kwargs))
        if is_success:
            collected_num += 1
            print(f'Success have collected {collected_num} datas')
        else:
            print('fail data:', index, desk_id, target_obj_id, objList)

        # --- Debug artifacts: final-scene image and episode video -------
        episode_tag = f"{index:04d}_{is_success}_{target_obj}_{args.event}"
        _save_preview_image(sim.getImage()[0], before_grasp_images_path + f"/{episode_tag}.png")
        if offline_data['trajectory']:
            _save_episode_video(offline_data['trajectory'], video_path + f"/{episode_tag}.mp4")

        # --- Persist successful episodes --------------------------------
        if is_success:
            os.makedirs(file_output_path, exist_ok=True)
            trajectory = offline_data['trajectory']
            _save_episode_frames(trajectory, file_output_path)
            with open(file_output_path + f'/{collected_num:06d}.pkl', 'wb') as f:
                pickle.dump(offline_data, f)
            if trajectory:
                # Preview the last and the first frame (once if they coincide).
                for vis_id in dict.fromkeys((len(trajectory) - 1, 0)):
                    _save_mask_visualization(file_output_path + f'/{vis_id:03d}.json',
                                             f'{index:06d}')

        # Record progress only after the episode's files are written, so a
        # crash while saving retries this index on resume.
        meta_data = {
            "collected_num": collected_num,
            "start_index": index + 1,
            "info": data_info,       # same key as when meta_data.json is first created
            "data_info": data_info,  # key previously written by this loop; kept for compatibility
        }
        with open(meta_data_path, 'w') as f:
            json.dump(meta_data, f)
        
        
        