import datetime
import functools
import json
import logging
import multiprocessing
import time
import uuid
from multiprocessing import Pool

import boto3
import cv2
from decouple import config

# Set the logging level (debug, info, warning, error, critical).
# basicConfig sets the root logger to DEBUG. When the root logger has no
# handlers yet, it also attaches a console StreamHandler; otherwise it is a no-op.
logging.basicConfig(level=logging.DEBUG)
# Create a logger object. With no name this is the root logger, so records from
# every module's loggers propagate here.
logger = logging.getLogger()
 
# Create a file handler for the logger. FileHandler appends (default mode 'a')
# to aws_metadata_test.log, resolved relative to the current working directory.
file_handler = logging.FileHandler("aws_metadata_test.log")
# Create a formatter for the file handler
formatter = logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s] :: %(message)s")
file_handler.setFormatter(formatter)
# Add the file handler to the logger
logger.addHandler(file_handler)
# Drop the first root handler so records go only to the file. This depends on
# statement order: handlers[0] is expected to be the console handler that
# basicConfig added above.
# NOTE(review): if logging was configured before this module is imported,
# basicConfig is a no-op and handlers[0] is some other pre-existing handler.
# Consider removing a specific handler rather than removing by position.
logger.removeHandler(logging.getLogger().handlers[0])

# AWS region and client service name. Both are read through python-decouple
# (environment variables / settings file) and fall back to the defaults below.
AWS_REGION = config("AWS_REGION", default="us-west-2", cast=str)
AWS_SERVICE = config("AWS_SERVICE", default="rekognition", cast=str)

# Explicit AWS credentials. Each falls back to an empty string when it is not
# configured anywhere.
AWS_ACCESS_KEY_ID = config("AWS_ACCESS_KEY_ID", default="", cast=str)
AWS_SECRET_ACCESS_KEY = config("AWS_SECRET_ACCESS_KEY", default="", cast=str)

def execute_time(func):
    """Decorate ``func`` so every successful call logs its arguments and duration.

    The wrapped function's return value is passed through unchanged. The
    elapsed time is measured with ``time.perf_counter`` (monotonic, high
    resolution). Two INFO records are emitted through the module ``logger``
    after the call returns. A call that raises is not logged; its exception
    propagates unchanged.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper that keeps ``func``'s ``__name__``, ``__doc__`` and other
        metadata (via ``functools.wraps``).
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        started = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - started
        # Lazy %-style args: formatting is skipped when INFO is disabled.
        logger.info("Function '%s' called with arguments: %s %s", func.__name__, args, kwargs)
        logger.info("Time taken by %s: %s seconds", func.__name__, elapsed)
        return result
    return wrapper

@execute_time
def seconds_to_hhmmss(seconds):
    """Format a non-negative duration in seconds as ``HH:MM:SS.mmm``.

    The value is first normalised to microsecond precision by
    ``datetime.timedelta``. Milliseconds are then truncated, not rounded.
    Hours do not wrap at 24, so ``90000`` formats as ``"25:00:00.000"``
    instead of ``"01:00:00.000"``.

    Args:
        seconds: Duration in seconds (int or float).

    Returns:
        The formatted string, with hours zero-padded to at least two digits.

    Raises:
        ValueError: If ``seconds`` is negative.
    """
    if seconds < 0:
        raise ValueError(f"seconds must be non-negative, got {seconds!r}")
    delta = datetime.timedelta(seconds=seconds)
    # A timedelta stores days/seconds/microseconds separately, so fold the
    # days back into whole seconds before splitting into h/m/s.
    whole_seconds = delta.days * 86400 + delta.seconds
    hours, remainder = divmod(whole_seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    millis = delta.microseconds // 1000
    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{millis:03d}"

def custom_callback(result):
    """Success callback for ``Pool.apply_async``: echo the task result to stdout."""
    message = f'Got result: {result}'
    print(message)

def custom_error_callback(error):
    """Error callback for ``Pool.apply_async``: report a failed task.

    The multiprocessing pool calls this in the parent process with the
    exception the task raised. The short message is printed to stdout. The
    error is also logged at ERROR level with its traceback through the module
    ``logger``, so the failure is not lost once the console output scrolls away.

    Args:
        error: The exception instance raised by the pool task.
    """
    print(f'Got error: {error}')
    # exc_info accepts an exception instance and records its traceback
    # (for pool workers this includes the remote traceback as the cause).
    logger.error("Pool task failed: %s", error, exc_info=error)


def detect_labels_in_frame(frame, results_queue, counter=0, frame_duration=None):
    """Run Amazon Rekognition ``detect_labels`` on one frame and queue the result.

    Args:
        frame: Image array accepted by ``cv2.imencode`` (as returned by
            ``cv2.imread`` or ``VideoCapture.read``).
        results_queue: Object with a ``put`` method. When called from pool
            workers it is a ``multiprocessing.Manager`` queue proxy.
        counter: Frame index; stored under ``"frame_id"``.
        frame_duration: Frame timestamp string; stored under ``"frame_time"``.

    Raises:
        ValueError: If the frame cannot be JPEG-encoded.
        Exceptions from ``boto3``/botocore (e.g. credential or throttling
        errors) propagate to the caller.
    """
    # Unset credentials arrive as empty strings. botocore only falls back to
    # its default credential chain (env vars, shared config, instance role)
    # when the explicit keys are None, so convert "" to None here.
    client = boto3.client(
        AWS_SERVICE,
        aws_access_key_id=AWS_ACCESS_KEY_ID or None,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY or None,
        region_name=AWS_REGION
    )
    # Rekognition accepts raw image bytes; JPEG keeps the payload small.
    encoded_ok, img_encoded = cv2.imencode('.jpg', frame)
    if not encoded_ok:
        raise ValueError(f"Failed to JPEG-encode frame {counter}")
    result = client.detect_labels(Image={'Bytes': img_encoded.tobytes()})
    # Report progress only after the API call has actually succeeded.
    print(f"Proccesed Frame {counter}")
    # NOTE: "respense" is misspelled but kept: it is the key written to the JSON output.
    results_queue.put({"frame_id":counter, "frame_time":frame_duration, "respense":result})

def _analyse_image(input_path, result_queue):
    """Detect labels in a single image file and push the result onto ``result_queue``."""
    image = cv2.imread(input_path)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        raise ValueError(f"Could not read image: {input_path}")
    detect_labels_in_frame(frame=image, results_queue=result_queue)


def _analyse_video(input_path, result_queue, interval):
    """Submit every ``interval``-th frame of a video to a process pool for label detection."""
    # None/0/negative would make ``counter % interval`` fail; treat them as "every frame".
    step = interval if interval and interval > 0 else 1
    video = cv2.VideoCapture(input_path)
    try:
        if not video.isOpened():
            raise ValueError(f"Could not open video: {input_path}")
        frame_rate = video.get(cv2.CAP_PROP_FPS)
        # OpenCV reports 0 when the FPS is unknown; it is used as a divisor below.
        if not (frame_rate > 0):
            raise ValueError(f"Video reports no usable frame rate: {input_path}")
        num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        if num_frames > 0:
            logger.info("Video duration: %s", seconds_to_hhmmss(num_frames / frame_rate))

        with Pool(multiprocessing.cpu_count()) as pool:
            counter = 0
            while True:
                success, frame = video.read()
                # A failed read means end of stream (or a decode error).
                if not success:
                    break
                if counter % step == 0:
                    frame_duration = seconds_to_hhmmss(counter / frame_rate)
                    pool.apply_async(
                        detect_labels_in_frame,
                        args=(frame, result_queue, counter, frame_duration),
                        error_callback=custom_error_callback,
                    )
                    # Pace submissions. Presumably this keeps the request rate
                    # under the Rekognition TPS limit; TODO confirm the intended rate.
                    time.sleep(0.2)
                counter += 1
            # Pool.__exit__ calls terminate(), which would kill tasks still in
            # flight and silently drop their frames. Drain the pool first.
            pool.close()
            pool.join()
    finally:
        video.release()


@execute_time
def analyse(input_path, input_type=None, interval=None, save=True):
    """Run Rekognition label detection on an image file or a video file.

    Args:
        input_path: Path to the image or video.
        input_type: ``'image'`` or ``'video'``. Any other value produces no
            results (a warning is logged).
        interval: For videos, analyse every ``interval``-th frame. ``None`` or
            values below 1 are treated as 1 (every frame).
        save: Currently unused; drawing/saving boxes is not implemented.

    Returns:
        A list of ``{"frame_id", "frame_time", "respense"}`` dicts, one per
        successfully analysed frame, sorted by ``frame_id``. Frames whose task
        failed are reported via ``custom_error_callback`` and omitted.

    Raises:
        ValueError: If the image cannot be read, or the video cannot be opened
            or reports no usable frame rate.
    """
    # A Manager-backed queue can be shared with pool worker processes. The
    # context manager shuts the manager's server process down afterwards, so
    # the queue must be drained inside the block.
    with multiprocessing.Manager() as manager:
        result_queue = manager.Queue()
        if input_type == 'image':
            _analyse_image(input_path, result_queue)
        elif input_type == 'video':
            _analyse_video(input_path, result_queue, interval)
        else:
            logger.warning("Unsupported input type %r for %s; nothing analysed", input_type, input_path)
        # All producers have finished by now, so empty() is a reliable stop condition.
        results = []
        while not result_queue.empty():
            results.append(result_queue.get())
    # Workers finish out of order; restore frame order for stable output.
    results.sort(key=lambda item: item["frame_id"])
    return results
    
if __name__ == '__main__':

    import os
    import argparse
    import mimetypes

    # Build the command-line interface.
    parser = argparse.ArgumentParser(description="Analyze Video By Extracting Video Content", prog="VideoAnalyzer", epilog="Thanks for using %(prog)s ! 🚀", )
    # Path of the image/video to analyse
    parser.add_argument("-i", "--input", type=str, help="The path to the video file", required=True)
    # Analyse every N-th video frame
    parser.add_argument("-frame", "--interval", type=int, default=1, help="The interval at which to extract the MetaData")
    # Passed to analyse() as ``save``; analyse() does not use it yet
    parser.add_argument("-save", "--save-output", action="store_true", default=False, help="save the video with the boxes drawn on it in the same folder of the original video")
    # Write the results to <input-name>_<timestamp>.json in the current directory
    parser.add_argument("-json", "--save-json", action="store_true", default=False, help="Save the result as a JSON file")
    # Project version
    parser.add_argument("-v", "--version", action="version", version="%(prog)s V 0.1.0 🚀")
    args = parser.parse_args()

    try:
        # Validate the input path before doing any work.
        if not args.input:
            raise ValueError("[Error][Message : Input Path is Required, Please Provide Input Path.]")
        if not os.path.exists(args.input):
            print(f"[Error][< {args.input} >][Message : The Input Path Doesn't Exist.]")
            raise SystemExit(1)
        if not os.path.isfile(args.input):
            print(f"[Error][< {args.input} >][Message : The Input Path Is Not A File.]")
            raise SystemExit(1)
        # Frames are sampled with ``counter % interval``, so it must be >= 1.
        if args.interval < 1:
            raise ValueError("[Error][Message : The Interval Must Be A Positive Integer.]")

        # Classify the input by MIME type guessed from its extension.
        # guess_type returns None for unknown extensions, so fall back to "".
        mime_type = mimetypes.guess_type(args.input)[0] or ""
        if mime_type.startswith("image"):
            input_type = "image"
        elif mime_type.startswith("video"):
            input_type = "video"
        else:
            raise ValueError(f"[Error][< {args.input} >][Message : Unsupported File Type, Please Provide An Image Or A Video.]")

        print(" 🚀 : Start Processing The Image/Video, Please Wait ...")

        # Start the timer
        start_time = datetime.datetime.now()
        response = analyse(args.input, input_type, args.interval , args.save_output)
        print(f" ⛏️  : Processing Image/Video Complete - File: {args.input} -  Time taken: {int((datetime.datetime.now() - start_time).total_seconds() * 1000)} ms  ")

        if args.save_json:
            # Save the response to a JSON file named after the input plus a timestamp.
            file_name = os.path.splitext(os.path.basename(args.input))[0] + '_' + str(datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S'))
            with open(f"{file_name}.json", "w", encoding="utf-8") as f:
                json.dump(response, f)

        print(" 💻 : Processing Image/Video Finished, Thanks For patience ")

    except ValueError as ve:
        # Usage / input problems: show the message and the CLI help, then fail.
        print(ve)
        parser.print_help()
        raise SystemExit(1)
    except Exception as e:
        print(f"An error occurred: {e}")
        # The console shows only the message; send the full traceback to the logger.
        logger.exception("Unhandled error while processing %s", args.input)
        raise SystemExit(1)