import asyncio
import logging
from typing import Dict, Any
from datetime import datetime, timedelta
import aiohttp
from ..kafka_client.kafka_client import SearchKafkaClient

# Module-scoped logger; records are tagged with this module's dotted name.
logger = logging.getLogger(name=__name__)

class ServiceMonitor:
    """Track a service's task statistics and publish periodic heartbeats.

    Counters and the current status live in ``health_data``. A background
    asyncio task (started by ``start_monitoring``) periodically refreshes the
    timing fields and sends the dict to the ``"service-health"`` topic through
    ``kafka_client.send_message``, keyed by the service name.
    """

    # Status is only re-evaluated once more than this many tasks (completed +
    # failed) have been recorded, so a few early failures don't flip it.
    _MIN_TASKS_FOR_STATUS = 10
    # Failure-rate thresholds, compared with strict ">" against
    # failed_tasks / (processed_tasks + failed_tasks).
    _DEGRADED_FAILURE_RATE = 0.5
    _UNHEALTHY_FAILURE_RATE = 0.8

    def __init__(
        self,
        service_name: str,
        kafka_client: "SearchKafkaClient",
        *,
        heartbeat_interval: float = 30.0,
        error_retry_interval: float = 60.0,
    ):
        """Create a monitor for ``service_name``.

        Args:
            service_name: Name reported in each heartbeat and used as the
                message key.
            kafka_client: Client whose ``send_message`` coroutine publishes
                the heartbeat payload.
            heartbeat_interval: Seconds to wait between heartbeats.
            error_retry_interval: Seconds to wait after an unexpected error in
                the monitoring loop before trying again.

        Raises:
            ValueError: If either interval is not positive (a zero or negative
                sleep would turn the loop into a busy spin).
        """
        if heartbeat_interval <= 0 or error_retry_interval <= 0:
            raise ValueError(
                "heartbeat_interval and error_retry_interval must be positive"
            )
        self.service_name = service_name
        self.kafka_client = kafka_client
        self.heartbeat_interval = heartbeat_interval
        self.error_retry_interval = error_retry_interval
        # Strong reference to the background loop task: asyncio only keeps weak
        # references to tasks, so an unreferenced task can be garbage-collected.
        # Also lets stop_monitoring() cancel it.
        self._monitor_task = None
        # NOTE(review): datetime.utcnow() yields naive UTC and is deprecated
        # since Python 3.12; kept so the published payload format is unchanged.
        now = datetime.utcnow()
        self.health_data: Dict[str, Any] = {
            "service_name": service_name,
            "status": "healthy",
            "last_heartbeat": now,
            "processed_tasks": 0,
            "failed_tasks": 0,
            "average_processing_time": 0.0,
            # Despite the key name this is the monitor's start time; elapsed
            # seconds are published separately as "uptime_seconds".
            "uptime": now,
        }

    async def start_monitoring(self):
        """Start the background heartbeat loop (no-op if already running)."""
        if self._monitor_task is not None and not self._monitor_task.done():
            return
        self._monitor_task = asyncio.create_task(self._monitoring_loop())
        logger.info(f"Started monitoring for {self.service_name}")

    async def stop_monitoring(self):
        """Cancel the background heartbeat loop and wait for it to finish."""
        task = self._monitor_task
        if task is None:
            return
        self._monitor_task = None
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        logger.info(f"Stopped monitoring for {self.service_name}")

    async def _monitoring_loop(self):
        """Send a heartbeat every ``heartbeat_interval`` seconds until cancelled."""
        while True:
            try:
                await self._send_heartbeat()
                await asyncio.sleep(self.heartbeat_interval)
            except Exception as e:
                # asyncio.CancelledError derives from BaseException (3.8+), so
                # cancellation is not caught here and still ends the loop.
                logger.error(f"Error in monitoring loop: {e}")
                await asyncio.sleep(self.error_retry_interval)

    async def _send_heartbeat(self):
        """Refresh the timing fields in ``health_data`` and publish it.

        Any exception is logged and swallowed, so a failed send does not stop
        the monitoring loop.
        """
        try:
            # One timestamp so last_heartbeat and uptime_seconds agree.
            now = datetime.utcnow()
            self.health_data["last_heartbeat"] = now
            self.health_data["uptime_seconds"] = (
                now - self.health_data["uptime"]
            ).total_seconds()

            # NOTE(review): payload contains datetime objects; assumes
            # send_message can serialize them — confirm against the client.
            await self.kafka_client.send_message(
                "service-health",
                self.health_data,
                key=self.service_name,
            )

        except Exception as e:
            logger.error(f"Error sending heartbeat: {e}")

    def record_task_completed(self, processing_time: float):
        """Record a successful task and fold its duration into the average.

        Args:
            processing_time: Duration of the task, in whatever unit the caller
                uses consistently (the average is kept in the same unit).
        """
        self.health_data["processed_tasks"] += 1

        # Running mean over completed tasks only; failed tasks don't contribute.
        current_avg = self.health_data["average_processing_time"]
        total_tasks = self.health_data["processed_tasks"]
        new_avg = ((current_avg * (total_tasks - 1)) + processing_time) / total_tasks
        self.health_data["average_processing_time"] = new_avg

        # Successes lower the failure rate, so the status may recover.
        self._update_status()

    def record_task_failed(self):
        """Record a failed task and re-evaluate the service status."""
        self.health_data["failed_tasks"] += 1
        self._update_status()

    def _update_status(self):
        """Set ``status`` from the cumulative failure rate.

        The mapping is "unhealthy" above 80% failures, "degraded" above 50%,
        otherwise "healthy". The status is left untouched until more than
        ``_MIN_TASKS_FOR_STATUS`` tasks have been recorded.
        """
        total_tasks = self.health_data["processed_tasks"] + self.health_data["failed_tasks"]
        if total_tasks <= self._MIN_TASKS_FOR_STATUS:
            return
        failure_rate = self.health_data["failed_tasks"] / total_tasks
        # Check the stricter threshold first; otherwise it is unreachable.
        if failure_rate > self._UNHEALTHY_FAILURE_RATE:
            self.health_data["status"] = "unhealthy"
        elif failure_rate > self._DEGRADED_FAILURE_RATE:
            self.health_data["status"] = "degraded"
        else:
            self.health_data["status"] = "healthy"
