import torch
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import imageio
from proj_io_adapter_io_adapter import *
from cal_constraints import *

# Run name used by test(): the GIF file stem, the suffix of the viewer output
# directory, and the tag passed to run_c / run_d / cal_mse.
DEMO = "articulated_2"

def plot_gif(before_project, after_project, arrow, name = '0'):
    """Render one frame per time step and save them as ``./gif/<name>.gif`` at 10 fps.

    Args:
        before_project: sequence of 2-D point arrays drawn in yellow; columns 0
            and 1 are plotted as x and y. Its length sets the number of frames.
        after_project: matching sequence of point arrays drawn in blue.
        arrow: matching sequence of 4-row ``[X, Y, U, V]`` arrays passed to
            ``quiver``.
        name: file stem of the output GIF.
    """
    import os

    def generate_one_frame(data, pred, arrow, xy_max = 2.5):
        # Draw one frame and return it as an (H, W, 3) uint8 image.
        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            X, Y, U, V = arrow
            ax.quiver(X, Y, U, V, angles='xy', scale_units='xy', scale=1)

            ax.scatter(data[:, 0], data[:, 1], c='y')
            ax.scatter(pred[:, 0], pred[:, 1], c='b')
            ax.grid()
            ax.set(xlabel='X', ylabel='Y', title='yellow: points before projection; blue: after projection.')
            ax.set_xlim(-xy_max, xy_max)
            ax.set_ylim(-xy_max, xy_max)
            fig.canvas.draw()

            # buffer_rgba() replaces tostring_rgb() (deprecated in Matplotlib 3.8,
            # removed in 3.10). It carries its own pixel shape, so no manual reshape
            # against get_width_height() is needed. Copy before the figure is closed.
            image = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
        finally:
            # Without this, every frame's figure stays alive. Memory grows, and
            # Matplotlib warns about too many open figures.
            plt.close(fig)
        return image

    # Create the output directory so that saving works on a fresh checkout.
    os.makedirs('./gif', exist_ok=True)
    frames = [
        generate_one_frame(before, after_project[i], arrow[i])
        for i, before in enumerate(before_project)
    ]
    imageio.mimsave('./gif/' + name + '.gif', frames, fps=10)

def show_one_frame(data, pred, arrow, xy_max = 1):
    """Plot one before/after-projection frame on the current pyplot figure and show it.

    Args:
        data: 2-D point array drawn in yellow (columns 0 and 1 are x and y).
        pred: 2-D point array drawn in blue.
        arrow: 4-row ``[X, Y, U, V]`` array passed to ``quiver``.
        xy_max: half-width of the square view window.
    """
    xs, ys, us, vs = zip(arrow)
    plt.quiver(xs, ys, us, vs, angles='xy', scale_units='xy', scale=1)

    for points, colour in ((data, 'y'), (pred, 'b')):
        plt.scatter(points[:, 0], points[:, 1], c=colour)

    plt.grid()
    plt.title('yellow: points before projection; blue: after projection.')
    window = (-xy_max, xy_max)
    plt.xlim(*window)
    plt.ylim(*window)
    plt.show()


class Proj():
    """Wrap a pickled PyTorch model that projects a single point set.

    The model file is read from ``'models/' + model_path`` (relative to the
    current working directory) and kept on the CPU.
    """

    def __init__(self, model_path):
        model_path = 'models/' + model_path
        # The file holds a whole pickled nn.Module, not a state_dict. It therefore
        # needs weights_only=False; torch >= 2.6 defaults to True and would refuse
        # it. Only load trusted files, because unpickling can execute arbitrary code.
        # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only machines.
        try:
            model = torch.load(model_path, map_location='cpu', weights_only=False)
        except TypeError:
            # torch < 1.13 has no weights_only keyword (and never restricted unpickling).
            model = torch.load(model_path, map_location='cpu')
        self.model = model.cpu()
        # NOTE(review): presumably caps the model's internal iteration count at 1
        # per call; the attribute's meaning is defined by the model class - confirm.
        self.model.iter = 1

    def project(self, data):
        """Run the model on one point set and return its output as a NumPy array.

        Args:
            data: 2-D array (points x coordinates). It is wrapped as a batch of one
                and converted to torch's default float dtype.

        Returns:
            The first (only) batch element of the model output, detached from
            the autograd graph.
        """
        d = torch.Tensor(data[None, :, :])
        pred = self.model(d)
        pred = pred[0, :, :]
        return pred.detach().numpy()

# Row ranges of the five chain segments in the (28, 2) state array, in the order
# they are built and projected, with the template/model kind for each. Each
# segment shares its last row with the first row of the next one (rows 7, 10,
# 17, 20), which is how neighbouring segments are coupled.
_CHAIN_SEGMENTS = (
    (slice(0, 8), 'rope'),
    (slice(7, 11), 'rigid'),
    (slice(10, 18), 'rope'),
    (slice(17, 21), 'rigid'),
    (slice(20, 28), 'rope'),
)


def _build_initial_chain(dx = 0.1):
    """Return ``(data, rigid)``: the initial (28, 2) chain and the rigid template.

    The rope template is 8 points spaced ``dx`` apart along -x from the origin.
    The rigid template is a fixed 4-point shape shifted so that its first point
    is at the origin. Segment 0 is anchored at (2, 0); every later segment is
    anchored at the shared row written by the previous segment.
    """
    rope = np.zeros([8, 2])
    rope[:, 0] = -np.arange(8) * dx
    rigid = np.array([[ 0.2514-1,  0.1707],
        [ 0.0483-1, -0.4097],
        [-0.3870-1,  0.0387],
        [-0.4260-1, -0.4451]]) - np.array([ 0.2514-1,  0.1707])
    templates = {'rope': rope, 'rigid': rigid}

    data = np.zeros([28, 2])
    data[0, :] = [2, 0]
    for seg, kind in _CHAIN_SEGMENTS:
        # The RHS is evaluated before assignment, so data[seg.start] is the anchor
        # row left by the previous segment (or (2, 0) for the first one).
        data[seg, :] = templates[kind] + data[seg.start, :]
    return data, rigid


def test(te = 0, n_steps = 100, n_proj_iters = 100):
    """Simulate the rope-rigid-rope-rigid-rope chain and export the results.

    Each of the ``n_steps`` steps does the following:
      * applies constant forces to the first and last point;
      * applies gravity to every point;
      * advances positions by ``timestamp``;
      * runs ``n_proj_iters`` sweeps of the ``Proj`` models over the five segments.
    Velocities are then recomputed from the projected positions.

    Afterwards it writes ``./gif/<DEMO>.gif`` via ``plot_gif`` and passes the
    projected trajectory to ``run_c``, ``run_d``, ``write_viewer`` and
    ``cal_mse`` (star-imported helpers whose outputs are not visible here).

    Args:
        te: scenario id. Only 0 is defined.
        n_steps: number of simulation steps.
        n_proj_iters: projection sweeps per step.

    Raises:
        ValueError: if ``te`` is not 0. Previously any other value crashed later
            with a NameError on the undefined forces.
    """
    if te != 0:
        raise ValueError("unsupported scenario te=%r; only te=0 is defined" % (te,))
    force0 = np.array([10, 28])
    force1 = np.array([-10, 28])
    g = np.array([0, -2])
    timestamp = 0.1

    data, rigid = _build_initial_chain()
    vel = np.zeros_like(data)

    # Arrow components drawn at the two driven end points, scaled by 1/20 for display.
    force_u = [force0[0] / 20, force1[0] / 20]
    force_v = [force0[1] / 20, force1[1] / 20]

    def arrow_rows(points, u, v):
        # Rows X, Y, U, V for quiver: tails at the first and last point.
        return np.array([
            [points[0, 0], points[-1, 0]],
            [points[0, 1], points[-1, 1]],
            u,
            v,
        ])

    data_list = [np.array(data)]
    proj_list = [np.array(data)]
    force_list = [arrow_rows(data, [0, 0], [0, 0])]

    rope_proj = Proj('rope_8_soft_bend.pt')
    rigid_proj = Proj('rigid_4.pt')
    projectors = {'rope': rope_proj, 'rigid': rigid_proj}

    for _step in range(n_steps):
        # End forces act only on the end points; gravity acts on every point.
        # Positions then advance with the updated velocity.
        vel[0, :] += force0 * timestamp
        vel[-1, :] += force1 * timestamp
        vel += timestamp * g
        new_data = data + vel * timestamp

        data_list.append(np.array(new_data))
        force_list.append(arrow_rows(new_data, force_u, force_v))

        for _sweep in range(n_proj_iters):
            for seg, kind in _CHAIN_SEGMENTS:
                new_data[seg, :] = projectors[kind].project(new_data[seg, :])

        proj_list.append(np.array(new_data))

        # Velocity is taken from the projected displacement, so projection
        # corrections feed back into the next step.
        vel = (new_data - data) / timestamp
        data = np.array(new_data)

    plot_gif(data_list, proj_list, force_list, DEMO)
    run_c(proj_list, [1,0,0,0,0], rigid, DEMO)
    run_d(proj_list, [0,1,1,0,0], rigid, DEMO)
    write_viewer(proj_list, output_dir = "output_" + DEMO)
    cal_mse(proj_list, 8, DEMO)
        
if __name__ == '__main__':
    # Run scenario 0, the only one test() configures forces for.
    test(te=0)