from imports import *
from utils import load_json
from torch.utils.data import DataLoader, Dataset

from tango import step
import time
import tango
from tango.common import FromParams
from weights_composer import remove_components
from tango.common.det_hash import CustomDetHash

class PromptDataset(Dataset):
    """Map-style ``Dataset`` over an in-memory collection of prompts.

    Items are handed back exactly as stored in ``prompts`` (no copy, no
    transformation), so a ``DataLoader`` batches whatever objects the caller
    supplied.
    """

    def __init__(self, prompts):
        # Kept by reference: later mutation of ``prompts`` is visible here.
        self.prompts = prompts

    def __getitem__(self, idx):
        prompt = self.prompts[idx]
        return prompt

    def __len__(self):
        count = len(self.prompts)
        return count
    

#@step(cacheable=False)
def load_dataset(path, batch_size=20) -> DataLoader:
    """Read prompts from the JSON file at ``path`` and wrap them in a DataLoader.

    Args:
        path: file parsed with ``load_json`` (presumably a list of prompt
            records — confirm against the dataset files).
        batch_size: number of prompts per batch (default 20).

    Returns:
        A ``DataLoader`` over a ``PromptDataset`` that yields batches in file
        order (``shuffle=False``).
    """
    print("loading dataset from", path)
    prompts = PromptDataset(load_json(path))
    return DataLoader(prompts, batch_size=batch_size, shuffle=False)

#@step(cacheable=True, deterministic=True)
def load_model(model_name):
    """Load ``model_name`` as a ``HookedTransformer`` with value biases folded.

    ``refactor_factored_attn_matrices=True`` is additionally requested for every
    model whose name does not contain ``'pythia'``. The reason Pythia is
    excluded is not recorded here — presumably the refactoring is unsupported
    or unwanted for those checkpoints; confirm before changing.
    """
    options = {'fold_value_biases': True}
    if 'pythia' not in model_name:
        options['refactor_factored_attn_matrices'] = True
    return HookedTransformer.from_pretrained(model_name, **options)


class DataParams(FromParams):
    """Pair a dataset path with the DataLoader built from it.

    ``det_hash_object`` depends only on ``dataset_path`` and ``VERSION``;
    ``batch_size`` shapes the loader but is not part of that hash object.
    """
    VERSION = '002'

    def __init__(self, dataset_path: str, batch_size: int = 20):
        self.dataset_path = dataset_path
        # Built eagerly: the JSON file is read when DataParams is constructed.
        self.dataset = load_dataset(dataset_path, batch_size=batch_size)

    def det_hash_object(self):
        return f"{self.dataset_path}-{self.VERSION}"

class ModelParams(CustomDetHash):
    """Lazily loadable handle on a model, hashed by its name only.

    ``det_hash_object`` returns just ``model_name``. Two instances for the same
    name therefore hash identically, whether or not ``model`` has been loaded.
    This holds even if ``model`` has since been edited in place.

    Args:
        model_name: name passed to ``load_model``.
        should_load: when False, ``model`` stays ``None`` until ``load()`` is
            called.
    """
    CACHEABLE = True
    DETERMINISTIC = True
    # Original misspelling, kept as an alias so any existing references work.
    DETERIMINISTIC = DETERMINISTIC

    def __init__(self, model_name: str, should_load: bool = True):
        self.model = None
        self.model_name = model_name
        if should_load:
            self.load()

    def load(self):
        """(Re)load ``self.model`` from ``self.model_name`` via ``load_model``."""
        self.model = load_model(self.model_name)

    def det_hash_object(self):
        return str(self.model_name)

# TODO: make this a Registrable (or something similar).
class SimpleEditedModel(FromParams):
    """Delete one SVD component from the OV matrix of each inhibition head.

    For every (layer, head) in the table below, the head's OV matrix is
    decomposed with ``svd()``. ``remove_components`` then drops the listed
    component indices, and the resulting factors are written back into
    ``W_V[head]`` / ``W_O[head]``.

    NOTE: the edit is applied in place to ``model_params.model`` (which is
    loaded first if it is still ``None``); the caller's ``ModelParams`` is
    modified as a side effect.

    Args:
        model_params: handle on the model to edit.
        note: accepted for interface compatibility; not used.
    """
    VERSION = '002'
    def __init__(self, model_params: ModelParams, note='removing all top in inhibition heads'):
        import torch

        if model_params.model is None:
            model_params.load()
        model = model_params.model
        # (layer, head) -> SVD component indices to remove from that head's OV.
        components_to_remove = {
            (7, 9): [6],
            (7, 3): [1],
            (8, 6): [2],
            (8, 10): [1],
        }
        # Writing into slices of parameters that require grad raises under
        # autograd; no_grad makes the in-place edit work regardless of the
        # global grad mode.
        with torch.no_grad():
            for (layer, head), comps in components_to_remove.items():
                attn = model.blocks[layer].attn
                u, sig, vh = attn.OV[head].svd()
                new_ov = remove_components(u, sig, vh, comps)
                attn.W_V[head] = new_ov.A
                attn.W_O[head] = new_ov.B
        self.model = model

class OnlyTopCompsModel(FromParams):
    """Keep only the "top" SVD component of each inhibition head's OV matrix.

    For every (layer, head) below, every component index except the listed
    top one is passed to ``remove_components``. The resulting factors are
    written back into ``W_V[head]`` / ``W_O[head]``.

    NOTE: the edit is applied in place to ``model_params.model`` (loaded first
    if it is still ``None``).

    Args:
        model_params: handle on the model to edit.
        note: accepted for interface compatibility; not used.
    """
    VERSION='002'
    def __init__(self, model_params: ModelParams, note='removing all top in inhibition heads'):
        import torch

        if model_params.model is None:
            model_params.load()
        model = model_params.model
        # (layer, head) -> index of the single SVD component to retain.
        top_component = {
            (7, 9): 6,
            (7, 3): 1,
            (8, 6): 2,
            (8, 10): 1,
        }
        # no_grad: in-place writes into grad-requiring parameter slices would
        # otherwise raise under autograd.
        with torch.no_grad():
            for (layer, head), top in top_component.items():
                attn = model.blocks[layer].attn
                u, sig, vh = attn.OV[head].svd()
                # Number of singular values (was hard-coded as 64 = d_head of gpt2-small).
                n_comps = sig.shape[-1]
                to_remove = [i for i in range(n_comps) if i != top]
                new_ov = remove_components(u, sig, vh, to_remove)
                attn.W_V[head] = new_ov.A
                attn.W_O[head] = new_ov.B
        self.model = model

class TopAndZeroCompsModel(FromParams):
    """Keep SVD component 0 plus the "top" component of each inhibition head.

    For every (layer, head) below, all component indices other than 0 and the
    listed top one are passed to ``remove_components``. The resulting factors
    are written back into ``W_V[head]`` / ``W_O[head]``.

    NOTE: the edit is applied in place to ``model_params.model`` (loaded first
    if it is still ``None``).

    Args:
        model_params: handle on the model to edit.
        note: accepted for interface compatibility; not used.
    """
    VERSION='002'
    def __init__(self, model_params: ModelParams, note='removing all top in inhibition heads'):
        import torch

        if model_params.model is None:
            model_params.load()
        model = model_params.model
        # (layer, head) -> index of the "top" SVD component kept alongside index 0.
        top_component = {
            (7, 9): 6,
            (7, 3): 1,
            (8, 6): 2,
            (8, 10): 1,
        }
        # no_grad: in-place writes into grad-requiring parameter slices would
        # otherwise raise under autograd.
        with torch.no_grad():
            for (layer, head), top in top_component.items():
                attn = model.blocks[layer].attn
                u, sig, vh = attn.OV[head].svd()
                # Number of singular values (was hard-coded as 64 = d_head of gpt2-small).
                n_comps = sig.shape[-1]
                to_remove = [i for i in range(1, n_comps) if i != top]
                new_ov = remove_components(u, sig, vh, to_remove)
                attn.W_V[head] = new_ov.A
                attn.W_O[head] = new_ov.B
        self.model = model

class ZeroCompsModel(FromParams):
    """Keep only SVD component 0 of each inhibition head's OV matrix.

    For every (layer, head) below, all component indices >= 1 are passed to
    ``remove_components``. The resulting factors are written back into
    ``W_V[head]`` / ``W_O[head]``.

    NOTE: the edit is applied in place to ``model_params.model`` (loaded first
    if it is still ``None``).

    Args:
        model_params: handle on the model to edit.
        note: accepted for interface compatibility; not used.
    """
    VERSION='001'
    def __init__(self, model_params: ModelParams, note='removing all top in inhibition heads'):
        import torch

        if model_params.model is None:
            model_params.load()
        model = model_params.model
        inhibition_heads = [(7, 9), (7, 3), (8, 6), (8, 10)]
        # no_grad: in-place writes into grad-requiring parameter slices would
        # otherwise raise under autograd.
        with torch.no_grad():
            for layer, head in inhibition_heads:
                attn = model.blocks[layer].attn
                u, sig, vh = attn.OV[head].svd()
                # Remove every component except index 0 (count was hard-coded as 64).
                to_remove = list(range(1, sig.shape[-1]))
                new_ov = remove_components(u, sig, vh, to_remove)
                attn.W_V[head] = new_ov.A
                attn.W_O[head] = new_ov.B
        self.model = model

def get_token_idx(str_tokens, token):
    """Return the position of the first occurrence of ``token`` in ``str_tokens``.

    Raises ValueError (from ``.index``) when ``token`` is absent.
    """
    first_position = str_tokens.index(token)
    return first_position

def calc_inhib_score(model, prompt, cache, mover_layer, mover_head):
    """Attention to IO minus attention to the first S, from the final token.

    Reads the attention pattern of ``mover_layer`` from ``cache`` and indexes
    it as ``[head, query_pos, key_pos]``. The cache is therefore expected to
    hold a single prompt with no batch dimension. The query is the last token
    of ``prompt['text']``. IO and S are located as the space-prefixed names
    ``' ' + prompt['IO']`` / ``' ' + prompt['S']``. Only the FIRST occurrence
    of S is used.

    Raises ValueError (via ``get_token_idx``) if either token is missing.
    """
    pattern = cache['pattern', mover_layer, 'attn']
    io_tok, s_tok = ' ' + prompt['IO'], ' ' + prompt['S']
    tokens = model.to_str_tokens(prompt['text'])
    io_pos = get_token_idx(tokens, io_tok)
    s_pos = get_token_idx(tokens, s_tok)
    query_pos = len(tokens) - 1
    io_attn = pattern[mover_head, query_pos, io_pos]
    s_attn = pattern[mover_head, query_pos, s_pos]
    return io_attn - s_attn

def calc_inhib_score_to_all_S(model, prompt, cache, mover_layer, mover_head):
    """Average of (attention to IO - attention to S) over every occurrence of S.

    Uses the same conventions as ``calc_inhib_score``:
    - the pattern is indexed ``[head, query_pos, key_pos]``;
    - the query is the last token;
    - tokens are matched as ``' ' + name``.

    Every position whose string token equals ``' ' + prompt['S']``
    contributes one term to the average.

    Raises:
        ValueError: if the IO token (via ``get_token_idx``) or the S token is
            not present in the tokenized prompt.
    """
    attn_pat = cache['pattern', mover_layer, 'attn']
    io_token = ' ' + prompt['IO']
    s_token = ' ' + prompt['S']
    str_tokens = model.to_str_tokens(prompt['text'])
    io_idx = get_token_idx(str_tokens, io_token)
    s_idxs = [i for i, tok in enumerate(str_tokens) if tok == s_token]
    if not s_idxs:
        # Previously this surfaced as a ZeroDivisionError in the average below.
        raise ValueError(f"S token {s_token!r} not found in prompt tokens")
    last_tok = len(str_tokens) - 1
    io_attn = attn_pat[mover_head, last_tok, io_idx]
    avg_inhib_score = 0.0
    for s_idx in s_idxs:
        avg_inhib_score += io_attn - attn_pat[mover_head, last_tok, s_idx]
    avg_inhib_score /= len(s_idxs)
    return avg_inhib_score

def calc_inhib_score_to_S2(model, prompt, cache, mover_layer, mover_head):
    """Attention to IO minus attention to the SECOND occurrence of S (S2).

    Uses the same conventions as ``calc_inhib_score``:
    - the pattern is indexed ``[head, query_pos, key_pos]``;
    - the query is the last token;
    - tokens are matched as ``' ' + name``.

    Raises IndexError if S appears fewer than two times, and ValueError (via
    ``get_token_idx``) if IO is absent.
    """
    pattern = cache['pattern', mover_layer, 'attn']
    io_tok, s_tok = ' ' + prompt['IO'], ' ' + prompt['S']
    tokens = model.to_str_tokens(prompt['text'])
    io_pos = get_token_idx(tokens, io_tok)
    s_positions = [pos for pos, tok in enumerate(tokens) if tok == s_tok]
    s2_pos = s_positions[1]
    query_pos = len(tokens) - 1
    return pattern[mover_head, query_pos, io_pos] - pattern[mover_head, query_pos, s2_pos]

@step(cacheable=True, deterministic=True, version='003')
def get_inhibition_scores(
    model: FromParams,
    dataset: DataParams,
    mover_layer: int,
    mover_head: int) -> np.array:
    """Compute the inhibition score of one mover head for every prompt in a dataset.

    Args:
        model: object exposing ``.model``. If ``.model`` is ``None``, its
            ``load()`` method is called first.
        dataset: object exposing ``.dataset``, an iterable of batches. Each
            batch is indexed by key (``batch['text']`` plus whatever keys
            ``calc_inhib_score`` reads, i.e. ``'IO'`` and ``'S'``).
        mover_layer: layer of the mover head.
        mover_head: head index within ``mover_layer``.

    Returns:
        1-D array with one ``calc_inhib_score`` value per prompt, in dataset
        order.
    """
    if model.model is None:
        model.load()
    hooked_model = model.model
    loader = dataset.dataset

    # calc_inhib_score only reads this layer's attention pattern. Caching just
    # that hook avoids storing every activation in the network, and keeps the
    # per-prompt apply_slice_to_batch_dim below cheap.
    pattern_hook = f"blocks.{mover_layer}.attn.hook_pattern"

    inhib_scores = []
    for batch in loader:
        texts = batch['text']
        _, cache = hooked_model.run_with_cache(texts, names_filter=pattern_hook)
        for batch_idx in range(len(texts)):
            # Rebuild one prompt's record from the batch's per-key sequences.
            cur_prompt = {key: values[batch_idx] for key, values in batch.items()}
            score = calc_inhib_score(
                hooked_model,
                cur_prompt,
                cache.apply_slice_to_batch_dim(batch_idx),
                mover_layer,
                mover_head)
            inhib_scores.append(score.item())
    return np.array(inhib_scores)

def _report_inhibition_scores(ws, label, model, data_params, mover_layer=9, mover_head=9,
                              hist_range=(-1.0, 1.0), bins=10, threshold=0.05):
    """Score ``model`` with ``get_inhibition_scores`` and print a summary.

    Prints:
    - ``label`` and the wall-clock time of the (possibly cached) step;
    - a horizontal terminal histogram;
    - the mean score and how many scores fall below ``threshold``.

    Returns the per-prompt score array.
    """
    import termplotlib as tpl

    print(label)
    start_time = time.perf_counter()
    scores = get_inhibition_scores(
        model=model,
        dataset=data_params,
        mover_layer=mover_layer,
        mover_head=mover_head).result(ws)
    print('time taken', time.perf_counter() - start_time)

    # Scores are differences of two attention weights, so they lie in [-1, 1].
    # A narrower range (e.g. the old (-.05, 1.0)) silently drops samples from
    # np.histogram.
    counts, bin_edges = np.histogram(scores, range=hist_range, bins=bins)
    fig = tpl.figure()
    fig.hist(counts, bin_edges, orientation="horizontal", force_ascii=False)
    fig.show()
    print(scores.mean(), 'MEAN')
    print(f'# <{threshold}', (scores < threshold).sum())
    return scores


if __name__ == "__main__":
    ws = tango.Workspace.from_url("./tango_workspace")
    model_name = 'gpt2-small'
    dataset_path = 'datasets/ioi_dataset_200.json'

    model_params = ModelParams(model_name)
    data_params = DataParams(dataset_path, batch_size=20)

    _report_inhibition_scores(ws, "Base Model", model_params, data_params)

    # SimpleEditedModel edits model_params.model in place, so it must come
    # after the base model has been scored.
    _report_inhibition_scores(ws, "Edited Model", SimpleEditedModel(model_params), data_params)
    del model_params

    # Each remaining variant gets a freshly loaded model so edits do not stack.
    _report_inhibition_scores(
        ws, "Retained Components Model", OnlyTopCompsModel(ModelParams(model_name)), data_params)
    _report_inhibition_scores(
        ws, "Top and 0 Components Model", TopAndZeroCompsModel(ModelParams(model_name)), data_params)
    _report_inhibition_scores(
        ws, "Index 0 Only Components Model", ZeroCompsModel(ModelParams(model_name)), data_params)