# from models.mpt_7b.modeling_mpt import *
# import models.mpt_7b.modeling_mpt as modeling_mpt
import copy

from models.mpt_7b.attention import *
# import models.mpt_7b.attention as mpt_attention

from models.mpt_7b.modeling_mpt import *


# import models.mpt_7b.attention as mpt_attention
# from models.mpt_7b.weave_attention import scaled_multihead_dot_product_attention
# mpt_attention.scaled_multihead_dot_product_attention = scaled_multihead_dot_product_attention

"""
    weave-mpt3:  mesa 阶梯形
        尝试固定宽度 width=100
        初始，512/ 1024/ 1536
    fix-update: 采用 chunk
    validate: chunk, 应该跟 weave-mpt3的结果一致。
    
    update： 
        重新修改chunk对 log 的使用，log 的使用，本质上是为了温度下降，使得softmax后，更确定。
        应该使用 chunk 中，key 的个数，来逆序地，对q 赋值
        
"""


def weave_build_attn_bias(attn_impl: str, attn_bias: torch.Tensor, n_heads: int, seq_len: int, causal: bool = False,
                          alibi: bool = False, alibi_bias_max: int = 8) -> Optional[torch.Tensor]:
    """Return the attention bias for ``attn_impl``, optionally with ALiBi added.

    Args:
        attn_impl: Attention backend name: ``'flash'``, ``'torch'`` or ``'triton'``.
        attn_bias: Base bias tensor; its device and dtype are used for the ALiBi term.
        n_heads: Number of attention heads, forwarded to ``build_alibi_bias``.
        seq_len: Sequence length, forwarded to ``build_alibi_bias``.
        causal: When False the ALiBi term is built with ``full=True``.
        alibi: Whether to add the ALiBi term at all.
        alibi_bias_max: Forwarded to ``build_alibi_bias``.

    Returns:
        ``None`` for ``'flash'``; for ``'torch'``/``'triton'`` either ``attn_bias``
        itself (``alibi`` false) or a new tensor ``attn_bias + alibi_term``.

    Raises:
        ValueError: If ``attn_impl`` is not one of the three supported names.
    """
    if attn_impl == 'flash':
        return None
    if attn_impl not in ('torch', 'triton'):
        raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
    if not alibi:
        return attn_bias

    # Build the positional term on the same device/dtype as the base bias.
    positional_term = build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max,
                                       device=attn_bias.device, dtype=attn_bias.dtype)
    return attn_bias.add(positional_term)


def build_alibi_bias(n_heads: int, seq_len: int, full: bool = False, alibi_bias_max: int = 8,
                     device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    """Build an ALiBi bias whose distances become a staircase beyond the training length.

    For ``seq_len <= train_max_pos`` the key offsets are the usual
    ``1 - seq_len .. 0``. For longer sequences the ``push_pos`` most recent keys
    keep exact offsets ``1 - push_pos .. 0``, while older keys share one offset
    per run of ``push_width`` consecutive positions (the oldest run may be cut
    short by the final trim to ``seq_len`` entries).

    ``train_max_pos``, ``push_pos`` and ``push_width`` are module-level globals
    read at call time.

    Args:
        n_heads: Number of heads, forwarded to ``gen_slopes``.
        seq_len: Number of key positions to cover.
        full: If True, produce the symmetric (non-causal) form ``-|key - query|``.
        alibi_bias_max: Forwarded to ``gen_slopes``.
        device: Device for the created tensors.
        dtype: Dtype of the returned tensor.

    Returns:
        The offsets multiplied by the per-head slopes, cast to ``dtype``.
    """
    if seq_len <= train_max_pos:
        key_offsets = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len)
    else:
        # Nearest push_pos keys: exact offsets -(push_pos - 1) .. 0.
        near_offsets = torch.arange(1 - push_pos, 1, dtype=torch.int32, device=device)

        # Farther keys: one offset value per step, each held for push_width positions.
        farthest = (seq_len - push_pos) // push_width + 1 + push_pos
        step_values = torch.arange(-farthest, 1 - push_pos, dtype=torch.int32, device=device)
        far_offsets = step_values.repeat_interleave(push_width)

        # Keep only the trailing seq_len entries. The result stays 1-D here;
        # NOTE(review): the later arithmetic relies on broadcasting to lift it
        # to 4-D - confirm gen_slopes returns a (1, n_heads, 1, 1) tensor.
        key_offsets = torch.cat([far_offsets, near_offsets], dim=0)[-seq_len:]

    if full:
        query_offsets = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, seq_len, 1)
        key_offsets = (key_offsets - query_offsets).abs().mul(-1)

    head_slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
    return (key_offsets * head_slopes).to(dtype=dtype)





# Attention with an added log-n scaling factor on the queries.
def log_scaled_multihead_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int]=None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, softmax_scale: Optional[float]=None, attn_bias: Optional[torch.Tensor]=None, key_padding_mask: Optional[torch.Tensor]=None, is_causal: bool=False, dropout_p: float=0.0, training: bool=False, needs_weights: bool=False, multiquery: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
    """Scaled dot-product attention with a per-query log-length scaling factor.

    Each query is multiplied by ``max(1, log(n) / log(train_max_len))`` where
    ``n`` is the 1-based index of the key position that query sits on (the
    queries are aligned to the *last* ``s_q`` key positions, so cached keys are
    taken into account). ``train_max_len`` is a module-level global read at
    call time. Everything else follows the usual softmax attention recipe.

    Args:
        query: Tensor laid out as ``b s (h d)`` with ``h = n_heads``.
        key: Tensor laid out as ``b s (h d)`` with ``h = kv_n_heads``.
        value: Tensor laid out as ``b s (h d)`` with ``h = kv_n_heads``.
        n_heads: Number of query heads.
        kv_n_heads: Number of key/value heads; defaults to ``n_heads`` (with a
            deprecation warning) when omitted.
        past_key_value: Cached ``(k, v)`` with ``k`` as ``b h d s`` and ``v`` as
            ``b h s d``; an empty tuple means "cache requested, nothing cached yet".
        softmax_scale: Score multiplier; defaults to ``1 / sqrt(d)``.
        attn_bias: Additive bias; only its trailing ``s_q`` rows / ``s_k``
            columns are used.
        key_padding_mask: Boolean mask viewed as ``(b, 1, 1, s_k)``; False
            entries are masked out.
        is_causal: Apply a causal mask (skipped when there is a single query).
        dropout_p: Dropout probability on the attention weights.
        training: Whether dropout is active.
        needs_weights: Return the attention weights as the second element.
        multiquery: Deprecated; forces ``kv_n_heads = 1``.

    Returns:
        ``(out, attn_weight or None, past_key_value)`` where ``out`` is laid out
        as ``b s (h d)`` and ``past_key_value`` is the updated ``(k, v)`` cache
        (or ``None`` if no cache was passed in).

    Raises:
        RuntimeError: If ``attn_bias`` cannot broadcast to the score shape.
    """
    if multiquery:
        warnings.warn(DeprecationWarning('The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'))
        kv_n_heads = 1
    elif kv_n_heads is None:
        warnings.warn(DeprecationWarning('Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'))
        kv_n_heads = n_heads
    q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
    # Keys are stored pre-transposed (b h d s) so that q @ k needs no transpose.
    k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
    v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
    if past_key_value is not None:
        if len(past_key_value) != 0:
            k = torch.cat([past_key_value[0], k], dim=3)
            v = torch.cat([past_key_value[1], v], dim=2)
        past_key_value = (k, v)
    (b, _, s_q, d) = q.shape
    s_k = k.size(-1)
    if kv_n_heads > 1 and kv_n_heads < n_heads:
        k = repeat_kv_for_gqa(k.transpose(1, 2), n_heads // kv_n_heads).transpose(1, 2)
        v = repeat_kv_for_gqa(v.transpose(1, 2), n_heads // kv_n_heads).transpose(1, 2)
    if softmax_scale is None:
        softmax_scale = 1 / math.sqrt(d)

    # Log-length scaling: position n gets max(1, log(n) / log(train_max_len)).
    # The vector is built once per call directly on q's device as float32
    # (no host-to-device copy, no reliance on int->float promotion or on numpy).
    key_positions = torch.arange(1, k.shape[3] + 1, dtype=torch.float32, device=q.device)
    log_n = (key_positions.log() / math.log(train_max_len)).clamp(min=1).to(q.dtype)
    # The queries occupy the last s_q key positions, so take the tail of the vector.
    q = q * log_n[-q.shape[2]:].view(1, 1, -1, 1)

    attn_weight = q.matmul(k) * softmax_scale

    if attn_bias is not None:
        # attn_bias may cover more positions than this call; keep the trailing part.
        _s_q = max(0, attn_bias.size(2) - s_q)
        _s_k = max(0, attn_bias.size(3) - s_k)
        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
        if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
            raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')

        attn_weight = attn_weight + attn_bias

    min_val = torch.finfo(q.dtype).min
    if key_padding_mask is not None:
        if attn_bias is not None:
            warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
        attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
    if is_causal and q.size(2) != 1:
        # Square lower-triangular mask, cropped to the bottom-right (s_q, s_k) corner
        # so that queries line up with the most recent keys.
        s = max(s_q, s_k)
        causal_mask = attn_weight.new_ones(s, s, dtype=torch.float32)
        causal_mask = causal_mask.tril()
        causal_mask = causal_mask.to(torch.bool)
        causal_mask = ~causal_mask
        causal_mask = causal_mask[-s_q:, -s_k:]
        attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
    attn_weight = torch.softmax(attn_weight, dim=-1)
    if dropout_p:
        attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
    out = attn_weight.to(v.dtype).matmul(v)
    out = rearrange(out, 'b h s d -> b s (h d)')
    if needs_weights:
        return (out, attn_weight, past_key_value)
    return (out, None, past_key_value)



# Replacement for MPTForCausalLM.forward (chunked prefill).
def chunk_forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
                  attention_mask: Optional[torch.ByteTensor] = None, prefix_mask: Optional[torch.ByteTensor] = None,
                  sequence_id: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
                  return_dict: Optional[bool] = None, output_attentions: Optional[bool] = None,
                  output_hidden_states: Optional[bool] = None, use_cache: Optional[bool] = None,
                  inputs_embeds: Optional[torch.FloatTensor] = None) -> CausalLMOutputWithPast:
    """Replacement for ``MPTForCausalLM.forward`` that prefills the prompt in chunks.

    When ``past_key_values`` is None, ``input_ids`` is fed to
    ``self.transformer`` in consecutive slices of ``chunk_width`` tokens
    (module-level global read at call time). Every slice attends to the
    key/value cache accumulated from all earlier slices, and the per-slice
    logits are concatenated along the sequence dimension. When
    ``past_key_values`` is given, one ordinary forward pass is performed.

    Args:
        input_ids: Token ids; sliced along the last dimension during prefill.
        past_key_values: Existing cache; ``None`` selects the chunked prefill path.
        attention_mask: Optional mask; during prefill each call receives the
            columns covering the cached keys plus the current slice.
        prefix_mask: Forwarded unchanged to every transformer call (not sliced).
        sequence_id: Forwarded unchanged to every transformer call (not sliced).
        labels: Accepted for interface compatibility only; not used.
        return_dict: Defaults to ``self.config.return_dict``.
        output_attentions: Forwarded to the transformer; the result is not returned.
        output_hidden_states: Forwarded to the transformer; the result is not returned.
        use_cache: Defaults to ``self.config.use_cache``. The prefill path
            always caches internally (later slices need the earlier keys and
            values); this flag only controls whether the cache is returned.
        inputs_embeds: Must be ``None``.

    Returns:
        ``CausalLMOutputWithPast`` with ``logits`` and ``past_key_values``;
        ``loss``, ``hidden_states`` and ``attentions`` are always ``None``.

    Raises:
        NotImplementedError: If ``inputs_embeds`` is not ``None``.
    """
    return_dict = return_dict if return_dict is not None else self.config.return_dict
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    if inputs_embeds is not None:
        raise NotImplementedError('inputs_embeds has to be None (for hf/peft support).')

    if past_key_values is None:
        # ---- chunked prefill ----
        merged_kv = None      # cache covering every slice processed so far
        logit_chunks = []     # per-slice logits, joined once after the loop
        mask_beg = 0
        input_length = len(input_ids[0])

        for beg in range(0, input_length, chunk_width):
            end = beg + chunk_width
            outputs = self.transformer(input_ids=input_ids[..., beg:end],
                                       past_key_values=merged_kv,
                                       attention_mask=attention_mask[..., mask_beg:end] if attention_mask is not None else None,
                                       prefix_mask=prefix_mask,
                                       sequence_id=sequence_id,
                                       return_dict=return_dict,
                                       output_attentions=output_attentions,
                                       output_hidden_states=output_hidden_states,
                                       # Forced on: the next slice cannot be computed without
                                       # this slice's keys/values, whatever the caller asked for.
                                       use_cache=True)

            # wte called with True as second argument, as in the non-chunked path below.
            chunk_logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
            logit_chunks.append(chunk_logits)
            chunk_len = chunk_logits.shape[1]

            if merged_kv is None:
                merged_kv = outputs.past_key_values
            else:
                # Append only the entries this slice added. Keys are stored with the
                # sequence on dim 3 and values with the sequence on dim 2.
                merged_kv = tuple(
                    (torch.cat([prev[0], new[0][:, :, :, -chunk_len:]], dim=3),
                     torch.cat([prev[1], new[1][:, :, -chunk_len:, :]], dim=2))
                    for prev, new in zip(merged_kv, outputs.past_key_values))

            # First mask column for the next call: the oldest position still in the cache.
            mask_beg = end - merged_kv[0][0].shape[3]

        logits = torch.cat(logit_chunks, dim=1) if logit_chunks else None
        past_key_values = merged_kv if use_cache else None

    else:
        # ---- next-token inference with an existing cache ----
        outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask,
                                   prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict,
                                   output_attentions=output_attentions, output_hidden_states=output_hidden_states,
                                   use_cache=use_cache)

        logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
        past_key_values = outputs.past_key_values

    if self.logit_scale is not None:
        if self.logit_scale == 0:
            warnings.warn(
                f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
        logits *= self.logit_scale
    loss = None
    return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=past_key_values,
                                  hidden_states=None, attentions=None)



# 1. Install the staircase ALiBi position bias by rebinding
#    modeling_mpt.build_attn_bias to the variant defined in this file.
#    NOTE(review): this only affects code that looks the function up through
#    the modeling_mpt module attribute after this line runs - confirm.
import models.mpt_7b.modeling_mpt as modeling_mpt
modeling_mpt.build_attn_bias = weave_build_attn_bias

# Globals read by build_alibi_bias at call time.
train_max_pos = 2048 - 1  # sequences longer than this use the staircase offsets
push_width = 200  # positions sharing one offset value (author's other noted value: 20000)
push_pos = 512  # nearest keys that keep exact offsets (author's other noted values: 4096, 1536)


# 2. Install the log-scaled attention function in the attention module.
#    NOTE(review): same caveat as above - only lookups through the
#    mpt_attention module attribute see the replacement; confirm.
import models.mpt_7b.attention as mpt_attention
mpt_attention.scaled_multihead_dot_product_attention = log_scaled_multihead_dot_product_attention


# 3. Install chunked prefill as MPTForCausalLM.forward.
modeling_mpt.MPTForCausalLM.forward = chunk_forward
chunk_width = 512  # tokens per prefill slice, read by chunk_forward (author's other noted values: 100, 1024)
train_max_len = 2048  # log base for the query scaling, read by log_scaled_multihead_dot_product_attention