# -*- coding: utf-8 -*-
"""
Authentication Utilities Module

This module contains utility functions for authentication and authorization
including IP detection, user agent parsing, location detection, and security helpers.

Author: Senior Django Developer
Date: 2024
"""

import re
import json
import logging
import requests
from urllib.parse import urlparse
from django.conf import settings
from django.core.cache import cache
from django.utils import timezone
from datetime import timedelta
from user_agents import parse as parse_user_agent

# Module-level logger named after this module (``__name__``); the helpers
# below use it for parse/lookup warnings, cleanup stats and security events.
logger = logging.getLogger(__name__)


def get_client_ip(request):
    """
    Return the best-guess client IP address for *request*.

    Candidate values are read from ``request.META`` in a fixed priority
    order (Cloudflare's header first, ``REMOTE_ADDR`` last). For every
    non-empty value, a comma-separated list is reduced to its first,
    whitespace-trimmed entry. The first candidate accepted by
    ``is_valid_ip`` is returned. If no candidate is accepted, the raw
    ``REMOTE_ADDR`` is returned without validation, or ``'127.0.0.1'``
    when that key is absent.

    NOTE(review): apart from REMOTE_ADDR, these headers are sent by the
    client or by intermediate proxies and can be forged unless a trusted
    proxy overwrites them. Treat the result as untrusted input for
    security decisions such as rate limiting.

    Args:
        request: Django HTTP request object

    Returns:
        str: Client IP address
    """
    meta = request.META
    header_priority = (
        'HTTP_CF_CONNECTING_IP',     # Cloudflare
        'HTTP_X_FORWARDED_FOR',      # de-facto standard proxy-chain header
        'HTTP_X_REAL_IP',            # commonly set by nginx
        'HTTP_X_FORWARDED',          # non-standard variant
        'HTTP_X_CLUSTER_CLIENT_IP',  # cluster / load-balancer environments
        'HTTP_FORWARDED_FOR',        # non-standard variant
        # RFC 7239 ``Forwarded`` values use ``for=...`` syntax, which is not
        # unwrapped here, so this entry rarely yields a usable candidate.
        'HTTP_FORWARDED',
        'REMOTE_ADDR',               # address of the directly connected peer
    )

    for header_name in header_priority:
        raw_value = meta.get(header_name)
        if not raw_value:
            continue
        # Proxy chains look like "client, proxy1, proxy2": keep the left-most hop.
        candidate = raw_value.partition(',')[0].strip() if ',' in raw_value else raw_value
        if is_valid_ip(candidate):
            return candidate

    return meta.get('REMOTE_ADDR', '127.0.0.1')


def is_valid_ip(ip):
    """
    Validate if a string is a valid IP address (IPv4 or IPv6).

    Parsing uses the standard-library ``ipaddress`` module, so every textual
    form of an address is accepted, including compressed IPv6 such as
    ``2001:db8::1``. The strings ``'unknown'`` and ``'none'`` (in any case),
    empty strings and non-string values are rejected.

    When ``settings.DEBUG`` is False, loopback, RFC 1918, link-local and
    IPv6 unique-local addresses are reported as invalid. IPv4-mapped IPv6
    addresses such as ``::ffff:10.0.0.1`` are classified by the IPv4 address
    they wrap. In development (DEBUG True) these private addresses are allowed.

    Args:
        ip (str): IP address string to validate

    Returns:
        bool: True if valid IP (and, outside DEBUG, not private), False otherwise
    """
    import ipaddress

    if not isinstance(ip, str) or ip.lower() in ('unknown', 'none', ''):
        return False

    try:
        address = ipaddress.ip_address(ip)
    except ValueError:
        return False

    # Classify IPv4-mapped IPv6 by the IPv4 address it carries.
    if address.version == 6 and address.ipv4_mapped is not None:
        address = address.ipv4_mapped

    private_networks = (
        ipaddress.ip_network('127.0.0.0/8'),     # loopback
        ipaddress.ip_network('10.0.0.0/8'),      # RFC 1918
        ipaddress.ip_network('172.16.0.0/12'),   # RFC 1918
        ipaddress.ip_network('192.168.0.0/16'),  # RFC 1918
        ipaddress.ip_network('169.254.0.0/16'),  # IPv4 link-local
        ipaddress.ip_network('::1/128'),         # IPv6 loopback
        ipaddress.ip_network('fe80::/10'),       # IPv6 link-local
        ipaddress.ip_network('fc00::/7'),        # IPv6 unique local
    )
    # Membership across IP versions is simply False, never an error.
    is_private = any(address in network for network in private_networks)

    # Private addresses are accepted in development only; settings are only
    # consulted when the answer actually depends on them.
    if is_private and not settings.DEBUG:
        return False
    return True


def get_user_agent_info(user_agent_string):
    """
    Summarise a raw User-Agent header as device, browser and OS fields.

    Parsing is delegated to the ``user_agents`` package. An empty or missing
    string, or any exception raised while parsing, yields a result where
    every text field is ``'Unknown'`` and every flag is ``False``.

    Args:
        user_agent_string (str): Raw user agent string

    Returns:
        dict: Keys ``device_type``, ``browser``, ``browser_version``, ``os``,
        ``os_version``, ``is_mobile``, ``is_tablet``, ``is_pc``, ``is_bot``.
    """
    def unknown_result():
        # Built fresh on every call so callers may safely mutate the result.
        return {
            'device_type': 'Unknown',
            'browser': 'Unknown',
            'browser_version': 'Unknown',
            'os': 'Unknown',
            'os_version': 'Unknown',
            'is_mobile': False,
            'is_tablet': False,
            'is_pc': False,
            'is_bot': False
        }

    if not user_agent_string:
        return unknown_result()

    try:
        parsed = parse_user_agent(user_agent_string)

        # First truthy flag wins, in this priority order.
        device_labels = (
            (parsed.is_mobile, 'Mobile'),
            (parsed.is_tablet, 'Tablet'),
            (parsed.is_pc, 'Desktop'),
            (parsed.is_bot, 'Bot'),
        )
        device_type = next((label for flag, label in device_labels if flag), 'Unknown')

        browser, os_info = parsed.browser, parsed.os
        return {
            'device_type': device_type,
            'browser': browser.family or 'Unknown',
            'browser_version': browser.version_string or 'Unknown',
            'os': os_info.family or 'Unknown',
            'os_version': os_info.version_string or 'Unknown',
            'is_mobile': parsed.is_mobile,
            'is_tablet': parsed.is_tablet,
            'is_pc': parsed.is_pc,
            'is_bot': parsed.is_bot
        }

    except Exception as e:
        logger.warning(f'Error parsing user agent "{user_agent_string}": {e}')
        return unknown_result()


def get_location_from_ip(ip_address):
    """
    Get geographical location information from IP address.

    Queries the ip-api.com JSON endpoint and caches the outcome:
    - Successful lookups are cached for 24 hours.
    - Failures (non-200 response, non-"success" status, or any exception)
      are cached for 1 hour, so a failing address is not retried on every
      call.

    Addresses rejected by ``is_valid_ip`` or reported by ``is_private_ip``
    are never sent to the service.

    Args:
        ip_address (str): IP address to lookup

    Returns:
        dict: Location information with keys ``country``, ``country_code``,
        ``city``, ``region``, ``latitude``, ``longitude``, ``timezone`` and
        ``isp``. Text fields fall back to placeholders ('Unknown', 'XX',
        'UTC') when the service omits them or returns them empty/null.
    """
    # Default response for invalid or private IPs (and for failed lookups)
    default_location = {
        'country': 'Unknown',
        'country_code': 'XX',
        'city': 'Unknown',
        'region': 'Unknown',
        'latitude': None,
        'longitude': None,
        'timezone': 'UTC',
        'isp': 'Unknown'
    }

    # Skip location lookup for invalid or private IPs
    if not is_valid_ip(ip_address) or is_private_ip(ip_address):
        return default_location

    cache_key = f'ip_location_{ip_address}'
    cached_location = cache.get(cache_key)
    if cached_location:
        return cached_location

    try:
        # ip-api.com free service (rate limited, plain HTTP). For production,
        # consider a paid service like MaxMind GeoIP2.
        response = requests.get(
            f'http://ip-api.com/json/{ip_address}',
            timeout=5,
            params={
                'fields': 'status,country,countryCode,region,city,lat,lon,timezone,isp'
            }
        )

        if response.status_code == 200:
            data = response.json()

            if data.get('status') == 'success':
                # ``or`` rather than a .get() default: a key that is present
                # but empty/null must also fall back to the placeholder.
                location_info = {
                    'country': data.get('country') or 'Unknown',
                    'country_code': data.get('countryCode') or 'XX',
                    'city': data.get('city') or 'Unknown',
                    'region': data.get('region') or 'Unknown',
                    # Coordinates are passed through unchanged: 0.0 is valid.
                    'latitude': data.get('lat'),
                    'longitude': data.get('lon'),
                    'timezone': data.get('timezone') or 'UTC',
                    'isp': data.get('isp') or 'Unknown'
                }

                # Cache the result for 24 hours
                cache.set(cache_key, location_info, 86400)
                return location_info

    except Exception as e:
        # Best-effort lookup: errors are logged and the default is returned.
        logger.warning(f'Error getting location for IP {ip_address}: {e}')

    # Cache the default result for 1 hour to avoid repeated failed lookups
    cache.set(cache_key, default_location, 3600)
    return default_location


def is_private_ip(ip_address):
    """
    Check if an IP address is private/local.

    The following are considered private:
    - Loopback (127.0.0.0/8, ::1).
    - RFC 1918 ranges (10/8, 172.16/12, 192.168/16).
    - Link-local (169.254/16, fe80::/10).
    - IPv6 unique-local (fc00::/7).
    - The literal hostname ``localhost``, case-insensitive.

    The address is parsed with the standard-library ``ipaddress`` module, so
    any textual IPv6 form is recognised (e.g. ``0:0:0:0:0:0:0:1``).
    IPv4-mapped IPv6 addresses are classified by the IPv4 address they wrap.
    Non-string input and strings that are not IP addresses return False.

    Args:
        ip_address (str): IP address to check

    Returns:
        bool: True if private IP, False otherwise
    """
    import ipaddress

    if not isinstance(ip_address, str):
        return False

    candidate = ip_address.strip()
    if candidate.lower() == 'localhost':
        return True

    try:
        address = ipaddress.ip_address(candidate)
    except ValueError:
        return False

    # Classify IPv4-mapped IPv6 (e.g. ::ffff:192.168.0.1) by its IPv4 part.
    if address.version == 6 and address.ipv4_mapped is not None:
        address = address.ipv4_mapped

    private_networks = (
        ipaddress.ip_network('127.0.0.0/8'),     # loopback
        ipaddress.ip_network('10.0.0.0/8'),      # RFC 1918
        ipaddress.ip_network('172.16.0.0/12'),   # RFC 1918
        ipaddress.ip_network('192.168.0.0/16'),  # RFC 1918
        ipaddress.ip_network('169.254.0.0/16'),  # IPv4 link-local
        ipaddress.ip_network('::1/128'),         # IPv6 loopback
        ipaddress.ip_network('fe80::/10'),       # IPv6 link-local
        ipaddress.ip_network('fc00::/7'),        # IPv6 unique local
    )
    # Membership across IP versions is simply False, never an error.
    return any(address in network for network in private_networks)


def is_safe_url(url, allowed_hosts=None):
    """
    Check if a URL is safe for redirection.

    Prevents open redirects. After trimming surrounding whitespace, a URL is
    accepted when it is either:
    - relative: an absolute path such as ``/account``, or a path such as
      ``next/page`` with neither scheme nor host; or
    - an absolute ``http``/``https`` URL whose host (``netloc``) exactly
      matches an entry of *allowed_hosts*.

    Protocol-relative URLs (two leading slashes) are always rejected.

    The following browser quirks that a plain ``urlparse`` check misses are
    also rejected:
    - A URL that is unsafe after replacing every backslash with a forward
      slash, because browsers treat the two alike in URLs.
    - A URL containing control/format characters such as tab, CR, LF or
      NUL, because browsers silently drop some of them and can turn a path
      into a protocol-relative URL.
    - A URL with a scheme but no host (e.g. ``https:evil.com``), which
      browsers may resolve as an absolute URL.

    Args:
        url (str): URL to validate
        allowed_hosts (list | set | str): Allowed host(s) (optional). A single
            string is treated as one host, not as a container of characters.

    Returns:
        bool: True if URL is safe, False otherwise
    """
    import unicodedata

    if not url:
        return False

    # Remove leading/trailing whitespace
    url = url.strip()
    if not url:
        return False

    if isinstance(allowed_hosts, str):
        # ``netloc in "example.com"`` would be a substring test; wrap the
        # string so membership means host equality.
        allowed_hosts = {allowed_hosts}

    # e.g. "/<TAB>/evil.com" becomes "//evil.com" once the browser drops the tab.
    if any(unicodedata.category(ch).startswith('C') for ch in url):
        return False

    def _check(candidate):
        # Apply the redirect rules to a single spelling of the URL.
        # Protocol-relative ("//host") and "///host" URLs are never accepted.
        if candidate.startswith('//'):
            return False

        # Absolute path on the current host.
        if candidate.startswith('/'):
            return True

        try:
            parsed = urlparse(candidate)
        except ValueError:  # e.g. malformed IPv6 literal such as "http://[::1"
            return False

        # Reject URLs with schemes other than http/https
        if parsed.scheme and parsed.scheme not in ('http', 'https'):
            return False

        if not parsed.netloc:
            # No host: fine for a scheme-less relative path, but
            # "https:evil.com" / "http:///evil.com" may be treated by
            # browsers as an absolute URL to evil.com.
            return not parsed.scheme

        # Absolute URL: only allowed when the host is explicitly listed.
        return bool(allowed_hosts) and parsed.netloc in allowed_hosts

    # Check both spellings: as written, and with backslashes read as slashes
    # (so "/<backslash>evil.com" is seen as "//evil.com").
    return _check(url) and _check(url.replace('\\', '/'))


def generate_secure_token(length=32):
    """
    Build a random token from a URL-safe alphabet using the ``secrets`` CSPRNG.

    Each of the *length* characters is drawn independently from ASCII
    letters, digits, ``-`` and ``_``. A non-positive *length* yields ``''``.

    Args:
        length (int): Number of characters in the token

    Returns:
        str: Secure random token
    """
    import secrets
    import string

    charset = f'{string.ascii_letters}{string.digits}-_'
    picks = [secrets.choice(charset) for _ in range(length)]
    return ''.join(picks)


def hash_password_with_salt(password, salt=None):
    """
    Produce a Django password hash for *password*.

    Thin wrapper over ``django.contrib.auth.hashers.make_password`` using the
    project's default hasher.

    Args:
        password (str): Password to hash
        salt (str): Optional salt; when None, Django generates one

    Returns:
        str: Encoded password hash
    """
    from django.contrib.auth import hashers

    return hashers.make_password(password, salt=salt)


def verify_password(password, hashed_password):
    """
    Check a plain-text password against a stored Django password hash.

    Thin wrapper over ``django.contrib.auth.hashers.check_password``.

    Args:
        password (str): Plain text password
        hashed_password (str): Encoded hash to compare against

    Returns:
        bool: True when the password matches the hash, False otherwise
    """
    from django.contrib.auth import hashers

    return hashers.check_password(password, hashed_password)


def is_password_strong(password):
    """
    Check if a password meets strength requirements.

    Args:
        password (str): Password to check

    Returns:
        dict: One boolean per criterion, plus two aggregates.
            - ``length``: at least 8 characters.
            - ``uppercase``: contains an ASCII upper-case letter.
            - ``lowercase``: contains an ASCII lower-case letter.
            - ``digit``: contains a digit.
            - ``special``: contains one of ``!@#$%^&*(),.?":{}|<>``.
            - ``no_common``: not in the common-password list.
            - ``is_strong``: every criterion passes.
            - ``score``: number of criteria that pass (0-6).
    """
    checks = {
        'length': len(password) >= 8,
        'uppercase': bool(re.search(r'[A-Z]', password)),
        'lowercase': bool(re.search(r'[a-z]', password)),
        'digit': bool(re.search(r'\d', password)),
        'special': bool(re.search(r'[!@#$%^&*(),.?":{}|<>]', password)),
        'no_common': not is_common_password(password)
    }

    # Count passed criteria before adding the aggregate keys. Subtracting 1
    # after summing with 'is_strong' included would undercount every
    # non-strong password, since 'is_strong' then contributes 0, not 1.
    score = sum(checks.values())
    checks['is_strong'] = all(checks.values())
    checks['score'] = score

    return checks


# Small built-in list of very common passwords (lower-case). Built once at
# import time instead of on every call; in production, load a larger list
# from a file.
_COMMON_PASSWORDS = frozenset({
    'password', '123456', '123456789', 'qwerty', 'abc123',
    'password123', 'admin', 'letmein', 'welcome', 'monkey',
    'dragon', 'master', 'shadow', 'superman', 'michael',
    'football', 'baseball', 'liverpool', 'jordan', 'princess'
})


def is_common_password(password):
    """
    Check if a password is in the list of common passwords.

    The comparison is case-insensitive (the password is lower-cased first).

    Args:
        password (str): Password to check

    Returns:
        bool: True if password is common, False otherwise
    """
    return password.lower() in _COMMON_PASSWORDS


def rate_limit_key(identifier, action, window_minutes=60, now=None):
    """
    Generate a cache key for rate limiting.

    Time is split into consecutive, equal-length windows aligned to the Unix
    epoch, and every moment inside one window maps to the same key. This
    works for any window length, including ones longer than an hour or not
    dividing 60. Bucketing on the minute field alone only handled windows
    that divide an hour.

    Args:
        identifier (str): Unique identifier (IP, user ID, etc.)
        action (str): Action being rate limited
        window_minutes (int): Time window in minutes; must be positive
        now (datetime): Moment to compute the key for; defaults to
            ``timezone.now()``

    Returns:
        str: Cache key for rate limiting

    Raises:
        ValueError: If the window is not positive.
    """
    window_seconds = int(window_minutes * 60)
    if window_seconds <= 0:
        raise ValueError('window_minutes must be positive')

    if now is None:
        now = timezone.now()

    # Round down to the start of the current window.
    window_start = int(now.timestamp()) // window_seconds * window_seconds

    # Rendered as a float to keep the "<seconds>.0" key format, so counters
    # stay compatible with keys built from ``datetime.timestamp()``.
    return f'rate_limit_{action}_{identifier}_{float(window_start)}'


def is_rate_limited(identifier, action, max_attempts=5, window_minutes=60):
    """
    Report whether *identifier* has used up its attempts for *action*.

    Reads the counter stored under the current window's key (built by
    ``rate_limit_key``); a missing counter counts as zero attempts. This
    only reads the counter and never records an attempt; that is done by
    ``increment_rate_limit``.

    Args:
        identifier (str): Unique identifier (IP, user ID, etc.)
        action (str): Action being checked
        max_attempts (int): Maximum attempts allowed per window
        window_minutes (int): Time window in minutes

    Returns:
        bool: True once the stored count has reached *max_attempts*
    """
    attempts = cache.get(rate_limit_key(identifier, action, window_minutes), 0)
    return max_attempts <= attempts


def increment_rate_limit(identifier, action, window_minutes=60):
    """
    Increment the rate limit counter for a given identifier and action.

    The counter is created with ``cache.add`` and bumped with ``cache.incr``
    instead of a ``get``/``set`` pair. With ``get``/``set``, concurrent
    requests can read the same value and each write ``n + 1``, losing
    increments and letting a client exceed the limit. Django documents
    ``incr`` as atomic on backends that support atomic increments (e.g.
    Memcached).

    If the cache backend raises, the error is logged and ``1`` is returned,
    so a broken cache does not fail the request.

    Args:
        identifier (str): Unique identifier (IP, user ID, etc.)
        action (str): Action being incremented
        window_minutes (int): Time window in minutes

    Returns:
        int: New attempt count
    """
    cache_key = rate_limit_key(identifier, action, window_minutes)
    timeout = window_minutes * 60

    try:
        # add() only writes when the key is absent, so exactly one concurrent
        # caller initialises the counter.
        if cache.add(cache_key, 1, timeout):
            return 1
        try:
            return cache.incr(cache_key)
        except ValueError:
            # Django's incr() raises ValueError for a missing key: it expired
            # between add() and incr(), so start a fresh count.
            cache.set(cache_key, 1, timeout)
            return 1
    except Exception:
        logger.warning(
            'Rate limit counter update failed for key %s', cache_key, exc_info=True
        )
        return 1


def clean_expired_tokens():
    """
    Clean up expired tokens from the database.

    Deletes ``PasswordResetToken`` and ``EmailVerificationToken`` rows whose
    ``expires_at`` is earlier than the current time, logs how many were
    removed and returns the counts. Intended to be called periodically
    (e.g. via a cron job).

    Returns:
        dict: ``{'password_reset_tokens': int, 'email_verification_tokens': int}``,
        the number of rows of each model actually deleted by this call.
    """
    from .models import PasswordResetToken, EmailVerificationToken

    # One cutoff shared by both deletions.
    cutoff = timezone.now()

    def _purge_expired(model):
        # Delete expired rows of *model* and return how many of them were removed.
        # Using delete()'s own result (instead of count() followed by
        # delete()) avoids reporting rows another worker deleted in between.
        # delete() returns (total, {model_label: rows}); the total also
        # includes cascaded rows, so read this model's own entry.
        _total, per_model = model.objects.filter(expires_at__lt=cutoff).delete()
        return per_model.get(model._meta.label, 0)

    reset_count = _purge_expired(PasswordResetToken)
    verification_count = _purge_expired(EmailVerificationToken)

    logger.info(
        f'Cleaned up {reset_count} expired password reset tokens '
        f'and {verification_count} expired email verification tokens'
    )

    return {
        'password_reset_tokens': reset_count,
        'email_verification_tokens': verification_count
    }


def log_security_event(event_type, user=None, ip_address=None, details=None):
    """
    Log security-related events for monitoring and auditing.

    Emits one WARNING record of the form ``SECURITY_EVENT: <json>`` containing
    the event type, an ISO-8601 timestamp, the user's id and email (if a user
    is given), the IP address and any extra details.

    Args:
        event_type (str): Type of security event
        user: User instance (optional). Objects lacking ``id``/``email`` are
            logged with None for the missing fields.
        ip_address (str): IP address (optional)
        details (dict): Additional event details (optional)
    """
    log_data = {
        'event_type': event_type,
        'timestamp': timezone.now().isoformat(),
        # getattr: objects such as Django's AnonymousUser are truthy but have
        # no ``email`` attribute, which previously raised AttributeError.
        'user_id': getattr(user, 'id', None) if user else None,
        'user_email': getattr(user, 'email', None) if user else None,
        'ip_address': ip_address,
        'details': details or {}
    }

    # default=str: a logging helper must not raise because of a UUID primary
    # key or a datetime in ``details``; such values are logged as strings.
    logger.warning('SECURITY_EVENT: %s', json.dumps(log_data, default=str))


def mask_sensitive_data(data, fields_to_mask=None):
    """
    Mask sensitive data in dictionaries for logging.

    A value is masked when its key, converted to ``str`` and lower-cased,
    contains any entry of *fields_to_mask* (also lower-cased) as a
    substring. Nested dicts and lists are processed recursively. The input
    is never modified.

    Masking format:
    - Strings of 12 or more characters keep their first and last two
      characters (e.g. ``ab********yz``).
    - Shorter strings and all non-string values become ``'***'``.
    Shorter strings get no partial reveal because the four edge characters
    would expose most of the secret.

    Args:
        data: Data to mask. Dicts and lists are traversed; any other value is
            returned unchanged.
        fields_to_mask (list): Substrings identifying sensitive keys

    Returns:
        A masked copy of *data* (dicts are copied with ``copy()``, so their
        type is preserved), or *data* itself when it is neither dict nor list.
    """
    if fields_to_mask is None:
        fields_to_mask = [
            'password', 'token', 'secret', 'key', 'auth',
            'credential', 'session', 'csrf'
        ]

    if isinstance(data, list):
        return [mask_sensitive_data(item, fields_to_mask) for item in data]

    if not isinstance(data, dict):
        return data

    # Lower-case both sides so e.g. ['API_KEY'] still matches 'api_key'.
    needles = [field.lower() for field in fields_to_mask]
    masked_data = data.copy()

    for key, value in masked_data.items():
        # str(): keys are not guaranteed to be strings (e.g. int keys).
        if any(needle in str(key).lower() for needle in needles):
            # Reveal the 2+2 edge characters only when at least 8 stay hidden.
            if isinstance(value, str) and len(value) >= 12:
                masked_data[key] = f"{value[:2]}{'*' * (len(value) - 4)}{value[-2:]}"
            else:
                masked_data[key] = '***'
        elif isinstance(value, (dict, list)):
            # Recursively mask nested dictionaries and lists of them
            masked_data[key] = mask_sensitive_data(value, fields_to_mask)

    return masked_data