import torch
import torch.nn
import torch.nn.functional as F
from torchvision.transforms import transforms
from IPython import embed
import numpy as np
import torch.nn as nn
import transformers
from transformers import BertConfig, BertModel 
from transformers import DistilBertConfig, DistilBertModel, GPT2Config, GPT2Model
import matplotlib.pyplot as plt
import time
import argparse
import torch.nn.functional as F
import time
import os
import pickle
from dataset import TrajDataset

# Default compute device: CUDA when PyTorch reports it available, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class Net(torch.nn.Module):
    """MLP policy head over a flattened history of transitions.

    Each timestep contributes the row ``[x, u, x', r]`` (widths dx, du, dx, 1).
    Row 0 is the query state from ``x['states']``, padded with slices of
    ``x['zeros']`` for u, x' and r. The following rows come from the
    ``rollin_*`` tensors. All rows are flattened into one vector per batch
    element and passed through four Linear layers with ReLU between them,
    producing ``du`` outputs.

    Config keys read: 'H', 'n_embd', 'n_layer', 'n_head', 'dx', 'du', 'Q'.
    'n_layer' and 'n_head' are stored but not used by this class.

    When ``config['Q']`` is truthy, every rollin row is zeroed before
    flattening, and ``x['Qs']`` (reshaped to ``dx * dx`` values per batch
    element) is appended to the flat vector.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.H = config['H']
        self.n_embd = config['n_embd']
        self.n_layer = config['n_layer']
        self.n_head = config['n_head']
        self.dx = config['dx']
        self.du = config['du']
        self.useQ = config['Q']

        # The first layer's input width is fixed here from H. forward() must
        # therefore see exactly H rollin steps (H + 1 rows with the query row).
        row_width = 2 * self.dx + self.du + 1
        flat_width = row_width * (self.H + 1)
        if self.useQ:
            flat_width += self.dx ** 2

        # Attribute names are kept as-is because they are the state_dict keys.
        # ln1/ln2/ln3 are Linear layers, not LayerNorms.
        self.embed_state_action = torch.nn.Linear(flat_width, self.n_embd)
        self.ln1 = torch.nn.Linear(self.n_embd, self.n_embd)
        self.ln2 = torch.nn.Linear(self.n_embd, self.n_embd)
        self.ln3 = torch.nn.Linear(self.n_embd, self.du)

    def forward(self, x):
        """Map a batch of transition histories to ``du`` outputs per element.

        Args:
            x: dict of tensors with keys 'states', 'zeros', 'rollin_xs',
                'rollin_us', 'rollin_xps', 'rollin_rs', plus 'Qs' when
                config['Q'] is truthy. ``x['zeros']`` is sliced to widths du,
                dx and 1 for the row-0 padding. Presumably 'states' is
                (batch, dx) and the rollin tensors are (batch, steps, width);
                TODO confirm against the dataset.

        Returns:
            Tensor of shape (batch, du).
        """
        query = x['states'][:, None, :]
        pad = x['zeros'][:, None, :]

        rows = torch.cat(
            [
                torch.cat([query, x['rollin_xs']], dim=1),
                torch.cat([pad[:, :, :self.du], x['rollin_us']], dim=1),
                torch.cat([pad[:, :, :self.dx], x['rollin_xps']], dim=1),
                torch.cat([pad[:, :, :1], x['rollin_rs']], dim=1),
            ],
            dim=-1,
        )
        if self.useQ:
            # Hide the rollin history; only the query row keeps its values.
            rows[:, 1:] = 0.0

        n_batch, n_rows = rows.shape[0], rows.shape[1]
        flat = torch.reshape(rows, (n_batch, (2 * self.dx + self.du + 1) * n_rows))
        if self.useQ:
            flat = torch.cat((flat, x['Qs'].reshape((n_batch, self.dx ** 2))), dim=-1)

        hidden = flat
        for layer in (self.embed_state_action, self.ln1, self.ln2):
            hidden = F.relu(layer(hidden))
        return self.ln3(hidden)




class TransformerTall(torch.nn.Module):
    """GPT-2 policy over a sequence of low-dimensional transition tokens.

    There is one token per timestep, built as ``[x, u, x', r]`` (widths
    dx, du, dx, 1). Token 0 is the query state from ``x['states']``, with
    zero padding sliced from ``x['zeros']``. The remaining tokens come from
    the ``rollin_*`` tensors. Tokens are linearly embedded to ``n_embd`` and
    fed to a ``GPT2Model`` as ``inputs_embeds``. A linear head then maps
    every hidden state to ``du`` outputs.

    Config keys read: 'H', 'n_embd', 'n_layer', 'n_head', 'dx', 'du', 'Q',
    'dropout' (constructor) and 'full' (forward).

    NOTE(review): 'Q' and 'n_head' are stored but unused; the GPT-2 config
    hard-codes ``n_head=1``. ``embed_ln`` is created but never applied in
    forward. All three are kept so existing configs and checkpoints still fit.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.H = config['H']
        self.n_embd = config['n_embd']
        self.n_layer = config['n_layer']
        self.n_head = config['n_head']
        self.dx = config['dx']
        self.du = config['du']
        self.useQ = config['Q']
        self.dropout = config['dropout']

        gpt_config = GPT2Config(
            # forward() builds 1 + (rollin length) tokens, so 4 * (1 + H)
            # position slots leave headroom when the rollin length is H.
            n_positions=4 * (self.H + 1),
            n_embd=self.n_embd,
            n_layer=self.n_layer,
            n_head=1,
            resid_pdrop=self.dropout,
            embd_pdrop=self.dropout,
            attn_pdrop=self.dropout,
            use_cache=False,
        )
        self.transformer = GPT2Model(gpt_config)

        token_width = 2 * self.dx + self.du + 1
        self.embed_transition = torch.nn.Linear(token_width, self.n_embd)

        self.embed_ln = nn.LayerNorm(self.n_embd)
        self.pred_actions = nn.Linear(self.n_embd, self.du)

    def forward(self, x):
        """Run the transformer over the query token plus the rollin tokens.

        Args:
            x: dict of tensors with keys 'states', 'zeros', 'rollin_xs',
                'rollin_us', 'rollin_xps', 'rollin_rs'. ``x['zeros']`` is
                sliced to widths du, dx and 1 for the token-0 padding.
                Presumably 'states' is (batch, dx) and the rollin tensors are
                (batch, steps, width); TODO confirm against the dataset.

        Returns:
            ``preds[:, 1:, :]`` (one prediction per rollin token) when
            ``config['full']`` is truthy, otherwise ``preds[:, -1, :]``.
        """
        query = x['states'][:, None, :]
        pad = x['zeros'][:, None, :]

        # Each pair is (token-0 entry, rollin entries) for one token field.
        fields = (
            (query, x['rollin_xs']),
            (pad[:, :, :self.du], x['rollin_us']),
            (pad[:, :, :self.dx], x['rollin_xps']),
            (pad[:, :, :1], x['rollin_rs']),
        )
        tokens = torch.cat([torch.cat(pair, dim=1) for pair in fields], dim=2)

        hidden = self.transformer(
            inputs_embeds=self.embed_transition(tokens),
        )['last_hidden_state']
        preds = self.pred_actions(hidden)

        if not self.config['full']:
            return preds[:, -1, :]
        return preds[:, 1:, :]


class TransformerVision(torch.nn.Module):
    """GPT-2 policy over a history of image-observation transitions.

    Each timestep token concatenates, in order:
      - the CNN embedding of the observation (``im_embd`` = 8 values),
      - the pose (divided by 10),
      - the angle,
      - the action,
      - the reward,
      - the CNN embedding of the next observation,
      - the next pose (divided by 10),
      - the next angle.
    Token 0 is the query observation/pose/angle, with zeros for the fields
    that do not exist yet. Tokens are linearly embedded, run through a
    ``GPT2Model``, and a linear head predicts ``du`` values per token.

    Config keys read: 'H', 'n_embd', 'n_layer', 'n_head', 'du', 'dropout',
    'small', 'shift' (constructor) and 'full' (forward).

    NOTE(review): ``n_head`` is stored but the GPT-2 config hard-codes
    ``n_head=1``, and ``embed_ln`` is created but never applied in forward.
    Both are left untouched so behavior and checkpoints stay compatible.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.H = config['H']
        self.n_embd = config['n_embd']
        self.n_layer = config['n_layer']
        self.n_head = config['n_head']
        self.du = config['du']
        self.dropout = config['dropout']
        self.small = config['small']
        self.im_embd = 8  # width of each per-image CNN embedding

        # Random horizontal-shift augmentation; forward applies it only in
        # training mode.
        self.shift_aug = bool(config['shift'])

        if self.small:
            # Spatial size after a stride-2 3x3 conv, then a stride-1 3x3
            # conv, on a 25x25 input.
            size = (25 - 3) // 2 + 1
            size = (size - 3) // 1 + 1
            self.image_encoder = nn.Sequential(
                nn.Conv2d(3, 16, kernel_size=3, stride=2),
                nn.ReLU(),
                nn.Conv2d(16, 16, kernel_size=3, stride=1),
                nn.ReLU(),
                # fc layer
                nn.Flatten(start_dim=1),
                nn.Linear(int(16 * size * size), self.im_embd),
                nn.ReLU(),
            )
        else:
            # NOTE(review): 600 must equal 2 * out_h * out_w of the conv below.
            # A 25x25 input gives 2 * 6 * 6 = 72, so this branch presumably
            # expects a different image size. Confirm this, especially together
            # with shift augmentation, which asserts 25x25 images.
            self.image_encoder = nn.Sequential(
                nn.Conv2d(3, 2, kernel_size=3, stride=4),
                nn.ReLU(),
                # fc layer
                nn.Flatten(),
                nn.Linear(600, self.im_embd),
                nn.ReLU(),
            )

        gpt_config = GPT2Config(
            n_positions=4 * (1 + self.H),
            n_embd=self.n_embd,
            n_layer=self.n_layer,
            n_head=1,
            resid_pdrop=self.dropout,
            embd_pdrop=self.dropout,
            attn_pdrop=self.dropout,
            use_cache=False,
        )
        self.transformer = GPT2Model(gpt_config)

        # Token width = 2 image embeddings + action + reward + 8. The 8 is
        # presumably pose, angle, next pose and next angle at 2 values each
        # (confirm).
        self.embed_transition = torch.nn.Linear(self.im_embd * 2 + self.du + 1 + 8, self.n_embd)

        self.embed_ln = nn.LayerNorm(self.n_embd)  # NOTE(review): unused in forward
        self.pred_actions = nn.Linear(self.n_embd, self.du)

    def _shift_augment(self, state_seq, next_state_seq):
        """Randomly translate each (observation, next observation) pair horizontally.

        The maximum shift fraction is drawn once per call from U(0, 0.1).
        ``transform`` is then called separately for every pair. Because the
        two images of a pair are concatenated along channels, they receive
        the same translation. The concatenation must be (N, 6, 25, 25); this
        is asserted.

        Returns:
            The augmented ``(state_seq, next_state_seq)``, 3 channels each.
        """
        shift_value = torch.FloatTensor(1).uniform_(0.0, 0.1)
        transform = transforms.RandomAffine(
            degrees=0,
            translate=(shift_value.item(), 0),  # (horizontal_shift_range, vertical_shift_range)
        )
        # N = batch size * number of steps, because the step axis has been
        # folded into the batch axis.
        pairs = torch.cat([state_seq, next_state_seq], dim=1)  # (N, 6, 25, 25)
        pairs = torch.stack([transform(pair) for pair in pairs])
        assert pairs.dim() == 4 and tuple(pairs.shape[1:]) == (6, 25, 25), pairs.shape
        return pairs[:, :3], pairs[:, 3:]

    def forward(self, x):
        """Predict actions from a batch of image-based in-context rollouts.

        Args:
            x: dict of tensors.
                Query-step keys: 'states', 'poses', 'angles'.
                Rollin keys: 'rollin_obs', 'rollin_poses', 'rollin_angles',
                'rollin_actions', 'rollin_rewards', 'rollin_next_obs',
                'rollin_next_poses', 'rollin_next_angles'.
                'rollin_rewards' may be 2-D; a trailing axis is added then.
                Presumably 'states' is (batch, 3, 25, 25) and 'rollin_obs' is
                (batch, steps, 3, 25, 25); TODO confirm against the dataset.

        Returns:
            ``preds[:, 1:, :]`` (one prediction per rollin step) when
            ``config['full']`` is truthy, otherwise ``preds[:, -1, :]``.
        """
        states = x['states'][:, None, :]
        poses = x['poses'][:, None, :] / 10.0
        angles = x['angles'][:, None, :]
        rollin_obs = x['rollin_obs']
        rollin_poses = x['rollin_poses'] / 10.0
        rollin_angles = x['rollin_angles']
        rollin_actions = x['rollin_actions']
        rollin_rewards = x['rollin_rewards']
        rollin_next_obs = x['rollin_next_obs']
        rollin_next_poses = x['rollin_next_poses'] / 10.0
        rollin_next_angles = x['rollin_next_angles']
        if rollin_rewards.dim() == 2:
            rollin_rewards = rollin_rewards[:, :, None]
        bsize = states.shape[0]

        # Prepend step 0: the query observation/pose/angle, with zero padding
        # for the fields that do not exist yet. The padding is built with
        # new_zeros so it matches each tensor's device and dtype. The previous
        # module-level `device` broke torch.cat when the model and inputs were
        # on the CPU while CUDA was available.
        state_seq = torch.cat([states, rollin_obs], dim=1)
        next_state_seq = torch.cat(
            [state_seq.new_zeros(bsize, 1, *state_seq.shape[2:]), rollin_next_obs], dim=1)
        rollin_poses = torch.cat([poses, rollin_poses], dim=1)
        rollin_angles = torch.cat([angles, rollin_angles], dim=1)
        rollin_actions = torch.cat(
            [rollin_actions.new_zeros(bsize, 1, self.du), rollin_actions], dim=1)
        rollin_rewards = torch.cat(
            [rollin_rewards.new_zeros(bsize, 1, 1), rollin_rewards], dim=1)
        rollin_next_poses = torch.cat(
            [rollin_next_poses.new_zeros(bsize, 1, rollin_next_poses.shape[-1]), rollin_next_poses],
            dim=1)
        rollin_next_angles = torch.cat(
            [rollin_next_angles.new_zeros(bsize, 1, rollin_next_angles.shape[-1]), rollin_next_angles],
            dim=1)

        # Fold the step axis into the batch axis so the CNN sees one image per row.
        state_seq = state_seq.view(-1, *state_seq.shape[2:])
        next_state_seq = next_state_seq.view(-1, *next_state_seq.shape[2:])

        # Augmentation is a training-time regularizer. Skip it in eval mode so
        # evaluation predictions are not randomly perturbed.
        if self.shift_aug and self.training:
            state_seq, next_state_seq = self._shift_augment(state_seq, next_state_seq)

        states_enc_seq = self.image_encoder(state_seq).view(bsize, -1, self.im_embd)
        next_states_enc_seq = self.image_encoder(next_state_seq).view(bsize, -1, self.im_embd)

        tokens = torch.cat([
            states_enc_seq,
            rollin_poses,
            rollin_angles,
            rollin_actions,
            rollin_rewards,
            next_states_enc_seq,
            rollin_next_poses,
            rollin_next_angles,
        ], dim=2)

        hidden = self.transformer(
            inputs_embeds=self.embed_transition(tokens),
        )['last_hidden_state']
        preds = self.pred_actions(hidden)
        if self.config['full']:
            return preds[:, 1:, :]
        return preds[:, -1, :]


if __name__ == '__main__':
    # Smoke test: load a training split and push one 64-sample batch through
    # the MLP policy and the low-dimensional transformer policy.
    config = {}
    n_envs = 1000
    n_hists = 1
    n_samples = 1
    H = 10
    dim = 1
    path_train = f'datasets/trajs_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_train.pkl'
    path_test = f'datasets/trajs_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_test.pkl'  # unused here
    ds = TrajDataset(path_train, config)

    config = {
        'H': H,
        'dx': dim,
        'du': dim,
        'n_layer': 3,
        'n_embd': 32,
        'n_head': 1,
        'Q': False,
        # TransformerTall also reads 'dropout' and 'full'; Net ignores them.
        'dropout': 0.0,
        'full': False,
    }
    model = Net(config).to(device)
    model(ds[:64])

    # TransformerTall consumes the same batch keys as Net
    # (states, zeros, rollin_xs/us/xps/rs).
    model = TransformerTall(config).to(device)
    model(ds[:64])
