import logging
from typing import Callable, Generic, List

from typing_extensions import ParamSpec  # backport: typing.ParamSpec needs Python 3.10+

logger = logging.getLogger(__name__)
P = ParamSpec("P")


class CallbackRegistry(Generic[P]):
    def __init__(self, name: str):
        self.name = name
        self.callback_list: List[Callable[P, None]] = []

    def add_callback(self, cb: Callable[P, None]) -> None:
        self.callback_list.append(cb)

    def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None:
        for cb in self.callback_list:
            try:
                cb(*args, **kwargs)
            except Exception as e:
                logger.exception(
                    "Exception in callback for %s registered with CUDA trace", self.name
                )


# One registry per traced CUDA operation. The type argument is the callback
# signature: ``CallbackRegistry[int]`` holds callbacks taking a single int,
# ``CallbackRegistry[int, int]`` callbacks taking two ints, and
# ``CallbackRegistry[[]]`` callbacks taking no arguments.

# --- CUDA events ---
CUDAEventCreationCallbacks: "CallbackRegistry[int]"
CUDAEventCreationCallbacks = CallbackRegistry("CUDA event creation")

CUDAEventDeletionCallbacks: "CallbackRegistry[int]"
CUDAEventDeletionCallbacks = CallbackRegistry("CUDA event deletion")

CUDAEventRecordCallbacks: "CallbackRegistry[int, int]"
CUDAEventRecordCallbacks = CallbackRegistry("CUDA event record")

CUDAEventWaitCallbacks: "CallbackRegistry[int, int]"
CUDAEventWaitCallbacks = CallbackRegistry("CUDA event wait")

# --- CUDA memory ---
CUDAMemoryAllocationCallbacks: "CallbackRegistry[int]"
CUDAMemoryAllocationCallbacks = CallbackRegistry("CUDA memory allocation")

CUDAMemoryDeallocationCallbacks: "CallbackRegistry[int]"
CUDAMemoryDeallocationCallbacks = CallbackRegistry("CUDA memory deallocation")

# --- CUDA streams ---
CUDAStreamCreationCallbacks: "CallbackRegistry[int]"
CUDAStreamCreationCallbacks = CallbackRegistry("CUDA stream creation")

# --- Synchronization (device-wide, per stream, per event) ---
CUDADeviceSynchronizationCallbacks: "CallbackRegistry[[]]"
CUDADeviceSynchronizationCallbacks = CallbackRegistry("CUDA device synchronization")

CUDAStreamSynchronizationCallbacks: "CallbackRegistry[int]"
CUDAStreamSynchronizationCallbacks = CallbackRegistry("CUDA stream synchronization")

CUDAEventSynchronizationCallbacks: "CallbackRegistry[int]"
CUDAEventSynchronizationCallbacks = CallbackRegistry("CUDA event synchronization")


def register_callback_for_cuda_event_creation(cb: Callable[[int], None]) -> None:
    """Subscribe ``cb`` to CUDA event creation notifications.

    ``cb`` is called with one int each time the event-creation registry fires;
    exceptions it raises are logged, not propagated. The int presumably
    identifies the event; confirm at the firing site.
    """
    registry = CUDAEventCreationCallbacks
    registry.add_callback(cb)


def register_callback_for_cuda_event_deletion(cb: Callable[[int], None]) -> None:
    """Subscribe ``cb`` to CUDA event deletion notifications.

    ``cb`` is called with one int each time the event-deletion registry fires;
    exceptions it raises are logged, not propagated. The int presumably
    identifies the event; confirm at the firing site.
    """
    registry = CUDAEventDeletionCallbacks
    registry.add_callback(cb)


def register_callback_for_cuda_event_record(cb: Callable[[int, int], None]) -> None:
    """Subscribe ``cb`` to CUDA event record notifications.

    ``cb`` is called with two ints each time the event-record registry fires;
    exceptions it raises are logged, not propagated. The meaning of the two ints
    (presumably event and stream identifiers) should be confirmed at the firing
    site.
    """
    registry = CUDAEventRecordCallbacks
    registry.add_callback(cb)


def register_callback_for_cuda_event_wait(cb: Callable[[int, int], None]) -> None:
    """Subscribe ``cb`` to CUDA event wait notifications.

    ``cb`` is called with two ints each time the event-wait registry fires;
    exceptions it raises are logged, not propagated. The meaning of the two ints
    (presumably event and stream identifiers) should be confirmed at the firing
    site.
    """
    registry = CUDAEventWaitCallbacks
    registry.add_callback(cb)


def register_callback_for_cuda_memory_allocation(cb: Callable[[int], None]) -> None:
    """Subscribe ``cb`` to CUDA memory allocation notifications.

    ``cb`` is called with one int each time the memory-allocation registry fires;
    exceptions it raises are logged, not propagated. The int presumably
    identifies the allocation (e.g. a pointer); confirm at the firing site.
    """
    registry = CUDAMemoryAllocationCallbacks
    registry.add_callback(cb)


def register_callback_for_cuda_memory_deallocation(cb: Callable[[int], None]) -> None:
    """Subscribe ``cb`` to CUDA memory deallocation notifications.

    ``cb`` is called with one int each time the memory-deallocation registry
    fires; exceptions it raises are logged, not propagated. The int presumably
    identifies the freed allocation; confirm at the firing site.
    """
    registry = CUDAMemoryDeallocationCallbacks
    registry.add_callback(cb)


def register_callback_for_cuda_stream_creation(cb: Callable[[int], None]) -> None:
    """Subscribe ``cb`` to CUDA stream creation notifications.

    ``cb`` is called with one int each time the stream-creation registry fires;
    exceptions it raises are logged, not propagated. The int presumably
    identifies the stream; confirm at the firing site.
    """
    registry = CUDAStreamCreationCallbacks
    registry.add_callback(cb)


def register_callback_for_cuda_device_synchronization(cb: Callable[[], None]) -> None:
    """Subscribe ``cb`` to CUDA device synchronization notifications.

    ``cb`` takes no arguments and is called each time the device-synchronization
    registry fires; exceptions it raises are logged, not propagated.
    """
    registry = CUDADeviceSynchronizationCallbacks
    registry.add_callback(cb)


def register_callback_for_cuda_stream_synchronization(
    cb: Callable[[int], None]
) -> None:
    """Subscribe ``cb`` to CUDA stream synchronization notifications.

    ``cb`` is called with one int each time the stream-synchronization registry
    fires; exceptions it raises are logged, not propagated. The int presumably
    identifies the stream; confirm at the firing site.
    """
    registry = CUDAStreamSynchronizationCallbacks
    registry.add_callback(cb)


def register_callback_for_cuda_event_synchronization(cb: Callable[[int], None]) -> None:
    """Subscribe ``cb`` to CUDA event synchronization notifications.

    ``cb`` is called with one int each time the event-synchronization registry
    fires; exceptions it raises are logged, not propagated. The int presumably
    identifies the event; confirm at the firing site.
    """
    registry = CUDAEventSynchronizationCallbacks
    registry.add_callback(cb)
