import os
import cv2
import json
import math 
import numpy as np
from datetime import datetime

from celery import Task

# import face_recognition 
from deepface import DeepFace
from sklearn.cluster import DBSCAN

from app.core.config import get_config
from app.tasks.worker import celery_app


# Application configuration, loaded once when this module is imported.
# The functions below read `settings.UPLOAD_DIR` to decide where output files go.
settings = get_config()
 
 
@celery_app.task(ignore_result=False, bind=True, base=Task)
def process_faces(self, video_file_path, frame_interval=25):
    """
    Process a video file to extract face data from sampled frames.

    Every ``frame_interval``-th frame is decoded and passed to
    ``extract_faces``. The per-frame results are written to a JSON file in
    ``settings.UPLOAD_DIR`` and the collected embeddings are handed to
    ``cluster_faces``.

    Args:
        video_file_path (str): Path to the input video file.
        frame_interval (int, optional): Analyse one frame out of this many.
            Values below 1 are treated as 1. Defaults to 25.

    Returns:
        dict or None: A dictionary containing the extracted face data and the
            processing time. The dictionary structure is as follows:
            {
                'face_data': [
                    {
                        'id': <frame_id>,
                        'time': <time_in_seconds, or None if the FPS is unknown>,
                        'faces': [
                            {
                                'age': <age>,
                                'region': {
                                    'x': <x_coordinate>,
                                    'y': <y_coordinate>,
                                    'w': <width>,
                                    'h': <height>,
                                    'left_eye': <left_eye_coordinates>,
                                    'right_eye': <right_eye_coordinates>
                                },
                                'confidence': <face_confidence>,
                                'gender': <gender>,
                                'emotion': <emotion>,
                                'race': <race>,
                                'name': <person_name>
                            },
                            ...
                        ]
                    },
                    ...
                ],
                'process_time': <elapsed_seconds>
            }
        If an error occurs during processing, returns None.
    """
    start_time = datetime.now()

    try:
        # A zero or negative stride would break the modulo test below
        frame_interval = max(1, int(frame_interval))

        video = cv2.VideoCapture(video_file_path)
        face_data = []
        face_embeddings = []

        try:
            if not video.isOpened():
                raise ValueError(f"Unable to open video file: {video_file_path}")

            # Keep the real (possibly fractional) frame rate so timestamps
            # stay accurate for rates such as 29.97 FPS
            fps = video.get(cv2.CAP_PROP_FPS)
            has_valid_fps = math.isfinite(fps) and fps > 0
            total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))

            frame_count = 0
            # grab() advances the stream without decoding; it returns False
            # as soon as the stream ends, so the loop stops on the exact
            # last frame instead of at the next sampled index
            while video.grab():
                if frame_count % frame_interval == 0:
                    # Only sampled frames are actually decoded
                    ret, frame = video.retrieve()
                    if ret:
                        faces_dict, face_encodings = extract_faces(
                            frame=frame,
                            frame_id=frame_count,
                            save_faces=True
                        )
                        if faces_dict:
                            face_data.append({
                                "id": frame_count,
                                "time": frame_count / fps if has_valid_fps else None,
                                "faces": faces_dict
                            })
                            # Keep only entries that carry a usable embedding
                            for face_encoding in face_encodings or []:
                                if face_encoding.get("embedding") is not None:
                                    face_embeddings.append(face_encoding)
                            print(f"Processed {frame_count}/{total_frames}")

                frame_count += 1
        finally:
            # Release the capture even when frame processing fails
            video.release()

        # Create a unique name for the JSON file
        video_filename = os.path.basename(video_file_path).split(".")[0]
        json_filename = f'faces_{video_filename}_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.json'
        os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
        json_file_path = os.path.join(settings.UPLOAD_DIR, json_filename)

        # Save face_data to the JSON file
        with open(json_file_path, 'w') as json_file:
            json.dump({'face_data': face_data}, json_file)

        # Cluster detected faces
        cluster_faces(
            face_embeddings=face_embeddings,
            output_dir=f"{settings.UPLOAD_DIR}/faces/clustering",
            min_samples=10,
            metric="cosine"
        )

        end_time = datetime.now()
        return {
            'face_data': face_data,
            'process_time': (end_time - start_time).total_seconds()
        }

    except Exception as e:
        # Task boundary: report the failure and signal it with None
        print(f"Error processing video file: {e}")
        return None


def cluster_faces(face_embeddings, output_dir, eps=0.5, min_samples=5, metric="euclidean", include_noise=True):
    """
    Cluster detected faces based on their embeddings and save cropped face images in cluster directories.

    Each face is written to ``<output_dir>/Cluster_<label>_Size_<size>/``.
    Noise points (label -1) always go to a directory with size 1. Errors are
    printed rather than raised.

    Args:
        face_embeddings (list): List of dictionaries, each with the keys
            "embedding" (the vector to cluster), "frame_id" and "face"
            (the cropped image to save).
        output_dir (str): Path to the directory where the clusters and cropped face images will be saved.
        eps (float, optional): The maximum distance between two samples for one to be considered as in the neighborhood of the other.
        min_samples (int, optional): The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
        metric (str, optional): The distance metric to use for clustering. Default is "euclidean".
        include_noise (bool, optional): Whether to include noise points (labeled as -1) in the clustering results.

    Returns:
        None
    """
    # Nothing to cluster: fitting DBSCAN on an empty sample list raises
    if face_embeddings is None or len(face_embeddings) == 0:
        return

    try:
        # Extract face embeddings
        vectors = [entry["embedding"] for entry in face_embeddings]

        # Perform DBSCAN clustering on face embeddings
        labels = DBSCAN(
            eps=eps,
            min_samples=min_samples,
            metric=metric
        ).fit(vectors).labels_

        # Count the members of every label once instead of re-scanning
        # the whole label array for each face
        unique_labels, counts = np.unique(labels, return_counts=True)
        sizes = dict(zip(unique_labels.tolist(), counts.tolist()))

        for index, (label, embeddings_info) in enumerate(zip(labels, face_embeddings)):
            label = int(label)
            if label == -1 and not include_noise:
                # Skip noise point if include_noise is False
                continue

            # Create cluster directory if it doesn't exist
            cluster_size = sizes[label] if label != -1 else 1
            cluster_dir = os.path.join(output_dir, f"Cluster_{label}_Size_{cluster_size}")
            os.makedirs(cluster_dir, exist_ok=True)

            current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            frame_id = embeddings_info["frame_id"]
            # The running index keeps file names unique when several faces
            # from the same frame are written within the same second
            file_path = os.path.join(
                cluster_dir,
                f"Person_{frame_id}_{index}_{current_datetime}.jpg"
            )
            # imwrite reports failure through its return value, not an exception
            if not cv2.imwrite(file_path, embeddings_info["face"]):
                print(f"Error clustering faces: could not write {file_path}")

    except Exception as e:
        print(f"Error clustering faces: {e}")
 

def _match_known_face(face_image, db_path, max_distance):
    """Return the name of the closest database match for *face_image*, or None if there is none within *max_distance*."""
    search_results = DeepFace.find(
        img_path=face_image,
        db_path=db_path,
        model_name="Facenet512",
        distance_metric="cosine",
        enforce_detection=False,
        detector_backend="retinaface"
    )
    if search_results is None:
        search_results = []

    # DeepFace.find yields one result table per face it finds in the input.
    # Pick the closest row across all of them; sorting the tables themselves
    # by their "distance" column would compare whole Series and raise.
    best_identity = None
    best_distance = None
    for matches in search_results:
        if matches is None or matches.empty:
            continue
        candidates = matches[matches["identity"].notna()]
        if candidates.empty:
            continue
        closest = candidates.sort_values("distance").iloc[0]
        if best_distance is None or closest["distance"] < best_distance:
            best_identity = closest["identity"]
            best_distance = closest["distance"]

    if best_identity is None:
        print("No similar face found in the database.")
        return None

    print(f"Distance between faces: {best_distance}")
    if best_distance > max_distance:
        return None

    print("Similar face found in the database!")
    # The name is the part of the file name before the first "." and "_"
    return os.path.basename(best_identity).split(".")[0].split("_")[0]


def extract_faces(frame, frame_id, face_confidence_threshold=0.990, save_faces=False,
                  db_path="data/faces/database", max_match_distance=0.13, min_face_size=26):
    """
    Extract faces from a given frame/image.

    Faces are detected with DeepFace (RetinaFace backend). Each sufficiently
    confident and sufficiently large face is looked up in the face database,
    embedded with Facenet512 and optionally saved to disk under
    ``<settings.UPLOAD_DIR>/faces/detection/<person_name>/``.

    Args:
        frame (numpy.ndarray): The image to analyse, as an array indexed
            ``[row, column]`` (for example a frame read with OpenCV).
        frame_id (int): Unique identifier for the frame.
        face_confidence_threshold (float, optional): Confidence threshold for face detection.
            Faces with confidence below this threshold will be ignored. Defaults to 0.990.
        save_faces (bool, optional): Whether to save the extracted faces to disk. Defaults to False.
        db_path (str, optional): Directory of known faces searched for a match.
            Defaults to "data/faces/database".
        max_match_distance (float, optional): Largest cosine distance still
            accepted as a database match. Defaults to 0.13.
        min_face_size (int, optional): Faces whose width or height is not
            strictly larger than this many pixels are reported but neither
            recognised, embedded nor saved. Defaults to 26.

    Returns:
        Tuple[List[Dict], List[Dict]] or Tuple[None, None]: A tuple containing:
            - List of dictionaries, each representing a detected face with its attributes.
            - List of dictionaries with the keys "embedding", "frame_id" and
              "face" (the cropped image) for every embedded face.
        ``(None, None)`` is returned when nothing is detected or when an
        error occurs; errors are printed, not raised.
    """
    try:
        # Detect faces in the frame
        detected_faces = DeepFace.analyze(
            img_path=frame,
            actions=["age"],  # Specify the facial attributes to analyze
            detector_backend="retinaface",  # Use RetinaFace for face detection
            enforce_detection=False,  # Do not raise when no face is found
            align=True  # Align detected faces
        )

        # Check if no faces are detected
        if not detected_faces:
            return None, None

        face_encodings = []  # Embedding records for clustering
        face_dicts = []  # Dictionaries representing detected faces

        for face_index, face in enumerate(detected_faces):
            confidence = face.get("face_confidence")
            if confidence is None or confidence < face_confidence_threshold:
                continue

            region = face["region"]
            x, y, w, h = region["x"], region["y"], region["w"], region["h"]

            # Name used when the face is not matched against the database
            person_name = "Person"

            if w > min_face_size and h > min_face_size:
                # Clamp the origin: a negative start index would wrap around
                # to the other side of the image instead of cropping the face
                cropped_face = frame[max(y, 0):y + h, max(x, 0):x + w]

                if cropped_face.size > 0:
                    matched_name = _match_known_face(cropped_face, db_path, max_match_distance)
                    if matched_name:
                        person_name = matched_name

                    # Compute face embeddings
                    representations = DeepFace.represent(
                        img_path=cropped_face,
                        enforce_detection=False,
                        model_name="Facenet512"
                    )
                    if representations:
                        face_encodings.append({
                            "embedding": representations[0]["embedding"],
                            "frame_id": frame_id,
                            "face": cropped_face
                        })

                    # Save faces to disk if required
                    if save_faces:
                        current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
                        faces_path = os.path.join(f"{settings.UPLOAD_DIR}/faces/detection", f"{person_name}")
                        os.makedirs(faces_path, exist_ok=True)
                        # The face index keeps names unique when one frame
                        # contains several faces with the same person name
                        cv2.imwrite(
                            os.path.join(
                                faces_path,
                                f"{person_name}_{frame_id}_{face_index}_{current_datetime}.jpg"
                            ),
                            cropped_face
                        )

            # Create dictionary representing the detected face
            face_dicts.append({
                "region": {
                    "x": x,
                    "y": y,
                    "w": w,
                    "h": h,
                    "left_eye": region.get("left_eye"),
                    "right_eye": region.get("right_eye")
                },
                "age": face.get("age"),
                "confidence": confidence,
                "gender": face.get("dominant_gender"),
                "emotion": face.get("dominant_emotion"),
                "race": face.get("dominant_race"),
                "name": person_name
            })

        # Return both the list of face dictionaries and the list of face encodings
        return face_dicts, face_encodings
    except Exception as e:
        print(f"Error processing frame {frame_id}: {e}")
        return None, None

