import inspect
import logging
import time
from functools import wraps

from prometheus_client import Counter, Histogram, Gauge, start_http_server

logger = logging.getLogger(__name__)

# Metrics
# --- Search request metrics -------------------------------------------------
# Count of search requests, split by platform, search type and outcome status.
search_requests_total = Counter(
    name='search_requests_total',
    documentation='Total number of search requests',
    labelnames=('platform', 'search_type', 'status'),
)

# Latency distribution of search requests (seconds), per platform.
search_request_duration = Histogram(
    name='search_request_duration_seconds',
    documentation='Time spent processing search requests',
    labelnames=('platform',),
)

# Unlabelled gauge holding the number of in-flight search requests.
active_search_requests = Gauge(
    name='active_search_requests',
    documentation='Number of currently active search requests',
)

# --- Kafka messaging metrics (per topic) ------------------------------------
kafka_messages_sent = Counter(
    name='kafka_messages_sent_total',
    documentation='Total number of Kafka messages sent',
    labelnames=('topic',),
)

kafka_messages_failed = Counter(
    name='kafka_messages_failed_total',
    documentation='Total number of failed Kafka messages',
    labelnames=('topic',),
)

# --- Storage metrics ---------------------------------------------------------
database_operations = Counter(
    name='database_operations_total',
    documentation='Total number of database operations',
    labelnames=('operation', 'collection'),
)

cache_operations = Counter(
    name='cache_operations_total',
    documentation='Total number of cache operations',
    labelnames=('operation', 'result'),
)

class MetricsCollector:
    """Convenience facade over the module-level Prometheus metrics.

    Each ``record_*`` method increments its matching counter by one for the
    supplied label values. ``update_active_requests`` overwrites the
    active-search gauge with an absolute value.

    Attributes:
        start_time: Value of ``time.time()`` captured when the collector was
            constructed.
    """

    def __init__(self):
        self.start_time = time.time()

    def record_search_request(self, platform: str, search_type: str, status: str):
        """Increment ``search_requests_total`` for one platform/type/status combination."""
        child = search_requests_total.labels(
            platform=platform, search_type=search_type, status=status
        )
        child.inc()

    def record_kafka_message_sent(self, topic: str):
        """Increment the sent-messages counter for *topic*."""
        child = kafka_messages_sent.labels(topic=topic)
        child.inc()

    def record_kafka_message_failed(self, topic: str):
        """Increment the failed-messages counter for *topic*."""
        child = kafka_messages_failed.labels(topic=topic)
        child.inc()

    def record_database_operation(self, operation: str, collection: str):
        """Increment the database-operation counter for *operation* on *collection*."""
        child = database_operations.labels(operation=operation, collection=collection)
        child.inc()

    def record_cache_operation(self, operation: str, result: str):
        """Increment the cache-operation counter for *operation* with outcome *result*."""
        child = cache_operations.labels(operation=operation, result=result)
        child.inc()

    def update_active_requests(self, count: int):
        """Set the active-search-requests gauge to *count* (an absolute value, not a delta)."""
        gauge = active_search_requests
        gauge.set(count)

# Module-level MetricsCollector instance, built at import time (so its
# start_time is the import timestamp). Presumably this is the shared instance
# other modules import rather than constructing their own; confirm with callers.
metrics_collector = MetricsCollector()

def _resolve_platform(signature: "inspect.Signature", args: tuple, kwargs: dict):
    """Return the ``platform`` argument of one call, or ``'unknown'``.

    An explicit ``platform=`` keyword takes precedence, matching the original
    lookup. Otherwise the call is bound against *signature* so that a
    positionally passed ``platform`` is found too. Parameter defaults are not
    applied, so a call that omits ``platform`` still yields ``'unknown'``.
    *signature* may be ``None`` when the callable could not be introspected.
    """
    if 'platform' in kwargs:
        return kwargs['platform']
    if signature is None:
        return 'unknown'
    try:
        bound = signature.bind_partial(*args, **kwargs)
    except TypeError:
        # Arguments don't fit the signature (most likely the call itself raised).
        return 'unknown'
    return bound.arguments.get('platform', 'unknown')


def track_time(metric_name: str):
    """Decorator factory that times every call of the wrapped function.

    Works for both ``async def`` coroutine functions and plain functions. The
    returned wrapper is async or sync to match the function it decorates.
    Elapsed time is measured with ``time.perf_counter()``, a clock meant for
    durations, unlike ``time.time()``, which can jump when the system clock is
    adjusted. The duration is recorded even when the call raises.

    Only ``metric_name == 'search_request'`` records anything. The duration is
    observed on ``search_request_duration``, labelled with the call's
    ``platform`` argument (keyword or positional), or ``'unknown'`` when the
    call has none. Any other name times the call but records nothing.

    Args:
        metric_name: Which metric the duration should be recorded into.

    Returns:
        A decorator that wraps a function while preserving its metadata
        (via ``functools.wraps``).
    """
    def decorator(func):
        # Introspect once at decoration time, not per call.
        try:
            signature = inspect.signature(func)
        except (TypeError, ValueError):
            signature = None

        def _record(started: float, args: tuple, kwargs: dict) -> None:
            """Observe the elapsed time since *started* into the selected metric."""
            duration = time.perf_counter() - started
            if metric_name == 'search_request':
                platform = _resolve_platform(signature, args, kwargs)
                search_request_duration.labels(platform=platform).observe(duration)

        if inspect.iscoroutinefunction(func):
            @wraps(func)
            async def wrapper(*args, **kwargs):
                started = time.perf_counter()
                try:
                    return await func(*args, **kwargs)
                finally:
                    _record(started, args, kwargs)
        else:
            # A sync function must get a sync wrapper: awaiting its plain
            # return value would raise TypeError.
            @wraps(func)
            def wrapper(*args, **kwargs):
                started = time.perf_counter()
                try:
                    return func(*args, **kwargs)
                finally:
                    _record(started, args, kwargs)
        return wrapper
    return decorator

async def start_metrics_server(port: int = 8001):
    """Start the Prometheus metrics HTTP server on *port*.

    ``start_http_server`` is called synchronously and nothing is awaited, so
    this coroutine finishes as soon as that call returns. According to the
    prometheus_client docs, the server itself runs in a background thread.

    This is deliberately best-effort: any failure (e.g. the port already being
    in use) is logged with its traceback and suppressed, so the application
    keeps running without a metrics endpoint.

    Args:
        port: TCP port for the metrics endpoint.
    """
    try:
        start_http_server(port)
    except Exception:
        # logger.exception keeps the traceback that logger.error(f"...{e}") dropped.
        logger.exception("Error starting metrics server on port %s", port)
    else:
        logger.info("Metrics server started on port %s", port)
