"""Accounts Middleware Components

Middleware components for the Adtlas Accounts module.
Provides user-specific middleware for authentication, activity tracking, and security.

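Middleware Components:
- UserActivityMiddleware: Track user activity and last seen
- RateLimitMiddleware: Rate limiting for login, API and all other endpoints
- SecurityHeadersMiddleware: Additional security headers
- RequestLoggingMiddleware: User-specific request logging
- SessionSecurityMiddleware: Enhanced session security
- LoginAttemptMiddleware: Track and limit login attempts
- PasswordPolicyMiddleware: Enforce password policies
- TwoFactorMiddleware: Two-factor authentication support

NOTE: only the first four components (plus BaseAccountsMiddleware) are
defined in this module and exported via ``__all__``. The last four are not
implemented here, so adding them to MIDDLEWARE as shown below will fail at
import time unless they are provided elsewhere.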

Usage:
    Add to MIDDLEWARE setting in Django settings:
    
    MIDDLEWARE = [
        ...
        'apps.accounts.middleware.UserActivityMiddleware',
        'apps.accounts.middleware.RateLimitMiddleware',
        'apps.accounts.middleware.SecurityHeadersMiddleware',
        'apps.accounts.middleware.RequestLoggingMiddleware',
        'apps.accounts.middleware.SessionSecurityMiddleware',
        'apps.accounts.middleware.LoginAttemptMiddleware',
        'apps.accounts.middleware.PasswordPolicyMiddleware',
        'apps.accounts.middleware.TwoFactorMiddleware',
        ...
    ]

Configuration:
    Configure middleware behavior in Django settings:
    
    # User activity settings
    USER_ACTIVITY_TRACKING_ENABLED = True
    USER_ONLINE_TIMEOUT = 300  # 5 minutes
    
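    # Rate limiting settings
    # Format: '<count>/<period>', where period is an integer optionally
    # followed by a unit: s, m, h or d (no unit = seconds), e.g. '5m', '1h'.
    RATE_LIMIT_ENABLED = True
    LOGIN_RATE_LIMIT = '5/5m'  # 5 attempts per 5 minutes
    API_RATE_LIMIT = '100/1h'  # 100 requests per hour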
    
    # Security settings
    SECURITY_HEADERS_ENABLED = True
    SESSION_SECURITY_ENABLED = True
    
    # Password policy settings
    PASSWORD_POLICY_ENABLED = True
    PASSWORD_EXPIRY_DAYS = 90
    
    # Two-factor authentication
    TWO_FACTOR_ENABLED = False
    TWO_FACTOR_REQUIRED_FOR_STAFF = True
"""

import time
import logging
import hashlib
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List

from django.conf import settings
from django.core.cache import cache
from django.http import HttpRequest, HttpResponse, JsonResponse
from django.utils import timezone
from django.utils.deprecation import MiddlewareMixin
from django.contrib.auth.models import AnonymousUser
from django.contrib.auth import logout
from django.urls import reverse
from django.shortcuts import redirect
from django.contrib import messages
from django.utils.translation import gettext as _

# Configure logging
logger = logging.getLogger(__name__)


# BASE MIDDLEWARE
class BaseAccountsMiddleware(MiddlewareMixin):
    """
    Base middleware class for accounts components.
    
    Provides common functionality for all accounts middleware:
    - Configuration management
    - User context handling
    - Rate limiting utilities
    - Security utilities
    
    Attributes:
        enabled: Whether middleware is enabled (``ACCOUNTS_MIDDLEWARE_ENABLED``
            setting, default True)
        config: Middleware configuration: built-in defaults merged with the
            ``<CLASSNAME>_CONFIG`` settings dict of the concrete subclass
        num_proxies: Number of trusted reverse proxies in front of the site
            (``ACCOUNTS_NUM_PROXIES`` setting). ``None`` (the default) keeps
            the legacy behaviour of trusting the left-most
            ``X-Forwarded-For`` entry; see ``_get_client_ip``.
    """
    
    def __init__(self, get_response=None):
        """
        Initialize base accounts middleware.
        
        Args:
            get_response: Django get_response callable
        """
        super().__init__(get_response)
        self.get_response = get_response
        self.enabled = getattr(settings, 'ACCOUNTS_MIDDLEWARE_ENABLED', True)
        self.num_proxies = getattr(settings, 'ACCOUNTS_NUM_PROXIES', None)
        self.config = self._load_config()
        self._initialize_middleware()
    
    def _load_config(self) -> Dict[str, Any]:
        """
        Load middleware configuration.
        
        The settings key is derived from the concrete class name, e.g.
        ``RATELIMITMIDDLEWARE_CONFIG`` for ``RateLimitMiddleware``.
        
        Returns:
            Dictionary containing middleware configuration; falls back to the
            built-in defaults if the settings value cannot be merged.
        """
        try:
            # Default configuration
            default_config = {
                'enabled': True,
                'debug': getattr(settings, 'DEBUG', False),
                'log_level': 'INFO',
            }
            
            # Get middleware-specific config
            config_key = f'{self.__class__.__name__.upper()}_CONFIG'
            middleware_config = getattr(settings, config_key, {})
            
            # Merge configurations (settings values win over defaults)
            config = {**default_config, **middleware_config}
            
            return config
            
        except Exception as e:
            logger.error(f"Failed to load middleware config: {e}")
            return {'enabled': True, 'debug': False, 'log_level': 'INFO'}
    
    def _initialize_middleware(self):
        """
        Initialize middleware components.
        
        Override in subclasses for specific initialization.
        """
        pass
    
    def _is_enabled(self) -> bool:
        """
        Check if middleware is enabled.
        
        Returns:
            True only if both the global switch and this middleware's
            ``config['enabled']`` flag are truthy.
        """
        return self.enabled and self.config.get('enabled', True)
    
    def _get_client_ip(self, request: HttpRequest) -> str:
        """
        Get client IP address from request.
        
        ``X-Forwarded-For`` is a comma-separated list to which each proxy
        appends the address it received the request from; everything to the
        left of what your own proxies appended is supplied by the client and
        can be forged.
        
        - ``num_proxies`` is None (default, legacy): use the left-most entry.
          This is spoofable, so a client can evade IP-based rate limits by
          sending a fake header. Configure ``ACCOUNTS_NUM_PROXIES`` in
          production.
        - ``num_proxies`` is 0: ignore the header and use ``REMOTE_ADDR``.
        - ``num_proxies`` is N > 0: use the N-th entry from the right, i.e.
          the address recorded by the outermost trusted proxy.
        
        Blank header entries are ignored; if nothing usable remains,
        ``REMOTE_ADDR`` is returned.
        
        Args:
            request: HTTP request object
            
        Returns:
            Client IP address ('' if none is available)
        """
        remote_addr = request.META.get('REMOTE_ADDR', '')
        forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR', '')
        hops = [hop.strip() for hop in forwarded_for.split(',') if hop.strip()]
        if not hops:
            return remote_addr
        
        num_proxies = getattr(self, 'num_proxies', None)
        if num_proxies is None:
            return hops[0]
        if num_proxies <= 0:
            return remote_addr
        return hops[-min(num_proxies, len(hops))]
    
    def _get_cache_key(self, prefix: str, identifier: str) -> str:
        """
        Generate cache key for rate limiting.
        
        The identifier is hashed so the key has a fixed length and a safe
        character set regardless of its input (not for secrecy).
        
        Args:
            prefix: Cache key prefix
            identifier: Unique identifier (IP, user ID, etc.)
            
        Returns:
            Cache key string
        """
        return f"{prefix}:{hashlib.md5(identifier.encode()).hexdigest()}"
    
    def _log_middleware_action(
        self, 
        action: str, 
        request: HttpRequest, 
        extra_data: Optional[Dict[str, Any]] = None,
        level: str = 'INFO'
    ):
        """
        Log middleware action.
        
        Request context is attached via ``extra``; ``extra_data`` entries are
        merged on top and override same-named keys. Logging failures are
        themselves logged and never propagate.
        
        Args:
            action: Action being performed
            request: HTTP request object
            extra_data: Additional data to log
            level: Log level name (e.g. 'INFO', 'WARNING')
        """
        try:
            log_data = {
                'action': action,
                'middleware': self.__class__.__name__,
                'path': request.path,
                'method': request.method,
                'user': str(request.user) if hasattr(request, 'user') else 'Anonymous',
                'ip_address': self._get_client_ip(request),
                'timestamp': timezone.now().isoformat(),
            }
            
            if extra_data:
                log_data.update(extra_data)
            
            getattr(logger, level.lower())(f"Middleware {action}", extra=log_data)
            
        except Exception as e:
            logger.error(f"Middleware logging failed: {e}")


# USER ACTIVITY MIDDLEWARE
class UserActivityMiddleware(BaseAccountsMiddleware):
    """
    User activity tracking middleware.
    
    For each authenticated request it:
    - Stores the current time in the cache under a per-user
      ``user_activity`` key that expires after ``2 * USER_ONLINE_TIMEOUT``
      seconds (usable as an "online" indicator)
    - Updates ``user.last_seen`` on the user model, if the model has that
      attribute, at most once per ``USER_LAST_SEEN_UPDATE_INTERVAL`` seconds
      (default 60) to limit database writes
    - Appends the request to a per-user page-view list in the cache (last
      100 entries, 1 hour expiry)
    """
    
    def __init__(self, get_response=None):
        super().__init__(get_response)
        self.activity_enabled = getattr(settings, 'USER_ACTIVITY_TRACKING_ENABLED', True)
        self.online_timeout = getattr(settings, 'USER_ONLINE_TIMEOUT', 300)  # 5 minutes
        # Minimum number of seconds between two ``last_seen`` DB writes for a user.
        self.last_seen_update_interval = getattr(settings, 'USER_LAST_SEEN_UPDATE_INTERVAL', 60)
    
    def process_request(self, request: HttpRequest) -> Optional[HttpResponse]:
        """
        Process request for user activity tracking.
        
        Args:
            request: HTTP request object
            
        Returns:
            None to continue processing (errors are logged, never raised)
        """
        if not self._is_enabled() or not self.activity_enabled:
            return None
        
        try:
            if hasattr(request, 'user') and request.user.is_authenticated:
                user = request.user
                current_time = timezone.now()
                
                # Update last seen timestamp in the cache
                cache_key = self._get_cache_key('user_activity', str(user.id))
                cache.set(cache_key, current_time.isoformat(), timeout=self.online_timeout * 2)
                
                # Update user model if it has a last_seen field
                self._update_last_seen(user, current_time)
                
                # Track page view
                self._track_page_view(request, user)
                
                self._log_middleware_action('user_activity_tracked', request, {
                    'user_id': user.id,
                    'last_seen': current_time.isoformat(),
                })
        
        except Exception as e:
            logger.error(f"User activity middleware error: {e}")
        
        return None
    
    def _update_last_seen(self, user, current_time) -> None:
        """
        Persist ``user.last_seen`` if it is unset or older than the interval.
        
        Best effort: failures (e.g. a naive/aware datetime mismatch or a
        database error) are logged as warnings and do not affect the request.
        
        Args:
            user: Authenticated user object
            current_time: Timestamp of the current request
        """
        try:
            if not hasattr(user, 'last_seen'):
                return
            last_seen = user.last_seen
            # total_seconds(), not .seconds: timedelta.seconds is only the
            # 0-86399 seconds component and ignores whole days, so a user
            # last seen "2 days and 30 seconds ago" would never be updated.
            if not last_seen or (current_time - last_seen).total_seconds() > self.last_seen_update_interval:
                user.last_seen = current_time
                user.save(update_fields=['last_seen'])
        except Exception as e:
            logger.warning(f"Failed to update last_seen for user {getattr(user, 'id', None)}: {e}")
    
    def _track_page_view(self, request: HttpRequest, user):
        """
        Track page view for analytics.
        
        The list is read, appended to and written back, so concurrent
        requests for the same user may drop entries; acceptable for
        analytics.
        
        Args:
            request: HTTP request object
            user: User object
        """
        try:
            # Store page view in cache for analytics
            page_view_data = {
                'user_id': user.id,
                'path': request.path,
                'method': request.method,
                'timestamp': timezone.now().isoformat(),
                'ip_address': self._get_client_ip(request),
                'user_agent': request.META.get('HTTP_USER_AGENT', ''),
            }
            
            # Add to page views list (keep last 100 views per user)
            cache_key = self._get_cache_key('page_views', str(user.id))
            page_views = cache.get(cache_key, [])
            page_views.append(page_view_data)
            
            # Keep only last 100 views
            if len(page_views) > 100:
                page_views = page_views[-100:]
            
            cache.set(cache_key, page_views, timeout=3600)  # 1 hour
            
        except Exception as e:
            logger.error(f"Page view tracking error: {e}")


# RATE LIMIT MIDDLEWARE
class RateLimitMiddleware(BaseAccountsMiddleware):
    """
    Rate limiting middleware.
    
    Provides:
    - IP-based rate limiting (anonymous clients)
    - User-based rate limiting (authenticated users)
    - Endpoint-specific limits (login, API, everything else)
    - Configurable time windows
    
    A rate limit is a string ``'<count>/<period>'``. The period is an integer
    with an optional unit suffix ``s``, ``m``, ``h`` or ``d`` (no suffix means
    seconds); a bare unit such as ``'h'`` means one of that unit. Examples:
    ``'5/5m'``, ``'100/h'``, ``'1000/1h'``, ``'10/30'``.
    
    Counting uses a fixed window per client and endpoint category. The first
    request creates a cache counter that expires after the period. Requests
    beyond ``count`` within that window receive HTTP 429.
    """
    
    # Seconds per period unit suffix accepted by ``_parse_period``.
    _PERIOD_UNITS = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}
    
    def __init__(self, get_response=None):
        super().__init__(get_response)
        self.rate_limit_enabled = getattr(settings, 'RATE_LIMIT_ENABLED', True)
        self.default_limit = self.config.get('default_limit', '100/h')  # 100 per hour
        self.login_limit = getattr(settings, 'LOGIN_RATE_LIMIT', '5/5m')  # 5 per 5 minutes
        self.api_limit = getattr(settings, 'API_RATE_LIMIT', '1000/h')  # 1000 per hour
    
    def process_request(self, request: HttpRequest) -> Optional[HttpResponse]:
        """
        Process request for rate limiting.
        
        Args:
            request: HTTP request object
            
        Returns:
            A 429 JSON response (with a ``Retry-After`` header) if the client
            exceeded its limit, None otherwise. Errors are logged and the
            request is let through (fail open).
        """
        if not self._is_enabled() or not self.rate_limit_enabled:
            return None
        
        try:
            rate_limit = self._get_rate_limit_for_path(request.path)
            if not rate_limit:
                return None
            
            # Namespace the client by endpoint category. Otherwise API calls
            # and ordinary page views with the same period would share a
            # counter and consume each other's budget.
            scope = self._get_rate_limit_scope(request.path)
            identifier = f"{self._get_rate_limit_identifier(request)}:{scope}"
            
            # Count the request and compare in one step. With a separate
            # check followed by an increment, concurrent requests can all
            # pass the check before any of them is counted.
            if self._consume_rate_limit(identifier, rate_limit):
                self._log_middleware_action('rate_limit_exceeded', request, {
                    'identifier': identifier,
                    'rate_limit': rate_limit,
                }, level='WARNING')
                
                retry_after = self._get_retry_after(rate_limit)
                response = JsonResponse({
                    'error': 'Rate limit exceeded',
                    'message': 'Too many requests. Please try again later.',
                    'retry_after': retry_after,
                }, status=429)
                response['Retry-After'] = str(retry_after)
                return response
        
        except Exception as e:
            logger.error(f"Rate limit middleware error: {e}")
        
        return None
    
    def _get_rate_limit_scope(self, path: str) -> str:
        """
        Classify a request path into a rate-limit category.
        
        Args:
            path: Request path
            
        Returns:
            'login', 'api' or 'default'
        """
        # Login endpoints
        if '/auth/login' in path or '/accounts/login' in path:
            return 'login'
        
        # API endpoints
        if path.startswith('/api/'):
            return 'api'
        
        # Everything else
        return 'default'
    
    def _get_rate_limit_for_path(self, path: str) -> Optional[str]:
        """
        Get rate limit configuration for path.
        
        Args:
            path: Request path
            
        Returns:
            Rate limit string, or a falsy value to disable limiting for it
        """
        limits = {
            'login': self.login_limit,
            'api': self.api_limit,
            'default': self.default_limit,
        }
        return limits.get(self._get_rate_limit_scope(path), self.default_limit)
    
    def _get_rate_limit_identifier(self, request: HttpRequest) -> str:
        """
        Get identifier for rate limiting.
        
        Args:
            request: HTTP request object
            
        Returns:
            ``'user:<id>'`` if authenticated, otherwise ``'ip:<address>'``
        """
        if hasattr(request, 'user') and request.user.is_authenticated:
            return f"user:{request.user.id}"
        return f"ip:{self._get_client_ip(request)}"
    
    def _parse_rate_limit(self, rate_limit: str) -> tuple:
        """
        Split a rate limit string into its parts.
        
        Args:
            rate_limit: Rate limit string (e.g. '5/5m')
            
        Returns:
            ``(limit, period_seconds, period_label)``
            
        Raises:
            ValueError: if the string is malformed
        """
        limit, period = rate_limit.split('/')
        return int(limit), self._parse_period(period), period.strip()
    
    def _rate_limit_cache_key(self, identifier: str, period_label: str) -> str:
        """Return the cache key holding the counter for identifier/period."""
        return self._get_cache_key('rate_limit', f"{identifier}:{period_label}")
    
    def _incr_counter(self, cache_key: str, timeout: int) -> int:
        """
        Increment a windowed counter and return its new value.
        
        ``cache.add`` only creates the key if it is missing. The window's
        expiry is therefore set by the first request in the window, not
        refreshed by later ones.
        
        Args:
            cache_key: Counter cache key
            timeout: Window length in seconds
            
        Returns:
            Counter value after this increment
        """
        cache.add(cache_key, 0, timeout=timeout)
        try:
            return cache.incr(cache_key)
        except ValueError:
            # The key expired between add() and incr(); start a new window.
            cache.set(cache_key, 1, timeout=timeout)
            return 1
    
    def _consume_rate_limit(self, identifier: str, rate_limit: str) -> bool:
        """
        Count one request for identifier and report whether it is over limit.
        
        Args:
            identifier: Rate limit identifier
            rate_limit: Rate limit string (e.g. '5/5m')
            
        Returns:
            True if this request exceeds the limit, False otherwise (also
            False on errors, i.e. fail open)
        """
        try:
            limit, period_seconds, period_label = self._parse_rate_limit(rate_limit)
            cache_key = self._rate_limit_cache_key(identifier, period_label)
            return self._incr_counter(cache_key, period_seconds) > limit
        except Exception as e:
            logger.error(f"Rate limit check error: {e}")
            return False
    
    def _is_rate_limited(self, identifier: str, rate_limit: str) -> bool:
        """
        Check, without counting, whether identifier is rate limited.
        
        Args:
            identifier: Rate limit identifier
            rate_limit: Rate limit string (e.g., '5/5m')
            
        Returns:
            True if rate limited, False otherwise (also False on errors)
        """
        try:
            limit, _, period_label = self._parse_rate_limit(rate_limit)
            current_count = cache.get(self._rate_limit_cache_key(identifier, period_label), 0)
            return current_count >= limit
        except Exception as e:
            logger.error(f"Rate limit check error: {e}")
            return False
    
    def _increment_rate_limit_counter(self, identifier: str, rate_limit: str):
        """
        Increment rate limit counter.
        
        Args:
            identifier: Rate limit identifier
            rate_limit: Rate limit string
        """
        try:
            _, period_seconds, period_label = self._parse_rate_limit(rate_limit)
            self._incr_counter(self._rate_limit_cache_key(identifier, period_label), period_seconds)
        except Exception as e:
            logger.error(f"Rate limit increment error: {e}")
    
    def _parse_period(self, period: str) -> int:
        """
        Parse period string to seconds.
        
        Args:
            period: Period string, e.g. '30', '5m', '1h', 'h' (= '1h'), '1d'
            
        Returns:
            Period in seconds
            
        Raises:
            ValueError: if the period is malformed or not positive
        """
        text = period.strip()
        unit = text[-1:]
        if unit in self._PERIOD_UNITS:
            amount = text[:-1].strip()
            # A bare unit ('h') means one of that unit ('1h').
            seconds = (int(amount) if amount else 1) * self._PERIOD_UNITS[unit]
        else:
            seconds = int(text)  # Assume seconds
        if seconds <= 0:
            raise ValueError(f"Rate limit period must be positive: {period!r}")
        return seconds
    
    def _get_retry_after(self, rate_limit: str) -> int:
        """
        Get retry after seconds.
        
        This is the full window length, an upper bound on the real wait.
        
        Args:
            rate_limit: Rate limit string
            
        Returns:
            Retry after seconds (60 if the rate limit cannot be parsed)
        """
        try:
            _, period = rate_limit.split('/')
            return self._parse_period(period)
        except Exception:
            return 60  # Default to 1 minute


# SECURITY HEADERS MIDDLEWARE
class SecurityHeadersMiddleware(BaseAccountsMiddleware):
    """
    Security headers middleware.
    
    Adds the following headers unless the response already sets them:
    - Content Security Policy
    - X-Frame-Options
    - X-Content-Type-Options
    - X-XSS-Protection (set to '0', disabling the legacy XSS auditor)
    - Referrer-Policy
    - Permissions-Policy
    
    ``CUSTOM_SECURITY_HEADERS`` (dict) overrides or extends the defaults.
    Map a header to ``None`` to stop this middleware from sending it.
    """
    
    def __init__(self, get_response=None):
        super().__init__(get_response)
        self.security_enabled = getattr(settings, 'SECURITY_HEADERS_ENABLED', True)
        self.custom_headers = getattr(settings, 'CUSTOM_SECURITY_HEADERS', {})
    
    def _get_security_headers(self) -> Dict[str, Optional[str]]:
        """
        Build the header map to apply.
        
        Returns:
            Default headers updated with ``CUSTOM_SECURITY_HEADERS``; a
            ``None`` value means "do not send this header".
        """
        security_headers = {
            'X-Content-Type-Options': 'nosniff',
            'X-Frame-Options': 'DENY',
            # '1; mode=block' is deprecated. The browser XSS auditor it enables
            # is gone from modern browsers and, where still present, can
            # introduce XSS issues in otherwise safe pages. Disable it
            # explicitly and rely on CSP instead.
            'X-XSS-Protection': '0',
            'Referrer-Policy': 'strict-origin-when-cross-origin',
            'Permissions-Policy': 'geolocation=(), microphone=(), camera=()',
            'Content-Security-Policy': "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline';",
        }
        security_headers.update(self.custom_headers)
        return security_headers
    
    def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
        """
        Add security headers to response.
        
        Args:
            request: HTTP request object
            response: HTTP response object
            
        Returns:
            The same response object, with missing security headers added
            (errors are logged and the response is returned unchanged)
        """
        if not self._is_enabled() or not self.security_enabled:
            return response
        
        try:
            for header, value in self._get_security_headers().items():
                # Never override a header the view set explicitly; skip
                # headers disabled via a None custom value.
                if value is None or header in response:
                    continue
                response[header] = value
            
            self._log_middleware_action('security_headers_added', request)
        
        except Exception as e:
            logger.error(f"Security headers middleware error: {e}")
        
        return response


# REQUEST LOGGING MIDDLEWARE
class RequestLoggingMiddleware(BaseAccountsMiddleware):
    """
    Request logging middleware for accounts.
    
    Emits a ``sensitive_request_logged`` middleware log entry for each request
    whose path starts with one of the configured sensitive prefixes
    (account, auth, admin and auth-API URLs by default). The entry records
    who made the request (user id and username, or 'Anonymous'), the client
    IP and User-Agent, and the path and method.
    
    Settings:
        REQUEST_LOGGING_ENABLED: switch for this middleware (default True)
        REQUESTLOGGINGMIDDLEWARE_CONFIG['sensitive_paths']: path prefixes
            to log, overriding the defaults
    """
    
    def __init__(self, get_response=None):
        super().__init__(get_response)
        self.logging_enabled = getattr(settings, 'REQUEST_LOGGING_ENABLED', True)
        fallback_prefixes = ['/accounts/', '/auth/', '/admin/', '/api/auth/']
        self.sensitive_paths = self.config.get('sensitive_paths', fallback_prefixes)
    
    def process_request(self, request: HttpRequest) -> Optional[HttpResponse]:
        """
        Log the request when it targets a sensitive path.
        
        Args:
            request: Incoming HTTP request
            
        Returns:
            Always None so processing continues. Failures are reported to the
            module logger and never block the request.
        """
        if not (self._is_enabled() and self.logging_enabled):
            return None
        
        try:
            path = request.path
            if not any(path.startswith(prefix) for prefix in self.sensitive_paths):
                return None
            
            is_authenticated = hasattr(request, 'user') and request.user.is_authenticated
            if is_authenticated:
                user_id = request.user.id
                username = str(request.user)
            else:
                user_id = None
                username = 'Anonymous'
            
            self._log_middleware_action('sensitive_request_logged', request, {
                'user_id': user_id,
                'username': username,
                'ip_address': self._get_client_ip(request),
                'user_agent': request.META.get('HTTP_USER_AGENT', ''),
                'path': path,
                'method': request.method,
                'timestamp': timezone.now().isoformat(),
            })
        
        except Exception as e:
            logger.error(f"Request logging middleware error: {e}")
        
        return None


# Public API for ``from ... import *``: the middleware classes defined above.
__all__ = [
    'BaseAccountsMiddleware', 'UserActivityMiddleware', 'RateLimitMiddleware',
    'SecurityHeadersMiddleware', 'RequestLoggingMiddleware',
]