"""
JWT token handling
"""
import secrets
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional

from jose import JWTError, jwt
from pydantic import ValidationError

from app.core.config.settings import get_settings
from app.core.exceptions.auth import InvalidTokenError, TokenExpiredError


class JWTHandler:
    """Create and validate signed JWTs, and mint opaque random tokens.

    The signing key, algorithm and default token lifetimes are read from the
    application settings when the handler is constructed. Every JWT produced
    here carries ``exp``, ``iat``, ``sub`` and a ``type`` claim (``"access"``
    or ``"refresh"``) so that the two kinds can be told apart on decode.
    """

    def __init__(self):
        self.settings = get_settings()
        self.algorithm = self.settings.JWT_ALGORITHM
        self.secret_key = self.settings.JWT_SECRET_KEY
        # Default lifetimes, used when a caller does not pass ``expires_delta``.
        self.access_token_expire_minutes = self.settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES
        self.refresh_token_expire_days = self.settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS

    @staticmethod
    def _build_claims(
        subject: str,
        token_type: str,
        expires_delta: Optional[timedelta],
        default_lifetime: timedelta,
        additional_claims: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Return the claim set for a new token of kind *token_type*.

        Args:
            subject: Value for the ``sub`` claim.
            token_type: Value for the ``type`` claim (e.g. ``"access"``).
            expires_delta: Explicit lifetime; ``None`` means use
                *default_lifetime*. A zero delta is honoured as-is.
            default_lifetime: Lifetime applied when *expires_delta* is None.
            additional_claims: Extra claims to embed. They are copied (the
                caller's dict is never mutated) and cannot override the
                reserved ``exp``/``iat``/``sub``/``type`` claims.

        Returns:
            A new dict with timezone-aware UTC datetimes for ``exp``/``iat``.
        """
        # ``is None`` rather than truthiness: ``timedelta(0)`` is falsy and
        # would otherwise silently fall back to the default lifetime.
        lifetime = default_lifetime if expires_delta is None else expires_delta
        # One clock read so that ``exp - iat`` equals the lifetime exactly.
        now = datetime.now(timezone.utc)
        # Caller claims go in first; reserved claims are written last so a
        # caller cannot forge ``type``/``sub`` or extend ``exp`` this way.
        claims: Dict[str, Any] = dict(additional_claims or {})
        claims.update(
            {
                "exp": now + lifetime,
                "iat": now,
                "sub": subject,
                "type": token_type,
            }
        )
        return claims

    def create_access_token(
        self,
        subject: str,
        additional_claims: Optional[Dict[str, Any]] = None,
        expires_delta: Optional[timedelta] = None
    ) -> str:
        """Create a signed access token for *subject*.

        Args:
            subject: Value stored in the ``sub`` claim.
            additional_claims: Extra claims to embed; reserved claims
                (``exp``, ``iat``, ``sub``, ``type``) cannot be overridden.
            expires_delta: Token lifetime; defaults to
                ``access_token_expire_minutes`` from settings.

        Returns:
            The encoded JWT string.
        """
        claims = self._build_claims(
            subject,
            "access",
            expires_delta,
            timedelta(minutes=self.access_token_expire_minutes),
            additional_claims,
        )
        return jwt.encode(claims, self.secret_key, algorithm=self.algorithm)

    def create_refresh_token(
        self,
        subject: str,
        expires_delta: Optional[timedelta] = None
    ) -> str:
        """Create a signed refresh token for *subject*.

        Args:
            subject: Value stored in the ``sub`` claim.
            expires_delta: Token lifetime; defaults to
                ``refresh_token_expire_days`` from settings.

        Returns:
            The encoded JWT string.
        """
        claims = self._build_claims(
            subject,
            "refresh",
            expires_delta,
            timedelta(days=self.refresh_token_expire_days),
        )
        return jwt.encode(claims, self.secret_key, algorithm=self.algorithm)

    def decode_token(
        self, token: str, expected_type: Optional[str] = None
    ) -> Dict[str, Any]:
        """Verify *token*'s signature and expiry and return its claims.

        Args:
            token: Encoded JWT.
            expected_type: When given, the token's ``type`` claim must equal
                it (e.g. ``"access"``), so a refresh token cannot be used
                where an access token is required. ``None`` skips the check.

        Returns:
            The decoded claim dict.

        Raises:
            TokenExpiredError: The token's ``exp`` is in the past.
            InvalidTokenError: Bad signature, malformed token, failed claim
                validation, or a ``type`` mismatch with *expected_type*.
        """
        try:
            # Pin the accepted algorithm list to the configured one.
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        # ExpiredSignatureError subclasses JWTError, so it must be caught first.
        except jwt.ExpiredSignatureError as exc:
            raise TokenExpiredError("Token has expired") from exc
        except JWTError as exc:
            raise InvalidTokenError("Invalid token") from exc

        if expected_type is not None and payload.get("type") != expected_type:
            raise InvalidTokenError("Invalid token type")
        return payload

    def generate_reset_token(self) -> str:
        """Return a random URL-safe password-reset token (32 random bytes)."""
        return secrets.token_urlsafe(32)

    def generate_verification_token(self) -> str:
        """Return a random URL-safe email-verification token (32 random bytes)."""
        return secrets.token_urlsafe(32)
