# api/tasks.py
import logging
import os

import librosa
import moviepy.editor as mp
import torch
from celery import shared_task
from langdetect import detect
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

from .models import Video


@shared_task
def generate_subtitles(video_id):
    """Transcribe a video's audio track and store word-chunk subtitles on the Video row.

    Pipeline: extract audio with moviepy -> resample to 16 kHz with librosa ->
    transcribe with Wav2Vec2 -> detect language -> split the transcript into
    10-word segments stored as ``video.subtitles = {'segments': [...]}``.

    Side effects:
        Updates ``video.status`` (PROCESSING -> COMPLETED/FAILED),
        ``video.detected_language`` and ``video.subtitles``; on failure the
        exception text is stored in ``video.subtitles`` (pre-existing contract).
        The intermediate ``<upload>.wav`` file is removed afterwards.
    """
    logger = logging.getLogger(__name__)

    video = Video.objects.get(id=video_id)
    video.status = 'PROCESSING'
    video.save()

    audio_path = f"{video.file.path}.wav"
    video_clip = None
    try:
        # Extract the audio track to a sidecar WAV file.
        video_clip = mp.VideoFileClip(video.file.path)
        video_clip.audio.write_audiofile(audio_path)

        # NOTE(review): loading the model per task is expensive; consider a
        # module-level cache if this task runs frequently.
        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")
        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h")

        # Wav2Vec2 expects 16 kHz mono input.
        speech, _rate = librosa.load(audio_path, sr=16000)
        input_values = processor(speech, return_tensors="pt", padding="longest").input_values
        # Inference only: skip autograd graph construction to save memory.
        with torch.no_grad():
            logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]

        # langdetect raises LangDetectException on empty/whitespace text;
        # a silent video should still complete with empty subtitles.
        if transcription.strip():
            video.detected_language = detect(transcription)

        # Chunk the transcript into fixed-size word groups.
        words = transcription.split()
        segment_size = 10  # Number of words per segment
        segments = [
            ' '.join(words[i:i + segment_size])
            for i in range(0, len(words), segment_size)
        ]

        video.subtitles = {'segments': segments}
        video.status = 'COMPLETED'
    except Exception as e:
        # Keep the task from crashing the worker, but preserve the traceback
        # in the logs instead of silently swallowing it.
        logger.exception("Subtitle generation failed for video %s", video_id)
        video.status = 'FAILED'
        video.subtitles = str(e)
    finally:
        # Always release moviepy's file handles / ffmpeg subprocess.
        if video_clip is not None:
            video_clip.close()
        # Best-effort cleanup of the intermediate WAV file.
        try:
            os.remove(audio_path)
        except OSError:
            pass
        video.save()
