# import os
import torch
import logging
import numpy as np
from tqdm import tqdm

import cv2
import librosa
import soundfile as sf

from lib import dataset
from lib import nets
from lib import spec_utils

import gc

# Module-level logger. Use __name__ (dotted module path) rather than
# __file__ so the logger slots into the standard logging hierarchy;
# __file__ would have produced a filesystem path as the logger name.
logger = logging.getLogger(__name__)


#

class VocalRemover(object):
    """Separate a stereo mixture into vocal and instrumental stems.

    Wraps a CascadedASPPNet model: the input waveform is converted to a
    complex spectrogram, the network predicts an instrumental magnitude
    window-by-window, and both stems are reconstructed via inverse STFT.
    """

    def __init__(
        self,
        device,
        n_fft=2048,
        hop_length=1024,
        window_size=512,
        tta=False,
        postprocess=False,
        output_image=False,
        model_weights="weights/baseline.pth",
    ):
        """Build the network, pick the compute device, and load weights.

        Args:
            device: "cuda" to request GPU inference; anything else (or an
                unavailable GPU) silently falls back to CPU.
            n_fft: FFT size used for the STFT and expected by the model.
            hop_length: STFT hop length in samples.
            window_size: number of spectrogram frames fed to the model at once.
            tta: enable test-time augmentation (second, half-window-shifted pass).
            postprocess: apply silence masking to the prediction.
            output_image: also render spectrogram images after separation.
            model_weights: path to the serialized state dict.
        """
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.tta = tta
        self.postprocess = postprocess
        self.output_image = output_image
        self.model = nets.CascadedASPPNet(n_fft)
        # Lazy %-style args keep formatting out of the hot path and off
        # disabled log levels.
        logger.info("Check GPU & CUDA availability. [status: %s].", torch.cuda.is_available())
        if torch.cuda.is_available() and device == "cuda":
            self.device = torch.device("cuda")
            self.model.to(self.device)
        else:
            self.device = torch.device("cpu")

        # map_location keeps CUDA-saved checkpoints loadable on CPU.
        self.model.load_state_dict(torch.load(model_weights, map_location=self.device))
        self.offset = self.model.offset
        self.window_size = window_size

    def _execute(self, X_mag_pad, roi_size, n_window):
        """Run the model over ``n_window`` sliding windows and stitch results.

        Args:
            X_mag_pad: padded magnitude spectrogram; indexed as
                (channel, bins, frames) -- assumed 3-D, TODO confirm upstream.
            roi_size: stride between consecutive windows, in frames.
            n_window: number of windows to evaluate.

        Returns:
            Predicted magnitude, windows concatenated along the frame axis.
        """
        self.model.eval()
        preds = []
        with torch.no_grad():
            for idx in tqdm(range(n_window)):
                start = idx * roi_size
                window = X_mag_pad[None, :, :, start : start + self.window_size]
                window = torch.from_numpy(window).to(self.device)
                pred = self.model.predict(window)
                preds.append(pred.detach().cpu().numpy()[0])
        return np.concatenate(preds, axis=2)

    def preprocess(self, X_spec):
        """Split a complex spectrogram into magnitude and phase arrays."""
        return np.abs(X_spec), np.angle(X_spec)

    def _predict_windows(self, X_mag_pre, pad_l, pad_r, roi_size, n_window):
        """Zero-pad the frame axis and run windowed inference (untrimmed).

        Shared by ``inference`` and ``inference_tta`` to avoid the previous
        copy-pasted pad/execute sequence.
        """
        X_mag_pad = np.pad(X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
        return self._execute(X_mag_pad, roi_size, n_window)

    def inference(self, X_spec):
        """Predict the instrumental magnitude for one complex spectrogram.

        Returns:
            (pred, X_mag, phase): predicted magnitude rescaled to the input
            level, the input magnitude, and ``exp(1j * angle)`` phase factor.
        """
        X_mag, X_phase = self.preprocess(X_spec)

        coef = X_mag.max()
        if coef == 0:
            coef = 1.0  # silent input: avoid 0/0 -> NaN during normalization
        X_mag_pre = X_mag / coef

        n_frame = X_mag_pre.shape[2]
        pad_l, pad_r, roi_size = dataset.make_padding(
            n_frame, self.window_size, self.offset
        )
        n_window = int(np.ceil(n_frame / roi_size))

        pred = self._predict_windows(X_mag_pre, pad_l, pad_r, roi_size, n_window)
        pred = pred[:, :, :n_frame]  # drop the padding frames

        return pred * coef, X_mag, np.exp(1.0j * X_phase)

    def inference_tta(self, X_spec):
        """Like :meth:`inference`, but averages a second pass shifted by half
        a window (test-time augmentation) to smooth window-boundary artifacts.
        """
        X_mag, X_phase = self.preprocess(X_spec)

        coef = X_mag.max()
        if coef == 0:
            coef = 1.0  # silent input: avoid 0/0 -> NaN during normalization
        X_mag_pre = X_mag / coef

        n_frame = X_mag_pre.shape[2]
        pad_l, pad_r, roi_size = dataset.make_padding(
            n_frame, self.window_size, self.offset
        )
        n_window = int(np.ceil(n_frame / roi_size))

        pred = self._predict_windows(X_mag_pre, pad_l, pad_r, roi_size, n_window)
        pred = pred[:, :, :n_frame]

        # Second pass shifted by half a window; one extra window covers the tail.
        shift = roi_size // 2
        pred_tta = self._predict_windows(
            X_mag_pre, pad_l + shift, pad_r + shift, roi_size, n_window + 1
        )
        pred_tta = pred_tta[:, :, shift : shift + n_frame]

        return (pred + pred_tta) * 0.5 * coef, X_mag, np.exp(1.0j * X_phase)

    def split(self, X):
        """Separate waveform ``X`` into (vocals, instruments).

        Args:
            X: waveform array, mono ``(n,)`` or multi-channel ``(2, n)``.
               Mono input is duplicated onto both channels.

        Returns:
            Tuple of (vocals, instruments) waveforms, each transposed to
            (samples, channels) for writing with soundfile.
        """
        logger.info("releasing cached memory before inference...")
        gc.collect()
        if torch.cuda.is_available():
            # Only touch the CUDA allocator when CUDA actually exists.
            torch.cuda.empty_cache()
        logger.info("done.")

        if X.ndim == 1:
            X = np.asarray([X, X])  # duplicate mono to pseudo-stereo

        logger.info("stft of wave source...")
        X = spec_utils.wave_to_spectrogram(X, self.hop_length, self.n_fft)
        logger.info("done.")

        if self.tta:
            pred, X_mag, X_phase = self.inference_tta(X)
        else:
            pred, X_mag, X_phase = self.inference(X)

        if self.postprocess:
            logger.info("post processing...")
            pred_inv = np.clip(X_mag - pred, 0, np.inf)
            pred = spec_utils.mask_silence(pred, pred_inv)
            logger.info("done.")

        logger.info("inverse stft of vocals...")
        # Residual magnitude (input minus instrumental prediction) -> vocals.
        y_spec = np.clip(X_mag - pred, 0, np.inf) * X_phase
        vocals = spec_utils.spectrogram_to_wave(y_spec, hop_length=self.hop_length)
        logger.info("done.")

        logger.info("inverse stft of instruments...")
        v_spec = pred * X_phase
        instruments = spec_utils.spectrogram_to_wave(v_spec, hop_length=self.hop_length)
        logger.info("done.")

        if self.output_image:
            # Keep the rendered spectrograms on the instance so callers can
            # retrieve them; previously they were computed and then discarded.
            self.vocals_img = spec_utils.spectrogram_to_image(y_spec)
            self.instruments_img = spec_utils.spectrogram_to_image(v_spec)

        return vocals.T, instruments.T
