# https://github.com/FlagOpen/FlagEmbedding/blob/master/LM_Cocktail/LM_Cocktail/utils.py#L120
import torch

def llm_loss(base_model, input_data):
    """Return the mean language-modelling loss of ``base_model`` over batches.

    Args:
        base_model: Callable invoked as ``base_model(**batch)``; the returned
            object must expose a single-element ``loss`` tensor (presumably a
            Hugging Face causal LM called with ``labels`` -- confirm against
            the preprocessing code).
        input_data: Sized iterable of keyword-argument dicts, one per batch.

    Returns:
        The unweighted mean of the per-batch losses as a Python ``float``.

    Raises:
        ValueError: If ``input_data`` contains no batches.
    """
    num_batches = len(input_data)
    if num_batches == 0:
        raise ValueError("input_data must contain at least one batch")

    total = 0.0
    for batch in input_data:
        output = base_model(**batch)
        # .item() moves the scalar to host memory as a Python float, so the
        # running sum is kept in double precision (not the model's possibly
        # half-precision dtype) and no autograd graph is kept alive.
        total += output.loss.item()
    return total / num_batches

def embedder_loss(base_model, input_data, temperature: float = 0.05):
    """Return the mean contrastive loss of an embedding model over batches.

    For every ``(query_inputs, passage_inputs)`` pair, both sides are encoded,
    L2-normalised, and compared by dot product. The similarity matrix is
    divided by ``temperature`` and scored with cross-entropy.

    Args:
        base_model: Callable invoked as ``base_model(**inputs,
            return_dict=True)``; the result must expose ``last_hidden_state``
            indexable as ``[:, 0]``.
        input_data: Sized iterable of ``(query_inputs, passage_inputs)``
            pairs, each element a keyword-argument dict for ``base_model``.
        temperature: Divisor applied to the similarity scores before the
            cross-entropy. Defaults to ``0.05``.

    Returns:
        The unweighted mean of the per-batch losses as a Python ``float``.

    Raises:
        ValueError: If ``input_data`` contains no batches.
    """
    num_batches = len(input_data)
    if num_batches == 0:
        raise ValueError("input_data must contain at least one batch")

    def _embed(inputs):
        # Use the hidden state at position 0 of each sequence as the sentence
        # embedding (presumably the [CLS] token -- depends on the tokenizer).
        hidden = base_model(**inputs, return_dict=True).last_hidden_state
        return torch.nn.functional.normalize(hidden[:, 0], dim=-1)

    total = 0.0
    for q_inputs, p_inputs in input_data:
        q_embeddings = _embed(q_inputs)
        p_embeddings = _embed(p_inputs)
        scores = torch.matmul(q_embeddings, p_embeddings.transpose(0, 1)) / temperature

        # Query i is matched against passage i * group_size. This assumes the
        # passages are laid out as equal-sized groups, one per query, with the
        # positive passage first in each group -- confirm with the preprocessing.
        group_size = p_embeddings.size(0) // q_embeddings.size(0)
        target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
        target = target * group_size

        batch_loss = torch.nn.functional.cross_entropy(scores, target, reduction='mean')
        # .item() yields a detached Python float, so the running sum stays in
        # double precision and no autograd graph is retained across batches.
        total += batch_loss.item()

    return total / num_batches

@torch.no_grad()
def compute_weights(
    base_model,
    tokenizer,
    param_list,
    model_type: str,
    example_data,
    temperature: float = 5.0,
    batch_size: int = 2,
    max_input_length: int = 2048,
    neg_number: int = 7,
):
    """Compute softmax merging weights from each candidate's few-shot loss.

    Each state dict in ``param_list`` is loaded into ``base_model`` in turn
    and its loss on ``example_data`` is measured. The weights are
    ``softmax(-loss / temperature)``, so a larger loss gives a smaller weight.

    Side effects: ``base_model`` is mutated in place and is left holding the
    last state dict in ``param_list``. This function does not change the
    model's train/eval mode; callers that need deterministic losses should put
    the model in eval mode beforehand.

    NOTE(review): ``preprocess_data_for_llm`` and
    ``preprocess_data_for_embedder`` are not defined or imported in the
    visible part of this file; they must be in module scope at call time.

    Args:
        base_model: Model exposing ``load_state_dict`` and callable as
            required by ``llm_loss`` / ``embedder_loss``.
        tokenizer: Tokenizer forwarded to the preprocessing helper.
        param_list: Iterable of state dicts accepted by
            ``base_model.load_state_dict``.
        model_type: ``'decoder'`` (uses ``llm_loss``) or ``'encoder'`` (uses
            ``embedder_loss``).
        example_data: Few-shot examples forwarded to the preprocessing helper.
        temperature: Divisor applied to the negated losses before the softmax.
        batch_size: Forwarded to the preprocessing helper.
        max_input_length: Forwarded to the preprocessing helper.
        neg_number: Forwarded to ``preprocess_data_for_embedder`` only.

    Returns:
        A list of Python floats, one weight per entry of ``param_list``.

    Raises:
        ValueError: If ``model_type`` is neither ``'decoder'`` nor
            ``'encoder'``.
    """
    if model_type == 'decoder':
        input_data = preprocess_data_for_llm(
            example_data=example_data,
            tokenizer=tokenizer,
            batch_size=batch_size,
            max_input_length=max_input_length,
        )
        loss_func = llm_loss
    elif model_type == 'encoder':
        input_data = preprocess_data_for_embedder(
            example_data=example_data,
            tokenizer=tokenizer,
            batch_size=batch_size,
            max_input_length=max_input_length,
            neg_number=neg_number,
        )
        loss_func = embedder_loss
    else:
        # Fail fast: without this branch input_data/loss_func stay unbound and
        # the error only surfaces later as an UnboundLocalError (or not at all
        # when param_list is empty).
        raise ValueError(
            f"model_type must be 'decoder' or 'encoder', got {model_type!r}"
        )

    # weight_i = softmax(-EvalLoss(Model_i, fewshot-examples) / temperature)
    # the larger loss, the smaller weight
    example_loss = []
    for params in param_list:
        base_model.load_state_dict(params)
        example_loss.append(loss_func(base_model=base_model, input_data=input_data))

    neg_scaled_loss = -torch.tensor(example_loss, dtype=torch.float32) / temperature
    return torch.softmax(neg_scaled_loss, dim=-1).tolist()