import os, socket
import numpy as np
import random
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import lightning as L
from functools import partial

from datasets import load_dataset


def collator(sample_list, tokenizer):
    """Collate token-id sequences into padded next-token-prediction batches.

    Each sample is split into an input (every token but the last) and a
    target (every token but the first), so ``targets[i][t]`` is the token
    that follows ``inputs[i][t]``.

    Args:
        sample_list: Sequences of token ids, one per sample.
        tokenizer: Object exposing ``pad_token_id``; that id fills the
            right-hand padding of both returned tensors.

    Returns:
        Tuple ``(inputs, targets)`` of batch-first ``LongTensor``s, each
        padded to the longest shifted sequence in the batch.
    """
    shifted_inputs = []
    shifted_targets = []
    for token_ids in sample_list:
        shifted_inputs.append(torch.LongTensor(token_ids[:-1]))
        shifted_targets.append(torch.LongTensor(token_ids[1:]))

    pad_id = tokenizer.pad_token_id
    inputs = pad_sequence(shifted_inputs, batch_first=True, padding_value=pad_id)
    targets = pad_sequence(shifted_targets, batch_first=True, padding_value=pad_id)
    return inputs, targets

# Maps the integer ``cop`` field to (record key of the option, printed label).
_MEDMCQA_ANSWER_OPTIONS = {
    0: ('opa', 'A. '),
    1: ('opb', 'B. '),
    2: ('opc', 'C. '),
    3: ('opd', 'D. '),
}


def _format_medmcqa_sample(s):
    """Render one MedMCQA record as a single space-joined prompt string."""
    # Dict lookup (not tuple indexing) so an unexpected ``cop`` raises KeyError
    # instead of silently selecting an option.
    answer_field, answer_label = _MEDMCQA_ANSWER_OPTIONS[s['cop']]
    parts = [
        s['question'],
        f"A. {s['opa']}",
        f"B. {s['opb']}",
        f"C. {s['opc']}",
        f"D. {s['opd']}.",
    ]
    # A missing explanation would otherwise be formatted as the literal
    # text "None." and end up in the training data.
    if s['exp']:
        parts.append(f"{s['exp']}.")
    parts.append(f"Answer: {answer_label}{s[answer_field]}")
    return " ".join(parts)


def load_medmcqa(tokenizer, cache_dir, max_seq_length):
    """Return the tokenized MedMCQA train split, filtered by sequence length.

    The tokenized samples are cached in ``<cache_dir>/medmcqa.pt``. The cache
    holds the samples *before* length filtering, and the filter is applied on
    every call, so the result always honours the ``max_seq_length`` passed in
    rather than whatever limit was in effect when the cache was written.
    (A cache file written by an older version may already have been filtered
    with a smaller limit; delete it to rebuild.)

    Args:
        tokenizer: Object with an ``encode(str)`` method returning token ids.
            Only used when the cache file does not exist yet.
        cache_dir: Directory for both the ``datasets`` download cache and the
            ``medmcqa.pt`` token cache.
        max_seq_length: Exclusive upper bound on the number of tokens per
            returned sample.

    Returns:
        List of token-id sequences ``s`` with ``1 < len(s) < max_seq_length``.
    """
    cache_path = os.path.join(cache_dir, "medmcqa.pt")

    if os.path.exists(cache_path):
        # NOTE(security): torch.load unpickles the file; only point cache_dir
        # at locations you trust.
        # NOTE: the cache is not keyed by tokenizer; remove the file when
        # switching tokenizers.
        dataset = torch.load(cache_path)["dataset"]
    else:
        records = load_dataset("medmcqa", cache_dir=cache_dir)['train']
        dataset = [tokenizer.encode(_format_medmcqa_sample(s)) for s in records]
        os.makedirs(cache_dir, exist_ok=True)
        torch.save({"dataset": dataset}, cache_path)

    # Need at least two tokens to form an (input, target) pair after shifting.
    return [s for s in dataset if 1 < len(s) < max_seq_length]

def get_medmcqa(tokenizer, max_seq_length, val_split, effective_batch_size, cache_dir):
    """Build train and validation DataLoaders over tokenized MedMCQA samples.

    The samples returned by ``load_medmcqa`` are shuffled in place with the
    global ``random`` state, then split so that the last ``val_split``
    fraction becomes the validation set.

    Args:
        tokenizer: Passed to ``load_medmcqa`` and bound into the collate
            function (which reads its ``pad_token_id``).
        max_seq_length: Forwarded to ``load_medmcqa``.
        val_split: Fraction of samples held out for validation.
        effective_batch_size: Batch size used by both loaders.
        cache_dir: Forwarded to ``load_medmcqa``.

    Returns:
        Tuple ``(train_loader, val_loader)``. The train loader shuffles and
        drops the last incomplete batch; the validation loader does neither.
    """
    samples = load_medmcqa(tokenizer, cache_dir=cache_dir, max_seq_length=max_seq_length)
    print("MedQA dataset samples:", len(samples))

    random.shuffle(samples)
    split_at = int(len(samples) * (1 - val_split))
    train_samples, val_samples = samples[:split_at], samples[split_at:]

    # Settings common to both loaders.
    shared_kwargs = dict(
        batch_size=effective_batch_size,
        collate_fn=partial(collator, tokenizer=tokenizer),
        num_workers=0,
        pin_memory=True,
    )
    train_loader = DataLoader(train_samples, shuffle=True, drop_last=True, **shared_kwargs)
    val_loader = DataLoader(val_samples, shuffle=False, drop_last=False, **shared_kwargs)

    print("train samples:", len(train_samples))
    print("val samples:", len(val_samples))

    return train_loader, val_loader