from __future__ import annotations

import itertools as it

from abc import abstractmethod
from collections import namedtuple
from typing import Dict, List, NamedTuple, Optional, Tuple, Union

import torch

from flashlight.lib.text.decoder import (
    CriterionType as _CriterionType,
    LexiconDecoder as _LexiconDecoder,
    LexiconDecoderOptions as _LexiconDecoderOptions,
    LexiconFreeDecoder as _LexiconFreeDecoder,
    LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions,
    LM as _LM,
    LMState as _LMState,
    SmearingMode as _SmearingMode,
    Trie as _Trie,
    ZeroLM as _ZeroLM,
)
from flashlight.lib.text.dictionary import (
    create_word_dict as _create_word_dict,
    Dictionary as _Dictionary,
    load_words as _load_words,
)
from torchaudio.utils import download_asset

try:
    from flashlight.lib.text.decoder.kenlm import KenLM as _KenLM
except Exception:
    try:
        from flashlight.lib.text.decoder import KenLM as _KenLM
    except Exception:
        _KenLM = None

# Names exported by ``from <this module> import *``; this is the module's public API.
__all__ = [
    "CTCHypothesis",
    "CTCDecoder",
    "CTCDecoderLM",
    "CTCDecoderLMState",
    "ctc_decoder",
    "download_pretrained_files",
]

# Bundle of the files that make up a pretrained decoder setup: a lexicon file, a tokens
# file and a language model file (``lm`` is ``None`` when the bundle has no LM).
# ``_get_filenames`` fills it with asset paths; ``download_pretrained_files`` returns it
# with the paths of the downloaded files.
_PretrainedFiles = namedtuple("PretrainedFiles", ["lexicon", "tokens", "lm"])


def _construct_trie(tokens_dict, word_dict, lexicon, lm, silence):
    """Build the lexicon trie used by the lexicon-constrained decoder.

    Every spelling of every lexicon word is inserted as a path of token indices. The
    path is labelled with the word's index and weighted by the score the language model
    gives that word from its initial state (``lm.start(False)``). The finished trie is
    smeared with ``_SmearingMode.MAX``.

    Args:
        tokens_dict: token dictionary; ``index_size()`` gives the trie's vocabulary size
            and ``get_index`` maps each spelling token to its index.
        word_dict: word dictionary mapping each lexicon word to its index.
        lexicon: mapping from word to a list of spellings (each a sequence of tokens).
        lm: language model providing ``start`` and ``score``.
        silence: token index of the silence token.

    Returns:
        The populated and smeared trie.
    """
    trie = _Trie(tokens_dict.index_size(), silence)
    initial_state = lm.start(False)

    for word, spellings in lexicon.items():
        word_index = word_dict.get_index(word)
        _next_state, word_score = lm.score(initial_state, word_index)
        for spelling in spellings:
            token_indices = list(map(tokens_dict.get_index, spelling))
            trie.insert(token_indices, word_index, word_score)

    trie.smear(_SmearingMode.MAX)
    return trie


def _get_word_dict(lexicon, lm, lm_dict, tokens_dict, unk_word):
    """Resolve the word dictionary shared by the decoder and the language model.

    Resolution order:

    1. ``lm_dict`` given: load it as a ``_Dictionary``.
    2. Otherwise, a non-empty ``lexicon``: build the dictionary from the lexicon's words.
    3. Otherwise, ``lm`` is a path string (a KenLM file used for lexicon-free decoding):
       build a dictionary whose "words" are the individual entries of ``tokens_dict``
       plus ``unk_word``.
    4. Otherwise: ``None``.

    Args:
        lexicon: loaded lexicon (word -> spellings) or a falsy value for lexicon-free decoding.
        lm: language model object, path string, or ``None``.
        lm_dict: path to an LM dictionary file, or ``None``.
        tokens_dict: token dictionary (only used in case 3).
        unk_word: word representing unknown words (only used in case 3).

    Returns:
        The word dictionary, or ``None`` if none of the cases above applies.
    """
    if lm_dict is not None:
        return _Dictionary(lm_dict)

    if lexicon:
        return _create_word_dict(lexicon)

    if isinstance(lm, str):
        # Lexicon-free decoding with a KenLM: every token acts as a single-token word.
        entries = (tokens_dict.get_entry(i) for i in range(tokens_dict.index_size()))
        word_spellings = {entry: [[entry]] for entry in entries}
        word_spellings[unk_word] = [[unk_word]]
        return _create_word_dict(word_spellings)

    return None


class CTCHypothesis(NamedTuple):
    r"""Represents hypothesis generated by CTC beam search decoder :class:`CTCDecoder`."""
    tokens: torch.LongTensor
    """Predicted sequence of token IDs, with repeated tokens collapsed and blank tokens removed.
    Shape `(L, )`, where `L` is the length of the output sequence"""

    words: List[str]
    """List of predicted words.

    Note:
        This attribute is only applicable if a lexicon is provided to the decoder. If
        decoding without a lexicon, it will be empty. Please refer to :attr:`tokens` and
        :func:`~torchaudio.models.decoder.CTCDecoder.idxs_to_tokens` instead.
    """

    score: float
    """Score corresponding to hypothesis"""

    timesteps: torch.IntTensor
    """Frame index at which each token in :attr:`tokens` starts. Shape `(L, )`, where `L` is
    the length of the output sequence"""


class CTCDecoderLMState(_LMState):
    """Language model state."""

    @property
    def children(self) -> Dict[int, CTCDecoderLMState]:
        """Map of indices to LM states"""
        return super().children

    def child(self, usr_index: int) -> CTCDecoderLMState:
        """Returns child corresponding to usr_index, or creates and returns a new state if input index
        is not found.

        Args:
            usr_index (int): index corresponding to child state

        Returns:
            CTCDecoderLMState: child state corresponding to usr_index
        """
        return super().child(usr_index)

    def compare(self, state: CTCDecoderLMState) -> int:
        """Compare two language model states.

        Args:
            state (CTCDecoderLMState): LM state to compare against

        Returns:
            int: 0 if the states are the same, -1 if self is less, +1 if self is greater.
        """
        # Delegate to the base implementation, as ``child`` and ``children`` do. A bare
        # ``pass`` body would shadow the base method and make this return ``None``.
        return super().compare(state)


class CTCDecoderLM(_LM):
    """Base class for user-defined language models that plug into the decoder.

    Subclasses provide :meth:`start`, :meth:`score` and :meth:`finish`; the versions
    defined here only raise ``NotImplementedError``.
    """

    @abstractmethod
    def start(self, start_with_nothing: bool) -> CTCDecoderLMState:
        """Create (or reset to) the initial language model state.

        Args:
            start_with_nothing (bool): whether the sentence starts with the sil token or not.

        Returns:
            CTCDecoderLMState: the initial state
        """
        raise NotImplementedError

    @abstractmethod
    def score(self, state: CTCDecoderLMState, usr_token_idx: int) -> Tuple[CTCDecoderLMState, float]:
        """Advance the language model by one word from ``state``.

        Args:
            state (CTCDecoderLMState): state to advance from
            usr_token_idx (int): index of the next word

        Returns:
            (CTCDecoderLMState, float)
                CTCDecoderLMState:
                    resulting LM state
                float:
                    score of the word in this context
        """
        raise NotImplementedError

    @abstractmethod
    def finish(self, state: CTCDecoderLMState) -> Tuple[CTCDecoderLMState, float]:
        """Score the end of the sentence given the current LM state.

        Args:
            state (CTCDecoderLMState): state reached at the end of the sequence

        Returns:
            (CTCDecoderLMState, float)
                CTCDecoderLMState:
                    resulting LM state
                float:
                    end-of-sentence score
        """
        raise NotImplementedError


class CTCDecoder:
    """CTC beam search decoder from *Flashlight* :cite:`kahn2022flashlight`.

    .. devices:: CPU

    Note:
        To build the decoder, please use the factory function :func:`ctc_decoder`.
    """

    def __init__(
        self,
        nbest: int,
        lexicon: Optional[Dict],
        word_dict: _Dictionary,
        tokens_dict: _Dictionary,
        lm: CTCDecoderLM,
        decoder_options: Union[_LexiconDecoderOptions, _LexiconFreeDecoderOptions],
        blank_token: str,
        sil_token: str,
        unk_word: str,
    ) -> None:
        """
        Args:
            nbest (int): number of best decodings to return
            lexicon (Dict or None): lexicon mapping of words to spellings, or None for lexicon-free decoder
            word_dict (_Dictionary): dictionary of words
            tokens_dict (_Dictionary): dictionary of tokens
            lm (CTCDecoderLM): language model. If using a lexicon, only word level LMs are currently supported
            decoder_options (_LexiconDecoderOptions or _LexiconFreeDecoderOptions):
                parameters used for beam search decoding
            blank_token (str): token corresponding to blank
            sil_token (str): token corresponding to silence
            unk_word (str): word corresponding to unknown
        """

        self.nbest = nbest
        self.word_dict = word_dict
        self.tokens_dict = tokens_dict
        self.blank = self.tokens_dict.get_index(blank_token)
        silence = self.tokens_dict.get_index(sil_token)
        transitions = []

        if lexicon:
            trie = _construct_trie(tokens_dict, word_dict, lexicon, lm, silence)
            unk_idx = word_dict.get_index(unk_word)
            token_lm = False  # use word level LM

            self.decoder = _LexiconDecoder(
                decoder_options,
                trie,
                lm,
                silence,
                self.blank,
                unk_idx,
                transitions,
                token_lm,
            )
        else:
            self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, transitions)
        # https://github.com/pytorch/audio/issues/3218
        # If lm is passed like rvalue reference, the lm object gets garbage collected,
        # and later call to the lm fails.
        # This ensures that lm object is not deleted as long as the decoder is alive.
        # https://github.com/pybind/pybind11/discussions/4013
        self.lm = lm

    @staticmethod
    def _check_emissions(emissions: torch.Tensor, ndim: int) -> None:
        """Check that ``emissions`` can be handed to the native decoder by raw pointer.

        Raises:
            ValueError: if ``emissions`` is not float32.
            RuntimeError: if ``emissions`` is not a contiguous CPU tensor with ``ndim`` dimensions.
        """
        if emissions.dtype != torch.float32:
            raise ValueError("emissions must be float32.")

        if not emissions.is_cpu:
            raise RuntimeError("emissions must be a CPU tensor.")

        if not emissions.is_contiguous():
            raise RuntimeError("emissions must be contiguous.")

        if emissions.ndim != ndim:
            raise RuntimeError(f"emissions must be {ndim}D. Found {emissions.shape}")

    def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
        """Collapse runs of repeated token IDs, then drop blank tokens."""
        idxs = (g[0] for g in it.groupby(idxs))
        idxs = filter(lambda x: x != self.blank, idxs)
        return torch.LongTensor(list(idxs))

    def _get_timesteps(self, idxs: torch.IntTensor) -> torch.IntTensor:
        """Returns frame numbers corresponding to non-blank tokens.

        A frame is reported when it holds a non-blank token that differs from the token in
        the previous frame, i.e. the frame where each token produced by ``_get_tokens`` starts.
        """

        timesteps = []
        for i, idx in enumerate(idxs):
            if idx == self.blank:
                continue
            if i == 0 or idx != idxs[i - 1]:
                timesteps.append(i)
        return torch.IntTensor(timesteps)

    def decode_begin(self):
        """Initialize the internal state of the decoder.

        See :py:meth:`decode_step` for the usage.

        .. note::

           This method is required only when performing online decoding.
           It is not necessary when performing batch decoding with :py:meth:`__call__`.
        """
        self.decoder.decode_begin()

    def decode_end(self):
        """Finalize the internal state of the decoder.

        See :py:meth:`decode_step` for the usage.

        .. note::

           This method is required only when performing online decoding.
           It is not necessary when performing batch decoding with :py:meth:`__call__`.
        """
        self.decoder.decode_end()

    def decode_step(self, emissions: torch.FloatTensor):
        """Perform incremental decoding on top of the current internal state.

        .. note::

           This method is required only when performing online decoding.
           It is not necessary when performing batch decoding with :py:meth:`__call__`.

        Args:
            emissions (torch.FloatTensor): CPU tensor of shape `(frame, num_tokens)` storing sequences of
                probability distribution over labels; output of acoustic model.

        Raises:
            ValueError: if ``emissions`` is not float32.
            RuntimeError: if ``emissions`` is not a contiguous 2D CPU tensor.

        Example:
            >>> decoder = torchaudio.models.decoder.ctc_decoder(...)
            >>> decoder.decode_begin()
            >>> decoder.decode_step(emission1)
            >>> decoder.decode_step(emission2)
            >>> decoder.decode_end()
            >>> result = decoder.get_final_hypothesis()
        """
        self._check_emissions(emissions, 2)

        T, N = emissions.size()
        self.decoder.decode_step(emissions.data_ptr(), T, N)

    def _to_hypo(self, results) -> List[CTCHypothesis]:
        """Convert native decoder results into :class:`CTCHypothesis` objects.

        Word indices below zero are skipped when mapping words back to strings.
        """
        return [
            CTCHypothesis(
                tokens=self._get_tokens(result.tokens),
                words=[self.word_dict.get_entry(x) for x in result.words if x >= 0],
                score=result.score,
                timesteps=self._get_timesteps(result.tokens),
            )
            for result in results
        ]

    def get_final_hypothesis(self) -> List[CTCHypothesis]:
        """Get the final hypothesis

        Returns:
            List[CTCHypothesis]:
                List of sorted best hypotheses.

        .. note::

           This method is required only when performing online decoding.
           It is not necessary when performing batch decoding with :py:meth:`__call__`.
        """
        results = self.decoder.get_all_final_hypothesis()
        return self._to_hypo(results[: self.nbest])

    def __call__(
        self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None
    ) -> List[List[CTCHypothesis]]:
        """
        Performs batched offline decoding.

        .. note::

           This method performs offline decoding in one go. To perform incremental decoding,
           please refer to :py:meth:`decode_step`.

        Args:
            emissions (torch.FloatTensor): CPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
                probability distribution over labels; output of acoustic model.
            lengths (Tensor or None, optional): CPU tensor of shape `(batch, )` storing the number of valid
                frames (along the time axis) of each sequence in the batch. Every value must lie in
                `[0, frame]`. If `None`, all `frame` frames of every sequence are decoded.

        Returns:
            List[List[CTCHypothesis]]:
                List of sorted best hypotheses for each audio sequence in the batch.

        Raises:
            ValueError: if ``emissions`` is not float32, or a length is outside `[0, frame]`.
            RuntimeError: if ``emissions`` is not a contiguous 3D CPU tensor, or ``lengths``
                is not a CPU tensor.
        """
        self._check_emissions(emissions, 3)

        if lengths is not None and not lengths.is_cpu:
            raise RuntimeError("lengths must be a CPU tensor.")

        B, T, N = emissions.size()
        if lengths is None:
            lengths = torch.full((B,), T)

        # The native decoder reads ``length * N`` floats starting at each batch entry's
        # pointer. A length above T would read into the next entry, or past the end of
        # the buffer for the last one, so validate every length before decoding anything.
        frame_counts = [int(lengths[b]) for b in range(B)]
        for b, num_frames in enumerate(frame_counts):
            if not 0 <= num_frames <= T:
                raise ValueError(f"lengths[{b}] must be in the range [0, {T}]. Found {num_frames}")

        # ``emissions`` is contiguous (checked above), so entry b starts b * stride(0) elements in.
        bytes_per_entry = emissions.element_size() * emissions.stride(0)
        hypos = []

        for b, num_frames in enumerate(frame_counts):
            emissions_ptr = emissions.data_ptr() + b * bytes_per_entry
            results = self.decoder.decode(emissions_ptr, num_frames, N)
            hypos.append(self._to_hypo(results[: self.nbest]))
        return hypos

    def idxs_to_tokens(self, idxs: torch.LongTensor) -> List:
        """
        Map raw token IDs into corresponding tokens

        Args:
            idxs (LongTensor or Sequence[int]): raw token IDs generated from decoder

        Returns:
            List: tokens corresponding to the input IDs
        """
        # ``int`` accepts both tensor elements and plain Python ints.
        return [self.tokens_dict.get_entry(int(idx)) for idx in idxs]


def ctc_decoder(
    lexicon: Optional[str],
    tokens: Union[str, List[str]],
    lm: Optional[Union[str, CTCDecoderLM]] = None,
    lm_dict: Optional[str] = None,
    nbest: int = 1,
    beam_size: int = 50,
    beam_size_token: Optional[int] = None,
    beam_threshold: float = 50,
    lm_weight: float = 2,
    word_score: float = 0,
    unk_score: float = float("-inf"),
    sil_score: float = 0,
    log_add: bool = False,
    blank_token: str = "-",
    sil_token: str = "|",
    unk_word: str = "<unk>",
) -> CTCDecoder:
    """Builds an instance of :class:`CTCDecoder`.

    Args:
        lexicon (str or None): lexicon file containing the possible words and corresponding spellings.
            Each line consists of a word and its space separated spelling. If `None`, uses lexicon-free
            decoding.
        tokens (str or List[str]): file or list containing valid tokens. If using a file, the expected
            format is for tokens mapping to the same index to be on the same line
        lm (str, CTCDecoderLM, or None, optional): either a path containing KenLM language model,
            custom language model of type `CTCDecoderLM`, or `None` if not using a language model
            (Default: None)
        lm_dict (str or None, optional): file consisting of the dictionary used for the LM, with a word
            per line sorted by LM index. If decoding with a lexicon, entries in lm_dict must also occur
            in the lexicon file. If `None`, dictionary for LM is constructed using the lexicon file.
            (Default: None)
        nbest (int, optional): number of best decodings to return (Default: 1)
        beam_size (int, optional): max number of hypos to hold after each decode step (Default: 50)
        beam_size_token (int, optional): max number of tokens to consider at each decode step.
            If `None`, it is set to the total number of tokens (Default: None)
        beam_threshold (float, optional): threshold for pruning hypothesis (Default: 50)
        lm_weight (float, optional): weight of language model (Default: 2)
        word_score (float, optional): word insertion score (Default: 0)
        unk_score (float, optional): unknown word insertion score (Default: -inf)
        sil_score (float, optional): silence insertion score (Default: 0)
        log_add (bool, optional): whether or not to use logadd when merging hypotheses (Default: False)
        blank_token (str, optional): token corresponding to blank (Default: "-")
        sil_token (str, optional): token corresponding to silence (Default: "|")
        unk_word (str, optional): word corresponding to unknown (Default: "<unk>")

    Returns:
        CTCDecoder: decoder

    Raises:
        ValueError: if ``lm_dict`` is neither `None` nor a string.
        RuntimeError: if ``lm`` is a path but the KenLM bindings are not available.

    Example:
        >>> decoder = ctc_decoder(
        ...     lexicon="lexicon.txt",
        ...     tokens="tokens.txt",
        ...     lm="kenlm.bin",
        ... )
        >>> results = decoder(emissions) # List of shape (B, nbest) of Hypotheses
    """
    if lm_dict is not None and not isinstance(lm_dict, str):
        raise ValueError("lm_dict must be None or str type.")

    tokens_dict = _Dictionary(tokens)
    # A falsy beam_size_token (None) means "consider every token".
    num_beam_tokens = beam_size_token or tokens_dict.index_size()

    # decoder options
    if lexicon:
        lexicon = _load_words(lexicon)
        decoder_options = _LexiconDecoderOptions(
            beam_size=beam_size,
            beam_size_token=num_beam_tokens,
            beam_threshold=beam_threshold,
            lm_weight=lm_weight,
            word_score=word_score,
            unk_score=unk_score,
            sil_score=sil_score,
            log_add=log_add,
            criterion_type=_CriterionType.CTC,
        )
    else:
        decoder_options = _LexiconFreeDecoderOptions(
            beam_size=beam_size,
            beam_size_token=num_beam_tokens,
            beam_threshold=beam_threshold,
            lm_weight=lm_weight,
            sil_score=sil_score,
            log_add=log_add,
            criterion_type=_CriterionType.CTC,
        )

    # construct word dict and language model
    word_dict = _get_word_dict(lexicon, lm, lm_dict, tokens_dict, unk_word)

    if isinstance(lm, str):
        if _KenLM is None:
            raise RuntimeError(
                "flashlight-text is installed, but KenLM is not installed. "
                "Please refer to https://github.com/kpu/kenlm#python-module for how to install it."
            )
        lm = _KenLM(lm, word_dict)
    elif lm is None:
        lm = _ZeroLM()

    return CTCDecoder(
        nbest=nbest,
        lexicon=lexicon,
        word_dict=word_dict,
        tokens_dict=tokens_dict,
        lm=lm,
        decoder_options=decoder_options,
        blank_token=blank_token,
        sil_token=sil_token,
        unk_word=unk_word,
    )


def _get_filenames(model: str) -> _PretrainedFiles:
    """Return the asset paths of the files belonging to a pretrained decoder bundle.

    The ``"librispeech"`` bundle has no language model, so its ``lm`` field is ``None``.

    Raises:
        ValueError: if ``model`` is not one of the supported bundle names.
    """
    supported = ("librispeech", "librispeech-3-gram", "librispeech-4-gram")
    if model not in supported:
        raise ValueError(
            f"{model} not supported. Must be one of ['librispeech-3-gram', 'librispeech-4-gram', 'librispeech']"
        )

    base = f"decoder-assets/{model}"
    has_lm = model != "librispeech"
    return _PretrainedFiles(
        lexicon=f"{base}/lexicon.txt",
        tokens=f"{base}/tokens.txt",
        lm=f"{base}/lm.bin" if has_lm else None,
    )


def download_pretrained_files(model: str) -> _PretrainedFiles:
    """
    Downloads the pretrained files used with :func:`ctc_decoder`.

    Args:
        model (str): name of the pretrained bundle to download.
            Valid values are: ``"librispeech-3-gram"``, ``"librispeech-4-gram"`` and ``"librispeech"``.

    Returns:
        Object with the following attributes

            * ``lm``: path of the downloaded language model,
              or ``None`` if the bundle has no lm
            * ``lexicon``: path of the downloaded lexicon file
            * ``tokens``: path of the downloaded tokens file
    """

    remote = _get_filenames(model)
    # Keyword arguments are evaluated left to right: lexicon, then tokens, then the optional lm.
    return _PretrainedFiles(
        lexicon=download_asset(remote.lexicon),
        tokens=download_asset(remote.tokens),
        lm=None if remote.lm is None else download_asset(remote.lm),
    )
