import io
import subprocess
import re
import csv
from tqdm.auto import tqdm, trange
from concurrent.futures import ProcessPoolExecutor, as_completed
import argparse
import numpy as np
from datasets import Dataset, DatasetDict, Features, Sequence, Value
import chess.pgn

# Vocabulary and special-token constants.
# Ids 1..70 are assigned by TOKENIZER_DICT below; 0, 74 and 75 are the
# special tokens. Ids 71..73 are not assigned anywhere in the visible code.
VOCAB_SIZE = 76
# NOTE(review): presumably the model's maximum sequence length -- confirm.
N_POSITION = 512
PAD_TOKEN_ID = 0
BOS_TOKEN_ID = 75
EOS_TOKEN_ID = 74

# xLAN tokenizer vocabulary: six piece letters followed by the 64 squares,
# enumerated file by file (a1, a2, ..., a8, b1, ..., h8).
CHESS_PIECES = list("KQRBNP")
CHESS_POSITIONS = [
    file_letter + rank_digit
    for file_letter in "abcdefgh"
    for rank_digit in "12345678"
]
# Pieces take ids 1..6 and squares take ids 7..70.
TOKENIZER_DICT = {
    symbol: token_id
    for token_id, symbol in enumerate(CHESS_PIECES + CHESS_POSITIONS, start=1)
}

# Column schema of the generated dataset (column name -> feature type).
# The key order also drives the column order built in main().
DATASET_FEATURES = dict(
    hash_code=Value("string"),
    prompt_token_ids=Sequence(Value("int64")),
    answer_token_ids=Sequence(Value("int64")),
    rating=Value("int64"),
    rating_deviation=Value("int64"),
    num_moves=Value("int64"),
)


# Load puzzles from CSV file
def load_puzzle_csv(csv_file_path):
    """Load the puzzle CSV into a column-oriented dictionary.

    The first row is taken as the header. Expected headers:
    PuzzleId,FEN,Moves,Rating,RatingDeviation,Popularity,NbPlays,Themes,
    GameUrl,OpeningTags

    Args:
        csv_file_path: Path to the puzzle CSV file.

    Returns:
        A dict mapping each header to a list with one entry per data row.
        ``Moves``, ``Themes`` and ``OpeningTags`` are split on single spaces
        into lists (an empty field therefore becomes ``[""]``); ``Rating``,
        ``RatingDeviation``, ``Popularity`` and ``NbPlays`` are converted to
        ``int``; every other column is kept as a string.
    """
    list_columns = {"Moves", "Themes", "OpeningTags"}
    int_columns = {"Rating", "RatingDeviation", "Popularity", "NbPlays"}

    # Count lines in Python (instead of shelling out to `wc -l`) so the
    # progress bar gets a total on every platform. It is only an estimate of
    # the row count and is used for display only.
    with open(csv_file_path, "rb") as raw:
        num_lines = sum(1 for _ in raw)

    # newline="" is what the csv module requires for correct handling of
    # quoted fields; the explicit encoding avoids locale-dependent decoding.
    with open(csv_file_path, mode="r", newline="", encoding="utf-8") as file:
        csv_reader = csv.reader(file)
        headers = next(csv_reader)  # Read the first row to get the headers
        puzzles = {header: [] for header in headers}
        # The header row is already consumed, so DictReader starts at the
        # first data row.
        for row in tqdm(
            csv.DictReader(file, fieldnames=headers),
            desc="Reading puzzles CSV",
            total=num_lines - 1,
        ):
            for key, value in row.items():
                if key in list_columns:
                    puzzles[key].append(value.split(" "))
                elif key in int_columns:
                    puzzles[key].append(int(value))
                else:
                    puzzles[key].append(value)
    return puzzles


# Load from PGN file
def load_pgn_file(pgn_file_path):
    """Read a multi-game PGN file into a dict keyed by Lichess game hash.

    A game block is expected to be a run of non-empty lines (the headers), one
    blank line, another run of non-empty lines (the movetext) and a closing
    blank line. Blank lines encountered before a block has started are
    ignored, so extra blank lines between games or at the end of the file do
    not shift the header/movetext pairing or create empty blocks.

    Args:
        pgn_file_path: Path to the PGN file.

    Returns:
        A dict mapping the game hash taken from each block's
        ``[Site "https://lichess.org/<hash>"]`` tag to the block's text. If
        several blocks share a hash, the last one wins.

    Raises:
        ValueError: If a block has no Lichess ``Site`` tag.
    """
    # Count lines in Python (instead of shelling out to `wc -l`) so the
    # progress bar gets a total on every platform.
    with open(pgn_file_path, "rb") as raw:
        num_lines = sum(1 for _ in raw)

    blocks = []  # Completed game blocks, in file order
    current = []  # Lines of the block being assembled
    separator_seen = False  # True once the blank line inside the block is seen
    with open(pgn_file_path, "r", encoding="utf-8") as file:
        for line in tqdm(file, desc="Reading PGN file", total=num_lines):
            if line.strip():
                current.append(line.rstrip())
            elif not current:
                # Stray blank line outside any block: it must not be counted
                # as the separator of the next game.
                continue
            elif not separator_seen:
                # First blank line of the block: keep it as the separator.
                current.append("")
                separator_seen = True
            else:
                # Second blank line: the block is complete.
                blocks.append("\n".join(current))
                current = []
                separator_seen = False
        if current:  # File ended without a closing blank line
            blocks.append("\n".join(current))

    # Index the blocks by the game hash found in the Site tag.
    site_pattern = re.compile(r'\[Site "https://lichess\.org/([a-zA-Z0-9]+)"\]')
    pgns = {}
    for block in blocks:
        site = site_pattern.search(block)
        if site is None:
            raise ValueError(
                f"PGN block has no Lichess Site tag: {block[:80]!r}"
            )
        pgns[site.group(1)] = block
    return pgns


# Parse game URL
def parse_game_url(url):
    """Split a Lichess game URL into its game hash and move count.

    Args:
        url: Text containing ``lichess.org/<hash>``, optionally followed by
            ``/black`` or ``/white`` and a ``#<digits>`` fragment.

    Returns:
        A ``(game_hash, num_moves)`` tuple. ``num_moves`` is the fragment
        number plus one, or ``None`` when no digits follow the hash.

    Raises:
        ValueError: If ``url`` contains no Lichess game path.
    """
    found = re.search(r"lichess\.org/([\w]+)(?:/black|/white)?#?(\d*)", url)
    if found is None:
        raise ValueError(f"Invalid URL: {url}")
    # The digit group always participates in the match; it is "" when the
    # URL carries no step number.
    game_hash, step = found.groups()
    if step:
        return game_hash, int(step) + 1
    return game_hash, None


# Convert PGN prompt & UCI answer to xLANs
def pgn_to_xlan(pgn, uci, num_moves):
    """Build the xLAN prompt and answer strings for one puzzle.

    Args:
        pgn: PGN text of the game.
        uci: The answer move in UCI notation, parsed against the position
            reached after the prompt moves.
        num_moves: Number of mainline moves to replay into the prompt.

    Returns:
        A ``(prompt, answer)`` tuple. ``prompt`` is the replayed moves joined
        by single spaces; each move, like ``answer``, has the form
        ``"<PIECE> <from_square> <to_square>"`` with an upper-case piece
        letter.
    """

    def describe(position, mv):
        # Has to run before the move is pushed: the piece is looked up on
        # its origin square.
        symbol = position.piece_at(mv.from_square).symbol().upper()
        origin = chess.square_name(mv.from_square)
        target = chess.square_name(mv.to_square)
        return f"{symbol} {origin} {target}"

    # python-chess reads games from a file-like object, so wrap the text.
    game = chess.pgn.read_game(io.StringIO(pgn))
    board = game.board()

    # Replay the first num_moves mainline moves, recording each in xLAN.
    played = []
    for move_num, move in enumerate(game.mainline_moves(), 1):
        if move_num > num_moves:
            break
        played.append(describe(board, move))
        board.push(move)
    prompt = " ".join(played)

    # The answer is interpreted on the position reached after the prompt.
    answer = describe(board, board.parse_uci(uci))
    return prompt, answer


# Tokenize XLAN
def tokenize_xlan(xlan_text):
    """Convert a whitespace-separated xLAN string into token ids.

    Args:
        xlan_text: xLAN text such as ``"P e2 e4 N g8 f6"``. May be empty.

    Returns:
        The list of ``TOKENIZER_DICT`` ids, one per word. Empty or blank text
        yields an empty list.

    Raises:
        KeyError: If a word is not in ``TOKENIZER_DICT``.
    """
    # split() without an argument returns [] for empty/blank text, whereas
    # split(" ") returns [""] and would fail the dictionary lookup.
    return [TOKENIZER_DICT[word] for word in xlan_text.split()]


def worker(args):
    """Turn one task tuple into a dataset record.

    Args:
        args: A ``(game_hash, pgn, uci, num_moves, rating,
            rating_deviation)`` tuple. It is passed as a single argument so
            the function can be submitted to an executor as-is.

    Returns:
        A dict with the keys ``hash_code``, ``prompt_token_ids``,
        ``answer_token_ids``, ``rating``, ``rating_deviation`` and
        ``num_moves``.
    """
    game_hash, pgn, uci, num_moves, rating, rating_deviation = args
    prompt, answer = pgn_to_xlan(pgn, uci, num_moves)
    record = dict(
        hash_code=game_hash,
        prompt_token_ids=tokenize_xlan(prompt),
        answer_token_ids=tokenize_xlan(answer),
        rating=rating,
        rating_deviation=rating_deviation,
        num_moves=num_moves,
    )
    return record


def submit_tasks_in_chunks(executor, tasks, chunk_size=10000):
    """Submit every task to ``executor``, one slice of tasks at a time.

    The slicing only sets how often the "Submitting tasks" progress bar
    advances; each task is still submitted individually to ``worker``.

    Args:
        executor: Object providing ``submit(fn, arg)``.
        tasks: Sliceable sequence of ``worker`` argument tuples.
        chunk_size: Number of tasks submitted per progress-bar step.

    Returns:
        The list of futures, in submission order.
    """
    pending = []
    for start in trange(0, len(tasks), chunk_size, desc="Submitting tasks"):
        for task in tasks[start : start + chunk_size]:
            pending.append(executor.submit(worker, task))
    return pending


def process_in_chunks(tasks, max_workers=8, chunk_size=10000):
    """Run ``worker`` over all tasks in a process pool and collect results.

    Args:
        tasks: Sequence of ``worker`` argument tuples.
        max_workers: Number of worker processes (default 8).
        chunk_size: Number of tasks submitted per progress-bar step; passed
            on to ``submit_tasks_in_chunks`` (default 10000).

    Returns:
        A list with one ``worker`` result per task. Results are appended as
        the futures complete, so the order need not match ``tasks``.
    """
    dataset = []
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = submit_tasks_in_chunks(executor, tasks, chunk_size=chunk_size)
        # Wrap as_completed(futures) with tqdm for progress visualization
        for future in tqdm(
            as_completed(futures), total=len(futures), desc="Processing tasks"
        ):
            # result() re-raises any exception raised inside the worker.
            dataset.append(future.result())
    return dataset


def main(pgn_file_path, puzzle_csv_path, huggingface_path, version_tag):
    """Build the tokenized puzzle dataset and push it to the Hugging Face Hub.

    Only puzzles whose ``Moves`` list has exactly two entries are kept. For
    each of them the game's PGN is looked up by the hash in ``GameUrl`` and
    the second move is used as the answer.

    Args:
        pgn_file_path: Path to the PGN file holding the puzzles' games.
        puzzle_csv_path: Path to the puzzle CSV file.
        huggingface_path: First positional argument of
            ``DatasetDict.push_to_hub``.
        version_tag: Second positional argument of
            ``DatasetDict.push_to_hub``.

    Raises:
        KeyError: If a selected puzzle's game hash is missing from the PGN
            file.
    """
    # Load puzzles and pgns
    puzzles = load_puzzle_csv(puzzle_csv_path)
    pgns = load_pgn_file(pgn_file_path)
    puzzle_indices = np.where([len(moves) == 2 for moves in puzzles["Moves"]])[0]

    # Create a list of arguments for the worker function
    tasks = []
    for idx in tqdm(puzzle_indices, desc="Preparing args"):
        # Parse the URL once per puzzle and reuse both parts.
        game_hash, num_moves = parse_game_url(puzzles["GameUrl"][idx])
        tasks.append(
            (
                game_hash,
                pgns[game_hash],
                puzzles["Moves"][idx][1],
                num_moves,
                puzzles["Rating"][idx],
                puzzles["RatingDeviation"][idx],
            )
        )

    # Process puzzles in parallel
    dataset = process_in_chunks(tasks)

    # Pivot the per-puzzle records into columns and give the feature types
    # explicitly so the column dtypes are controlled.
    dataset_dict = {key: [item[key] for item in dataset] for key in DATASET_FEATURES}
    DatasetDict(
        {
            "default": Dataset.from_dict(
                dataset_dict, features=Features(DATASET_FEATURES)
            )
        }
    ).push_to_hub(huggingface_path, version_tag)


if __name__ == "__main__":
    # Command-line entry point: every option has a default, so the script
    # can be run without arguments.
    argparser = argparse.ArgumentParser(
        description=(
            "Build a tokenized Lichess puzzle dataset from a PGN file and the "
            "puzzle CSV, and push it to the Hugging Face Hub."
        )
    )
    argparser.add_argument(
        "--pgn_file_path",
        type=str,
        default="./data/Lichess/one_move_puzzle.pgn",
        help="Path to the PGN file",
    )
    argparser.add_argument(
        "--puzzle_csv_path",
        type=str,
        default="./data/Lichess/lichess_db_puzzle.csv",
        help="Path to the CSV file containing the puzzles",
    )
    argparser.add_argument(
        "--huggingface_path",
        type=str,
        default="mcding-org/Easy2Hard-Lichess",
        help="Path to the Hugging Face dataset",
    )
    argparser.add_argument(
        "--version_tag",
        type=str,
        default="v1",
        help="Version tag for the Hugging Face dataset",
    )
    args = argparser.parse_args()

    main(
        args.pgn_file_path,
        args.puzzle_csv_path,
        args.huggingface_path,
        args.version_tag,
    )
