import os

import torch
import torch.distributed as dist
import torch.nn as nn

from ploraq import bitnet158

# from bitnet import BitLinear


def get_proj_update_steps(
    proj_gap_progression, proj_gap, num_iters, increment_size, warm_up, max_proj_gap
):
    """Return the iteration indices at which the projection is refreshed.

    The schedule always starts with step 0; further steps are generated while
    they do not exceed ``num_iters``.

    Args:
        proj_gap_progression: ``"static"``, ``"linear"`` or ``"exponential"``.
            * ``static``: a step every ``proj_gap`` iterations.
            * ``linear``: once a step is >= ``warm_up``, the gap grows by
              ``increment_size`` after each update (the gap is truncated with
              ``int`` when advancing).
            * ``exponential``: once a step is >= ``warm_up``, the k-th such
              advance is ``proj_gap + int(increment_size ** k)``.
        proj_gap: Initial gap between updates.
        num_iters: Largest step index that may be scheduled (inclusive).
        increment_size: Growth parameter of the linear / exponential schedules.
        warm_up: Step index from which the gap starts growing.
        max_proj_gap: When non-zero, every returned step index is clamped to it.

    Returns:
        List of step indices beginning with 0.

    Raises:
        ValueError: if ``proj_gap_progression`` is unknown, or if the schedule
            would stop advancing (an advance <= 0), which would otherwise loop
            forever appending the same step.
    """
    if proj_gap_progression not in ("static", "linear", "exponential"):
        raise ValueError(
            f"Unknown proj_gap_progression {proj_gap_progression!r}; "
            "expected 'static', 'linear' or 'exponential'."
        )

    steps = [0]
    current_step = proj_gap
    current_gap = proj_gap  # only grows in the linear schedule
    step_count = 0  # number of post-warm-up advances (exponential schedule)

    while current_step <= num_iters:
        steps.append(current_step)

        if proj_gap_progression == "static":
            advance = proj_gap
        elif proj_gap_progression == "linear":
            if current_step >= warm_up:
                current_gap += increment_size
            advance = int(current_gap)
        else:  # "exponential"
            if current_step >= warm_up:
                step_count += 1
                advance = proj_gap + int(increment_size**step_count)
            else:
                advance = proj_gap

        # A non-positive advance never moves past num_iters: fail loudly instead
        # of growing `steps` without bound.
        if advance <= 0:
            raise ValueError(
                f"Projection schedule does not advance at step {current_step} "
                f"(advance={advance}); check proj_gap and increment_size."
            )
        current_step += advance

    if max_proj_gap != 0:
        # NOTE(review): this clamps the step *indices*, not the gap between
        # them, so every step beyond max_proj_gap collapses onto max_proj_gap.
        # Confirm this is intended rather than capping the gap.
        steps = [min(step, max_proj_gap) for step in steps]

    return steps


def get_model(model):
    """Unwrap a DistributedDataParallel wrapper.

    Returns ``model.module`` when ``model`` is a
    ``torch.nn.parallel.DistributedDataParallel`` instance, otherwise returns
    ``model`` itself unchanged.
    """
    is_ddp_wrapped = isinstance(model, torch.nn.parallel.DistributedDataParallel)
    return model.module if is_ddp_wrapped else model


def quantize_linear_to_1bit(W, use_bias, dtype, use_bitnet158=True):
    """Build a ``bitnet158.BitLinear158`` layer that reuses the parameters of ``W``.

    Args:
        W: Source layer; its ``in_features``, ``out_features``, ``weight`` and
            (when ``use_bias``) ``bias`` are read.
        use_bias: Whether the new layer has a bias, taken from ``W.bias``.
        dtype: dtype the new layer's parameters (including the copied weight
            and bias) are cast to.
        use_bitnet158: Currently ignored; ``BitLinear158`` is always used. Kept
            so existing callers continue to work.

    Returns:
        The new layer. Its weight is frozen (``requires_grad=False``); its bias,
        if present, is trainable.
    """
    linear_1bit = bitnet158.BitLinear158(W.in_features, W.out_features, bias=use_bias)

    # Replace the freshly initialised parameters with W's tensors (no copy is
    # made by nn.Parameter). Only the bias stays trainable.
    linear_1bit.weight = nn.Parameter(W.weight.data, requires_grad=False)
    if use_bias:
        linear_1bit.bias = nn.Parameter(W.bias.data, requires_grad=True)

    # Cast after assigning W's tensors so they actually end up in `dtype`;
    # casting before the assignment left weight/bias in W's dtype. When W is
    # already in `dtype` this is a no-op.
    linear_1bit.to(dtype=dtype)

    return linear_1bit


def create_zero_initialized_linear_layer(
    input_features, output_features, use_bias, device, dtype=None
):
    """Return an ``nn.Linear`` whose weight (and bias, if requested) are all zero.

    The layer is constructed with PyTorch's defaults, zeroed, then moved to
    ``device`` and finally converted with ``.to(dtype=dtype)``.
    """
    layer = nn.Linear(input_features, output_features, bias=use_bias)

    params_to_zero = [layer.weight]
    if use_bias:
        params_to_zero.append(layer.bias)
    for tensor in params_to_zero:
        nn.init.zeros_(tensor)

    layer = layer.to(device)
    return layer.to(dtype=dtype)


# Function to check if the dataset is already downloaded
def is_dataset_downloaded(path):
    """Return True if ``path`` is an existing directory with at least one entry.

    Only non-emptiness is checked; the directory's contents are not validated.
    A path that exists but is not a directory yields False rather than raising
    ``NotADirectoryError`` from the directory listing.
    """
    if not os.path.isdir(path):
        return False
    # Stop at the first entry instead of listing the whole directory.
    with os.scandir(path) as entries:
        return next(entries, None) is not None


def compare_parameters(model):
    """Report parameters of ``model`` that differ between rank 0 and other ranks.

    For every named parameter, each process's copy is all-gathered; rank 0 then
    compares every gathered copy against its own and prints one line per
    mismatch. ``dist.all_gather`` is a collective, so every rank must call this
    function, and the parameters must have matching shapes across ranks.
    """
    for param_name, local_tensor in model.named_parameters():
        per_rank = [
            torch.zeros_like(local_tensor) for _ in range(dist.get_world_size())
        ]
        # All ranks must participate even though only rank 0 inspects the result.
        dist.all_gather(per_rank, local_tensor)

        if dist.get_rank() != 0:
            continue

        baseline = per_rank[0]
        mismatched_ranks = (
            rank
            for rank, candidate in enumerate(per_rank)
            if not torch.equal(baseline, candidate)
        )
        for rank in mismatched_ranks:
            print(
                f"Mismatch found in parameter {param_name} between rank 0 and rank {rank}"
            )


def broadcast_parameters(model, rank, root=0):
    """Place every parameter of ``model`` on ``cuda:{rank}`` and sync it from ``root``.

    Each parameter's data is made contiguous, moved to ``cuda:{rank}`` if it is
    on another device, and then overwritten in place via ``dist.broadcast``
    from rank ``root``. A ``dist.barrier()`` follows the loop. Any exception is
    caught and printed, not re-raised.
    """
    try:
        for tensor in model.parameters():
            already_there = tensor.device == torch.device(f"cuda:{rank}")
            data = tensor.data.contiguous()
            tensor.data = data if already_there else data.to(f"cuda:{rank}")
            dist.broadcast(tensor.data, src=root)
        dist.barrier()
    except Exception as e:
        # NOTE(review): swallowing the error lets this rank skip the remaining
        # broadcasts/barrier while the other ranks wait on them; consider
        # re-raising. Also, `rank` is used as the CUDA device index — confirm
        # callers pass a local (per-node) rank on multi-node jobs.
        print(f"Rank {rank} encountered an error: {e}")


def log_tensor_statistics(param, name, rank):
    """Print mean, max, min and std (torch default, unbiased) of ``param.data``.

    The line is prefixed with ``rank`` and ``name`` so output from several
    processes can be told apart.
    """
    values = param.data
    summary = (
        ("Mean", torch.mean(values).item()),
        ("Max", torch.max(values).item()),
        ("Min", torch.min(values).item()),
        ("Std", torch.std(values).item()),
    )
    details = ", ".join(f"{label}: {number}" for label, number in summary)
    print(f"Rank {rank}, Tensor {name} - {details}")


def check_gradients(model):
    """Print, for each named parameter of ``model``, the state of its gradient.

    One line per parameter: no gradient (``param.grad is None``), an all-zero
    gradient, or a non-zero gradient together with its mean. A gradient counts
    as zero only when every element is zero. A ``mean == 0`` test would
    misreport gradients whose entries cancel out, such as ``[1, -1]``.
    """
    for name, param in model.named_parameters():
        grad = param.grad
        if grad is None:
            print(f"No gradient for {name}")
            continue
        if not torch.any(grad).item():
            print(f"Gradient for {name} is zero!")
        else:
            grad_mean = torch.mean(grad).item()
            print(f"Gradient for {name} is non-zero and mean is {grad_mean}")


def eigenH_decomposition(A, out="u", checks=False):
    """Return eigenvectors of ``A A^H`` (``out="u"``) or ``A^H A`` (``out="v"``).

    The Gram matrix is Hermitian, so ``torch.linalg.eigh`` applies. Its columns
    are reordered from ascending (eigh's order) to descending eigenvalue.
    Up to sign/phase, these are the left (``"u"``) or right (``"v"``) singular
    vectors of ``A``.

    Args:
        A: Matrix (real or complex) to decompose.
        out: ``"u"`` or ``"v"`` (case-insensitive).
        checks: If True, print the Frobenius norm of ``U U^H - I`` as a
            unitarity sanity check.

    Returns:
        Tensor whose columns are the eigenvectors, largest eigenvalue first.

    Raises:
        ValueError: if ``out`` is neither ``"u"`` nor ``"v"``.
    """
    out = out.lower()
    if out == "u":
        # .mH (conjugate transpose) keeps the matrix Hermitian for complex A;
        # for real A it is identical to .T.
        symmetric_input = A @ A.mH
    elif out == "v":
        symmetric_input = A.mH @ A
    else:
        raise ValueError("Invalid output type. Choose 'u' or 'v'.")
    res = torch.linalg.eigh(symmetric_input)

    # eigh returns eigenvalues in ascending order; flip columns to descending.
    eigenvectors = res.eigenvectors.flip(-1)

    if checks:
        # Size the identity from the eigenvectors themselves: for out="v" they are
        # A.shape[1] x A.shape[1], not A.shape[0] x A.shape[0].
        identity = torch.eye(
            eigenvectors.shape[-1],
            dtype=eigenvectors.dtype,
            device=eigenvectors.device,
        )
        diff_unitary = torch.norm(eigenvectors @ eigenvectors.mH - identity, "fro")
        print("Difference from unitary:", diff_unitary)

    return eigenvectors


def load_model_from_checkpoint(directory, model):
    """Load a checkpoint from ``directory``.

    ``pytorch_model.bin`` takes precedence over ``model.safetensors``:

    * ``pytorch_model.bin``: the file is passed to ``torch.load`` and whatever
      object it holds is returned; the ``model`` argument is not used.
    * ``model.safetensors``: ``model.from_pretrained(directory)`` is returned.

    Raises:
        FileNotFoundError: if ``directory`` does not exist or contains neither file.
    """
    if not os.path.exists(directory):
        raise FileNotFoundError(f"The directory {directory} does not exist.")

    entries = os.listdir(directory)

    if "pytorch_model.bin" in entries:
        bin_path = os.path.join(directory, "pytorch_model.bin")
        # NOTE(review): torch.load unpickles the file — only use on trusted
        # checkpoints. If the file holds a state_dict this returns a dict rather
        # than a model; confirm callers expect that.
        loaded = torch.load(bin_path)
        print(f"Loaded model from {bin_path}")
        return loaded

    if "model.safetensors" not in entries:
        raise FileNotFoundError(
            "No compatible checkpoint file found (.bin or .safetensors)."
        )

    print(model)
    st_path = os.path.join(directory, "model.safetensors")
    abs_path = os.path.abspath(st_path)
    print("checkpoint_path", st_path)
    print("abs_path", abs_path)
    loaded = model.from_pretrained(directory)
    print(f"Loaded model from {abs_path}")
    return loaded


def filter_target_modules(model, target_modules_list):
    """Keep the targets that appear as a substring of some module name in ``model``.

    The order and any duplicates of ``target_modules_list`` are preserved. The
    root module's name is ``""``, so an empty-string target is always kept.
    """
    module_names = [module_name for module_name, _ in model.named_modules()]
    kept = []
    for target in target_modules_list:
        if any(target in module_name for module_name in module_names):
            kept.append(target)
    return kept


def filter_linear_target_modules(model, target_modules_list):
    """Return names of ``nn.Linear`` submodules whose name contains any target.

    Matching is by substring against each name from ``model.named_modules()``,
    and the result follows that iteration order.
    """

    def _name_matches(module_name):
        return any(target in module_name for target in target_modules_list)

    return [
        module_name
        for module_name, submodule in model.named_modules()
        if _name_matches(module_name) and isinstance(submodule, nn.Linear)
    ]
