import asyncio
import inspect
import json
import logging
from datetime import datetime, timezone
from typing import Any, Awaitable, Callable, Dict, Optional

from aiokafka import AIOKafkaConsumer, AIOKafkaProducer

# Module-scoped logger shared by SearchKafkaClient.
logger = logging.getLogger(name=__name__)

class SearchKafkaClient:
    """Async Kafka client for publishing and consuming JSON search messages.

    The producer is started lazily on the first ``send_message`` call (or
    explicitly via ``start_producer``). The consumer must be started with
    ``start_consumer`` before ``consume_messages`` is used. ``stop`` shuts
    both down and clears them, so they can be started again later.
    """

    # Sentinel returned by the value deserializer for records with no value
    # or a payload that is not UTF-8 JSON. consume_messages() skips these so
    # one malformed ("poison pill") record cannot terminate the consume loop.
    _UNDECODABLE = object()

    def __init__(self, bootstrap_servers: list, consumer_group: str):
        """Store connection settings; no network I/O happens here.

        Args:
            bootstrap_servers: Broker addresses passed to aiokafka as
                ``bootstrap_servers``.
            consumer_group: ``group_id`` used by the consumer.
        """
        self.bootstrap_servers = bootstrap_servers
        self.consumer_group = consumer_group
        self.producer: Optional[AIOKafkaProducer] = None
        self.consumer: Optional[AIOKafkaConsumer] = None
        # Serializes producer start-up so concurrent send_message() calls
        # cannot create (and leak) several producers. Created lazily inside a
        # coroutine so it belongs to the running event loop.
        self._producer_lock: Optional[asyncio.Lock] = None

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as an ISO-8601 string without offset.

        Produces the same format as ``datetime.utcnow().isoformat()`` while
        avoiding ``utcnow()``, which is deprecated since Python 3.12.
        """
        # NOTE(review): the offset is stripped only to keep the existing wire
        # format; emitting "+00:00" would be unambiguous once consumers agree.
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    @classmethod
    def _deserialize_value(cls, raw: Optional[bytes]) -> Any:
        """Decode a record value from UTF-8 JSON.

        Returns ``cls._UNDECODABLE`` (after logging) instead of raising when
        the value is missing or is not valid UTF-8 JSON. An exception raised
        here would propagate out of the consumer iterator and end consumption.
        """
        if raw is None:
            logger.debug("Skipping Kafka message without a value")
            return cls._UNDECODABLE
        try:
            return json.loads(raw.decode('utf-8'))
        except ValueError as e:  # covers UnicodeDecodeError and JSONDecodeError
            logger.warning("Skipping undecodable Kafka message: %s", e)
            return cls._UNDECODABLE

    @staticmethod
    async def _stop_quietly(client: Any) -> None:
        """Call ``client.stop()`` on a cleanup path, logging instead of raising."""
        try:
            await client.stop()
        except Exception:
            logger.warning("Error stopping Kafka client during cleanup", exc_info=True)

    async def start_producer(self):
        """Start the Kafka producer unless one is already running.

        The producer is stored on ``self.producer`` only after ``start()``
        succeeds. A producer whose start failed is stopped so it does not leak.
        Calling this while a producer is running is a no-op.

        Raises:
            Exception: errors from creating/starting the producer (logged,
                then re-raised).
        """
        if self._producer_lock is None:
            self._producer_lock = asyncio.Lock()
        async with self._producer_lock:
            if self.producer is not None:
                return
            producer = None
            try:
                producer = AIOKafkaProducer(
                    bootstrap_servers=self.bootstrap_servers,
                    # default=str stringifies values json can't encode natively.
                    value_serializer=lambda v: json.dumps(v, default=str).encode('utf-8'),
                    # Falsy keys (None or "") are sent without a key.
                    key_serializer=lambda k: k.encode('utf-8') if k else None
                )
                await producer.start()
            except Exception as e:
                logger.error("Error starting Kafka producer: %s", e)
                if producer is not None:
                    await self._stop_quietly(producer)
                raise
            self.producer = producer
            logger.info("Kafka producer started")

    async def start_consumer(self, topics: list):
        """Start a Kafka consumer subscribed to ``topics``.

        Any consumer already held by this client is stopped first. The new
        consumer is stored on ``self.consumer`` only after ``start()``
        succeeds. A consumer whose start failed is stopped so it does not leak.

        Raises:
            Exception: errors from creating/starting the consumer (logged,
                then re-raised).
        """
        if self.consumer is not None:
            # Replacing a consumer: stop the old one so its connections are not leaked.
            previous, self.consumer = self.consumer, None
            await self._stop_quietly(previous)
        consumer = None
        try:
            consumer = AIOKafkaConsumer(
                *topics,
                bootstrap_servers=self.bootstrap_servers,
                group_id=self.consumer_group,
                value_deserializer=self._deserialize_value,
                auto_offset_reset='latest'
            )
            await consumer.start()
        except Exception as e:
            logger.error("Error starting Kafka consumer: %s", e)
            if consumer is not None:
                await self._stop_quietly(consumer)
            raise
        self.consumer = consumer
        logger.info("Kafka consumer started for topics: %s", topics)

    async def send_message(self, topic: str, message: Dict[str, Any], key: Optional[str] = None):
        """Publish ``message`` to ``topic`` and wait for the broker acknowledgement.

        Starts the producer on first use. ``send_and_wait`` is used so that
        delivery failures raise here. Plain ``send`` only enqueues the record,
        so its delivery error would be lost in the unawaited future.

        Returns:
            Whatever ``send_and_wait`` returns (aiokafka documents it as the
            written record's ``RecordMetadata``).

        Raises:
            Exception: start-up or delivery errors (logged, then re-raised).
        """
        try:
            if self.producer is None:
                await self.start_producer()
            metadata = await self.producer.send_and_wait(topic, value=message, key=key)
        except Exception as e:
            logger.error("Error sending message to %s: %s", topic, e)
            raise
        logger.debug("Message sent to topic %s", topic)
        return metadata

    async def consume_messages(self, message_handler: Callable[[Dict[str, Any]], Optional[Awaitable[None]]]):
        """Pass each decoded message value to ``message_handler``.

        Runs until iteration over the consumer ends (e.g. once it is stopped).
        ``message_handler`` may be a coroutine function or a plain function;
        awaitable results are awaited. A handler exception is logged with its
        traceback and the loop moves on to the next message. Records the
        deserializer could not decode are skipped.

        Raises:
            RuntimeError: if ``start_consumer`` has not been called.
            Exception: errors raised by the consumer itself (logged, then
                re-raised).
        """
        consumer = self.consumer
        if consumer is None:
            raise RuntimeError("Consumer not started")
        try:
            async for message in consumer:
                value = message.value
                if value is self._UNDECODABLE:
                    continue
                try:
                    result = message_handler(value)
                    if inspect.isawaitable(result):
                        await result
                except Exception as e:
                    # Isolate handler failures so one bad message does not stop consumption.
                    logger.exception("Error processing message: %s", e)
        except Exception as e:
            logger.error("Error consuming messages: %s", e)
            raise

    async def send_search_result(self, task_id: str, platform: str, results: Dict[str, Any], status: str = "completed"):
        """Publish a task result to ``search-results``, keyed by ``task_id``.

        ``type`` is ``"task_completed"`` only when ``status == "completed"``.
        Any other status is reported as ``"task_failed"``.
        """
        result_message = {
            "type": "task_completed" if status == "completed" else "task_failed",
            "task_id": task_id,
            "platform": platform,
            "results": results,
            "timestamp": self._utc_timestamp(),
            "status": status
        }
        return await self.send_message("search-results", result_message, key=task_id)

    async def send_error_result(self, task_id: str, platform: str, error: str):
        """Publish a ``task_failed`` message carrying ``error`` to ``search-results``."""
        error_message = {
            "type": "task_failed",
            "task_id": task_id,
            "platform": platform,
            "error": error,
            "timestamp": self._utc_timestamp(),
            "status": "failed"
        }
        return await self.send_message("search-results", error_message, key=task_id)

    async def stop(self):
        """Stop the producer and consumer and clear them so they can be restarted.

        The consumer is stopped even if stopping the producer raises. The
        producer's exception is then propagated.
        """
        producer, self.producer = self.producer, None
        consumer, self.consumer = self.consumer, None
        try:
            if producer is not None:
                await producer.stop()
                logger.info("Kafka producer stopped")
        finally:
            if consumer is not None:
                await consumer.stop()
                logger.info("Kafka consumer stopped")
