import os, socket
import numpy as np
import random
import torch
import torch.nn.functional as F
import tqdm
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import lightning as L
from functools import partial

from datasets import load_dataset

from workbench.data.pubmedqa import collator, load_pubmed




def load_owt(tokenizer, model_name, cache_dir, max_seq_length, artificial_numb):
    """Return a random sample of tokenized, truncated OpenWebText documents.

    The result is cached in ``cache_dir`` under a file name derived from
    ``model_name``, ``max_seq_length`` and ``artificial_numb``. When that file
    already exists it is loaded with ``torch.load`` and returned as-is; no
    dataset download or tokenization happens in that case.

    Args:
        tokenizer: Object exposing ``encode(text)``; only used on a cache miss.
        model_name: Identifier embedded in the cache file name.
        cache_dir: Directory passed to ``load_dataset`` and used for the
            cache file.
        max_seq_length: Maximum number of leading tokens kept per document.
        artificial_numb: Number of documents to draw from the train split.

    Returns:
        On a cache miss, a list with ``artificial_numb`` entries, each the
        output of ``tokenizer.encode`` cut to ``max_seq_length`` items. On a
        cache hit, whatever object was stored in the cache file.

    Raises:
        ValueError: If ``artificial_numb`` is larger than the number of
            documents in the train split (or negative).
    """
    cache_file = os.path.join(
        cache_dir, f"owt_{model_name}_len{max_seq_length}_{artificial_numb}.pt"
    )

    if os.path.exists(cache_file):
        # NOTE(review): torch.load unpickles the file; only point cache_dir at
        # locations whose contents are trusted.
        return torch.load(cache_file)

    corpus = load_dataset("Skylion007/openwebtext", cache_dir=cache_dir)['train']
    print("load complete")

    sample_count = len(corpus)
    if artificial_numb > sample_count:
        raise ValueError(
            f"artificial_numb={artificial_numb} exceeds the {sample_count} "
            "documents available in the dataset"
        )

    # Uses the module-level `random` state, so the selection depends on
    # whatever seed the caller set beforehand.
    random_samples = random.sample(range(sample_count), artificial_numb)

    # Tokenize and truncate in a single pass: the progress bar tracks the
    # expensive step and full-length encodings are never all held at once.
    artificial_datasets = [
        tokenizer.encode(corpus[i]['text'])[:max_seq_length]
        for i in tqdm.tqdm(random_samples)
    ]

    torch.save(artificial_datasets, cache_file)
    return artificial_datasets

def get_pubmedqa_owt(tokenizer, config):
    """Build train/validation DataLoaders over PubMed plus OpenWebText samples.

    Seeds the module-level ``random`` generator with ``config.data_seed``,
    loads samples via ``load_pubmed`` and ``load_owt`` (both with
    ``config.pubmedqa.artificial_numb``), concatenates and shuffles them, and
    splits the result according to ``config.val_split``.

    Args:
        tokenizer: Passed to both loaders and bound into the collate function.
        config: Object providing ``data_seed``, ``model_name``, ``cache_dir``,
            ``max_seq_length``, ``val_split``, ``effective_batch_size`` and
            ``pubmedqa.artificial_numb``.

    Returns:
        A ``(train_loader, val_loader)`` tuple of ``DataLoader`` objects.
    """
    random.seed(config.data_seed)

    # Only the last path component of the model id is handed to the loaders.
    short_name = config.model_name.split("/")[-1]

    loader_kwargs = dict(
        model_name=short_name,
        cache_dir=config.cache_dir,
        max_seq_length=config.max_seq_length,
        artificial_numb=config.pubmedqa.artificial_numb,
    )
    pubmed_part = load_pubmed(tokenizer, **loader_kwargs)
    owt_part = load_owt(tokenizer, **loader_kwargs)

    samples = pubmed_part + owt_part

    print("artificial samples:", len(samples))
    print("total samples:", len(samples))

    random.shuffle(samples)
    split_at = int(len(samples) * (1 - config.val_split))
    train_samples, val_samples = samples[:split_at], samples[split_at:]

    # Settings shared by both loaders; only shuffling and drop_last differ.
    common = dict(
        batch_size=config.effective_batch_size,
        collate_fn=partial(collator, tokenizer=tokenizer),
        num_workers=0,
        pin_memory=True,
    )
    train_loader = DataLoader(train_samples, shuffle=True, drop_last=True, **common)
    val_loader = DataLoader(val_samples, shuffle=False, drop_last=False, **common)

    print("train samples:", len(train_samples))
    print("val samples:", len(val_samples))

    return train_loader, val_loader


if __name__ == "__main__":
    # Ad-hoc script: tokenize both corpora for one model and print the number
    # of OpenWebText samples and their total token count.
    from transformers import AutoTokenizer, AutoModelForCausalLM

    cache_dir = "/p/scratch/transfernetx/franke5/model/cache/"
    hf_model_id = "mistralai/Mistral-7B-Instruct-v0.2"

    tokenizer = AutoTokenizer.from_pretrained(hf_model_id, cache_dir=cache_dir)

    # The loaders receive only the last path component of the model id.
    model_name = hf_model_id.split("/")[-1]
    limits = dict(max_seq_length=400, artificial_numb=100000)

    # NOTE(review): the PubMed result is not used below; only the OpenWebText
    # samples are counted. Presumably this call is kept for its side effects
    # inside load_pubmed — TODO confirm.
    load_pubmed(tokenizer, model_name, cache_dir, **limits)
    samples = load_owt(tokenizer, model_name, cache_dir, **limits)

    print("tokenized samples:", len(samples))

    total_tokens = sum(len(seq) for seq in samples)

    print("total tokens:", total_tokens)