import logging
import time
from typing import Dict, Tuple

from fastapi import HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, Response

from ...core.config import settings
from ...infrastructure.cache.redis_cache import redis_cache

# Module-level logger; rate-limit backend failures are reported through it.
logger = logging.getLogger(__name__)

class RateLimitMiddleware(BaseHTTPMiddleware):
    """Fixed-window, per-client-IP rate limiter backed by ``redis_cache``.

    Each client may make ``settings.RATE_LIMIT_REQUESTS`` requests per window
    of ``settings.RATE_LIMIT_WINDOW`` seconds. Windows are aligned to
    multiples of the window size in epoch seconds. Requests whose path starts
    with one of ``EXEMPT_PATH_PREFIXES`` are never limited. If the cache
    backend errors, the request is allowed (fail-open).
    """

    # Path prefixes that bypass rate limiting (health checks, metrics scraping).
    EXEMPT_PATH_PREFIXES: Tuple[str, ...] = ("/health", "/metrics")

    def __init__(self, app):
        super().__init__(app)
        self.enabled = settings.ENABLE_RATE_LIMITING
        self.requests_per_window = settings.RATE_LIMIT_REQUESTS
        self.window_size = settings.RATE_LIMIT_WINDOW

    async def dispatch(self, request: Request, call_next):
        """Reject over-limit requests with 429; otherwise forward and annotate.

        Allowed responses get ``X-RateLimit-Limit``, ``X-RateLimit-Remaining``
        and ``X-RateLimit-Window`` headers. Rejected requests get a JSON 429
        body shaped like FastAPI's ``HTTPException`` output plus a
        ``Retry-After`` header.
        """
        if not self.enabled:
            return await call_next(request)

        if request.url.path.startswith(self.EXEMPT_PATH_PREFIXES):
            return await call_next(request)

        client_ip = self._get_client_ip(request)
        is_allowed, remaining_requests = await self._check_rate_limit(client_ip)

        if not is_allowed:
            # Return the response directly instead of raising HTTPException:
            # BaseHTTPMiddleware runs outside FastAPI's exception handlers, so
            # a raised HTTPException would surface to the client as a 500.
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded. Please try again later."},
                headers={
                    "X-RateLimit-Limit": str(self.requests_per_window),
                    "X-RateLimit-Remaining": "0",
                    "X-RateLimit-Window": str(self.window_size),
                    "Retry-After": str(self._seconds_until_window_reset()),
                },
            )

        response = await call_next(request)

        response.headers["X-RateLimit-Limit"] = str(self.requests_per_window)
        response.headers["X-RateLimit-Remaining"] = str(remaining_requests)
        response.headers["X-RateLimit-Window"] = str(self.window_size)

        return response

    def _seconds_until_window_reset(self) -> int:
        """Return seconds until the current aligned window ends (1..window_size)."""
        # Only reached after a rejection, which requires _check_rate_limit to
        # have succeeded, so window_size is non-zero here.
        return self.window_size - (int(time.time()) % self.window_size)

    def _get_client_ip(self, request: Request) -> str:
        """Return the client identifier used as the rate-limit key.

        Prefers the first entry of ``X-Forwarded-For``, then ``X-Real-IP``,
        then the socket peer address. Returns ``"unknown"`` when the ASGI
        server provides no client address.
        """
        # NOTE(review): these headers are client-controlled unless a trusted
        # reverse proxy overwrites them; a client reaching the app directly
        # can rotate them to evade the limit. Confirm the deployment sets them.
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            first_hop = forwarded_for.split(",")[0].strip()
            # Skip an empty first hop so unrelated clients don't share one key.
            if first_hop:
                return first_hop

        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip.strip()

        # request.client is None when the server supplies no peer address.
        if request.client is not None:
            return request.client.host
        return "unknown"

    async def _check_rate_limit(self, client_ip: str) -> Tuple[bool, int]:
        """Count this request against ``client_ip``'s current window.

        Returns:
            ``(is_allowed, remaining)``, where ``remaining`` is how many more
            requests the client may make in this window. On any cache error
            the request is allowed and ``remaining`` is the full limit.
        """
        try:
            current_time = int(time.time())
            window_start = current_time - (current_time % self.window_size)
            key = f"rate_limit:{client_ip}:{window_start}"

            # Increment first and decide on the returned value. A separate
            # read-then-increment lets concurrent requests all observe a
            # below-limit count and collectively exceed the limit.
            new_count = int(await redis_cache.increment(key))

            if new_count == 1:
                # First request in this window: attach a TTL so the key expires.
                # NOTE(review): this rewrites the value, so an increment from a
                # concurrent request landing between the two calls is lost
                # (undercount by one). An atomic expire on the backend would
                # avoid this — check whether redis_cache exposes one.
                await redis_cache.set(key, new_count, ttl=self.window_size)

            if new_count > self.requests_per_window:
                return False, 0

            return True, max(0, self.requests_per_window - new_count)

        except Exception as e:
            logger.exception("Error checking rate limit: %s", e)
            # Fail open: a cache outage should not take the API down.
            return True, self.requests_per_window
