import torch
import torch.nn as nn
from transformers import GPT2Model, GPT2Config
from tqdm import tqdm
from sklearn.svm import LinearSVC
from sklearn.linear_model import LogisticRegression, Lasso
import warnings
from sklearn import tree
import xgboost as xgb
import math

from base_models import NeuralNetwork, ParallelNetworks


def build_model(conf):
    """Instantiate the model selected by ``conf.family``.

    Args:
        conf: Configuration object exposing ``family`` plus the attributes
            read below for the selected family.

    Returns:
        A ``TransformerNN`` when ``conf.family == "gpt2_nn"`` or a
        ``TransformerSkill`` when ``conf.family == "gpt2_skill"``.

    Raises:
        NotImplementedError: If ``conf.family`` is any other value.
    """
    if conf.family == "gpt2_nn":
        return TransformerNN(
            n_dims=conf.n_dims,
            n_positions=conf.n_positions,
            n_embd=conf.n_embd,
            n_layer=conf.n_layer,
            n_head=conf.n_head,
            zero_pad_embed=conf.zero_pad_embed,
            hidden_layer_size=conf.hidden_layer_size,
            n_intermediate_activations=conf.n_intermediate_activations,
            n_seen_intermediate=conf.n_seen_intermediate,
            hidden_sep_linear=conf.hidden_sep_linear,
        )
    if conf.family == "gpt2_skill":
        return TransformerSkill(
            n_dims=conf.n_dims,
            n_positions=conf.n_positions,
            n_embd=conf.n_embd,
            n_layer=conf.n_layer,
            n_head=conf.n_head,
            n_skills=conf.n_skills,
        )
    # Name the offending value so a misconfigured run is easy to diagnose.
    raise NotImplementedError(f"Unknown model family: {conf.family!r}")


class TransformerNN(nn.Module):
    def __init__(self, n_dims, n_positions, n_embd=128, n_layer=12, n_head=4, zero_pad_embed=False,
        hidden_layer_size=4, n_intermediate_activations=0, n_seen_intermediate=0, hidden_sep_linear=False):
        """Build a GPT-2 backbone with shared linear read-in / read-out layers.

        Args:
            n_dims: Input feature size of the shared read-in layer.
            n_positions: The backbone is sized for
                ``(2 + n_intermediate_activations) * n_positions`` tokens.
            n_embd: Embedding width, forwarded to ``GPT2Config``.
            n_layer: Number of layers, forwarded to ``GPT2Config``.
            n_head: Number of attention heads, forwarded to ``GPT2Config``.
            zero_pad_embed: Stored on the instance unchanged.
            hidden_layer_size: Output feature size of the shared read-out layer.
            n_intermediate_activations: Number of activation tokens per example.
            n_seen_intermediate: Stored after clamping to at most
                ``n_intermediate_activations``.
            hidden_sep_linear: Accepted but ignored here;
                ``self.hidden_sep_linear`` is always set to ``False``.
        """
        super().__init__()

        # Each example contributes one x token, one y token and one token
        # per intermediate activation.
        tokens_per_example = 2 + n_intermediate_activations
        backbone_config = GPT2Config(
            n_positions=tokens_per_example * n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
            resid_pdrop=0.0,
            embd_pdrop=0.0,
            attn_pdrop=0.0,
            use_cache=False,
        )
        self.name = f"gpt2_nn_embd={n_embd}_layer={n_layer}_head={n_head}"

        self.n_positions = n_positions
        self.n_dims = n_dims
        self.n_embd = n_embd
        self.zero_pad_embed = zero_pad_embed
        self.n_intermediate_activations = n_intermediate_activations
        self.n_seen_intermediate = min(n_seen_intermediate, n_intermediate_activations)

        # A single read-in layer is shared by x, y and activation tokens, and
        # a single read-out layer by the final and the hidden predictions.
        shared_read_in = nn.Linear(n_dims, n_embd)
        self._read_in_x = shared_read_in
        self._read_in_y = shared_read_in
        self._backbone = GPT2Model(backbone_config)
        shared_read_out = nn.Linear(n_embd, hidden_layer_size)
        self._read_out = shared_read_out
        self._read_in_hidden = shared_read_in
        self._read_out_hidden = shared_read_out
        self.hidden_sep_linear = False



        # # self._read_in_x = nn.Linear(hidden_layer_size, n_embd)
        # self._read_in_x = nn.Linear(n_dims, n_embd)
        # self._read_in_y = nn.Linear(hidden_layer_size, n_embd)
        # # self._read_in_y = nn.Linear(1, n_embd)
        # self._backbone = GPT2Model(configuration)
        # self._read_out = nn.Linear(n_embd, hidden_layer_size)
        # # self._read_out = nn.Linear(n_embd, 1)

        # # if apply_multihead is True: all the intermediate inputs/outputs use the same linear layer
        # # else: different intermediate output from different layers use different linear layers
        # self.hidden_sep_linear = hidden_sep_linear
        # if hidden_sep_linear:
        #     self._read_in_hidden = nn.ModuleList(
        #         [nn.Linear(hidden_layer_size, n_embd) for i in range(self.n_seen_intermediate)]
        #     )
        #     self._read_out_hidden = nn.ModuleList(
        #         [nn.Linear(n_embd, hidden_layer_size) for i in range(self.n_seen_intermediate)]
        #     )
        # else:
        #     self._read_in_hidden = nn.Linear(hidden_layer_size, n_embd)
        #     self._read_out_hidden = nn.Linear(n_embd, hidden_layer_size)

    def _combine_embed(self, xs_b, ys_b, layer_activations=None):
        bsize, points, dim = xs_b.shape
        xs_embed = self._read_in_x(xs_b)
        stacked_tensors = [xs_embed]

        if layer_activations is not None:
            es_embeds = []
            if self.hidden_sep_linear:
                for i, act in enumerate(layer_activations):
                    es_embeds.append(self._read_in_hidden[i](act))
            else:
                 for act in layer_activations:
                    es_embeds.append(self._read_in_hidden(act))
            stacked_tensors += es_embeds
        
        # ys_embed = self._read_in_y(ys_b.reshape(bsize, points, 1))
        ys_embed = self._read_in_y(ys_b)
        stacked_tensors += [ys_embed]

        zs_embed = torch.stack(stacked_tensors, dim=2)
        zs_embed = zs_embed.view(bsize, len(stacked_tensors) * points, self.n_embd)
        return zs_embed

    def forward(self, xs, ys, loss_func, inds=None, layer_activations=None):
        if inds is None:
            inds = torch.arange(ys.shape[1])
        else:
            inds = torch.tensor(inds)
            if max(inds) >= ys.shape[1] or min(inds) < 0:
                raise ValueError("inds contain indices where xs and ys are not defined")
        
        if ((layer_activations is not None) and (self.n_intermediate_activations != len(layer_activations))) \
            or ((layer_activations is None) and (self.n_intermediate_activations != 0)):
            raise ValueError("given activation number is not consistent with model setting")

        x_idx_freq = self.n_intermediate_activations + 2
        n_seen_intermediate = self.n_seen_intermediate

        embeds = self._combine_embed(xs, ys, layer_activations)
        output = self._backbone(inputs_embeds=embeds).last_hidden_state
        
        # calculate training loss
        losses = []
        for i in range(n_seen_intermediate):
            if self.hidden_sep_linear:
                pred_hidden = self._read_out_hidden[i](output[:, i:][:, ::x_idx_freq])
            else:
                pred_hidden = self._read_out_hidden(output[:, i:][:, ::x_idx_freq])
            losses += [loss_func(pred_hidden, layer_activations[i]).sum(-1).mean()]
        pred = self._read_out(output[:, n_seen_intermediate:][:, ::x_idx_freq])
        # losses += [loss_func(pred[:,:,0], ys).mean()]
        losses += [loss_func(pred, ys).sum(-1).mean()]
        # losses += [loss_func(pred[:,:,0], ys).mean()]
        return losses, sum(losses)
    
    def forward_circulant(self, xs, ys, loss_func, inds=None, layer_activations=None):
        if inds is None:
            inds = torch.arange(ys.shape[1])
        else:
            inds = torch.tensor(inds)
            if max(inds) >= ys.shape[1] or min(inds) < 0:
                raise ValueError("inds contain indices where xs and ys are not defined")
        
        # if ((layer_activations is not None) and (self.n_intermediate_activations != len(layer_activations))) \
        #     or ((layer_activations is None) and (self.n_intermediate_activations != 0)):
        #     raise ValueError("given activation number is not consistent with model setting")

        x_idx_freq = self.n_intermediate_activations + 2
        n_seen_intermediate = self.n_intermediate_activations
        if layer_activations is not None:
            layer_activations = layer_activations[:n_seen_intermediate]

        losses = []
        for act in layer_activations:
            act = torch.zeros_like(act)
        for i in range(len(layer_activations)):
            embeds = self._combine_embed(xs, ys, layer_activations)
            output = self._backbone(inputs_embeds=embeds).last_hidden_state
            if self.hidden_sep_linear:
                pred_hidden = self._read_out_hidden[i](output[:, i:][:, ::x_idx_freq])
            else:
                pred_hidden = self._read_out_hidden(output[:, i:][:, ::x_idx_freq])
            layer_activations[i] = pred_hidden
        embeds = self._combine_embed(xs, ys, layer_activations)
        output = self._backbone(inputs_embeds=embeds).last_hidden_state
        for i in range(len(layer_activations)):
            if self.hidden_sep_linear:
                pred_hidden = self._read_out_hidden[i](output[:, i:][:, ::x_idx_freq])
            else:
                pred_hidden = self._read_out_hidden(output[:, i:][:, ::x_idx_freq])
            # losses += [loss_func(pred_hidden, layer_activations[i]).sum(-1).mean()]
        pred = self._read_out(output[:, len(layer_activations):][:, ::x_idx_freq])
        losses += [loss_func(pred[:,:,0], ys).mean()]
        return losses, sum(losses)
    
    def forward_circulant_v2(self, xs, ys, loss_func, inds=None, layer_activations=None):
        if inds is None:
            inds = torch.arange(ys.shape[1])
        else:
            inds = torch.tensor(inds)
            if max(inds) >= ys.shape[1] or min(inds) < 0:
                raise ValueError("inds contain indices where xs and ys are not defined")
        
        # if ((layer_activations is not None) and (self.n_intermediate_activations != len(layer_activations))) \
        #     or ((layer_activations is None) and (self.n_intermediate_activations != 0)):
        #     raise ValueError("given activation number is not consistent with model setting")

        x_idx_freq = self.n_intermediate_activations + 2
        n_seen_intermediate = self.n_intermediate_activations
        if layer_activations is not None:
            layer_activations = layer_activations[:n_seen_intermediate]

        losses = []
        bsize, points, dim = xs.shape
        xs_embed = self._read_in_x(xs)
        ys_embed = self._read_in_y(ys.reshape(bsize, points, 1))
        es_embeds = [torch.zeros_like(xs_embed) for i in range(len(layer_activations))]
        for i in range(len(layer_activations)):
            ## 
            stacked_tensors = [xs_embed] + es_embeds + [ys_embed]
            embeds = torch.stack(stacked_tensors, dim=2)
            embeds = embeds.view(bsize, len(stacked_tensors) * points, self.n_embd)
            ##
            output = self._backbone(inputs_embeds=embeds).last_hidden_state
            # es_embeds[i] = (output[:, i:][:, ::x_idx_freq]).detach().clone()
            es_embeds[i] = (output[:, i:][:, ::x_idx_freq])
        ##
        stacked_tensors = [xs_embed] + es_embeds + [ys_embed]
        embeds = torch.stack(stacked_tensors, dim=2)
        embeds = embeds.view(bsize, len(stacked_tensors) * points, self.n_embd)
        ##
        output = self._backbone(inputs_embeds=embeds).last_hidden_state
        pred = self._read_out(output[:, len(layer_activations):][:, ::x_idx_freq])
        losses += [loss_func(pred[:,:,0], ys).mean()]
        return losses, sum(losses)
 
    def forward_circulant_pointwise(self, xs, ys, loss_func, inds=None, layer_activations=None):
        if inds is None:
            inds = torch.arange(ys.shape[1])
        else:
            inds = torch.tensor(inds)
            if max(inds) >= ys.shape[1] or min(inds) < 0:
                raise ValueError("inds contain indices where xs and ys are not defined")
        
        if ((layer_activations is not None) and (self.n_intermediate_activations != len(layer_activations))) \
            or ((layer_activations is None) and (self.n_intermediate_activations != 0)):
            raise ValueError("given activation number is not consistent with model setting")

        x_idx_freq = self.n_intermediate_activations + 2
        n_seen_intermediate = self.n_seen_intermediate

        embeds = self._combine_embed(xs, ys)
        output = self._backbone(inputs_embeds=embeds).last_hidden_state
        output = output[:,::2]
        
        # calculate training loss
        ids = [20,40,60,80,100]
        losses = []
        for i in range(len(ids)):
            if self.hidden_sep_linear:
                raise ValueError("TODO")
                pred_hidden = self._read_out_hidden[i](output[:, i:][:, ::x_idx_freq])
            else:
                pred_hidden = self._read_out_hidden(output[:, i:][:, ::x_idx_freq])
            losses += [loss_func(pred_hidden, layer_activations[i]).sum(-1).mean()]
        pred = self._read_out(output[:, n_seen_intermediate:][:, ::x_idx_freq])
        losses += [loss_func(pred[:,:,0], ys).mean()]
        return losses, sum(losses)
      
    def predict(self, xs, ys, inds=None, layer_activations=None):
        if inds is None:
            inds = torch.arange(ys.shape[1])
        else:
            inds = torch.tensor(inds)
            if max(inds) >= ys.shape[1] or min(inds) < 0:
                raise ValueError("inds contain indices where xs and ys are not defined")

        if ((layer_activations is not None) and (self.n_intermediate_activations != len(layer_activations))) \
            or ((layer_activations is None) and (self.n_intermediate_activations != 0)):
            raise ValueError("given activation number is not consistent with model setting")

        x_idx_freq = self.n_intermediate_activations + 2
        n_seen_intermediate = self.n_seen_intermediate

        with torch.no_grad():
            embeds = self._combine_embed(xs, ys, layer_activations)
            output = self._backbone(inputs_embeds=embeds).last_hidden_state
            ##### ATTENTION ####
            # all_output = self._backbone(inputs_embeds=embeds, output_attentions=True, output_hidden_states=True)
            # attens = all_output.attentions
            # hidden_states = all_output.hidden_states
            # for i in attens:
            #     i = i.cpu().numpy()
            # for i in hidden_states:
            #     i = i.cpu().numpy()
            # torch.save(attens, "attentions_chain6.pt")
            # torch.save(hidden_states, "hidden_states_chain6.pt")
            # # for n, p in self.named_parameters():
            # #     print(n)
            # # print(self._backbone.h[5].attn.c_proj.weight.shape)
            # # torch.save(self._backbone.h[4].attn.c_attn.weight, "attn_weight_4.pt")
            # exit()
            ######   END     ####
            # pred_y = self._read_out(output[:, n_seen_intermediate:][:, ::x_idx_freq])[:,:,0]
            # return pred_y[:, inds]
            pred_y = self._read_out_hidden(output[:, 1:][:, ::x_idx_freq])
            return pred_y[:, inds]


            pred_y = torch.zeros_like(ys)
            # pred_y = torch.zeros_like(layer_activations[0])
            for i in range(ys.shape[1]-1, -1, -1):
                xs = xs[:,:i+1]
                ys = ys[:,:i+1]
                if layer_activations is not None:
                    layer_activations = [item[:,:i+1] for item in layer_activations]

                for j in range(n_seen_intermediate):
                    if self.hidden_sep_linear:
                        layer_activations[j][:,i] = self._read_out_hidden[j](output[:, j:][:, ::x_idx_freq][:,i])
                    else:
                        layer_activations[j][:,i] = self._read_out_hidden(output[:, j:][:, ::x_idx_freq][:,i])
                        # pred_y[:,i] = self._read_out_hidden(output[:, j:][:, ::x_idx_freq][:,i])
                    embeds = self._combine_embed(xs, ys, layer_activations)
                    output = self._backbone(inputs_embeds=embeds).last_hidden_state
                    
                # pred_y[:,i] = self._read_out(output[:, n_seen_intermediate:][:, ::x_idx_freq][:,i])[:,0]
                pred_y[:,i] = self._read_out(output[:, n_seen_intermediate:][:, ::x_idx_freq][:,i])
            return pred_y[:, inds]
    
    def predict_circulant(self, xs, ys, inds=None, layer_activations=None):
        if inds is None:
            inds = torch.arange(ys.shape[1])
        else:
            inds = torch.tensor(inds)
            if max(inds) >= ys.shape[1] or min(inds) < 0:
                raise ValueError("inds contain indices where xs and ys are not defined")
        
        # if ((layer_activations is not None) and (self.n_intermediate_activations != len(layer_activations))) \
        #     or ((layer_activations is None) and (self.n_intermediate_activations != 0)):
        #     raise ValueError("given activation number is not consistent with model setting")

        x_idx_freq = self.n_intermediate_activations + 2
        n_seen_intermediate = self.n_intermediate_activations
        if layer_activations is not None:
            layer_activations = layer_activations[:n_seen_intermediate]

        with torch.no_grad():
            pred_y = torch.zeros_like(ys)
            for act in layer_activations:
                act = torch.zeros_like(act)
            for i in range(len(layer_activations)):
                embeds = self._combine_embed(xs, ys, layer_activations)
                output = self._backbone(inputs_embeds=embeds).last_hidden_state
                if self.hidden_sep_linear:
                    pred_hidden = self._read_out_hidden[i](output[:, i:][:, ::x_idx_freq])
                else:
                    pred_hidden = self._read_out_hidden(output[:, i:][:, ::x_idx_freq])
                layer_activations[i] = pred_hidden.detach().clone()
            embeds = self._combine_embed(xs, ys, layer_activations)
            output = self._backbone(inputs_embeds=embeds).last_hidden_state
            for i in range(len(layer_activations)):
                if self.hidden_sep_linear:
                    pred_hidden = self._read_out_hidden[i](output[:, i:][:, ::x_idx_freq])
                else:
                    pred_hidden = self._read_out_hidden(output[:, i:][:, ::x_idx_freq])
            pred_y = self._read_out(output[:, len(layer_activations):][:, ::x_idx_freq])[:,:,0]
            return pred_y[:, inds]

    def predict_circulant_v2(self, xs, ys, inds=None, layer_activations=None):
        if inds is None:
            inds = torch.arange(ys.shape[1])
        else:
            inds = torch.tensor(inds)
            if max(inds) >= ys.shape[1] or min(inds) < 0:
                raise ValueError("inds contain indices where xs and ys are not defined")
        
        # if ((layer_activations is not None) and (self.n_intermediate_activations != len(layer_activations))) \
        #     or ((layer_activations is None) and (self.n_intermediate_activations != 0)):
        #     raise ValueError("given activation number is not consistent with model setting")

        x_idx_freq = self.n_intermediate_activations + 2
        n_seen_intermediate = self.n_intermediate_activations
        if layer_activations is not None:
            layer_activations = layer_activations[:n_seen_intermediate]

        with torch.no_grad():
            pred_y = torch.zeros_like(ys)
            bsize, points, dim = xs.shape
            xs_embed = self._read_in_x(xs)
            ys_embed = self._read_in_y(ys.reshape(bsize, points, 1))
            es_embeds = [torch.zeros_like(xs_embed) for i in range(len(layer_activations))]
            for i in range(len(layer_activations)):
                ## 
                stacked_tensors = [xs_embed] + es_embeds + [ys_embed]
                embeds = torch.stack(stacked_tensors, dim=2)
                embeds = embeds.view(bsize, len(stacked_tensors) * points, self.n_embd)
                ##
                output = self._backbone(inputs_embeds=embeds).last_hidden_state
                es_embeds[i] = (output[:, i:][:, ::x_idx_freq])
            ##
            stacked_tensors = [xs_embed] + es_embeds + [ys_embed]
            embeds = torch.stack(stacked_tensors, dim=2)
            embeds = embeds.view(bsize, len(stacked_tensors) * points, self.n_embd)
            ##
            output = self._backbone(inputs_embeds=embeds).last_hidden_state
            pred_y = self._read_out(output[:, len(layer_activations):][:, ::x_idx_freq])[:,:,0]
            return pred_y[:, inds]


# TODO
class TransformerSkill(nn.Module):
    def __init__(self, n_dims, n_positions, n_embd=128, n_layer=12, n_head=4, n_skills=10):
        """Build a GPT-2 backbone whose embedding is split into two halves.

        One half of the embedding is produced from the value vector and the
        other from the skill one-hot (``n_skills + 1`` entries); the read-out
        layers each take ``n_embd // 2`` features.

        Args:
            n_dims: Feature size of the value vectors.
            n_positions: Forwarded to ``GPT2Config`` unchanged.
            n_embd: Backbone embedding width.
            n_layer: Number of layers, forwarded to ``GPT2Config``.
            n_head: Number of attention heads, forwarded to ``GPT2Config``.
            n_skills: Number of skills; one-hot vectors have ``n_skills + 1`` entries.
        """
        super().__init__()
        backbone_config = GPT2Config(
            n_positions=n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
            resid_pdrop=0.0,
            embd_pdrop=0.0,
            attn_pdrop=0.0,
            use_cache=False,
        )
        self.name = f"gpt2_skill_embd={n_embd}_layer={n_layer}_head={n_head}"

        self.n_positions = n_positions
        self.n_dims = n_dims
        self.n_embd = n_embd
        self.n_skills = n_skills

        half_embd = n_embd // 2
        n_id_classes = n_skills + 1
        self._read_in = nn.Linear(n_dims, half_embd)
        self._id_read_in = nn.Linear(n_id_classes, half_embd)
        self._backbone = GPT2Model(backbone_config)
        self._read_out = nn.Linear(half_embd, n_dims)
        self._id_read_out = nn.Linear(half_embd, n_id_classes)


    @staticmethod
    def _combine(xs_b, ys_b, layer_activations=None):
        """Interleaves the x's and the y's into a single sequence."""
        bsize, points, dim = xs_b.shape
        ys_b_wide = torch.cat(
            (
                ys_b.view(bsize, points, 1),
                torch.zeros(bsize, points, dim - 1, device=ys_b.device),
            ),
            axis=2,
        )
        stacked_tensors = [xs_b]

        if layer_activations is not None:
            layer_activations_wide = []
            for act in layer_activations:
                _, _, hidden_size = act.shape
                act_wide = torch.cat(
                (
                    act.view(bsize, points, hidden_size),
                    torch.zeros(bsize, points, dim - hidden_size, device=ys_b.device)
                ),
                axis=2,)
                layer_activations_wide.append(act_wide)

            stacked_tensors += layer_activations_wide + [ys_b_wide]
        else:
            stacked_tensors += [ys_b_wide]

        zs = torch.stack(stacked_tensors, dim=2)
        zs = zs.view(bsize, len(stacked_tensors) * points, dim)
        return zs


    def forward_ori(self, xs, IDs, loss_func, inds=None, layer_activations=None):
        """Earlier variant of ``forward`` that stacks value and skill embeddings.

        Computes a regression loss on the next value vector and a cross-entropy
        loss on the next skill id, skipping positions whose id is negative.

        NOTE(review): ``_read_out`` / ``_id_read_out`` are created with
        ``n_embd // 2`` input features in ``__init__`` while ``output`` here
        comes from embeddings viewed with width ``n_embd`` — this variant looks
        like it predates that change; confirm before using it.

        Args:
            xs: 3-D tensor of value vectors (batch, points, features).
            IDs: Integer skill ids indexed as ``IDs[i, j]``; negative means "no id".
            loss_func: Element-wise loss for the value prediction.
            inds: Optional indices; only validated against ``xs.shape[1]``.
            layer_activations: Unused.

        Returns:
            ``([loss1, loss2], loss1 + loss2)``.

        Raises:
            ValueError: On out-of-range ``inds``.
        """
        if inds is None:
            inds = torch.arange(xs.shape[1])
        else:
            inds = torch.tensor(inds)
            if max(inds) >= xs.shape[1] or min(inds) < 0:
                raise ValueError("inds contain indices where xs are not defined")
            
        n_skills = self.n_skills
        bsize, n_points, _ = xs.shape
        xs_embed = self._read_in(xs)
        # ids = torch.zeros_like(xs_embed)[:,:,:n_skills]
        # One-hot encoding of the skill ids; rows with a negative id stay all-zero.
        ids = torch.zeros_like(xs_embed)[:,:,:n_skills+1]
        for i in range(bsize):
            for j in range(n_points):
                if IDs[i,j] >=0:
                    ids[i,j,IDs[i,j]] = 1.
        ids_embed = self._id_read_in(ids)
        embeds = torch.stack([xs_embed, ids_embed], dim=2)
        embeds = embeds.view(bsize, -1, self.n_embd)

        output = self._backbone(inputs_embeds=embeds).last_hidden_state
        mask = (IDs >=0).reshape(bsize, n_points, 1)
        # mask_stop = (IDs != n_skills).reshape(bsize, n_points, 1)
        
        # NOTE(review): `weighting` is only referenced by the commented-out loss
        # variants below, and the repeat factor 64 is hard-coded.
        weighting = torch.tensor(list(map(lambda i: 1.1**i-1., range(mask.shape[1]-1)))).to(xs.device)
        # weighting = weighting / weighting.mean()
        weighting = weighting / weighting[-1]
        weighting = weighting.reshape(1,-1).repeat(64,1)

        mask_x = mask[:,:-1,0] #* mask_stop[:,:-1,0]
        pred_xs = self._read_out(output[:,1:][:,::2])
        # Prediction at position t is compared with the value at position t+1.
        loss1 = loss_func(pred_xs[:,:-1][mask_x], xs[:,1:][mask_x]).sum(-1).mean()
        # loss1 = (loss_func(pred_xs[:,:-1][mask_x], xs[:,1:][mask_x]).sum(-1) * weighting[mask_x]).mean()
        # loss1 = (loss_func(pred_xs[:,:-1][mask_x], xs[:,1:][mask_x]).sum(-1)[self.n_positions//2:]).mean()
        
        # Keep only transitions where both the current and the next id are valid.
        mask_id = mask[:,1:,0] * mask[:,:-1,0] #* mask_stop[:,:-1,0]


        pred_ids = self._id_read_out(output[:,::2])
        # loss2 = loss_func(nn.Softmax(dim=-1)(pred_ids[:,1:][mask_id]), ids[:,1:][mask_id]).sum(-1).mean()
        out = (pred_ids[:,1:][mask_id])
        target = (IDs[:,1:][mask_id])
        loss2 = nn.CrossEntropyLoss()(out, target)
        # loss2 = (nn.CrossEntropyLoss(reduction='none')(out, target) * weighting[mask_id]).mean()
        # loss2 = (nn.CrossEntropyLoss(reduction='none')(out, target) [self.n_positions//2:]).mean()
        # loss2 = loss_func(pred_ids[:,1:][mask_id], ids[:,1:][mask_id]).sum(-1).mean()

        losses = [loss1,loss2]
        # print(loss2)
        # exit()
        return losses, sum(losses)


    def forward(self, xs, IDs, loss_func, inds=None, layer_activations=None):
        if inds is None:
            inds = torch.arange(xs.shape[1])
        else:
            inds = torch.tensor(inds)
            if max(inds) >= xs.shape[1] or min(inds) < 0:
                raise ValueError("inds contain indices where xs are not defined")
            
        n_skills = self.n_skills
        bsize, n_points, _ = xs.shape
        xs_embed = self._read_in(xs)
        # ids = torch.zeros_like(xs_embed)[:,:,:n_skills]
        ids = torch.zeros_like(xs_embed)[:,:,:n_skills+1]
        for i in range(bsize):
            for j in range(n_points):
                if IDs[i,j] >=0:
                    ids[i,j,IDs[i,j]] = 1.
        ids_embed = self._id_read_in(ids)
        embeds = torch.cat([xs_embed, ids_embed], dim=2)
        # embeds = embeds.view(bsize, -1, self.n_embd)

        output = self._backbone(inputs_embeds=embeds).last_hidden_state
        mask = (IDs >=0).reshape(bsize, n_points, 1)
        # mask_stop = (IDs != n_skills).reshape(bsize, n_points, 1)
        
        weighting = torch.tensor(list(map(lambda i: 1.1**i-1., range(mask.shape[1]-1)))).to(xs.device)
        # weighting = weighting / weighting.mean()
        weighting = weighting / weighting[-1]
        weighting = weighting.reshape(1,-1).repeat(64,1)

        mask_x = mask[:,:-1,0] #* mask_stop[:,:-1,0]
        # pred_xs = self._read_out(output[:,1:][:,::2])
        pred_xs = self._read_out(output[:,:,:self.n_embd//2])
        loss1 = loss_func(pred_xs[:,:-1][mask_x], xs[:,1:][mask_x]).sum(-1).mean()
        # loss1 = (loss_func(pred_xs[:,:-1][mask_x], xs[:,1:][mask_x]).sum(-1) * weighting[mask_x]).mean()
        # loss1 = (loss_func(pred_xs[:,:-1][mask_x], xs[:,1:][mask_x]).sum(-1)[self.n_positions//2:]).mean()
        
        mask_id = mask[:,1:,0] * mask[:,:-1,0] #* mask_stop[:,:-1,0]


        # pred_ids = self._id_read_out(output[:,::2])
        pred_ids = self._id_read_out(output[:,:,self.n_embd//2:])
        # loss2 = loss_func(nn.Softmax(dim=-1)(pred_ids[:,1:][mask_id]), ids[:,1:][mask_id]).sum(-1).mean()
        out = (pred_ids[:,:-1][mask_id])
        target = (IDs[:,1:][mask_id])
        loss2 = nn.CrossEntropyLoss()(out, target)
        # loss2 = (nn.CrossEntropyLoss(reduction='none')(out, target) * weighting[mask_id]).mean()
        # loss2 = (nn.CrossEntropyLoss(reduction='none')(out, target) [self.n_positions//2:]).mean()
        # loss2 = loss_func(pred_ids[:,1:][mask_id], ids[:,1:][mask_id]).sum(-1).mean()

        losses = [loss1,loss2]
        # print(loss2)
        # exit()
        return losses, sum(losses)


    def predict_ori(self, xs, task, inds=None):
        """Debug roll-out: prompt with context pairs, then generate 7 steps.

        For each prompt length ``n`` in ``inds`` the first ``n`` (value, skill)
        context pairs are embedded, a query (the last ``xs`` entry and the first
        entry of ``skill_ids``) is appended, and the model alternately predicts
        a value and a skill id for seven steps. The recorded loss is the mean
        squared error between the last predicted value and ``ys_b``, obtained
        by chaining the task's functions on the query.

        NOTE(review): this is instrumented debugging code — ``n_points`` and
        ``inds`` are overwritten with hard-coded values, intermediate results
        are printed, tensors are moved with ``.cuda()``, and ``exit()`` is
        called before the ``return`` statement, so the method never returns.

        Args:
            xs: Input tensor passed to ``task.evaluate``.
            task: Object providing ``chain_length``, ``evaluate``, ``func_set``
                and ``act_func``.
            inds: Ignored (overwritten below).
        """
        n_skills = self.n_skills + 1
        chain_length = task.chain_length
        bsize = xs.shape[0]

        # inds = [80]
        # inds = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
        
        ys, ids_b, skill_ids, func_ids = task.evaluate(xs, eval=True)
        print(skill_ids[0], ids_b[0])
        xs, ys = xs.cuda(), ys.cuda()
        print(ys.shape[0], self.n_positions, n_skills)
        n_points = min(ys.shape[1], self.n_positions)-n_skills-1
        # NOTE(review): the computed value is immediately replaced by a constant.
        n_points = 50
        print(n_points, ys.shape[1]-n_skills-1,self.n_positions-n_skills-1)
        if inds is None:
            inds = list(range(n_points))
            # inds=[90]
            # print(inds)
        # NOTE(review): both the argument and the default are overridden here;
        # the effective value is [0, 8, 16, 24, 32].
        inds = [0,7,14,21,28,35,42,49]
        inds = [i * 8 for i in range(5)]
        # inds = [0,7]
        # inds = [0,12,24,36,48,60]
        with torch.no_grad():
            # The context values fed to the model are `ys`, not `xs`.
            xs_embed = self._read_in(ys)
            ids = torch.zeros_like(xs_embed)[:,:,:n_skills]
            for i in range(bsize):
                for j in range(n_points):
                    if ids_b[i,j] >=0:
                        ids[i,j,ids_b[i,j]] = 1.
            ids_embed = self._id_read_in(ids)

            # Query: last entry of xs plus the one-hot of the first skill id.
            init_x = xs[:,-1:]
            init_x_embed = self._read_in(init_x)
            init_id = torch.zeros_like(ids[:,:1])
            for batch_idx in range(bsize):
                init_id[batch_idx,0,skill_ids[batch_idx,0]] = 1.
            init_id_embed = self._id_read_in(init_id)

            # true label generation
            # func_set = task.func_set.cuda()
            func_set = torch.from_numpy(task.func_set).float().cuda()
            ys_b = torch.zeros_like(init_x)
            print(func_set.shape, ids_b.shape, func_ids.shape, chain_length)
            for batch_idx in range(bsize):
                y = init_x[batch_idx].clone()
                for i in range(chain_length):
                    y = y @ func_set[skill_ids[batch_idx,i],func_ids[batch_idx,i]]
                    y = task.act_func(y)
                ys_b[batch_idx] = y
            
            # calculate the loss
            # -1 marks "no loss recorded yet" for a (sample, prompt length) pair.
            losses = torch.ones(bsize, len(inds)) * -1
            for j in range(len(inds)):
                n = inds[j]
                zs = torch.stack([xs_embed[:,:n].clone(), ids_embed[:,:n].clone()], dim=2)
                zs = zs.view(bsize, -1, self.n_embd)

                if n > 0:
                    # print(ids_b[:,n-2:n])
                    init_id = torch.zeros_like(ids[:,:1])                   # feeding as s_empty
                    # init_id[:,0,n_skills] = torch.ones_like(init_id[:,0,0]) # feeding as s_stop
                    # for batch_idx in range(bsize):
                    #     if ids_b[batch_idx,n-1] == n_skills:
                    #         init_id[batch_idx, 0] = n_skills
                    zs[:,-1:] = self._id_read_in(init_id)

                zs = torch.cat([zs, init_x_embed.clone(), init_id_embed.clone()], dim=1)
                
                # losses = [-1 for i in range(bsize)]

                
                # print(ys_b)
                # ## ordered prompt
                # ys_b = init_x.clone()
                # for i in range(n_skills):
                #     ys_b = torch.nn.functional.relu(ys_b @ skill_set[i,subskill_ids[:,i]]) * math.sqrt(2 / self.n_dims)
                ## random prompt
                # ys_b = torch.zeros_like(init_x)
                # for i in range(bsize):
                #     y = init_x[i,0].clone()
                #     for j in range(subskill_ids.shape[1]):
                #         y = torch.nn.functional.relu(y @ skill_set[subskill_ids[i,j],0]) * math.sqrt(2 / self.n_dims)
                #     ys_b[i,0] = y

                # for _ in range(min(chain_length, self.n_positions-n+10)):
                # for _ in range(min(chain_length*3,self.n_positions-n-1)):
                # NOTE(review): the number of generation steps is hard-coded.
                for _ in range(7):
                    # predict y
                    out_embed = self._backbone(inputs_embeds=zs).last_hidden_state
                    out_embed = out_embed[:,-1:]
                    out = self._read_out(out_embed)
                    in_embed = self._read_in(out)
                    zs = torch.cat([zs, in_embed], dim=1)
                    # predict skill
                    id_out_embed = self._backbone(inputs_embeds=zs).last_hidden_state
                    id_out_embed = id_out_embed[:,-1:]
                    id_out = self._id_read_out(id_out_embed)
                    id_pred = id_out[:,:,:n_skills+1].argmax(dim=2)
                    # if _ == 2:
                    #     print(id_out[0,:,:n_skills+1])
                    #     exit()
                    print(n,_,id_pred.view(-1))
                    # exit()
                    id_in = torch.zeros_like(id_out)
                    for i in range(bsize):
                        id_in[i,0,id_pred[i]] = 1.
                        id_in_embed = self._id_read_in(id_in)
                        if id_pred[i] == skill_ids[i,-1] and losses[i,j] < 0:
                        # if False:
                        # if skill_ids[i,_+1] == 1 and losses[i] < 0:
                            # losses[i,j] = (out[i] - ys_b[i]).square().mean()
                            # print(n, i, _, losses[i,j], out[i].norm(), ys_b[i].norm())
                            print(n, i, _)
                            # if _ == 4:
                            #     print(id_out[i,:,:n_skills+1])
                            #     exit()
                            # break
                        # id_in[i,0,skill_ids[i,_+1]] = 1.
                        # id_in_embed = self._id_read_in(id_in)
                    zs = torch.cat([zs, id_in_embed], dim=1)
                # After the roll-out, score the last predicted value `out`
                # against the chained target for every sample still unset.
                for i in range(bsize):
                    if losses[i,j] < 0:
                        losses[i,j] = (out[i] - ys_b[i]).square().mean()
                        print(n, i, _, losses[i,j], out[i].norm(), ys_b[i].norm())
                        # print(skill_ids[i], ids_b[i,:n])
                        # print(n, i, _, losses[i])
            print(losses.mean(dim=0))
            # NOTE(review): terminates the process; the return below is unreachable.
            exit()
            return losses

    def predict(self, xs, task, inds=None, return_losses=False):
        """Roll the model out autoregressively and score its skill-chain predictions.

        For each evaluated prompt length ``n`` the method builds a prompt from
        the first ``n`` in-context positions, appends the query (the last
        position of ``xs`` together with its first skill id) and then lets the
        backbone generate ``10`` further positions.  At every generated
        position the first half of the hidden channels is decoded with
        ``_read_out`` (value prediction) and the second half with
        ``_id_read_out`` (skill-id prediction); both predictions are re-embedded
        and appended to the sequence for the next step.

        Args:
            xs: Input tensor; ``xs[:, -1:]`` is used as the query and
                ``xs.shape[0]`` as the batch size.  Moved to CUDA here.
            task: Object providing ``chain_length``, ``evaluate(xs, eval=True)``
                (returning ``ys, ids_b, skill_ids, func_ids``), ``func_set``
                (a NumPy array) and ``act_func``.
            inds: Currently ignored -- the evaluated prompt lengths are always
                ``[i * 3 * chain_length for i in range(6)]``.
                NOTE(review): the original code overwrote this argument
                unconditionally; that behaviour is preserved because callers
                are not visible from here.  Confirm whether it should be
                honoured.
            return_losses: When true, return the squared-error losses instead
                of the accuracy indicator.  Defaults to ``False``.

        Returns:
            A ``(batch, 6)`` float tensor.  By default entry ``[i, j]`` is
            ``1.0`` when sample ``i`` predicted all ``chain_length`` skill ids
            ``skill_ids[i, 1:chain_length + 1]`` correctly and consecutively
            from the first generated step for prompt length ``j``, else
            ``0.0``.  With ``return_losses=True`` entry ``[i, j]`` is the mean
            squared error between the value prediction and the ground-truth
            chain output, taken at the first step where the predicted id
            equals ``self.n_skills`` (or at the last generated step if that
            never happens).
        """
        n_skills = self.n_skills + 1
        chain_length = task.chain_length
        bsize = xs.shape[0]

        ys, ids_b, skill_ids, func_ids = task.evaluate(xs, eval=True)
        xs, ys = xs.cuda(), ys.cuda()

        # NOTE(review): hard-coded in the original.  Only the first 50 prompt
        # positions receive a one-hot skill id below, although the evaluated
        # prompt lengths can exceed 50 -- confirm this is intended.
        n_points = 50
        # Prompt lengths to evaluate (see the docstring note on ``inds``).
        inds = [i * 3 * chain_length for i in range(6)]
        # Number of positions generated after the query.
        n_steps = 10

        with torch.no_grad():
            # Embed the in-context values and their one-hot skill ids.
            xs_embed = self._read_in(ys)
            ids = torch.zeros_like(xs_embed)[:, :, :n_skills]
            for i in range(bsize):
                for j in range(n_points):
                    # Negative ids mark positions without a skill id.
                    if ids_b[i, j] >= 0:
                        ids[i, j, ids_b[i, j]] = 1.
            ids_embed = self._id_read_in(ids)

            # Query token: last input position plus the first skill of the chain.
            init_x = xs[:, -1:]
            init_x_embed = self._read_in(init_x)
            init_id = torch.zeros_like(ids[:, :1])
            for batch_idx in range(bsize):
                init_id[batch_idx, 0, skill_ids[batch_idx, 0]] = 1.
            init_id_embed = self._id_read_in(init_id)
            query_embed = torch.cat([init_x_embed, init_id_embed], dim=2)

            # Ground-truth output: apply the chain of functions to the query.
            func_set = torch.from_numpy(task.func_set).float().cuda()
            ys_b = torch.zeros_like(init_x)
            for batch_idx in range(bsize):
                y = init_x[batch_idx].clone()
                for i in range(chain_length):
                    y = y @ func_set[skill_ids[batch_idx, i], func_ids[batch_idx, i]]
                    y = task.act_func(y)
                ys_b[batch_idx] = y

            # -1 marks "loss not recorded yet" for a (sample, prompt length) pair.
            losses = torch.ones(bsize, len(inds)) * -1
            accs = torch.zeros(bsize, len(inds))
            for j in range(len(inds)):
                n = inds[j]
                # torch.cat allocates a new tensor, so the in-place write
                # below cannot touch xs_embed / ids_embed.
                zs = torch.cat([xs_embed[:, :n], ids_embed[:, :n]], dim=2)

                if n > 0:
                    # Replace the id half of the last prompt position with the
                    # embedding of an all-zero ("empty") id vector.
                    empty_id = torch.zeros_like(ids[:, :1])
                    zs[:, -1:, self.n_embd // 2:] = self._id_read_in(empty_id)

                zs = torch.cat([zs, query_embed], dim=1)

                for step in range(n_steps):
                    out_embed = self._backbone(inputs_embeds=zs).last_hidden_state[:, -1:]
                    # First half of the channels -> value, second half -> skill id.
                    out = self._read_out(out_embed[:, :, :self.n_embd // 2])
                    in_embed = self._read_in(out)
                    id_out = self._id_read_out(out_embed[:, :, self.n_embd // 2:])
                    id_pred = id_out[:, :, :n_skills + 1].argmax(dim=2)

                    id_in = torch.zeros_like(id_out)
                    for i in range(bsize):
                        id_in[i, 0, id_pred[i]] = 1.
                        # Count a hit only while every earlier step was also a
                        # hit (accs == step), so accs is the length of the
                        # correct prefix.
                        if step < chain_length and id_pred[i] == skill_ids[i, step + 1] and accs[i, j] == step:
                            accs[i, j] += 1
                        # Record the loss the first time id self.n_skills is predicted.
                        if id_pred[i] == self.n_skills and losses[i, j] < 0:
                            losses[i, j] = (out[i] - ys_b[i]).square().mean()
                    # Embed the completed one-hot batch once per step (the
                    # original recomputed this for every sample and only used
                    # the last result).
                    id_in_embed = self._id_read_in(id_in)
                    zs = torch.cat([zs, torch.cat([in_embed, id_in_embed], dim=2)], dim=1)

                # Samples that never predicted id self.n_skills are scored on
                # the last generated value.
                for i in range(bsize):
                    if losses[i, j] < 0:
                        losses[i, j] = (out[i] - ys_b[i]).square().mean()

        if return_losses:
            return losses
        # accs can reach chain_length only if every step of the chain matched;
        # the original compared against the literal 5.
        return (accs == chain_length).float()
