import os
import shutil
import subprocess
import re
import aiohttp
import asyncio
import aiofiles
import warnings
import time
import csv
import argparse
from tqdm.asyncio import tqdm
import numpy as np


async def download_pgn(game_hash, session, retry_delay=1, max_retry_delay=300):
    """Fetch the literate PGN export of one lichess game, retrying until it succeeds.

    Args:
        game_hash: lichess game ID interpolated into the export URL.
        session: an open ``aiohttp.ClientSession`` used for the request.
        retry_delay: seconds to wait before the first retry; doubled after
            every failed attempt (exponential backoff).
        max_retry_delay: upper bound, in seconds, on the wait between retries
            so a long outage does not push the backoff to hours. ``None``
            disables the cap.

    Returns:
        The response body with trailing whitespace stripped.

    Note:
        This never gives up: any client error or timeout is retried.
    """
    url = f"https://lichess.org/game/export/{game_hash}?literate=1"
    while True:
        try:
            async with session.get(url) as response:
                response.raise_for_status()
                body = await response.text()
        # Session timeouts surface as asyncio.TimeoutError, which is not a
        # subclass of aiohttp.ClientError; catch both so one slow request
        # does not abort every other download running under gather().
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            # repr() because str() of a timeout error is empty.
            warnings.warn(
                f"Error downloading PGN for {game_hash}: {e!r}, "
                f"retrying in {retry_delay} seconds..."
            )
            await asyncio.sleep(retry_delay)
            retry_delay *= 2  # Exponential backoff
            if max_retry_delay is not None:
                retry_delay = min(retry_delay, max_retry_delay)
        else:
            # Slight delay to avoid rate limits; taken after the response has
            # been released rather than while holding it open.
            await asyncio.sleep(0.1)
            return body.rstrip()


async def download_and_append(
    game_hash, session, file, semaphore, backup_interval, download_count
):
    """Download one game's PGN and append it to *file* while holding *semaphore*.

    Args:
        game_hash: lichess game ID passed to ``download_pgn``.
        session: shared aiohttp client session.
        file: path of the PGN file to append to.
        semaphore: ``asyncio.Semaphore`` bounding how many downloads run at once.
        backup_interval: call ``backup_file`` after every *backup_interval*
            completed appends; a value <= 0 disables backups.
        download_count: dict shared by all tasks whose ``"count"`` entry is
            incremented once per completed append.
    """
    async with semaphore:
        pgn_content = await download_pgn(game_hash, session)
        # Explicit UTF-8 so non-ASCII PGN text does not depend on the locale.
        async with aiofiles.open(file, "a", encoding="utf-8") as f:
            await f.write(pgn_content + "\n\n")

        # There is no await between the increment and the check below, so
        # tasks on the same event loop cannot interleave here.
        download_count["count"] += 1
        # Guard against ZeroDivisionError; non-positive interval = no backups.
        if backup_interval > 0 and download_count["count"] % backup_interval == 0:
            backup_file(file)


def load_puzzle_csv(csv_file_path):
    """Load the lichess puzzle CSV into a column-oriented dictionary.

    The first row supplies the column names. Each name becomes a key mapping to
    a list with one entry per data row.
    ``Moves``, ``Themes`` and ``OpeningTags`` are split on single spaces into
    lists. ``Rating``, ``RatingDeviation``, ``Popularity`` and ``NbPlays`` are
    parsed as ``int``. Every other column is kept as the raw string.

    Args:
        csv_file_path: path to the CSV file.

    Returns:
        dict mapping each header name to its list of per-row values.

    Raises:
        StopIteration: if the file is empty (no header row).
        ValueError: if a numeric column holds a non-integer value.
    """
    list_columns = ("Moves", "Themes", "OpeningTags")
    int_columns = ("Rating", "RatingDeviation", "Popularity", "NbPlays")

    # Count newlines in Python instead of shelling out to `wc -l` (not
    # available on every platform). Chunked binary reads give the same number
    # `wc -l` reports; it is only used as the progress-bar total.
    with open(csv_file_path, "rb") as raw:
        n_lines = sum(
            chunk.count(b"\n") for chunk in iter(lambda: raw.read(1 << 20), b"")
        )

    # newline="" is required by the csv module so quoted fields are parsed correctly.
    with open(csv_file_path, mode="r", newline="", encoding="utf-8") as file:
        csv_reader = csv.reader(file)
        headers = next(csv_reader)  # Read the first row to get the headers
        puzzles = {header: [] for header in headers}
        # Headers: PuzzleId,FEN,Moves,Rating,RatingDeviation,Popularity,NbPlays,Themes,GameUrl,OpeningTags
        for row in tqdm(
            csv.DictReader(file, fieldnames=headers),
            desc="Reading puzzles CSV",
            total=n_lines - 1,  # minus the header row
        ):
            for key, value in row.items():
                if key in list_columns:
                    puzzles[key].append(value.split(" "))
                elif key in int_columns:
                    puzzles[key].append(int(value))
                else:
                    puzzles[key].append(value)
    return puzzles


def parse_game_url(url):
    """Split a lichess game URL into ``(game_hash, num_step)``.

    Recognises ``lichess.org/<hash>``, optionally followed by ``/black`` or
    ``/white``, and then an optional ``#`` plus a run of digits.
    ``num_step`` is those digits as an ``int``, or ``None`` when there are none.

    Raises:
        ValueError: if the URL contains no ``lichess.org/<hash>`` segment.
    """
    found = re.search(r"lichess\.org/([\w]+)(?:/black|/white)?#?(\d*)", url)
    if found is None:
        raise ValueError(f"Invalid URL: {url}")
    game_hash, step_digits = found.groups()
    # The digit group always participates, so a missing step shows up as "".
    num_step = None if not step_digits else int(step_digits)
    return game_hash, num_step


def backup_file(original_file):
    """Copy *original_file* to ``<original_file>.<unix-seconds>.backup``.

    Only the file contents are copied (``shutil.copyfile``). Two backups taken
    within the same second share a name, so the later one overwrites the
    earlier one. The backup path is printed.
    """
    seconds_now = int(time.time())
    destination = "{}.{}.backup".format(original_file, seconds_now)
    shutil.copyfile(original_file, destination)
    print("Backup created at " + destination)


def find_downloaded_hashes(pgn_file_path):
    """Return the set of game IDs already stored in a PGN file.

    Every line is scanned for a ``[Site "..."]`` tag, and the last ``/``-separated
    segment of its value is kept. The value is presumably the lichess game URL,
    so that segment is the game ID; verify against the export format.

    If the file does not exist, it is created empty (along with any missing
    parent directories) so later appends succeed, and an empty set is returned.

    Args:
        pgn_file_path: path of the PGN file to scan.

    Returns:
        set of game-ID strings.
    """
    pattern = re.compile(r'\[Site "([^"]+)"\]')
    existing_hashes = set()
    if not os.path.exists(pgn_file_path):
        # Create the parent directory too, otherwise open(..., "w") raises
        # FileNotFoundError for a fresh path such as ./data/....
        parent_dir = os.path.dirname(pgn_file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(pgn_file_path, "w"):
            pass
        return existing_hashes

    # Portable replacement for `wc -l` (same count); only the progress-bar total.
    with open(pgn_file_path, "rb") as raw:
        n_lines = sum(
            chunk.count(b"\n") for chunk in iter(lambda: raw.read(1 << 20), b"")
        )

    # errors="replace": only the ASCII Site tag matters, so undecodable bytes
    # elsewhere in the file must not abort the scan.
    with open(pgn_file_path, "r", encoding="utf-8", errors="replace") as file:
        for line in tqdm(file, desc="Reading PGN file", total=n_lines):
            match = pattern.search(line)
            if match:
                existing_hashes.add(match.group(1).split("/")[-1])
    return existing_hashes


async def main(pgn_file_path, puzzle_csv_path, max_concurrent_tasks, backup_interval):
    """Download every one-move-puzzle game that is not yet in *pgn_file_path*.

    A "one-move puzzle" here is a row whose ``Moves`` list has exactly two
    entries. Each such row's ``GameUrl`` is reduced to a game hash, duplicates
    are removed, and hashes already present in the PGN file are skipped. The
    rest are downloaded concurrently and appended to the PGN file.

    Args:
        pgn_file_path: PGN file to append downloaded games to.
        puzzle_csv_path: lichess puzzle CSV to read puzzles from.
        max_concurrent_tasks: maximum number of simultaneous downloads.
        backup_interval: number of completed downloads between backups.
    """
    puzzles = load_puzzle_csv(puzzle_csv_path)
    one_move_game_hashes = [
        parse_game_url(url)[0]
        for moves, url in zip(puzzles["Moves"], puzzles["GameUrl"])
        if len(moves) == 2
    ]
    # Several puzzles can be cut from the same game. Dedupe, keeping first-seen
    # order, so each game is fetched and appended only once.
    all_game_hashes = list(dict.fromkeys(one_move_game_hashes))

    existing_hashes = find_downloaded_hashes(pgn_file_path)
    missing_hashes = [gh for gh in all_game_hashes if gh not in existing_hashes]

    # Shared, mutable counter so every task can see the running total.
    download_count = {"count": 0}
    async with aiohttp.ClientSession() as session:
        semaphore = asyncio.Semaphore(max_concurrent_tasks)
        tasks = [
            download_and_append(
                gh, session, pgn_file_path, semaphore, backup_interval, download_count
            )
            for gh in missing_hashes
        ]
        await tqdm.gather(*tasks, total=len(tasks))


if __name__ == "__main__":
    # Command-line entry point: parse options, validate them, then run main().
    parser = argparse.ArgumentParser(
        description="Download PGN files with backup and concurrency control."
    )
    parser.add_argument(
        "--pgn_file_path",
        type=str,
        default="./data/one_move_puzzle.pgn",
        help="Path to the PGN file",
    )
    parser.add_argument(
        "--puzzle_csv_path",
        type=str,
        default="./data/lichess_db_puzzle.csv",
        help="Path to the CSV file containing the puzzles",
    )
    parser.add_argument(
        "--max_concurrent_tasks",
        type=int,
        default=1,
        help="Maximum number of concurrent download tasks",
    )
    parser.add_argument(
        "--backup_interval",
        type=int,
        default=10000000,
        help="Backup interval based on number of downloads",
    )

    args = parser.parse_args()
    # asyncio.Semaphore(0) would block every download forever, and a negative
    # value raises ValueError deep inside main(). Reject both up front.
    if args.max_concurrent_tasks < 1:
        parser.error("--max_concurrent_tasks must be at least 1")

    asyncio.run(
        main(
            pgn_file_path=args.pgn_file_path,
            puzzle_csv_path=args.puzzle_csv_path,
            max_concurrent_tasks=args.max_concurrent_tasks,
            backup_interval=args.backup_interval,
        )
    )
