"""Non-standard embedding implementations."""

import torch
import math

from typing import Tuple
from einops import repeat
import random


class PositionalEmbedding(torch.nn.Module):
    """Sinusoidal position embedding in the style of Transformer-XL.

    Adapted from
    https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py#L15C1-L31C37

    Args:
        demb: embedding width. One inverse frequency is created per sine/cosine
            pair, so the output has ``2 * ceil(demb / 2)`` features (``demb``
            for even ``demb``).
    """

    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()

        self.demb = demb

        # inv_freq[i] = 1 / 10000 ** (2i / demb)
        inv_freq = (1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))).float()
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, pos_seq, bsz=None):
        """Embed a tensor of positions.

        Args:
            pos_seq: tensor of positions of any shape ``(*)``; cast to float.
            bsz: unused; kept for interface compatibility.

        Returns:
            Tensor of shape ``(*, 2 * len(inv_freq))`` holding the sines of
            ``pos * inv_freq`` followed by the cosines.
        """
        # Outer product of positions and inverse frequencies via broadcasting
        # (equivalent to torch.ger for 1-D input). Unlike matmul + squeeze(2),
        # this keeps the feature axis even with a single frequency (demb <= 2).
        sinusoid_inp = pos_seq.float().unsqueeze(-1) * self.inv_freq
        return torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)


class RandomNoise(torch.nn.Module):
    """Stand-in "positional embedding" that returns Gaussian noise (mean 0, std 0.1).

    Args:
        embedding_dim: size of the last output dimension.
        max_seq_length: accepted for interface parity with the other modules
            in this file; not used.
    """

    def __init__(self, embedding_dim, max_seq_length=5000):
        super().__init__()
        self.embedding_dim = embedding_dim

    def forward(self, input_ids):
        """Return fresh noise of shape ``(input_ids.size(0), input_ids.size(1), embedding_dim)``.

        The sample is drawn on the default device and then moved to
        ``input_ids.device``.
        """
        batch_size = input_ids.size(0)
        seq_len = input_ids.size(1)
        noise = torch.normal(0, 0.1, size=(batch_size, seq_len, self.embedding_dim))
        target_device = input_ids.device
        return noise.to(target_device)


class RPE(torch.nn.Module):
    """Causal multi-head self-attention with learned relative position embeddings.

    Follows the memory-efficient "skewing" formulation described in
    https://jaketae.github.io/study/relative-positional-encoding/ . A single
    table ``Er`` of shape ``(max_len, d_head)`` is shared by all heads. After
    skewing, the score between query ``i`` and key ``j <= i`` uses row
    ``max_len - 1 - (i - j)`` of ``Er``.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        max_len: longest supported sequence (rows of ``Er`` and size of the mask).
        dropout: dropout probability applied to the final output (not to the
            attention weights).

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, max_len=1024, dropout=0.1):
        super().__init__()
        d_head, remainder = divmod(d_model, num_heads)
        if remainder:
            raise ValueError("incompatible `d_model` and `num_heads`")
        self.max_len = max_len
        self.d_model = d_model
        self.num_heads = num_heads
        self.key = torch.nn.Linear(d_model, d_model)
        self.value = torch.nn.Linear(d_model, d_model)
        self.query = torch.nn.Linear(d_model, d_model)
        self.dropout = torch.nn.Dropout(dropout)
        # Relative-position table shared by all heads, randomly initialised.
        self.Er = torch.nn.Parameter(torch.randn(max_len, d_head))
        # Lower-triangular causal mask registered as a buffer.
        self.register_buffer("mask", torch.tril(torch.ones(max_len, max_len)).unsqueeze(0).unsqueeze(0))
        # self.mask.shape = (1, 1, max_len, max_len)

    def forward(self, x):
        """Apply causal relative-position self-attention.

        Args:
            x: tensor of shape ``(batch_size, seq_len, d_model)``.

        Returns:
            Tensor of shape ``(batch_size, seq_len, d_model)``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        # x.shape == (batch_size, seq_len, d_model)
        batch_size, seq_len, _ = x.shape

        if seq_len > self.max_len:
            raise ValueError("sequence length exceeds model capacity")

        k_t = self.key(x).reshape(batch_size, seq_len, self.num_heads, -1).permute(0, 2, 3, 1)
        # k_t.shape = (batch_size, num_heads, d_head, seq_len)  (keys pre-transposed)
        v = self.value(x).reshape(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
        q = self.query(x).reshape(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
        # shape = (batch_size, num_heads, seq_len, d_head)

        # Use the last seq_len rows of Er; the final row corresponds to distance 0.
        start = self.max_len - seq_len
        Er_t = self.Er[start:, :].transpose(0, 1)
        # Er_t.shape = (d_head, seq_len)
        QEr = torch.matmul(q, Er_t)
        # QEr.shape = (batch_size, num_heads, seq_len, seq_len)
        Srel = self.skew(QEr)
        # Srel.shape = (batch_size, num_heads, seq_len, seq_len)

        QK_t = torch.matmul(q, k_t)
        # QK_t.shape = (batch_size, num_heads, seq_len, seq_len)
        attn = (QK_t + Srel) / math.sqrt(q.size(-1))
        mask = self.mask[:, :, :seq_len, :seq_len]
        # mask.shape = (1, 1, seq_len, seq_len)
        # Block attention to future keys (and the skew's filler entries above the diagonal).
        attn = attn.masked_fill(mask == 0, float("-inf"))
        # attn.shape = (batch_size, num_heads, seq_len, seq_len)
        attn = torch.nn.functional.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)
        # out.shape = (batch_size, num_heads, seq_len, d_head)
        out = out.transpose(1, 2)
        # out.shape == (batch_size, seq_len, num_heads, d_head)
        out = out.reshape(batch_size, seq_len, -1)
        # out.shape == (batch_size, seq_len, d_model)
        return self.dropout(out)

    def skew(self, QEr):
        """Realign query/relative-embedding scores so that column ``j`` is key ``j``.

        Left-pad one column, reinterpret the buffer with rows and columns
        swapped, then drop the first row. For ``j <= i`` this yields
        ``Srel[..., i, j] = QEr[..., i, seq_len - 1 - i + j]``. Entries above the
        diagonal are filler and are removed by the causal mask in ``forward``.
        """
        # QEr.shape = (batch_size, num_heads, seq_len, seq_len)
        padded = torch.nn.functional.pad(QEr, (1, 0))
        # padded.shape = (batch_size, num_heads, seq_len, 1 + seq_len)
        batch_size, num_heads, num_rows, num_cols = padded.shape
        reshaped = padded.reshape(batch_size, num_heads, num_cols, num_rows)
        # reshaped.size = (batch_size, num_heads, 1 + seq_len, seq_len)
        Srel = reshaped[:, :, 1:, :]
        # Srel.shape = (batch_size, num_heads, seq_len, seq_len)
        return Srel


# Module partially adapted from the PyTorch examples.
class SinusoidalPositional(torch.nn.Module):
    r"""Fixed sine/cosine absolute positional encodings.

    Even feature indices hold ``sin(pos * w_i)`` and odd indices hold
    ``cos(pos * w_i)``, with ``w_i = 10000 ** (-2i / embedding_dim)``. The
    encodings have the same width as the token embeddings, so the two can be
    summed. The table is precomputed for ``max_seq_length`` positions and
    registered as a non-persistent buffer: it follows the module across devices
    but is not saved in the state dict.

    Args:
        embedding_dim: feature size of the encoding (odd sizes are supported).
        max_seq_length: number of positions to precompute.
    """

    def __init__(self, embedding_dim, max_seq_length=5000):
        super().__init__()

        pe = torch.zeros(max_seq_length, embedding_dim)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd embedding_dim there is one fewer cosine column than sine
        # column, so trim the frequencies to make the assignment shapes agree.
        pe[:, 1::2] = torch.cos(position * div_term[: embedding_dim // 2])

        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, input_ids):
        r"""Return the encodings for the first ``input_ids.shape[1]`` positions.

        Args:
            input_ids: batch-first tensor; only its second dimension (the
                sequence length) is read.

        Returns:
            Tensor of shape ``(1, seq_len, embedding_dim)``, which broadcasts
            against ``(batch, seq_len, embedding_dim)`` embeddings. If
            ``seq_len > max_seq_length``, only ``max_seq_length`` rows are
            returned.

        Examples:
            >>> pos = pos_encoder(input_ids)
        """
        return self.pe[:, : input_ids.shape[1], :]


class ScaledSinosoidal(SinusoidalPositional):
    """Sinusoidal encodings multiplied by one learnable scalar (see the FLASH paper).

    The scalar is initialised to ``1 / sqrt(embedding_dim)``.

    Args:
        embedding_dim: feature size of the encoding.
        max_seq_length: number of positions to precompute. Defaults to 5000,
            the same default as ``SinusoidalPositional``.
    """

    def __init__(self, embedding_dim, max_seq_length=5000):
        super().__init__(embedding_dim, max_seq_length)
        self.scale_factor = torch.nn.Parameter(torch.tensor([1.0 / embedding_dim**0.5]))

    def forward(self, input_ids):
        r"""Return the scaled encodings for the first ``input_ids.shape[1]`` positions.

        Args:
            input_ids: batch-first tensor; only its sequence length is read.

        Returns:
            Tensor of shape ``(1, seq_len, embedding_dim)``.

        Examples:
            >>> pos = pos_encoder(input_ids)
        """
        return self.scale_factor * self.pe[:, : input_ids.shape[1], :]


class LearnablePositional(torch.nn.Module):
    """Absolute positional embedding backed by a trainable lookup table.

    Args:
        embedding_dim: width of each positional vector.
        max_seq_length: number of rows in the table.
    """

    def __init__(self, embedding_dim, max_seq_length=1024):
        super().__init__()
        self.embedding = torch.nn.Embedding(max_seq_length, embedding_dim)
        every_position = torch.arange(max_seq_length).expand((1, -1))
        self.register_buffer("position_ids", every_position)

    def forward(self, input_ids):
        """Look up the first ``input_ids.shape[1]`` table rows (batch-first input).

        Returns:
            Tensor of shape ``(1, seq_len, embedding_dim)``.
        """
        seq_len = input_ids.shape[1]
        wanted = self.position_ids[:, :seq_len]
        return self.embedding(wanted)


class LearnablePositionalRand(torch.nn.Module):
    """Learnable positional embedding indexed by random, sorted positions.

    On every forward call, ``seq_len`` distinct positions are drawn uniformly
    from ``[0, max_seq_length)`` and sorted. This preserves token order while
    randomising absolute offsets.

    Args:
        embedding_dim: width of each positional vector.
        max_seq_length: number of rows in the table; longer sequences are rejected.
    """

    def __init__(self, embedding_dim, max_seq_length=1024):
        super().__init__()
        self.max_length = max_seq_length
        self.embedding = torch.nn.Embedding(max_seq_length, embedding_dim)
        # Not read by forward; kept so existing state dicts still load.
        self.register_buffer("position_ids", torch.arange(max_seq_length).expand((1, -1)))

    def forward(self, input_ids):
        """Embed sorted random positions for a batch-first ``input_ids``.

        Returns:
            Tensor of shape ``(seq_len, embedding_dim)``, with no batch axis. It
            broadcasts against ``(batch, seq_len, embedding_dim)``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_seq_length``. The table has
                no rows for the extra positions, so the lookup could only fail.
        """
        seq_length = input_ids.shape[1]
        if seq_length > self.max_length:
            raise ValueError("sequence length exceeds model capacity")
        sampled = torch.randperm(self.max_length, dtype=torch.long, device=input_ids.device)[:seq_length]
        position_ids = torch.sort(sampled).values
        return self.embedding(position_ids)


class RecyclePositional(torch.nn.Module):
    """Learnable embedding whose position ids restart for each part of ``a + b = c``.

    Pad-tokenizer ids: ``'+'`` = 14, ``'='`` = 17, ``[EOS]`` = 3. In every row,
    three spans are each numbered ``k, k + 1, ...``: the tokens before ``'+'``,
    the tokens between ``'+'`` and ``'='``, and the tokens after ``'='``. The
    ``'+'`` and ``'='`` tokens themselves get id 0.

    In training, ``k`` is drawn from ``[1, max_k]`` on each call, and the answer
    span stops before ``[EOS]``. In eval mode, ``k = 1`` and the answer span
    runs to the end of the row, so trailing padding is numbered too.

    Args:
        embedding_dim: width of each positional vector.
        max_seq_length: number of rows in the embedding table.
        max_k: largest random starting id sampled during training.
    """

    def __init__(self, embedding_dim, max_seq_length=1024, max_k=100):
        super().__init__()
        self.embedding = torch.nn.Embedding(max_seq_length, embedding_dim)
        self.register_buffer("position_ids", torch.arange(max_seq_length).expand((1, -1)))
        self.max_k = max_k

    def helper(self, mask, input_ids, k=1, device="cuda"):
        """Number the ``True`` entries of each row of ``mask`` as ``k, k + 1, ...``.

        Numbering restarts in every row independently of the other rows,
        including rows with no ``True`` entries. ``False`` entries are 0.

        Args:
            mask: boolean tensor of shape ``(batch, seq_len)``.
            input_ids: tensor whose dtype the result takes.
            k: first id assigned in each row.
            device: device of the returned tensor.

        Returns:
            Tensor shaped like ``mask`` with dtype ``input_ids.dtype``.
        """
        mask = mask.to(device)
        # The running count along a row is 1, 2, ... at its masked entries.
        ranks = mask.long().cumsum(dim=1) + (k - 1)
        return torch.where(mask, ranks, torch.zeros_like(ranks)).to(input_ids.dtype)

    @staticmethod
    def _first_per_row(hits):
        """Return the column of the first ``True`` in each row of ``hits`` that has one.

        ``nonzero`` lists coordinates in row-major order, so a row's first hit is
        wherever the row index changes.
        """
        coords = hits.nonzero()
        rows = coords[:, 0]
        keep = torch.ones_like(rows, dtype=torch.bool)
        keep[1:] = rows[1:] != rows[:-1]
        return coords[keep, 1]

    def forward(self, input_ids):
        """Embed the per-operand position ids of a batch-first ``input_ids``.

        Args:
            input_ids: integer tensor ``(batch, seq_len)``. Each row is expected
                to contain one ``'+'`` (14) and one ``'='`` (17), and in training
                one ``[EOS]`` (3). In eval mode, extra ``'+'``/``'='`` tokens
                (e.g. from generated completions) are ignored after the first one.

        Returns:
            Tensor of shape ``(batch, seq_len, embedding_dim)``.

        Example (eval mode, so ``k = 1``)::

            input_ids: 5  6  14  7  17  8  9  3
            ids:       1  2   0  1   0  1  2  3
        """
        device = input_ids.device
        batch_size, seq_len = input_ids.shape[0], input_ids.shape[1]
        # 1-based positions, so ``position <= idx`` selects the columns before idx.
        position_ids = torch.arange(1, seq_len + 1, device=device).expand(batch_size, -1)
        is_plus = input_ids == 14
        is_equals = input_ids == 17
        indices_14 = is_plus.nonzero()[:, 1]
        indices_17 = is_equals.nonzero()[:, 1]

        if not self.training:
            # Keep only the first '=' / '+' of each row when the model emitted several.
            if indices_17.shape[0] != batch_size:
                indices_17 = self._first_per_row(is_equals)
            if indices_14.shape[0] != batch_size:
                indices_14 = self._first_per_row(is_plus)

        plus_pos = indices_14.unsqueeze(1).expand_as(position_ids)
        equals_pos = indices_17.unsqueeze(1).expand_as(position_ids)
        mask_num_1 = position_ids <= plus_pos  # tokens before '+'
        mask_num_2 = (position_ids > plus_pos + 1) & (position_ids <= equals_pos)  # between '+' and '='
        mask_ans = position_ids > equals_pos + 1  # tokens after '='

        k = 1
        if self.training:
            # End the answer before [EOS] and randomise the starting id.
            eos_pos = (input_ids == 3).nonzero()[:, 1].unsqueeze(1).expand_as(position_ids)
            mask_ans = mask_ans & (position_ids <= eos_pos)
            k = random.randint(1, self.max_k)

        output = (
            self.helper(mask_num_1, input_ids, k=k, device=device)
            + self.helper(mask_num_2, input_ids, k=k, device=device)
            + self.helper(mask_ans, input_ids, k=k, device=device)
        )
        return self.embedding(output)


class RecycleMax100(torch.nn.Module):
    """Recycled per-operand position ids, folded cyclically into ``[1, max_k]``.

    Pad-tokenizer ids: ``'+'`` = 14, ``'='`` = 17, ``[EOS]`` = 3. In every row,
    three spans are each numbered ``k, k + 1, ...``: the tokens before ``'+'``,
    the tokens between ``'+'`` and ``'='``, and the tokens after ``'='``. The
    ``'+'`` and ``'='`` tokens get id 0.

    Any id above ``max_k`` then has ``max_k`` subtracted repeatedly, in both
    training and eval (with the default ``max_k=100``: 101 -> 1, 200 -> 100).
    In training, ``k`` is drawn from ``[1, max_k]`` and the answer stops before
    ``[EOS]``. In eval mode, ``k = 1`` and the answer runs to the end of the row.

    Args:
        embedding_dim: width of each positional vector.
        max_seq_length: number of rows in the embedding table.
        max_k: cycle length and largest random starting id.
    """

    def __init__(self, embedding_dim, max_seq_length=1024, max_k=100):
        super().__init__()
        self.embedding = torch.nn.Embedding(max_seq_length, embedding_dim)
        self.register_buffer("position_ids", torch.arange(max_seq_length).expand((1, -1)))
        self.max_k = max_k

    def helper(self, mask, input_ids, k=1, device="cuda"):
        """Number the ``True`` entries of each row of ``mask`` as ``k, k + 1, ...``.

        Numbering restarts in every row independently of the other rows,
        including rows with no ``True`` entries. ``False`` entries are 0.

        Args:
            mask: boolean tensor of shape ``(batch, seq_len)``.
            input_ids: tensor whose dtype the result takes.
            k: first id assigned in each row.
            device: device of the returned tensor.

        Returns:
            Tensor shaped like ``mask`` with dtype ``input_ids.dtype``.
        """
        mask = mask.to(device)
        ranks = mask.long().cumsum(dim=1) + (k - 1)
        return torch.where(mask, ranks, torch.zeros_like(ranks)).to(input_ids.dtype)

    @staticmethod
    def _first_per_row(hits):
        """Return the column of the first ``True`` in each row of ``hits`` that has one.

        ``nonzero`` is row-major, so a row's first hit is where the row index changes.
        """
        coords = hits.nonzero()
        rows = coords[:, 0]
        keep = torch.ones_like(rows, dtype=torch.bool)
        keep[1:] = rows[1:] != rows[:-1]
        return coords[keep, 1]

    def forward(self, input_ids):
        """Embed cyclic per-operand position ids of a batch-first ``input_ids``.

        Args:
            input_ids: integer tensor ``(batch, seq_len)``. Each row is expected
                to contain one ``'+'`` (14) and one ``'='`` (17), and in training
                one ``[EOS]`` (3). In eval mode, only the first ``'+'``/``'='``
                of a row is used.

        Returns:
            Tensor of shape ``(batch, seq_len, embedding_dim)``.
        """
        device = input_ids.device
        batch_size, seq_len = input_ids.shape[0], input_ids.shape[1]
        position_ids = torch.arange(1, seq_len + 1, device=device).expand(batch_size, -1)
        is_plus = input_ids == 14
        is_equals = input_ids == 17
        indices_14 = is_plus.nonzero()[:, 1]
        indices_17 = is_equals.nonzero()[:, 1]

        if not self.training:
            # Keep only the first '=' / '+' of each row when the model emitted several.
            if indices_17.shape[0] != batch_size:
                indices_17 = self._first_per_row(is_equals)
            if indices_14.shape[0] != batch_size:
                indices_14 = self._first_per_row(is_plus)

        plus_pos = indices_14.unsqueeze(1).expand_as(position_ids)
        equals_pos = indices_17.unsqueeze(1).expand_as(position_ids)
        mask_num_1 = position_ids <= plus_pos  # tokens before '+'
        mask_num_2 = (position_ids > plus_pos + 1) & (position_ids <= equals_pos)  # between '+' and '='
        mask_ans = position_ids > equals_pos + 1  # tokens after '='

        k = 1
        if self.training:
            eos_pos = (input_ids == 3).nonzero()[:, 1].unsqueeze(1).expand_as(position_ids)
            mask_ans = mask_ans & (position_ids <= eos_pos)
            k = random.randint(1, self.max_k)

        output = (
            self.helper(mask_num_1, input_ids, k=k, device=device)
            + self.helper(mask_num_2, input_ids, k=k, device=device)
            + self.helper(mask_ans, input_ids, k=k, device=device)
        )

        # Fold ids above max_k into [1, max_k]. A plain ``% max_k`` would map
        # exact multiples of max_k to 0, the id reserved for operator tokens.
        output = torch.where(output > self.max_k, (output - 1) % self.max_k + 1, output)
        return self.embedding(output)


class RecyclePlusAndMinus(torch.nn.Module):
    """Recycled per-operand position ids for both ``a + b = c`` and ``a - b = c``.

    Pad-tokenizer ids: ``'+'`` = 14, ``'-'`` = 15, ``'='`` = 17, ``[EOS]`` = 3.
    In every row, three spans are each numbered ``k, k + 1, ...``: the tokens
    before the operator, the tokens between the operator and ``'='``, and the
    tokens after ``'='``. The operator and ``'='`` get id 0.

    In training, ``k`` is drawn from ``[1, max_k]`` and the answer stops before
    ``[EOS]``. In eval mode, ``k = 1`` and the answer runs to the end of the row.

    Args:
        embedding_dim: width of each positional vector.
        max_seq_length: number of rows in the embedding table.
        max_k: largest random starting id sampled during training.
    """

    def __init__(self, embedding_dim, max_seq_length=1024, max_k=100):
        super().__init__()
        self.embedding = torch.nn.Embedding(max_seq_length, embedding_dim)
        self.register_buffer("position_ids", torch.arange(max_seq_length).expand((1, -1)))
        self.max_k = max_k

    def helper(self, mask, input_ids, k=1, device="cuda"):
        """Number the ``True`` entries of each row of ``mask`` as ``k, k + 1, ...``.

        Numbering restarts in every row independently of the other rows,
        including rows with no ``True`` entries. ``False`` entries are 0.

        Args:
            mask: boolean tensor of shape ``(batch, seq_len)``.
            input_ids: tensor whose dtype the result takes.
            k: first id assigned in each row.
            device: device of the returned tensor.

        Returns:
            Tensor shaped like ``mask`` with dtype ``input_ids.dtype``.
        """
        mask = mask.to(device)
        ranks = mask.long().cumsum(dim=1) + (k - 1)
        return torch.where(mask, ranks, torch.zeros_like(ranks)).to(input_ids.dtype)

    @staticmethod
    def _first_per_row(hits):
        """Return the column of the first ``True`` in each row of ``hits`` that has one.

        ``nonzero`` is row-major, so a row's first hit is where the row index changes.
        """
        coords = hits.nonzero()
        rows = coords[:, 0]
        keep = torch.ones_like(rows, dtype=torch.bool)
        keep[1:] = rows[1:] != rows[:-1]
        return coords[keep, 1]

    def forward(self, input_ids):
        """Embed per-operand position ids of a batch-first ``input_ids``.

        Args:
            input_ids: integer tensor ``(batch, seq_len)``. Each row is expected
                to contain one operator (``'+'`` = 14 or ``'-'`` = 15) and one
                ``'='`` (17), and in training one ``[EOS]`` (3). In eval mode,
                only the first operator/``'='`` of a row is used.

        Returns:
            Tensor of shape ``(batch, seq_len, embedding_dim)``.

        Example (eval mode, so ``k = 1``)::

            input_ids: 4  5  15  6  17  7  3
            ids:       1  2   0  1   0  1  2
        """
        device = input_ids.device
        batch_size, seq_len = input_ids.shape[0], input_ids.shape[1]
        position_ids = torch.arange(1, seq_len + 1, device=device).expand(batch_size, -1)
        is_op = (input_ids == 15) | (input_ids == 14)
        is_equals = input_ids == 17
        indices_15_or_14 = is_op.nonzero()[:, 1]
        indices_17 = is_equals.nonzero()[:, 1]

        if not self.training:
            # Keep only the first '=' / operator of each row when the model emitted
            # several. Deduplication is by row: grouping by column would drop valid
            # operators in other rows that happen to share a column.
            if indices_17.shape[0] != batch_size:
                indices_17 = self._first_per_row(is_equals)
            if indices_15_or_14.shape[0] != batch_size:
                indices_15_or_14 = self._first_per_row(is_op)

        op_pos = indices_15_or_14.unsqueeze(1).expand_as(position_ids)
        equals_pos = indices_17.unsqueeze(1).expand_as(position_ids)
        mask_num_1 = position_ids <= op_pos  # tokens before the operator
        mask_num_2 = (position_ids > op_pos + 1) & (position_ids <= equals_pos)  # between operator and '='
        mask_ans = position_ids > equals_pos + 1  # tokens after '='

        k = 1
        if self.training:
            eos_pos = (input_ids == 3).nonzero()[:, 1].unsqueeze(1).expand_as(position_ids)
            mask_ans = mask_ans & (position_ids <= eos_pos)
            k = random.randint(1, self.max_k)

        output = (
            self.helper(mask_num_1, input_ids, k=k, device=device)
            + self.helper(mask_num_2, input_ids, k=k, device=device)
            + self.helper(mask_ans, input_ids, k=k, device=device)
        )
        return self.embedding(output)


class RecyclePositionalMixed(torch.nn.Module):
    """Randomly mixes recycled position ids with a constant id, then applies a random scale.

    On each call, with probability 0.7 the ids are per-operand numbers. Using
    ``'+'`` = 14, ``'='`` = 17 and ``[EOS]`` = 3, the tokens before ``'+'``,
    between ``'+'`` and ``'='``, and after ``'='`` are each numbered
    ``k, k + 1, ...``. In training, ``k`` is drawn from ``[1, 100]`` and the
    answer stops before ``[EOS]``; in eval mode, ``k = 1``. Otherwise (the
    remaining 0.3), every token gets id 0.

    The looked-up embeddings are then multiplied by a factor drawn from
    ``{0, 0.1, 1, 10}`` with weights ``1:1:3:1``.

    NOTE(review): both random choices are also made in eval mode, so eval
    outputs are not deterministic; confirm this is intended.

    Args:
        embedding_dim: width of each positional vector.
        max_seq_length: number of rows in the embedding table.
    """

    def __init__(self, embedding_dim, max_seq_length=1024):
        super().__init__()
        self.embedding = torch.nn.Embedding(max_seq_length, embedding_dim)
        self.register_buffer("position_ids", torch.arange(max_seq_length).expand((1, -1)))

    def helper(self, mask, input_ids, k=1, device="cuda"):
        """Number the ``True`` entries of each row of ``mask`` as ``k, k + 1, ...``.

        Numbering restarts in every row independently of the other rows,
        including rows with no ``True`` entries. ``False`` entries are 0.

        Args:
            mask: boolean tensor of shape ``(batch, seq_len)``.
            input_ids: tensor whose dtype the result takes.
            k: first id assigned in each row.
            device: device of the returned tensor.

        Returns:
            Tensor shaped like ``mask`` with dtype ``input_ids.dtype``.
        """
        mask = mask.to(device)
        ranks = mask.long().cumsum(dim=1) + (k - 1)
        return torch.where(mask, ranks, torch.zeros_like(ranks)).to(input_ids.dtype)

    def forward(self, input_ids):
        """Embed a batch-first ``input_ids`` with randomly mixed position ids.

        Args:
            input_ids: integer tensor ``(batch, seq_len)``. When the recycled ids
                are used, each row is expected to contain one ``'+'`` and one
                ``'='``, and in training one ``[EOS]``.

        Returns:
            Tensor of shape ``(batch, seq_len, embedding_dim)``.
        """
        device = input_ids.device
        if random.random() < 0.7:
            batch_size, seq_len = input_ids.shape[0], input_ids.shape[1]
            position_ids = torch.arange(1, seq_len + 1, device=device).expand(batch_size, -1)
            plus_pos = (input_ids == 14).nonzero()[:, 1].unsqueeze(1).expand_as(position_ids)
            equals_pos = (input_ids == 17).nonzero()[:, 1].unsqueeze(1).expand_as(position_ids)
            mask_num_1 = position_ids <= plus_pos  # tokens before '+'
            mask_num_2 = (position_ids > plus_pos + 1) & (position_ids <= equals_pos)  # between '+' and '='
            mask_ans = position_ids > equals_pos + 1  # tokens after '='

            k = 1
            if self.training:
                eos_pos = (input_ids == 3).nonzero()[:, 1].unsqueeze(1).expand_as(position_ids)
                mask_ans = mask_ans & (position_ids <= eos_pos)
                k = random.randint(1, 100)

            output = (
                self.helper(mask_num_1, input_ids, k=k, device=device)
                + self.helper(mask_num_2, input_ids, k=k, device=device)
                + self.helper(mask_ans, input_ids, k=k, device=device)
            )
        else:
            output = torch.zeros_like(input_ids)
        # random.choices returns a one-element list, so this is a (1, 1) tensor
        # that broadcasts over (batch, seq_len, embedding_dim).
        factor = random.choices([0, 0.1, 1, 10], weights=[1, 1, 3, 1], k=1)
        return self.embedding(output) * torch.tensor([factor], device=device)


class RecyclePositionalRand(torch.nn.Module):
    """Per-operand positional embedding with random, sorted ids inside each span.

    Pad-tokenizer ids: ``'+'`` = 14, ``'='`` = 17, ``[EOS]`` = 3. Three spans
    each receive sorted, distinct random ids from ``[1, max_seq_length)``: the
    tokens before ``'+'``, between ``'+'`` and ``'='``, and after ``'='``. The
    ``'+'`` and ``'='`` tokens get id 0. In training, the answer span stops
    before ``[EOS]``; in eval mode, it runs to the end of the row.

    Args:
        embedding_dim: width of each positional vector.
        max_seq_length: number of rows in the embedding table.
    """

    def __init__(self, embedding_dim, max_seq_length=1024):
        super().__init__()
        self.max_length = max_seq_length
        self.embedding = torch.nn.Embedding(max_seq_length, embedding_dim)
        self.register_buffer("position_ids", torch.arange(max_seq_length).expand((1, -1)))

    def helper(self, mask, input_ids, k=1, device="cuda"):
        """Give the ``True`` entries of each row sorted, distinct random ids.

        For every row, ``mask[b].sum()`` distinct ids are drawn uniformly from
        ``[1, max_length)``, sorted, and written to the masked entries in column
        order. Unmasked entries are 0. Rows are sampled independently, so every
        id is a valid index into the embedding table and never equals the 0
        used for unmasked tokens.

        Args:
            mask: boolean tensor ``(batch, seq_len)``.
            input_ids: tensor whose dtype the result takes.
            k: unused; kept so the signature matches the other ``Recycle*`` helpers.
            device: device of the returned tensor.

        Returns:
            Tensor shaped like ``mask`` with dtype ``input_ids.dtype``.

        Raises:
            ValueError: if a row has more masked entries than the
                ``max_length - 1`` available ids.
        """
        mask = mask.to(device)
        batch_size, seq_len = mask.shape
        if not bool(mask.any()):
            return torch.zeros(mask.shape, dtype=input_ids.dtype, device=device)
        num_ids = self.max_length - 1  # usable ids are 1 .. max_length - 1
        counts = mask.sum(dim=1, keepdim=True)
        if int(counts.max()) > num_ids:
            raise ValueError("too many masked positions for the embedding table")
        width = min(seq_len, num_ids)
        # An independent random permutation of 1..num_ids per row. The first
        # ``count`` entries of a row form a uniform random subset of that size.
        candidates = torch.rand(batch_size, num_ids, device=device).argsort(dim=1)[:, :width] + 1
        cols = torch.arange(width, device=device).unsqueeze(0)
        # Push unused candidates past every real id so that sorting leaves the
        # chosen ids, in ascending order, at the front of each row.
        candidates = torch.where(cols < counts, candidates, torch.full_like(candidates, self.max_length))
        candidates = candidates.sort(dim=1).values
        # The j-th masked entry of a row (0-based) takes the j-th smallest chosen id.
        rank = (mask.long().cumsum(dim=1) - 1).clamp(min=0)
        chosen = candidates.gather(1, rank)
        return torch.where(mask, chosen, torch.zeros_like(chosen)).to(input_ids.dtype)

    def forward(self, input_ids):
        """Embed random per-operand position ids of a batch-first ``input_ids``.

        Args:
            input_ids: integer tensor ``(batch, seq_len)``. Each row is expected
                to contain exactly one ``'+'`` (14) and one ``'='`` (17), and in
                training one ``[EOS]`` (3).

        Returns:
            Tensor of shape ``(batch, seq_len, embedding_dim)``.
        """
        device = input_ids.device
        batch_size, seq_len = input_ids.shape[0], input_ids.shape[1]
        position_ids = torch.arange(1, seq_len + 1, device=device).expand(batch_size, -1)
        plus_pos = (input_ids == 14).nonzero()[:, 1].unsqueeze(1).expand_as(position_ids)
        equals_pos = (input_ids == 17).nonzero()[:, 1].unsqueeze(1).expand_as(position_ids)
        mask_num_1 = position_ids <= plus_pos  # tokens before '+'
        mask_num_2 = (position_ids > plus_pos + 1) & (position_ids <= equals_pos)  # between '+' and '='
        mask_ans = position_ids > equals_pos + 1  # tokens after '='

        k = 1
        if self.training:
            eos_pos = (input_ids == 3).nonzero()[:, 1].unsqueeze(1).expand_as(position_ids)
            mask_ans = mask_ans & (position_ids <= eos_pos)
            k = random.randint(1, 100)  # drawn for parity with the siblings; the helper ignores it

        output = (
            self.helper(mask_num_1, input_ids, k=k, device=device)
            + self.helper(mask_num_2, input_ids, k=k, device=device)
            + self.helper(mask_ans, input_ids, k=k, device=device)
        )
        return self.embedding(output)


class RecyclePositionalMul(torch.nn.Module):
    """Learnable "recycled" position embedding for multiplication prompts.

    The first operand, the second operand and the answer each get their own
    position count starting at ``k``; the 'x' and '=' tokens get position 0.
    ``k`` is 1 at eval time and drawn from ``[1, 100]`` during training.
    At eval time the tokens after the answer ([EOS]/[PAD]) are counted as
    part of the answer segment.

    Example (eval, k=1)::

        input_ids  5  6  x  7  =  8  9  [EOS]
        positions  1  2  0  1  0  1  2  3
    """

    def __init__(self, embedding_dim, max_seq_length=1024):
        super().__init__()
        self.embedding = torch.nn.Embedding(max_seq_length, embedding_dim)
        # Not read by ``forward``; kept as a persistent buffer so saved state_dicts still match.
        self.register_buffer("position_ids", torch.arange(max_seq_length).expand((1, -1)))

    @staticmethod
    def _first_per_row(hits):
        """Return the column index of the first True in each row of ``hits`` that contains one."""
        coords = hits.nonzero()  # row-major order, so equal row indices are adjacent
        rows, cols = coords[:, 0], coords[:, 1]
        keep = torch.ones_like(rows, dtype=torch.bool)
        keep[1:] = rows[1:] != rows[:-1]
        return cols[keep]

    def helper(self, mask, input_ids, k=1, device="cuda"):
        """Number the True entries of every row of ``mask`` as k, k+1, ... (0 elsewhere).

        Each row is counted independently, so a row without any True entry stays
        all zero and does not shift the numbering of the rows after it.

        Args:
            mask: boolean tensor with the same (batch, seq) shape as ``input_ids``.
            input_ids: token ids; only their dtype is used for the result.
            k: value assigned to the first True entry of each row.
            device: device of the returned tensor.

        Returns:
            Integer tensor shaped like ``mask``.
        """
        mask = mask.bool()
        running = mask.long().cumsum(dim=1) + (k - 1)
        output = torch.where(mask, running, torch.zeros_like(running))
        return output.to(device=device, dtype=input_ids.dtype)

    def forward(self, input_ids):
        """Batch-first: map token ids of shape (batch, seq) to position embeddings.

        Token ids used: 'x' = 16, '=' = 17, '[EOS]' = 3. Every row must contain
        an 'x' and an '='. At eval time, if a generated completion contains
        extra copies of either symbol, only the first occurrence per row is used.
        """
        device = input_ids.device
        batch_size, seq_len = input_ids.shape
        # 1-based positions, so "position <= index" selects columns strictly before index.
        position_ids = torch.arange(seq_len, device=device).expand(batch_size, -1) + 1

        indices_x = (input_ids == 16).nonzero()[:, 1]
        indices_eq = (input_ids == 17).nonzero()[:, 1]

        if not self.training:  # the dedup is slow, so it only runs at eval, where we do not control the inputs
            if indices_eq.shape[0] != batch_size:  # some completion contains several '=' signs
                indices_eq = self._first_per_row(input_ids == 17)
            if indices_x.shape[0] != batch_size:  # some completion contains several 'x' signs
                indices_x = self._first_per_row(input_ids == 16)

        x_pos = indices_x.unsqueeze(1)
        eq_pos = indices_eq.unsqueeze(1)
        mask_num_1 = position_ids <= x_pos  # columns before 'x'
        mask_num_2 = (position_ids > x_pos + 1) & (position_ids <= eq_pos)  # between 'x' and '='
        mask_ans = position_ids > eq_pos + 1  # columns after '='

        k = 1
        if self.training:
            # During training the answer stops at [EOS] and the starting offset is randomised.
            indices_eos = (input_ids == 3).nonzero()[:, 1]
            mask_ans = mask_ans & (position_ids <= indices_eos.unsqueeze(1))
            k = random.randint(1, 100)

        output = (
            self.helper(mask_num_1, input_ids, k=k, device=device)
            + self.helper(mask_num_2, input_ids, k=k, device=device)
            + self.helper(mask_ans, input_ids, k=k, device=device)
        )
        return self.embedding(output)


class RecyclePositionalDoubleForAns(torch.nn.Module):
    """Recycled position embedding for addition prompts, with answer positions weighted twice.

    Each segment (first operand, second operand, answer) gets its own position
    count starting at ``k``. Operand positions are looked up in the embedding
    table once; answer positions are looked up separately and their embedding
    is added with a factor of 2.
    """

    def __init__(self, embedding_dim, max_seq_length=1024):
        super().__init__()
        self.embedding = torch.nn.Embedding(max_seq_length, embedding_dim)
        # Not read by ``forward``; kept as a persistent buffer so saved state_dicts still match.
        self.register_buffer("position_ids", torch.arange(max_seq_length).expand((1, -1)))

    def helper(self, mask, input_ids, k=1, device="cuda"):
        """Number the True entries of every row of ``mask`` as k, k+1, ... (0 elsewhere).

        Each row is counted independently, so a row without any True entry stays
        all zero and does not shift the numbering of the rows after it.

        Args:
            mask: boolean tensor with the same (batch, seq) shape as ``input_ids``.
            input_ids: token ids; only their dtype is used for the result.
            k: value assigned to the first True entry of each row.
            device: device of the returned tensor.
        """
        mask = mask.bool()
        running = mask.long().cumsum(dim=1) + (k - 1)
        output = torch.where(mask, running, torch.zeros_like(running))
        return output.to(device=device, dtype=input_ids.dtype)

    def forward(self, input_ids):
        """Batch-first: map token ids of shape (batch, seq) to position embeddings.

        Token ids used: '+' = 14, '=' = 17, '[EOS]' = 3; every row must contain
        exactly one '+' and one '='. At eval time the tokens after the answer
        ([EOS]/[PAD]) are counted as part of the answer segment.
        """
        device = input_ids.device
        batch_size, seq_len = input_ids.shape
        # 1-based positions, so "position <= index" selects columns strictly before index.
        position_ids = torch.arange(seq_len, device=device).expand(batch_size, -1) + 1

        plus_pos = (input_ids == 14).nonzero()[:, 1].unsqueeze(1)
        eq_pos = (input_ids == 17).nonzero()[:, 1].unsqueeze(1)
        mask_num_1 = position_ids <= plus_pos  # columns before '+'
        mask_num_2 = (position_ids > plus_pos + 1) & (position_ids <= eq_pos)  # between '+' and '='
        mask_ans = position_ids > eq_pos + 1  # columns after '='

        k = 1
        if self.training:
            # During training the answer stops at [EOS] and the starting offset is randomised.
            eos_pos = (input_ids == 3).nonzero()[:, 1].unsqueeze(1)
            mask_ans = mask_ans & (position_ids <= eos_pos)
            k = random.randint(1, 100)

        output_1 = self.helper(mask_num_1, input_ids, k=k, device=device)
        output_2 = self.helper(mask_num_2, input_ids, k=k, device=device)
        output_ans = self.helper(mask_ans, input_ids, k=k, device=device)
        operands = output_1 + output_2

        return self.embedding(operands) + (self.embedding(output_ans) * 2)


class MultiplicativeFactorPositional(torch.nn.Module):
    """Position embedding built from one learned vector scaled by the position index.

    Position ``p`` is represented as ``p * v``, where ``v`` is a single learnable
    vector of size ``embedding_dim``; position 0 therefore maps to the zero vector.
    """

    def __init__(self, embedding_dim, max_seq_length=1024):
        super().__init__()
        # The only learned vector; every position is a multiple of it.
        self.embedding = torch.nn.Parameter(torch.randn(1, 1, embedding_dim))
        self.register_buffer("position_ids", torch.arange(max_seq_length).expand((1, -1)))

    def forward(self, input_ids):
        """Batch-first: return a (1, seq_len, embedding_dim) tensor for ``input_ids`` of shape (batch, seq_len)."""
        seq_len = input_ids.shape[1]
        scale = self.position_ids[:, :seq_len].unsqueeze(-1)  # (1, seq_len, 1)
        return scale * self.embedding


class MultiplicativeFactorPositionalRand(torch.nn.Module):
    """Multiplicative position embedding evaluated at randomly sampled positions.

    Each call draws ``seq_len`` distinct positions from
    ``range(max(seq_len, max_seq_length))``, sorts them, and scales one learnable
    vector (initialised to ones) by each position. Sampling happens in both
    training and eval mode.
    """

    def __init__(self, embedding_dim, max_seq_length=1024):
        super().__init__()

        self.max_length = max_seq_length
        # This random draw is discarded (the vector below starts at ones); it is
        # kept so the global RNG stream, and therefore the seeded initialisation
        # of modules built afterwards, stays unchanged.
        torch.randn(1, 1, embedding_dim)
        # Single learnable vector; position p is represented as p * embedding.
        self.embedding = torch.nn.Parameter(torch.ones(1, 1, embedding_dim))

        self.register_buffer("position_ids", torch.arange(max_seq_length).expand((1, -1)))

    def forward(self, input_ids):
        """Batch-first: return a (1, seq_len, embedding_dim) tensor for ``input_ids`` of shape (batch, seq_len)."""
        seq_length = input_ids.shape[1]
        # The sampling range widens to the sequence length when the input is longer than max_seq_length.
        max_length = max(seq_length, self.max_length)
        # Sorted random subset of positions, so order is preserved but gaps are random.
        positions = torch.sort(torch.randperm(max_length, dtype=torch.long, device=input_ids.device)[:seq_length]).values
        return positions.view(1, -1, 1) * self.embedding


# Code stolen from GPT-X:
class Rotary(torch.nn.Module):
    """Rotary position embedding (rotate-half layout) with a cos/sin cache keyed on sequence length.

    Args:
        dim: size of the rotated (last) feature dimension.
        base: base of the geometric frequency schedule.
        def_seq_length: sequence length the cache is first built for.
        seq_dim: sequence axis of the 4-D inputs; 0 gives a cache of shape
            (S, 1, 1, dim), any other value gives (1, S, 1, dim).
    """

    def __init__(self, dim, base=10000, def_seq_length=128, seq_dim: int = 0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=True)
        self.seq_len_cached = def_seq_length
        self.seq_dim = seq_dim
        cos_cache, sin_cache = self._get_cos_sin()
        self.register_buffer("cos_cached", cos_cache, persistent=False)
        self.register_buffer("sin_cached", sin_cache, persistent=False)

        # Q and K are rotated in one op by concatenating them along the leading
        # axis that is NOT the sequence axis (the cache has size 1 there), so the
        # cached cos/sin broadcast over the concatenation for either layout.
        cat_dim = 1 if seq_dim == 0 else 0

        # Force fusions on batched version
        def rotate_half(x: torch.Tensor):
            x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]  # torch.split(x, x.shape[-1] // 2, dim=-1)  # not faster
            return torch.cat((-x2, x1), dim=-1)

        def rope_fn(cos: torch.Tensor, sin: torch.Tensor, query_layer: torch.Tensor, key_layer: torch.Tensor):
            QK = torch.cat([query_layer, key_layer], dim=cat_dim)
            # Trim the cache to the sequence length along the sequence axis.
            if cat_dim == 1:
                cos, sin = cos[: QK.shape[0]], sin[: QK.shape[0]]
            else:
                cos, sin = cos[:, : QK.shape[1]], sin[:, : QK.shape[1]]
            rotated = QK * cos + rotate_half(QK) * sin
            return torch.split(rotated, query_layer.shape[cat_dim], dim=cat_dim)

        self.rope_fn = rope_fn  # handle fusion on module level

    @torch.no_grad()
    def get_cos_sin_cache(self, x: torch.Tensor):
        """Return the cos/sin cache, rebuilding it when x's sequence length differs from the cached one."""
        seq_len = x.shape[self.seq_dim]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = x.shape[self.seq_dim]
            cos_cache, sin_cache = self._get_cos_sin()
            self.cos_cached = cos_cache.to(x.device)
            self.sin_cached = sin_cache.to(x.device)
        return self.cos_cached, self.sin_cached

    def _get_cos_sin(self):
        """Build cos/sin tables for ``seq_len_cached`` positions, shaped for the configured ``seq_dim``."""
        t = torch.arange(self.seq_len_cached).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        if self.seq_dim == 0:
            return emb.cos()[:, None, None, :].detach(), emb.sin()[:, None, None, :].detach()
        else:
            return emb.cos()[None, :, None, :].detach(), emb.sin()[None, :, None, :].detach()

    def forward(self, query_layer: torch.Tensor, key_layer: torch.Tensor):
        """Rotate query and key (same shape) and return them as a pair."""
        cos_cached, sin_cached = self.get_cos_sin_cache(query_layer)
        return self.rope_fn(cos_cached, sin_cached, query_layer, key_layer)

    @torch.jit.export
    def single_forward(self, inputs: torch.Tensor):
        """For cases where shapes of Q and K do not match."""
        cos, sin = self.cos_cached[: inputs.shape[0]], self.sin_cached[: inputs.shape[0]]
        return inputs * cos + self.rotate_half(inputs) * sin

    def rotate_half(self, x: torch.Tensor):
        """Map the halves (x1, x2) of the last dimension to (-x2, x1)."""
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)  # torch.split(x, x.shape[-1] // 2, dim=-1)  # not faster


class RecyclePositionalForward(torch.nn.Module):
    """Learnable recycled position embedding whose counts run right-to-left.

    As in the other recycled embeddings, each operand and the answer get their
    own position count, but inside each segment the count is largest at the
    first token and decreases to ``k`` at the last token of the segment.

    Example (eval)::

        input_ids  5  6  +  7  =  8  9  [EOS]
        positions  2  1  0  1  0  2  1  0
    """

    def __init__(self, embedding_dim, max_seq_length=1024):
        super().__init__()
        self.embedding = torch.nn.Embedding(max_seq_length, embedding_dim)
        # Not read by ``forward``; kept as a persistent buffer so saved state_dicts still match.
        self.register_buffer("position_ids", torch.arange(max_seq_length).expand((1, -1)))

    @staticmethod
    def _first_per_row(hits):
        """Return the column index of the first True in each row of ``hits`` that contains one."""
        coords = hits.nonzero()  # row-major order, so equal row indices are adjacent
        rows, cols = coords[:, 0], coords[:, 1]
        keep = torch.ones_like(rows, dtype=torch.bool)
        keep[1:] = rows[1:] != rows[:-1]
        return cols[keep]

    def helper(self, mask, input_ids, k=1, device="cuda"):
        """Number the True entries of every row of ``mask`` from the right: last gets k, the one before k+1, ...

        Each row is counted independently, so a row without any True entry stays
        all zero and does not shift the numbering of the rows after it.

        Args:
            mask: boolean tensor with the same (batch, seq) shape as ``input_ids``.
            input_ids: token ids; only their dtype is used for the result.
            k: value assigned to the last True entry of each row.
            device: device of the returned tensor.
        """
        mask = mask.bool()
        # Reverse cumulative count: number of True entries at or to the right of each column.
        running = mask.long().flip(dims=[1]).cumsum(dim=1).flip(dims=[1]) + (k - 1)
        output = torch.where(mask, running, torch.zeros_like(running))
        return output.to(device=device, dtype=input_ids.dtype)

    def _eval_answer_positions(self, mask_ans, input_ids, output_1, output_2, device):
        """Eval-time answer positions (highly tailored to our eval).

        The answer count starts at the largest operand position of the FIRST row,
        then every nonzero answer position is lowered by
        ``max(int(nonzero answer positions / batch) - 1, 0)``.
        """
        start = torch.maximum(output_1.max(dim=1).values, output_2.max(dim=1).values)[0].item()
        output_ans = self.helper(mask_ans, input_ids, k=start, device=device)

        subtraction = max(int(torch.count_nonzero(output_ans).item() / output_ans.shape[0]) - 1, 0)
        nonzero = output_ans != 0
        output_ans[nonzero] = output_ans[nonzero] - subtraction
        return output_ans

    def forward(self, input_ids):
        """Batch-first: map token ids of shape (batch, seq) to position embeddings.

        Token ids used: '+' = 14, '=' = 17, '[EOS]' = 3. Every row must contain
        a '+' and an '='. At eval time, if a generated completion contains extra
        copies of either symbol, only the first occurrence per row is used, and
        the tokens after the answer ([EOS]/[PAD]) are counted as answer tokens.
        """
        device = input_ids.device
        batch_size, seq_len = input_ids.shape
        # 1-based positions, so "position <= index" selects columns strictly before index.
        position_ids = torch.arange(seq_len, device=device).expand(batch_size, -1) + 1

        indices_plus = (input_ids == 14).nonzero()[:, 1]
        indices_eq = (input_ids == 17).nonzero()[:, 1]

        if not self.training:  # the dedup is slow, so it only runs at eval, where we do not control the inputs
            if indices_eq.shape[0] != batch_size:  # some completion contains several '=' signs
                indices_eq = self._first_per_row(input_ids == 17)
            if indices_plus.shape[0] != batch_size:  # some completion contains several '+' signs
                indices_plus = self._first_per_row(input_ids == 14)

        plus_pos = indices_plus.unsqueeze(1)
        eq_pos = indices_eq.unsqueeze(1)
        mask_num_1 = position_ids <= plus_pos  # columns before '+'
        mask_num_2 = (position_ids > plus_pos + 1) & (position_ids <= eq_pos)  # between '+' and '='
        mask_ans = position_ids > eq_pos + 1  # columns after '='

        k = 1
        if self.training:
            # During training the answer stops at [EOS] and the starting offset is randomised.
            indices_eos = (input_ids == 3).nonzero()[:, 1]
            mask_ans = mask_ans & (position_ids <= indices_eos.unsqueeze(1))
            k = random.randint(1, 100)

        output_1 = self.helper(mask_num_1, input_ids, k=k, device=device)
        output_2 = self.helper(mask_num_2, input_ids, k=k, device=device)
        if self.training:
            output_ans = self.helper(mask_ans, input_ids, k=k, device=device)
        else:
            output_ans = self._eval_answer_positions(mask_ans, input_ids, output_1, output_2, device)
        return self.embedding(output_1 + output_2 + output_ans)


class RotarySanityCheck(torch.nn.Module):
    """Unfused rotary embedding (rotate-half layout) that rotates Q and K separately.

    Args:
        dim: size of the rotated (last) feature dimension.
        base: base of the geometric frequency schedule.
        def_seq_length: sequence length the cos/sin cache is first built for.
        seq_dim: sequence axis of the 4-D inputs; 0 gives a cache of shape
            (S, 1, 1, dim), any other value gives (1, S, 1, dim).
    """

    def __init__(self, dim, base=10000, def_seq_length=128, seq_dim: int = 0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=True)
        self.seq_len_cached = def_seq_length
        self.seq_dim = seq_dim
        cos_cache, sin_cache = self._get_cos_sin()
        self.register_buffer("cos_cached", cos_cache, persistent=False)
        self.register_buffer("sin_cached", sin_cache, persistent=False)

    @torch.no_grad()
    def get_cos_sin_cache(self, x: torch.Tensor):
        """Return the cos/sin cache, rebuilding it when x's sequence length differs from the cached one."""
        seq_len = x.shape[self.seq_dim]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = x.shape[self.seq_dim]
            cos_cache, sin_cache = self._get_cos_sin()
            self.cos_cached = cos_cache.to(x.device)
            self.sin_cached = sin_cache.to(x.device)
        return self.cos_cached, self.sin_cached

    def _get_cos_sin(self):
        """Build cos/sin tables for ``seq_len_cached`` positions, shaped for the configured ``seq_dim``."""
        t = torch.arange(self.seq_len_cached).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        if self.seq_dim == 0:
            return emb.cos()[:, None, None, :].detach(), emb.sin()[:, None, None, :].detach()
        else:
            return emb.cos()[None, :, None, :].detach(), emb.sin()[None, :, None, :].detach()

    def forward(self, query_layer: torch.Tensor, key_layer: torch.Tensor):
        """Rotate query and key (same sequence length) and return them as a pair.

        The cache is refreshed to the query's sequence length first, so inputs
        of any length work (not only ``def_seq_length``).
        """
        cos, sin = self.get_cos_sin_cache(query_layer)
        return (query_layer * cos) + (self.rotate_half(query_layer) * sin), (key_layer * cos) + (self.rotate_half(key_layer) * sin)

    def rotate_half(self, x: torch.Tensor):
        """Map the halves (x1, x2) of the last dimension to (-x2, x1)."""
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)  # torch.split(x, x.shape[-1] // 2, dim=-1)  # not faster

    @torch.jit.export
    def single_forward(self, inputs: torch.Tensor):
        """For cases where shapes of Q and K do not match."""
        cos, sin = self.cos_cached[: inputs.shape[0]], self.sin_cached[: inputs.shape[0]]
        return inputs * cos + self.rotate_half(inputs) * sin


# Adapted from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/rotary.py who adapted from
# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/components/positional_embedding/rotary.py
class RotaryEleutherAI(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.
    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration
    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    This variant rotates interleaved feature pairs (2i, 2i+1). The cos/sin
    tables are built for 128 positions at construction time and are grown in
    ``forward`` when a longer sequence arrives.
    """

    _seq_len_cached: int
    # _cos_cached: Optional[torch.Tensor]
    # _sin_cached: Optional[torch.Tensor]

    def __init__(self, dim_model: int, *_, **__):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
        self.register_buffer("inv_freq", inv_freq)

        # Only the probe's length along dim -2 (128), device and dtype are used.
        _cos_cached, _sin_cached = self._update_cos_sin_tables(torch.randn(1, 128, 1), seq_dimension=-2)
        self.register_buffer("_cos_cached", _cos_cached, persistent=False)
        self.register_buffer("_sin_cached", _sin_cached, persistent=False)

    @torch.jit.ignore
    def _update_cos_sin_tables(self, x: torch.Tensor, seq_dimension: int = -2) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build (seq_len, 2 * len(inv_freq)) cos/sin tables for x's length along ``seq_dimension``, in x's dtype."""
        seq_len = x.shape[seq_dimension]
        self._seq_len_cached = seq_len
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        # Don't do einsum, it converts fp32 to fp16
        freqs = torch.outer(t, self.inv_freq)
        # Duplicate each frequency so it covers both members of an interleaved pair: [a, a, b, b, ...].
        cos_cached = torch.cos(freqs).to(x.dtype).repeat_interleave(2, dim=-1)
        sin_cached = torch.sin(freqs).to(x.dtype).repeat_interleave(2, dim=-1)

        return cos_cached, sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_dimension: int = -2) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate q and k; ``seq_dimension`` is -2 for (..., seq, dim) or -3 for (..., seq, heads, dim)."""
        needed = max(q.shape[seq_dimension], k.shape[seq_dimension])
        if needed > self._cos_cached.shape[0]:
            # Grow the tables with the same formula, keeping the buffers' dtype and device.
            probe = torch.empty(1, needed, 1, dtype=self._cos_cached.dtype, device=self._cos_cached.device)
            self._cos_cached, self._sin_cached = self._update_cos_sin_tables(probe, seq_dimension=-2)

        return (
            self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, seq_dimension),
            self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, seq_dimension),
        )

    def rotate_half(self, x: torch.Tensor):
        """Map each interleaved pair (x1, x2) of the last dimension to (-x2, x1)."""
        x = x.unflatten(dim=-1, sizes=(-1, 2))
        x1, x2 = x.unbind(dim=-1)
        rotated_x = torch.stack((-x2, x1), dim=-1)
        return rotated_x.flatten(start_dim=-2)

    def apply_rotary_pos_emb(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, seq_dimension: int = -2):
        """Apply the rotation to x using the first ``x.shape[seq_dimension]`` rows of the tables."""
        # NOTE: This could probably be moved to Triton

        # Handle a possible sequence length mismatch in between q and k
        cos = cos[: x.shape[seq_dimension], :]
        sin = sin[: x.shape[seq_dimension], :]
        if seq_dimension == -3:
            # Insert a broadcast axis for the heads dimension between seq and features.
            cos = cos[:, None, :]
            sin = sin[:, None, :]
        return (x * cos) + (self.rotate_half(x) * sin)


class RotaryLLAMA(torch.nn.Module):
    """Facebook (LLaMA) implementation of rotary embeddings via complex multiplication.

    Adjacent feature pairs (2i, 2i+1) are viewed as one complex number and
    multiplied by a unit complex number of angle ``t * theta_i``. ``seq_dim``
    selects the sequence axis of the inputs and may be negative; inputs of any
    rank >= 2 are supported.
    """

    def __init__(self, hidden_per_head, base=10000, max_seq_length=512, seq_dim: int = 0):
        super().__init__()
        self.seq_dim: int = seq_dim
        # The complex table covers 2 * max_seq_length positions.
        freqs_cis = self.precompute_freqs_cis(dim=hidden_per_head, end=max_seq_length * 2, theta=base)
        self.register_buffer("freqs_cis", freqs_cis)

    def forward(self, query_layer: torch.Tensor, key_layer: torch.Tensor):
        """Rotate query and key and return them as a pair in their original dtypes."""
        return self.apply_rotary_emb(query_layer, key_layer, freqs_cis=self.freqs_cis)

    def apply_rotary_emb(self, xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Multiply the complex views of xq/xk by ``freqs_cis``; xk must share xq's sequence length."""
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_)

        # flatten(-2) merges the (pairs, 2) axes back into the feature axis for any input rank.
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
        return xq_out.type_as(xq), xk_out.type_as(xk)

    def reshape_for_broadcast(self, freqs_cis: torch.Tensor, x: torch.Tensor):
        """Trim ``freqs_cis`` to x's sequence length and reshape it to broadcast against x."""
        seq_dim = self.seq_dim % x.ndim  # normalise negative axes
        freqs_cis = freqs_cis[: x.shape[seq_dim]]
        # e.g. for seq_dim=1 and a 4-D x: [1, seq_length, 1, hidden_per_head // 2]
        shape = [s if i in (seq_dim, x.ndim - 1) else 1 for i, s in enumerate(x.shape)]
        return freqs_cis.view(*shape)

    @staticmethod
    def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
        """Return an (end, dim // 2) complex64 table of unit-magnitude rotations."""
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end, device=freqs.device)  # type: ignore
        freqs = torch.outer(t, freqs).float()  # type: ignore
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
        return freqs_cis


class RelativeMultiHeadAttention(torch.nn.Module):
    """Relative-position attention scores in the style of Transformer-XL.

    ``forward`` takes query/key tensors laid out as (seq, batch, heads,
    hidden_per_head) and returns the unnormalised scores ``ac + bd`` of shape
    (query_len, key_len, batch, heads). ``d_model`` and ``dropout_prob`` are
    accepted for interface compatibility but are not used.
    """

    # https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/xl/relative_mha.py
    def __init__(self, heads, d_model, hidden_per_head, dropout_prob=0):
        super().__init__()
        self.d_k = hidden_per_head
        # Half-width of the learned relative-position tables (2 * P rows each).
        self.P = 2**4  # 12 # TODO possibly change this num as makes model big

        self.key_pos_embeddings = torch.nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)
        self.key_pos_bias = torch.nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)
        self.query_pos_bias = torch.nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)

    def shift_right(self, x: torch.Tensor):
        """Relative shift from the referenced implementation.

        Pads one zero column along dim 1, reinterprets the result as
        (dim1 + 1, dim0, ...), drops the first row and views it back to x's shape.
        """
        zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
        x_padded = torch.cat([x, zero_pad], dim=1)
        x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
        x = x_padded[1:].view_as(x)
        return x

    def forward(self, query: torch.Tensor, key: torch.Tensor):
        """Return relative attention scores of shape (query_len, key_len, batch, heads)."""
        # NOTE(review): these slices assume key_len <= P and query_len <= P; longer inputs
        # give a negative start index / truncated table. Confirm callers respect this.
        key_pos_emb = self.key_pos_embeddings[self.P - key.shape[0] : self.P + query.shape[0]]
        key_pos_bias = self.key_pos_bias[self.P - key.shape[0] : self.P + query.shape[0]]
        query_pos_bias = self.query_pos_bias[None, None, :, :]

        # Content term plus the learned query bias: (query_len, key_len, batch, heads).
        ac = torch.einsum("ibhd,jbhd->ijbh", query + query_pos_bias, key)

        # Position term against every relative offset, then shifted onto key positions.
        b = torch.einsum("ibhd,jhd->ijbh", query, key_pos_emb)
        d = key_pos_bias[None, :, None, :]
        bd = self.shift_right(b + d)
        bd = bd[:, -key.shape[0] :]
        return ac + bd


class AlibiMultiHeadAttention(torch.nn.Module):
    """Static ALiBi attention bias of shape (1, num_heads, seq_len, seq_len).

    The bias for query i and key j is ``-slope_h * |j - i|``. With
    ``asymmetric=True``, keys after the query use distance ``j - i - 0.5``, so
    past and future tokens at the same distance get different biases. The bias
    is not learned, so it is cached per (sequence length, device).
    """

    # https://github.com/ofirpress/attention_with_linear_biases/issues/5
    # https://github.com/EIFY/fairseq/pull/2/commits/e67880a1ea67220abd197c194e2ebcd462d8a0be
    # not alibi is not learned it is static, so we can cache it after the first calculation
    def __init__(self, num_heads, max_seq_length=1024, asymmetric=False):
        super().__init__()
        self.attn_heads = num_heads
        self.asymmetric = asymmetric
        self.cached = None
        self.length = 0
        self._cached_device = None

    def forward(self, seq_length, device):  # this is to make the forward call idetical to FIRE
        """Return the (1, num_heads, seq_length, seq_length) bias on ``device``, reusing the cache when possible."""
        # getattr keeps instances pickled before ``_cached_device`` existed working.
        if self.cached is None or seq_length != self.length or getattr(self, "_cached_device", None) != device:
            positions = torch.arange(seq_length, device=device)
            # Float so the asymmetric half-step offset is not truncated away.
            relative_position = (positions[None, :] - positions[:, None]).float()
            if self.asymmetric:
                future = torch.triu(torch.ones(seq_length, seq_length, device=device), diagonal=1)
                relative_position = relative_position - 0.5 * future
            relative_position = relative_position.abs().unsqueeze(0).expand(self.attn_heads, -1, -1)
            slopes = -torch.tensor(self.get_slopes(self.attn_heads), device=device)
            alibi = slopes[:, None, None] * relative_position
            self.cached = alibi.view(1, self.attn_heads, seq_length, seq_length)
            self.length = seq_length
            self._cached_device = device
        return self.cached

    def get_slopes(self, n):  # n = num heads
        """Return the per-head ALiBi slopes (geometric sequence) for ``n`` heads."""

        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)  # In the paper, we only train models that have 2^a heads for some a. This function has
        else:  # some good properties that only occur when the input is a power of 2. To maintain that even
            closest_power_of_2 = 2 ** math.floor(math.log2(n))  # when the number of heads is not a power of 2, we use this workaround.
            return get_slopes_power_of_2(closest_power_of_2) + self.get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]


class FIRE(torch.nn.Module):
    """FIRE attention-bias module (https://arxiv.org/abs/2310.04418).

    For row position ``i`` and column position ``j`` the bias is
    ``mlp(psi(i - j) / (psi(max(i, L)) + eps))`` with ``psi(x) = log(|c * x| + 1)``.
    The MLP produces one value per attention head. ``c`` is learned, and ``L`` is a
    fixed initial value scaled by a learned multiplier.
    """

    def __init__(self, num_heads=12, mlp_width=32, init_c=0.1, init_L=512.0, eps=1e-6, max_length=0):
        """
        Build the FIRE bias network.

        Args:
            num_heads: number of attention heads (output width of the MLP).
            mlp_width: hidden width of the two-layer MLP.
            init_c: initial value of the learned log-transform scale ``c``.
            init_L: initial normalizer threshold ``L``. It is stored as a frozen
                parameter and scaled by the learned ``L_multiplier``.
            eps: small constant added to the normalizer for numerical stability.
            max_length: in training mode, when the sequence length does not exceed
                ``max_length``, positions are a sorted random subset of
                ``range(max_length)``. Otherwise the positions are ``0 .. seq_length-1``.
        """
        super().__init__()
        self.max_length = max_length  # sampling range for randomized positions in training

        # Scalar normalized distance -> one bias value per head.
        layers = [torch.nn.Linear(1, mlp_width), torch.nn.ReLU(), torch.nn.Linear(mlp_width, num_heads)]
        self.mlp = torch.nn.Sequential(*layers)

        self.c = torch.nn.Parameter(torch.tensor(init_c))  # learned log-transform scale
        self.init_L = torch.nn.Parameter(torch.tensor(init_L), requires_grad=False)  # frozen base threshold
        self.L_multiplier = torch.nn.Parameter(torch.tensor(1.0))  # learned multiplier on init_L

        self.eps = eps

    def forward(self, seq_length, device):
        """
        Compute the FIRE attention bias.

        Args:
            seq_length: number of positions along both attention axes.
            device: device for the position tensor. The positions are fed through
                ``self.mlp``, so this must be the device of the module's parameters.

        Returns:
            Attention bias of shape ``[1, num_heads, seq_length, seq_length]``.
        """
        # Randomized positions only while training and only when the sequence fits the range.
        if self.training and seq_length <= self.max_length:
            span = self.max_length
        else:
            span = seq_length

        # Sorted subset of a random permutation; this equals arange(seq_length) when span == seq_length.
        sampled = torch.randperm(span, dtype=torch.float, device=device)[:seq_length]
        positions, _ = torch.sort(sampled)

        offsets = positions.unsqueeze(1) - positions.unsqueeze(0)  # entry (i, j) = pos_i - pos_j
        # The threshold keeps short prefixes from being normalized by a tiny position value.
        floor_L = (self.L_multiplier * self.init_L).abs()
        scale = torch.max(positions, floor_L).unsqueeze(1)

        def psi(values):
            # Log transform amplifies differences among nearby positions.
            return torch.log(torch.abs(self.c * values) + 1)

        ratio = psi(offsets) / (psi(scale) + self.eps)
        per_pair = self.mlp(ratio.unsqueeze(-1))  # [seq_len, seq_len, num_heads]
        return per_pair.permute(2, 0, 1).unsqueeze(0)  # [1, num_heads, seq_len, seq_len]


class NOPE_RAND(torch.nn.Module):
    """Attention bias that is standard-normal noise in training mode and zero in eval mode."""

    def __init__(self, num_heads=12):
        super().__init__()
        self.num_heads = num_heads

    def forward(self, tensor_shape, device):
        """Return a ``[1, num_heads, tensor_shape, tensor_shape]`` bias on ``device``.

        ``tensor_shape`` is the sequence length (an int), used for both attention axes.
        Values are drawn with ``torch.randn`` while training and are all zero otherwise.
        """
        size = (1, self.num_heads, tensor_shape, tensor_shape)
        factory = torch.randn if self.training else torch.zeros
        return factory(size, device=device)


class RecycleForSorting(torch.nn.Module):
    """Learned embedding of each digit token's 1-based position within its run of digits.

    Token ids 4-13 (the characters '0'-'9' in the sort tokenizer) count as digits. Each
    maximal run of consecutive digit tokens within a row is numbered 1, 2, 3, ...; every
    other token gets position 0.

    In training mode a random offset ``k`` in ``[0, max_k]`` is added to every digit
    position, so runs are embedded starting from ``k + 1``. The offset is drawn once per
    call with ``random.randint`` and shared by the whole batch.
    """

    def __init__(self, embedding_dim, max_seq_length=1024, max_k=99):
        super().__init__()
        self.embedding = torch.nn.Embedding(max_seq_length, embedding_dim)
        self.register_buffer("position_ids", torch.arange(max_seq_length).expand((1, -1)))
        # Largest random offset used in training (default 99). The offset is added after the
        # positions are computed instead of being generated with them.
        self.max_k = max_k

    def helper(self, mask, device=None):
        """Number each run of consecutive True values along the last dim of ``mask``.

        Positions inside a run are 1, 2, 3, ... and all False entries are 0. Runs never
        continue across rows: each slice along the last dim is handled independently.

        Args:
            mask: boolean (or 0/1) tensor of any shape; runs are counted along dim -1.
            device: unused; kept for backward compatibility. The result lives on
                ``mask``'s device.

        Returns:
            int64 tensor with the same shape as ``mask``.
        """
        mask = mask.bool()
        running = torch.cumsum(mask.long(), dim=-1)
        # The running count at the most recent False entry (0 before any) marks where the
        # current run began. ``running`` is non-decreasing, so a cumulative max finds it.
        at_breaks = torch.where(mask, torch.zeros_like(running), running)
        last_break = torch.cummax(at_breaks, dim=-1).values
        return (running - last_break) * mask

    def forward(self, input_ids):
        """Embed digit-run positions of a batch-first ``input_ids`` tensor ``(batch, seq)``.

        Sort tokenizer vocabulary:
        {'0': 4, '1': 5, '2': 6, '3': 7, '4': 8, '5': 9, '6': 10, '7': 11, '8': 12, '9': 13, 'D': 14, ',': 15, ':': 16, '=': 17, ' ': 18, 'A': 19, 'B': 20, 'C': 21, 'E': 22, 'F': 23, 'G': 24, 'H': 25, 'I': 26, 'J': 27, 'K': 28, 'L': 29, 'M': 30, 'N': 31, 'O': 32, 'P': 33, 'Q': 34, 'R': 35, 'S': 36, 'T': 37, 'U': 38, 'V': 39, 'W': 40, 'X': 41, 'Y': 42, 'Z': 43, 'a': 44, 'b': 45, 'c': 46, 'd': 47, 'e': 48, 'f': 49, 'g': 50, 'h': 51, 'i': 52, 'j': 53, 'k': 54, 'l': 55, 'm': 56, 'n': 57, 'o': 58, 'p': 59, 'q': 60, 'r': 61, 's': 62, 't': 63, 'u': 64, 'v': 65, 'w': 66, 'y': 67, 'z': 68, '!': 69, '@': 70, '£': 71, '#': 72, '$': 73, '%': 74, '^': 75, '&': 76, '*': 77, '(': 78, ')': 79, '~': 80, '?': 81, '.': 82, '<': 83, '>': 84, '{': 85, '}': 86, '[': 87, ']': 88, ';': 89, '/': 90, '|': 91, 'β': 92, 'Γ': 93, 'Δ': 94, 'δ': 95, 'ε': 96, 'ζ': 97, 'η': 98, 'θ': 99, 'κ': 100, 'Λ': 101, 'λ': 102, 'μ': 103, 'Ξ': 104, 'ξ': 105, 'Π': 106, 'π': 107, 'Σ': 108, 'ς': 109, 'τ': 110, 'Φ': 111, 'φ': 112, 'χ': 113, 'Ψ': 114, 'ψ': 115, 'Ω': 116, 'ω': 117, '[PAD]': 0, '[UNK]': 1, '[BOS]': 2, '[EOS]': 3}

        Example (training mode with k = 5):
            input_ids: [59, 16, 12, 15, 60, 16, 13, 11,  9,  4, 12, 15]
            digits:    [ 0,  0,  1,  0,  0,  0,  1,  1,  1,  1,  1,  0]
            positions: [ 0,  0,  6,  0,  0,  0,  6,  7,  8,  9, 10,  0]

        Returns:
            ``self.embedding`` applied to the position tensor, shape ``(*input_ids.shape, embedding_dim)``.
        """
        mask = (input_ids >= 4) & (input_ids <= 13)
        output = self.helper(mask, input_ids.device)

        if self.training:
            k = random.randint(0, self.max_k)
            # Digit positions start at 1, so after the shift each run is numbered k + 1, k + 2, ...
            output = output + k * mask
        return self.embedding(output)