import time
from threading import Thread

import GPUtil
import torch


def optimizer_memory_usage_in_MB(optimizer):
    """
    Estimates the memory held by a PyTorch optimizer, in MB (1024**2 bytes).

    Args:
        optimizer: A ``torch.optim.Optimizer`` (anything exposing
            ``state_dict()`` and ``param_groups``).

    Returns:
        A ``(state_MB, gradient_MB)`` tuple:
          * ``state_MB`` -- total size of every tensor stored in the
            optimizer state, either directly or inside a list/tuple.
          * ``gradient_MB`` -- total size of the parameters in
            ``param_groups`` that have ``requires_grad=True``. This is
            computed from the parameters themselves and approximates gradient
            storage, assuming each gradient matches its parameter's size and
            dtype.
    """
    bytes_per_mb = 1024 ** 2

    def tensor_bytes(value):
        # Bytes held by a tensor, or by the tensors inside a list/tuple;
        # any other state entry (ints, floats, ...) contributes nothing.
        if torch.is_tensor(value):
            return value.numel() * value.element_size()
        if isinstance(value, (list, tuple)):
            return sum(
                item.numel() * item.element_size()
                for item in value
                if torch.is_tensor(item)
            )
        return 0

    state_bytes = sum(
        tensor_bytes(value)
        for param_state in optimizer.state_dict()["state"].values()
        for value in param_state.values()
    )
    trainable_param_bytes = sum(
        p.numel() * p.element_size()
        for group in optimizer.param_groups
        for p in group["params"]
        if p.requires_grad
    )
    return state_bytes / bytes_per_mb, trainable_param_bytes / bytes_per_mb


def layer_wise_memory_usage_in_MB(optimizer_dict, scheduler_dict):
    """
    Calculates the combined optimizer-state memory of several optimizers in MB.

    Presumably used when a separate optimizer is kept per layer/parameter
    (suggested by the name; confirm against the caller).

    Args:
        optimizer_dict: Mapping whose values are PyTorch optimizers exposing
            ``state_dict()``; the keys are ignored.
        scheduler_dict: Not used. Kept only so existing call sites still work.

    Returns:
        Total size of all tensors in every optimizer's ``state_dict()["state"]``
        (stored directly or inside a list/tuple), divided by 1024**2.
    """
    def tensor_bytes(value):
        # Bytes held by a tensor, or by the tensors inside a list/tuple;
        # any other state entry (ints, floats, ...) contributes nothing.
        if torch.is_tensor(value):
            return value.numel() * value.element_size()
        if isinstance(value, (list, tuple)):
            return sum(
                item.numel() * item.element_size()
                for item in value
                if torch.is_tensor(item)
            )
        return 0

    total_bytes = sum(
        tensor_bytes(value)
        for optimizer in optimizer_dict.values()
        for param_state in optimizer.state_dict()["state"].values()
        for value in param_state.values()
    )
    return total_bytes / (1024 ** 2)


def galore_optim_memory_usage_in_MB(optimizer):
    """
    Sums the tensors held directly in ``optimizer.state``, in MB (1024**2 bytes).

    Unlike the other helpers here, this reads the live ``optimizer.state``
    mapping rather than ``state_dict()``. It counts only values that are
    themselves tensors; lists, tuples and other objects in the state are not
    inspected.

    Args:
        optimizer: Optimizer whose ``state`` maps parameters to dicts of state
            entries.

    Returns:
        Total tensor bytes divided by 1024**2.
    """
    tensor_bytes = sum(
        entry.numel() * entry.element_size()
        for param_state in optimizer.state.values()
        for entry in param_state.values()
        if torch.is_tensor(entry)
    )
    return tensor_bytes / (1024 ** 2)


def model_memory_usage_in_MB(model):
    """
    Calculates the parameter memory of a PyTorch module in MB (1024**2 bytes).

    Args:
        model: A ``torch.nn.Module``.

    Returns:
        A ``(param_MB, trainable_param_MB)`` tuple: the size of all parameters
        yielded by ``model.parameters()``, and the size of those with
        ``requires_grad=True``. The second value is what the rest of this
        module treats as gradient memory.
    """
    total_bytes = 0
    trainable_bytes = 0
    # One pass over the parameters instead of two list-building traversals.
    for param in model.parameters():
        size = param.numel() * param.element_size()
        total_bytes += size
        if param.requires_grad:
            trainable_bytes += size
    bytes_per_mb = 1024 ** 2
    return total_bytes / bytes_per_mb, trainable_bytes / bytes_per_mb


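# # Example usage, assuming `model` and `optimizer` are already defined.
# # Both helpers return a (memory_MB, trainable/gradient_MB) tuple, so unpack them:
# model_memory, model_grad_memory = model_memory_usage_in_MB(model)
# optimizer_memory, optimizer_grad_memory = optimizer_memory_usage_in_MB(optimizer)

# print(f"Model Memory Usage: {model_memory:.2f} MB")
# print(f"Optimizer Memory Usage: {optimizer_memory:.2f} MB")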


class Monitor(Thread):
    """Background thread that records the peak GPU memory reported by GPUtil.

    The thread starts itself from ``__init__``. Until :meth:`stop` is called,
    it polls ``GPUtil.getGPUs()`` every ``delay`` seconds and keeps the largest
    ``memoryUsed`` value seen on any visible GPU. The unit is whatever GPUtil
    reports (presumably MB; confirm against the GPUtil docs).

    The thread is a daemon, so a forgotten :meth:`stop` call (or an exception
    in the caller before it is reached) cannot keep the interpreter alive at
    exit.
    """

    def __init__(self, delay, startup_wait=20):
        """Start monitoring immediately.

        Args:
            delay: Seconds to sleep between GPUtil polls.
            startup_wait: Seconds the constructor blocks after starting the
                thread. Defaults to 20, the previously hard-coded value;
                pass 0 to return immediately.
        """
        super().__init__(daemon=True)
        self.stopped = False
        self.delay = delay  # Time between calls to GPUtil
        self.max_memory_used = 0  # Peak memoryUsed seen so far, across all GPUs
        self.start()
        if startup_wait > 0:
            time.sleep(startup_wait)

    def run(self):
        """Poll GPUtil until ``stopped`` is set, tracking the peak memory used."""
        while not self.stopped:
            for gpu in GPUtil.getGPUs():
                self.max_memory_used = max(self.max_memory_used, gpu.memoryUsed)
            time.sleep(self.delay)

    def stop(self):
        """Ask the polling loop to exit after its current sleep."""
        self.stopped = True

    def get_max_memory_used(self):
        """Return the peak ``memoryUsed`` value observed so far."""
        return self.max_memory_used


def get_gpu_metrics_nvitop(this_process, suffix=""):
    """
    Collects a snapshot of device- and process-level GPU metrics as strings.

    Args:
        this_process: Process handle exposing ``update_gpu_status()``, a
            ``device`` attribute and the metric methods read below. It is
            presumably an nvitop ``GpuProcess``; confirm against the caller.
        suffix: Appended to the ``device``/``process`` key prefixes, e.g. to
            distinguish several processes in one metrics dict.

    Returns:
        Dict mapping ``"device{suffix}/..."`` and ``"process{suffix}/..."``
        keys to strings formatted with two decimals plus a unit. Memory byte
        counts are divided by 2**20 (MiB) but labelled "MB".
    """
    # Update the GPU status before any metric is read.
    this_process.update_gpu_status()

    device = this_process.device
    device_prefix = f"device{suffix}"
    process_prefix = f"process{suffix}"
    bytes_per_mib = 1 << 20

    # Filled one key at a time, in the same order the metric calls are made.
    metrics = {}
    metrics[f"{device_prefix}/memory_used_MB"] = (
        f"{float(device.memory_used()) / bytes_per_mib:.2f} MB"
    )
    metrics[f"{device_prefix}/memory_percent"] = f"{device.memory_percent():.2f} %"
    metrics[f"{device_prefix}/memory_utilization"] = (
        f"{device.memory_utilization():.2f} %"
    )
    metrics[f"{device_prefix}/gpu_utilization"] = f"{device.gpu_utilization():.2f} %"
    metrics[f"{process_prefix}/cpu_percent"] = f"{this_process.cpu_percent():.2f} %"
    metrics[f"{process_prefix}/memory_percent"] = (
        f"{this_process.memory_percent():.2f} %"
    )
    metrics[f"{process_prefix}/used_gpu_memory_MB"] = (
        f"{float(this_process.gpu_memory()) / bytes_per_mib:.2f} MB"
    )
    metrics[f"{process_prefix}/gpu_sm_utilization"] = (
        f"{this_process.gpu_sm_utilization():.2f} %"
    )
    metrics[f"{process_prefix}/gpu_memory_utilization"] = (
        f"{this_process.gpu_memory_utilization():.2f} %"
    )
    return metrics
