import asyncio
import logging
import re
import time
from datetime import datetime
from typing import Any, Dict, List, Optional
from urllib.parse import quote_plus

from ...shared.selenium_base.selenium_manager import SeleniumManager
from ...shared.kafka_client.kafka_client import SearchKafkaClient
from ...shared.common_models.search_models import SearchResult, SearchResponse
from ..core.config import settings

logger = logging.getLogger(__name__)

class RedditSearchService:
    """Scrape Reddit search results with Selenium and publish them via Kafka.

    ``process_search_task`` is the entry point for one task message: it runs
    ``search_reddit`` and sends either the results or an error report through
    the Kafka client created in ``__init__``.
    """

    # Locator strategy string passed to ``find_element`` on each post element.
    _BY_CSS = "css selector"

    def __init__(self):
        """Create the Kafka client from the service settings."""
        self.kafka_client = SearchKafkaClient(
            bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS,
            consumer_group=settings.KAFKA_CONSUMER_GROUP,
        )

    async def process_search_task(self, message: Dict[str, Any]):
        """Run one Reddit search task and publish its outcome.

        Args:
            message: Task payload; its ``id`` and ``keywords`` entries are read.

        Any exception raised while searching or publishing is logged and
        reported through ``send_error_result`` instead of being re-raised.
        """
        task_id = None
        try:
            task_id = message.get("id")
            keywords = message.get("keywords", [])

            logger.info("Processing Reddit search task: %s", task_id)

            # perf_counter() is monotonic, so the measured duration cannot be
            # skewed by wall-clock adjustments the way time.time() can.
            started = time.perf_counter()
            results = await self.search_reddit(keywords)
            processing_time = time.perf_counter() - started

            response = SearchResponse(
                task_id=task_id,
                platform="reddit",
                keywords=keywords,
                results=results,
                total_results=len(results),
                processing_time=processing_time,
                timestamp=datetime.utcnow(),
            )

            await self.kafka_client.send_search_result(
                task_id=task_id,
                platform="reddit",
                results=response.dict(),
            )

            logger.info(
                "Reddit search completed: %s, found %d results",
                task_id,
                len(results),
            )

        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception("Error processing Reddit search task: %s", e)
            try:
                await self.kafka_client.send_error_result(
                    task_id=task_id if task_id is not None else "unknown",
                    platform="reddit",
                    error=str(e),
                )
            except Exception:
                # Reporting the failure is best-effort: a broken Kafka
                # connection must not escape a method that otherwise never
                # raises.
                logger.exception(
                    "Failed to publish Reddit error result for task: %s", task_id
                )

    async def search_reddit(self, keywords: List[str], max_results: int = 20) -> "List[SearchResult]":
        """Search Reddit for ``keywords`` and return the extracted posts.

        Args:
            keywords: Search terms; they are joined with spaces into one query.
                A bare string is treated as a single keyword.
            max_results: Upper bound on the number of post elements inspected.

        Returns:
            The successfully extracted results. Empty when the query is blank
            or ``max_results`` is not positive (no browser is started then).

        Raises:
            Exception: Whatever the Selenium session raises while navigating
                or locating elements; it is logged and re-raised.
        """
        if isinstance(keywords, str):
            # " ".join("abc") would otherwise split the string into letters.
            keywords = [keywords]
        query = " ".join(keywords or []).strip()
        # Bail out before starting a browser: there is nothing to search for,
        # and a negative max_results would make the slice below drop posts
        # from the end instead of limiting the count.
        if not query or max_results <= 0:
            return []

        results = []

        async with SeleniumManager(headless=True).driver_context() as selenium:
            try:
                encoded_query = quote_plus(query)
                search_url = f"https://www.reddit.com/search/?q={encoded_query}&sort=relevance"

                logger.info("Searching Reddit for: %s", query)

                await selenium.navigate_to(search_url)
                await selenium.wait_for_page_load()
                await asyncio.sleep(3)

                # Handle Reddit's "Continue" button if it appears
                continue_button = await selenium.find_element("button[data-testid='continue-button']")
                if continue_button:
                    continue_button.click()
                    await asyncio.sleep(2)

                # Scroll to load more results
                for _ in range(3):
                    await selenium.scroll_page(1000)
                    await asyncio.sleep(1)

                # Extract post results
                post_elements = await selenium.find_elements("[data-testid='post-container']")

                for element in post_elements[:max_results]:
                    # _extract_post_result logs and returns None on failure,
                    # so no extra try/except is needed here.
                    result = await self._extract_post_result(element)
                    if result:
                        results.append(result)

                logger.info("Extracted %d results from Reddit", len(results))

            except Exception as e:
                logger.error("Error during Reddit search: %s", e)
                raise

        return results

    async def _extract_post_result(self, element) -> "Optional[SearchResult]":
        """Build a ``SearchResult`` from one post container element.

        Every field is looked up independently, so a post that lacks an
        optional part (description, author, votes, ...) still yields a result
        with that field's default instead of being discarded.

        Args:
            element: Post container exposing ``find_element(by, selector)``.

        Returns:
            The populated result, or ``None`` if building it failed.
        """
        try:
            title = self._child_text(element, "[data-testid='post-content'] h3", "No title")
            url = self._child_attribute(element, "[data-testid='post-content'] a", "href")
            description = self._child_text(element, "[data-testid='post-content'] p")
            subreddit = self._child_text(element, "[data-testid='subreddit-name']")
            author = self._child_text(element, "[data-testid='post_author_link']")
            upvotes = self._child_attribute(
                element, "[data-testid='vote-arrows'] button", "aria-label", "0"
            )
            post_time = self._child_text(element, "[data-testid='post_timestamp']")

            return SearchResult(
                title=title,
                url=url,
                description=description,
                snippet=description,
                author=author,
                source="reddit",
                metadata={
                    "platform": "reddit",
                    "subreddit": subreddit,
                    "upvotes": upvotes,
                    "post_time": post_time,
                    "extracted_at": datetime.utcnow().isoformat(),
                },
            )

        except Exception as e:
            logger.warning("Error extracting post result: %s", e)
            return None

    @classmethod
    def _find_child(cls, element, selector):
        """Return the first descendant of ``element`` matching ``selector``, or ``None``."""
        try:
            return element.find_element(cls._BY_CSS, selector)
        except Exception:
            # Presumably ``element`` is a Selenium WebElement, whose
            # find_element raises when nothing matches instead of returning
            # None. Selenium's exception types are not imported in this
            # module, so any lookup failure is treated as "not present".
            return None

    @classmethod
    def _child_text(cls, element, selector, default=""):
        """Return the matching descendant's text, or ``default`` if absent or empty."""
        child = cls._find_child(element, selector)
        text = child.text if child is not None else None
        return text if text else default

    @classmethod
    def _child_attribute(cls, element, selector, attribute, default=""):
        """Return ``attribute`` of the matching descendant, or ``default`` if absent or empty."""
        child = cls._find_child(element, selector)
        value = child.get_attribute(attribute) if child is not None else None
        return value if value else default
