import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from peft import PeftModel, get_peft_model, LoraConfig
from fastchat.model import get_conversation_template
from transformers import (
    AutoConfig,
    AutoTokenizer, 
    AutoModelForCausalLM, 
    GenerationConfig, 
    PreTrainedModel,
)
# from transformers import AutoModelForCausalLM
from transformers.models.llama import LlamaForCausalLM, LlamaPreTrainedModel, LlamaModel
from transformers.models.mistral import MistralForCausalLM
from transformers.generation import GenerationMixin
from omegaconf import OmegaConf
import os
import openai
from typing import Any, List, Optional, Union, Tuple
import time
import concurrent.futures
import random
import copy
from .gen_util import ContrastGenerationMixin
from .peft_util import find_all_linear_names
from .utils import NameTimer

def create_small_llm(basellm, small_layer_num):
    """Return the assistant LLM to pair with ``basellm``.

    If ``small_layer_num`` equals the number of decoder layers found at
    ``basellm.model.base_model.layers``, ``basellm`` itself is returned.
    Otherwise a ``SmallLLM`` is built on the CPU from the base model's
    config, with ``basellm`` handed over as the weight source.
    """
    full_depth = len(basellm.model.base_model.layers)
    if small_layer_num != full_depth:
        return SmallLLM(
            basellm.model.config,
            num_layer=small_layer_num,
            device='cpu',
            base_llm=basellm,
        )
    # Same depth requested: no separate assistant is needed.
    return basellm

def copy_weights(base_llm, model):
    """Copy embeddings, final norm, leading decoder layers and LM head into ``model``.

    The number of decoder layers copied is ``model.config.num_hidden_layers``;
    they are taken from the same indices of ``base_llm.model.layers``.

    Returns:
        ``model`` (modified in place).

    Raises:
        ValueError: if ``model.config._name_or_path`` (lower-cased) contains
            none of ``llama``, ``zephyr`` or ``mistral``.
    """
    cfg = model.config
    name = cfg._name_or_path.lower()
    if not any(family in name for family in ('llama', 'zephyr', 'mistral')):
        raise ValueError(f"Unsupported model: {name}")

    print(f"Copying {name} first layer: {cfg.num_hidden_layers}")
    source, target = base_llm.model, model.model
    target.embed_tokens.load_state_dict(source.embed_tokens.state_dict())
    target.norm.load_state_dict(source.norm.state_dict())
    for idx in range(cfg.num_hidden_layers):
        target.layers[idx].load_state_dict(source.layers[idx].state_dict())
    model.lm_head.load_state_dict(base_llm.lm_head.state_dict())
    return model
 
def init_small_huggingface_llm(origin_config, num_layer, device, hparams=None, base_llm=None, saved_path=None):
    """Build a (possibly truncated) causal LM on the CPU from a copy of ``origin_config``.

    Args:
        origin_config: Hugging Face config of the full model. It is deep-copied,
            so the caller's object is not modified.
        num_layer: value written to ``num_hidden_layers`` of the copied config.
            Only applied when the config's ``_name_or_path`` contains ``llama``,
            ``zephyr`` or ``mistral``; other names keep their original depth.
        device: accepted for interface compatibility but not used here; the
            model is always placed on the CPU.
        hparams: unused.
        base_llm: optional model whose weights are copied in via ``copy_weights``.
        saved_path: optional path to a state dict that is loaded after the
            (optional) copy from ``base_llm``.

    Returns:
        The newly created model, on the CPU.
    """
    config = copy.deepcopy(origin_config)
    tmppath = config._name_or_path.lower()
    if ('llama' in tmppath) or ('zephyr' in tmppath) or ('mistral' in tmppath):
        #! this is for llama-2
        config.num_hidden_layers = num_layer

    model = AutoModelForCausalLM.from_config(
        config,
        use_flash_attention_2=True,
        torch_dtype=torch.bfloat16,
    ).to('cpu')

    if base_llm is not None:
        copy_weights(base_llm, model)

    if saved_path is not None:
        # The model lives on the CPU, so map the checkpoint there as well;
        # otherwise a checkpoint saved from a GPU fails to load on CPU-only hosts.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        model.load_state_dict(
            torch.load(saved_path, map_location='cpu')
        )

    return model

class HuggingfaceLLMWrapper:
    """Thin delegating wrapper around a Hugging Face causal LM stored in ``self.model``.

    Subclasses are expected to assign ``self.model``; this base class only
    forwards calls to it.
    """

    # Wrapped Hugging Face model; set by subclasses (forward reference, not evaluated at runtime).
    model: "AutoModelForCausalLM"

    def __init__(self) -> None:
        super().__init__()

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """Forward to the wrapped model's ``prepare_inputs_for_generation``."""
        return self.model.prepare_inputs_for_generation(
            *args, **kwargs
        )

    def __call__(self, *args, **kwargs):
        """Call the wrapped model with the given arguments."""
        return self.model(
            *args, **kwargs
        )

    @property
    def generation_config(self):
        """Generation config of the wrapped model."""
        return self.model.generation_config

    def generate(self, *args, **kwargs):
        """Forward to the wrapped model's ``generate``."""
        return self.model.generate(
            *args, **kwargs
        )

    def to(self, device):
        """Move the wrapped model to ``device`` and remember it on ``self.device``.

        Returns:
            ``self``, so that ``wrapper = wrapper.to(device)`` works the same
            way it does for the models it wraps.
        """
        self.device = device
        self.model.to(device)
        return self



class UnlearnLLM(HuggingfaceLLMWrapper):
    """Wrapper that loads a pretrained causal LM and its tokenizer in eval mode."""

    def __init__(self, model_config, device, hparams=None):
        """
        Args:
            model_config: mapping that provides ``'tokenizer_path'`` and ``'model_path'``.
            device: device the model is moved to after loading.
            hparams (optional): stored on the instance as ``self.hparams``. Defaults to None.
        """
        super().__init__()

        self.hparams = hparams
        # Historical misspelling, kept as an alias so existing readers keep working.
        self.haprams = hparams
        self.device = device
        self.model_config = model_config
        self.tokenizer= AutoTokenizer.from_pretrained(
            self.model_config['tokenizer_path'],
            trust_remote_code=True,
            use_fast=False,
        )
        self.tokenizer.padding_side = 'left'
        # if 'llama-2' in self.model_config['tokenizer_path'].lower():
            # self.tokenizer.pad_token = self.tokenizer.unk_token
        if not self.tokenizer.pad_token:
            # No pad token configured: reuse EOS for padding.
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_config['model_path'],
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            use_cache=True,
        ).to(device).eval()

    @property
    def config(self):
        """Config of the wrapped model."""
        return self.model.config
    
    def __call__(self, *args, **kwargs):
        """Call the wrapped model with the given arguments."""
        return self.model(*args, **kwargs)

class SmallLLM(HuggingfaceLLMWrapper):
    """Assistant LM: either a truncated copy of a base model or a checkpoint loaded from disk."""

    #! Same model families that the module-level copy_weights() accepts
    _SUPPORTED_FAMILIES = ('llama', 'zephyr', 'mistral')

    def __init__(self, origin_config, num_layer, device, hparams=None, base_llm=None, save_path=None, is_lora=False) -> None:
        """
        Args:
            origin_config: config of the full model; deep-copied before modification.
            num_layer: number of hidden layers for the assistant. Applied to the
                copied config when its ``_name_or_path`` matches a supported family.
            device: device the resulting model is moved to.
            hparams: stored on the instance, otherwise unused here.
            base_llm: optional wrapper whose ``.model`` provides the weights to
                copy. Only used when ``save_path`` is None.
            save_path: optional checkpoint directory. When given, the model is
                loaded from it instead of being built from the config.
            is_lora: if True, ``save_path`` holds a PEFT adapter whose full model
                is at ``<save_path>/../fullmodel``; the adapter is merged in.
        """
        super().__init__()

        self.hparams = hparams
        self.device = device
        self.config = copy.deepcopy(origin_config)
        if self._is_supported(self.config._name_or_path.lower()):
            self.config.num_hidden_layers = num_layer
        # Always defined, whether or not the config was truncated.
        self.num_hidden_layers = self.config.num_hidden_layers
        
        if save_path is None: 
            print(f"Loading assist model from large model")
            self.model = AutoModelForCausalLM.from_config(
                self.config,
                torch_dtype=torch.bfloat16,
            )
            if base_llm is not None:
                self.copy_weights(base_llm)
        else:
            print(f"Loading assist model from {save_path}")
            if is_lora:
                print("Loading lora model")
                fullpath = os.path.join(save_path, "../fullmodel")
                self.model = AutoModelForCausalLM.from_pretrained(
                    fullpath, torch_dtype=torch.bfloat16, 
                ).to(device)
                peftmod = PeftModel.from_pretrained(self.model, save_path, torch_dtype=torch.bfloat16)
                # Fold the adapter into the full model and drop the PEFT wrapper.
                peftmod = peftmod.merge_and_unload()
                self.model = peftmod
            else:
                self.model = AutoModelForCausalLM.from_pretrained(
                    save_path,
                    torch_dtype=torch.bfloat16,
                )

        self.model.to(device)

    @classmethod
    def _is_supported(cls, lowered_name):
        """Return True if the lower-cased model name belongs to a supported family."""
        return any(family in lowered_name for family in cls._SUPPORTED_FAMILIES)
    
    def copy_weights(self, base_llm : UnlearnLLM):
        """Copy embeddings, final norm, the first decoder layers and the LM head from ``base_llm.model``.

        Raises:
            ValueError: if this model's name does not match a supported family.
        """
        name = self.model.config._name_or_path.lower()
        if not self._is_supported(name):
            raise ValueError(f"Unsupported model: {name}")

        print(f"Copying {name} first layer: {self.config.num_hidden_layers}")
        source = base_llm.model
        self.model.model.embed_tokens.load_state_dict(
            source.model.embed_tokens.state_dict()
        )
        self.model.model.norm.load_state_dict(
            source.model.norm.state_dict()
        )
        for layer_num in range(self.config.num_hidden_layers):
            self.model.model.layers[layer_num].load_state_dict(
                source.model.layers[layer_num].state_dict()
            )
        self.model.lm_head.load_state_dict(
            source.lm_head.state_dict()
        )

    def __call__(self, *args, **kwargs):
        """Call the wrapped model with the given arguments."""
        return self.model(*args, **kwargs)

# TODO: muhao's paper
def init_offset_model(model_path, data_type='bfloat16', **kwargs):
    """Create an ``OffsetAssitedModel`` from the config stored at ``model_path``.

    All ``kwargs`` are forwarded to the ``OffsetAssitedModel`` constructor. If
    they contain a truthy ``device`` entry, the model is also moved there.
    ``data_type`` is currently not consulted: ``torch.bfloat16`` is always passed.
    """
    offset_model = OffsetAssitedModel(
        AutoConfig.from_pretrained(model_path),
        torch_dtype=torch.bfloat16,
        **kwargs,
    )
    target_device = kwargs.get('device', None)
    if not target_device:
        return offset_model
    return offset_model.to(device=target_device)

class OffsetAssitedModel(PreTrainedModel):
    """Causal LM whose logits are a frozen base model's logits shifted by an assistant offset.

    ``forward`` computes
    ``base_logits + weight * (assist_logits - base_assist_logits)``.
    ``basellm`` and ``base_assist_llm`` are frozen and evaluated under
    ``torch.no_grad()``; ``assist_llm`` is the only part evaluated with gradients.
    """

    _keys_to_ignore_on_load_missing = [
        r"assist_model.*",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"assist_model.*",
    ]

    def __init__(self, config, base_assist_path, new_assist_path=None, weight=1.0, is_lora=False, Lora=OmegaConf.create({"r":0, "alpha": 32, "dropout": 0.05}), **kwargs):
        """
        Args:
            config: config of the base model; its ``_name_or_path`` is the
                checkpoint loaded as ``basellm``. The object is modified in place
                (Lora settings, paths and weight are recorded on it).
            base_assist_path: checkpoint of the frozen reference assistant.
            new_assist_path: checkpoint of the trainable assistant; falls back
                to ``base_assist_path`` when None.
            weight: multiplier applied to the logit offset.
            is_lora: accepted for interface compatibility; not used here
                (LoRA is enabled by ``Lora.r != 0``).
            Lora: OmegaConf config with ``r``, ``alpha``, ``dropout`` and
                optionally ``bias`` (defaults to ``"none"``).
            **kwargs: forwarded to ``PreTrainedModel.__init__``.
        """
        tmplora = OmegaConf.to_container(Lora)
        config.Lora = tmplora
        config.base_model_name = config._name_or_path
        config.is_offset = True
        config.base_assist_path = base_assist_path
        config.new_assist_path = new_assist_path
        config.weight = weight
        super().__init__(config, **kwargs)
        
        self.vocab_size = config.vocab_size

        self.basellm = AutoModelForCausalLM.from_pretrained(config.base_model_name, torch_dtype=torch.bfloat16, use_flash_attention_2=True)
        self.basellm.eval()
        self.basellm.requires_grad_(False) #! Freeze

        self.base_assist_llm = AutoModelForCausalLM.from_pretrained(base_assist_path, use_flash_attention_2=True, torch_dtype=torch.bfloat16)
        self.base_assist_llm.eval()
        self.base_assist_llm.requires_grad_(False) #! Freeze
        
        if new_assist_path is None:
            assist_path = base_assist_path
        else:
            assist_path = new_assist_path
        
        self.assist_llm = AutoModelForCausalLM.from_pretrained(assist_path, use_flash_attention_2=True, torch_dtype=torch.bfloat16)
        if Lora.r != 0:
            peftconfig = LoraConfig(
                r=Lora.r,
                lora_alpha=Lora.alpha,
                target_modules=find_all_linear_names(self.assist_llm), 
                lora_dropout=Lora.dropout,
                # The default Lora config of this class has no "bias" entry.
                bias=Lora.get("bias", "none"),
                task_type="CAUSAL_LM",
            )
            self.assist_llm = get_peft_model(self.assist_llm, peftconfig)

        self.weight = weight
        self.generation_config = GenerationConfig.from_model_config(self.config)
    
    def prepare_inputs_for_generation(self, *args, **kwargs):
        """Delegate generation input preparation to ``basellm``."""
        return self.basellm.prepare_inputs_for_generation(
            *args, **kwargs
        )

    def forward(
        self, 
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Run all three models on the same inputs and combine their logits.

        ``cache_position`` and extra ``kwargs`` are accepted but not forwarded.
        When ``labels`` is given, a next-token cross-entropy loss is computed on
        the combined logits.

        Returns:
            ``CausalLMOutputWithPast`` when ``return_dict`` is truthy, otherwise
            a tuple ``(loss?, logits, *rest)``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The sub-models are always asked for dict outputs: `.logits` is read
        # below, which would fail on the tuples returned for return_dict=False.
        with torch.no_grad():
            base_outputs = self.basellm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=True,
                # cache_position=cache_position,
            )
            base_logits = base_outputs.logits.detach() # make sure the gradient stops for oracle
            base_assist_outputs = self.base_assist_llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=True,
                # cache_position=cache_position, # Don't use cache for assistant               
            )
            base_assist_logits = base_assist_outputs.logits.detach()

        assist_outputs = self.assist_llm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            # cache_position=cache_position, # Don't use cache for assistant
        )
        assist_logits = assist_outputs.logits

        logits = base_logits + self.weight * (assist_logits - base_assist_logits) #! ajust the final distribution
        loss = None
        if labels is not None:
            # Shift so that position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        # NOTE(review): the auxiliary outputs (cache, hidden states, attentions)
        # returned below come from `base_assist_llm`, not from `basellm`, and the
        # same `past_key_values` object is handed to all three models. This is
        # the historical behaviour; confirm it is intended before relying on
        # cached generation with this class.
        outputs = base_assist_outputs

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    
    def save_pretrained(self, path, **kwargs):
        """Save only the trainable assistant and this model's config to ``path``.

        Extra ``kwargs`` are accepted but ignored.
        """
        self.assist_llm.save_pretrained(path)
        self.config.save_pretrained(path)

    
class AssistedModel(LlamaPreTrainedModel, GenerationMixin):
    """Frozen base causal LM plus a smaller assistant whose logits are added to the base logits.

    Build instances with ``init_from_basellm`` (fresh assistant copied from the
    base model) or ``init_from_pretrained`` (assistant restored from disk);
    ``from_pretrained`` is deliberately disabled.
    """

    _keys_to_ignore_on_load_missing = [
        r"assist_model.*",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"assist_model.*",
    ]
    def __init__(self, config, assist_num_layer=8, 
                 is_lora=True, Lora=OmegaConf.create({"r": 16, "alpha": 32, "dropout": 0.05}), **kwargs):
        """
        Args:
            config: config of the base model; its ``_name_or_path`` is the
                checkpoint loaded as ``basellm``. Modified in place (assistant
                depth, LoRA flag and LoRA settings are recorded on it).
            assist_num_layer: number of hidden layers of the assistant.
            is_lora: recorded on the config; consulted by the alternate constructors.
            Lora: OmegaConf config with the LoRA settings, stored on the config
                as a plain container.
            **kwargs: forwarded to the parent constructor.
        """
        config.assist_num_layer = assist_num_layer
        config.is_lora = is_lora
        tmplora = OmegaConf.to_container(Lora)
        config.Lora = tmplora
        config.base_model_name = config._name_or_path
        super().__init__(config, **kwargs)

        # ! Name the origin model oracle, as we would remove the weights in BaseModule.on_save_checkpoint to save disk space
        self.basellm = AutoModelForCausalLM.from_pretrained(config.base_model_name, torch_dtype=torch.bfloat16, use_flash_attention_2=True)
        self.basellm.requires_grad_(False) #! Freeze
        self.vocab_size = config.vocab_size

        # Freeze the original model, but copy a small part to work as the assistant 
        small_config = copy.deepcopy(config)
        small_config.num_hidden_layers = assist_num_layer
        self.assist_llm = AutoModelForCausalLM.from_config(small_config, use_flash_attention_2=True, torch_dtype=torch.bfloat16)

        self.weight = 1.0
        self.top_logit_filter = 0.0
        self.generation_config = GenerationConfig.from_model_config(self.config)
   
    def copy_weights(self):
        """Initialise the assistant from the base model via the module-level ``copy_weights``."""
        print("Copying small model weights for assisted model")
        return copy_weights(self.basellm, self.assist_llm)
    
    def from_pretrained(self, path):
        """Disabled on purpose; always raises ``ValueError``."""
        raise ValueError("Don't use this function, use init_from_basellm instead")
    
    @classmethod
    def init_from_basellm(cls, model_path, assist_num_layer=8, is_lora=True, Lora=OmegaConf.create({"r": 16, "alpha": 32, "dropout": 0.05, "bias": "none", "task_type":"CAUSAL_LM"})):
        """Create a model whose assistant is copied from the base model at ``model_path``.

        When ``is_lora`` is true the assistant is additionally wrapped in a PEFT
        LoRA model configured from ``Lora``.
        """
        baseconfig = AutoConfig.from_pretrained(model_path)
        model = cls(baseconfig, assist_num_layer, is_lora, Lora, torch_dtype=torch.bfloat16)
        model.copy_weights()
        if model.config.is_lora:
            peftconfig = LoraConfig(
                r=Lora.r,
                lora_alpha=Lora.alpha,
                target_modules=find_all_linear_names(model.assist_llm), 
                lora_dropout=Lora.dropout,
                # Configs without a "bias" entry (e.g. the __init__ default) must not fail here.
                bias=Lora.get("bias", "none"),
                task_type="CAUSAL_LM",
            )
            model.assist_llm = get_peft_model(model.assist_llm, peftconfig)
        return model
    
    @classmethod 
    def init_from_pretrained(cls, model_path):
        """Rebuild a model saved with ``save_pretrained`` and return it in eval mode.

        The config at ``model_path`` supplies the base model name, assistant
        depth and LoRA settings; the assistant weights (adapter or full model)
        are then loaded from ``model_path``.
        """
        config = AutoConfig.from_pretrained(model_path)
        config.Lora = OmegaConf.create(config.Lora)
        model = cls.init_from_basellm(config.base_model_name, config.assist_num_layer, config.is_lora, config.Lora)
        if config.is_lora:
            print("Initialize with loaded peft assist")
            peftpath = model_path
            unloaded = model.assist_llm.unload() #! Unload the randomly initialized
            model.assist_llm = PeftModel.from_pretrained(unloaded, peftpath, torch_dtype=torch.bfloat16)
        else:
            print("Initialize with loadef assist")
            saved_assist = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)
            model.assist_llm.load_state_dict(saved_assist.state_dict())
        model.eval()
        return model
    
    def save_pretrained(self, path, **kwargs):
        """Save only the assistant and this model's config to ``path``.

        Extra ``kwargs`` are accepted but ignored.
        """
        self.assist_llm.save_pretrained(path)
        self.config.save_pretrained(path)

    def forward(
        self, 
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        """Add the assistant's logits to the frozen base model's logits.

        When ``labels`` is given, a next-token cross-entropy loss is computed on
        the summed logits. Auxiliary outputs (cache, hidden states, attentions)
        are those of the base model.

        Returns:
            ``CausalLMOutputWithPast`` when ``return_dict`` is truthy, otherwise
            a tuple ``(loss?, logits, *rest)``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The sub-models are always asked for dict outputs: `.logits` is read
        # below, which would fail on the tuples returned for return_dict=False.
        with torch.no_grad():
            outputs = self.basellm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=True,
                cache_position=cache_position,
            )
            logits = outputs.logits.detach() # make sure the gradient stops for oracle

        # NOTE(review): the assistant receives the same `past_key_values` as the
        # base model although it has a different depth -- confirm this is safe
        # for cached generation.
        assist_outputs = self.assist_llm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=None,
        )
        assist_logits = assist_outputs.logits
        logits = logits + assist_logits #! ajust the final distribution

        loss = None
        if labels is not None:
            # Shift so that position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    
from transformers.models.llama.modeling_llama import CausalLMOutputWithPast

#! Used for computing next-token probability
class ContrastLLM(torch.nn.Module, ContrastGenerationMixin):
    """Inference-time combination of a base LM and an assistant LM.

    ``forward`` returns ``base_logits + weight * assist_logits``; when
    ``top_logit_filter`` is positive, the assistant's contribution is zeroed at
    the positions selected by ``relative_top_filter`` first.
    """

    def __init__(self, basellm : AutoModelForCausalLM, assist_llm : AutoModelForCausalLM, weight : float, top_logit_filter=0.0) -> None:
        """
        Args:
            basellm: base model (or wrapper); must expose ``device``, ``config``
                and ``generation_config`` and be callable.
            assist_llm: assistant model (or wrapper); must be callable.
            weight: multiplier applied to the assistant logits.
            top_logit_filter: threshold handed to ``relative_top_filter``;
                values ``<= 0`` disable the filter.
        """
        super().__init__()
        self.basellm = basellm
        self.assist_llm = assist_llm
        self.weight = weight
        self.device = self.basellm.device
        self.config = self.basellm.config
        self.generation_config = basellm.generation_config
        self.top_logit_filter = top_logit_filter
        self.tokenizer = AutoTokenizer.from_pretrained(self.config._name_or_path)
    
    def get_loss(self, logits, labels=None, attention_mask=None, reduciton='mean'):
        """Next-token cross-entropy of ``logits`` against ``labels``.

        Args:
            logits: tensor whose last dimension is ``config.vocab_size`` and
                whose second-to-last dimension is the sequence.
            labels: target token ids; when None, None is returned.
            attention_mask: only used for ``'batchmean'``: its per-sample sum is
                the normaliser. When None, the number of non-ignored shifted
                labels is used instead.
            reduciton: ``'batchmean'`` for one length-normalised loss per
                sample, otherwise any reduction accepted by ``CrossEntropyLoss``.
        """
        if labels is None:
            return None

        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        # Flatten the tokens
        flat_logits = shift_logits.view(-1, self.config.vocab_size)
        flat_labels = shift_labels.view(-1)

        if reduciton == 'batchmean':
            loss_fct = CrossEntropyLoss(reduction='none')
            # Restore the per-sample layout before summing; summing the
            # flattened vector would mix the tokens of every sample in the batch.
            token_loss = loss_fct(flat_logits, flat_labels).view_as(shift_labels)
            if attention_mask is not None:
                lengths = attention_mask.sum(dim=-1)
            else:
                lengths = (shift_labels != loss_fct.ignore_index).sum(dim=-1).clamp(min=1)
            return token_loss.sum(dim=-1) / lengths.to(token_loss.device)

        loss_fct = CrossEntropyLoss(reduction=reduciton)
        return loss_fct(flat_logits, flat_labels)

    # adapted from LLamaForCausalLM
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> "CausalLMOutputWithPast":
        """Combine base and assistant logits.

        ``output_attentions``, ``output_hidden_states`` and ``return_dict`` are
        accepted for signature compatibility but overridden: the result is
        always a ``CausalLMOutputWithPast`` carrying only ``loss`` (mean
        cross-entropy, when ``labels`` is given) and ``logits``.
        """
        #! This forward only returns the logits, never use this for training

        output_attentions = False
        output_hidden_states = False

        outputs = self.basellm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        assit_outputs = self.assist_llm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        
        baselogits = outputs.logits
        assist_logits = assit_outputs.logits

        if self.top_logit_filter > 0.0:
            baselogits, mask, _ = self.relative_top_filter(baselogits, self.top_logit_filter)
            #! lowprob-0 filter
            # presumably `mask` marks the low-probability tokens of the base
            # distribution -- verify against relative_top_filter.
            assist_logits[mask] = 0

        logits = baselogits + self.weight * assist_logits

        loss = self.get_loss(logits, labels)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

def load_unlearned_model(hparams, device):
    """Load the model to evaluate, as described by ``hparams``.

    Without a ``remember`` section (missing or None) a single model is loaded
    from ``hparams.model.model_path``:

    * local directory without ``config.json``: treated as a LoRA adapter that is
      merged into the hard-coded ``locuslab/tofu_ft_llama2-7b`` base model;
    * local directory whose config has ``is_offset`` set: an offset model built
      with ``init_offset_model``;
    * other local directory: a plain causal LM;
    * non-local path whose config has ``base_model_name``: an ``AssistedModel``;
    * other non-local path: a plain causal LM.

    With a ``remember`` section a ``ContrastLLM`` is assembled from an
    ``UnlearnLLM`` base and a ``SmallLLM`` assistant.

    Returns:
        The loaded model in eval mode.
    """
    if 'remember' not in hparams or hparams.remember is None: #! evaluate baseline LLM
        path = hparams.model.model_path
        print(f"Loading without assist from {path}")
        # Load configuration
        if os.path.exists(path): #! local model
            if not os.path.exists(os.path.join(path, 'config.json')): #! this is lora model
                #! Baseline full model with Lora
                model = AutoModelForCausalLM.from_pretrained(
                    "locuslab/tofu_ft_llama2-7b", torch_dtype=torch.bfloat16, use_flash_attention_2=True
                ).to(device)
                with NameTimer("Load lora modules"):
                    peftmod = PeftModel.from_pretrained(model, path, torch_dtype=torch.bfloat16, use_flash_attention_2=True)
                with NameTimer("Merge and unload"):
                    peftmod = peftmod.merge_and_unload()
                model = peftmod.eval().to(device)
            else:
                print("Load from a huggingface pretrained model")
                config = AutoConfig.from_pretrained(path)
                if hasattr(config, 'is_offset') and config.is_offset:
                    if hasattr(config, 'weight'):
                        weight = config.weight
                    else:
                        weight = 1.0
                    if 'mistral' in config.base_model_name.lower():
                        # Hard-coded local checkpoint used as the mistral base model.
                        base_name = "trained_models2/finetune-hp/mistral/checkpoint-441"
                    else:
                        base_name = config.base_model_name
                    model = init_offset_model(
                        base_name, 
                        device=device, 
                        base_assist_path=config.base_assist_path, 
                        weight=weight, 
                        new_assist_path=path)
                else:
                    model = AutoModelForCausalLM.from_pretrained(
                        path, torch_dtype=torch.bfloat16, use_flash_attention_2=True
                    ).to(device)
        else:
            config = AutoConfig.from_pretrained(path)
            if 'base_model_name' in vars(config):
                print("Loading assisted models")
                #! Assisted model
                with NameTimer("Init assisted model"):
                    model = AssistedModel.init_from_pretrained(path)
                with NameTimer("Merge and unload model"):
                    # Only PEFT-wrapped assistants can be merged; keep the
                    # returned plain model instead of discarding it.
                    if hasattr(model.assist_llm, "merge_and_unload"):
                        model.assist_llm = model.assist_llm.merge_and_unload()
                model = model.to(device).eval() 
            else:
                #! load from a huggingface pretrained model
                print("Load from a huggingface pretrained model")
                model = AutoModelForCausalLM.from_pretrained(
                    path, torch_dtype=torch.bfloat16, use_flash_attention_2=True
                ).to(device)
        model = model.eval()
    else: #! evaluate constrast llm 
        with NameTimer("Loading our model"):
            basellm = UnlearnLLM(
                hparams.model, device=device, hparams=hparams,
            )
            assist_llm = SmallLLM(
                basellm.model.config,
                num_layer=hparams.remember.num_layer,
                device=device,
                base_llm=basellm if hparams.remember.save_path == "" else None,
                save_path=hparams.remember.save_path if hparams.remember.save_path != "" else None,
                is_lora=hparams.remember.is_lora,
            )
            # The wrappers are not nn.Modules, so ContrastLLM.eval() would not
            # reach the assistant; an assistant built from a config starts in
            # training mode.
            assist_llm.model.eval()
            model = ContrastLLM(
                basellm, assist_llm, 
                weight=hparams.remember.weight, 
                top_logit_filter=hparams.remember.top_logit_filter
            )

    return model