from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

import torch

from torchaudio.models import Emformer


__all__ = ["RNNT", "emformer_rnnt_base", "emformer_rnnt_model"]


class _TimeReduction(torch.nn.Module):
    r"""Coalesces frames along time dimension into a
    fewer number of frames with higher feature dimensionality.

    Args:
        stride (int): number of frames to merge for each output frame.
    N)stridereturnc                    s   t    || _d S N)super__init__r   )selfr   	__class__ U/var/www/html/Darija-Ai-API/env/lib/python3.8/site-packages/torchaudio/models/rnnt.pyr      s    
z_TimeReduction.__init__inputlengthsr   c           	      C   sr   |j \}}}||| j  }|ddd|ddf }|j| jdd}|| j }||||| j }| }||fS )a  Forward pass.

        B: batch size;
        T: maximum input sequence length in batch;
        D: feature dimension of each input sequence frame.

        Args:
            input (torch.Tensor): input sequences, with shape `(B, T, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.

        Returns:
            (torch.Tensor, torch.Tensor):
                torch.Tensor
                    output sequences, with shape
                    `(B, T // stride, D * stride)`
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid frames for i-th batch element in output sequences.
        Ntrunc)Zrounding_mode)shaper   divZreshape
contiguous)	r   r   r   BTDZ
num_framesZT_maxoutputr   r   r   forward   s    
z_TimeReduction.forward)__name__


class _CustomLSTM(torch.nn.Module):
    r"""Custom long-short-term memory (LSTM) block that applies layer normalization
    to internal nodes.

    Args:
        input_dim (int): input dimension.
        hidden_dim (int): hidden dimension.
        layer_norm (bool, optional): if ``True``, enables layer normalization. (Default: ``False``)
        layer_norm_epsilon (float, optional): value of epsilon to use in
            layer normalization layers. (Default: 1e-5)
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        layer_norm: bool = False,
        layer_norm_epsilon: float = 1e-5,
    ) -> None:
        super().__init__()
        self.x2g = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=(not layer_norm))
        self.p2g = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=False)
        if layer_norm:
            self.c_norm = torch.nn.LayerNorm(hidden_dim, eps=layer_norm_epsilon)
            self.g_norm = torch.nn.LayerNorm(4 * hidden_dim, eps=layer_norm_epsilon)
        else:
            self.c_norm = torch.nn.Identity()
            self.g_norm = torch.nn.Identity()

        self.hidden_dim = hidden_dim

    def forward(
        self, input: torch.Tensor, state: Optional[List[torch.Tensor]]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        r"""Forward pass.

        B: batch size;
        T: maximum sequence length in batch;
        D: feature dimension of each input sequence element.

        Args:
            input (torch.Tensor): with shape `(T, B, D)`.
            state (List[torch.Tensor] or None): list of tensors
                representing internal state generated in preceding invocation
                of ``forward``.

        Returns:
            (torch.Tensor, List[torch.Tensor]):
                torch.Tensor
                    output, with shape `(T, B, hidden_dim)`.
                List[torch.Tensor]
                    list of tensors representing internal state generated
                    in current invocation of ``forward``.
        """
        if state is None:
            B = input.size(1)
            h = torch.zeros(B, self.hidden_dim, device=input.device, dtype=input.dtype)
            c = torch.zeros(B, self.hidden_dim, device=input.device, dtype=input.dtype)
        else:
            h, c = state

        gated_input = self.x2g(input)
        outputs = []
        for gates in gated_input.unbind(0):
            gates = gates + self.p2g(h)
            gates = self.g_norm(gates)
            input_gate, forget_gate, cell_gate, output_gate = gates.chunk(4, 1)
            input_gate = input_gate.sigmoid()
            forget_gate = forget_gate.sigmoid()
            cell_gate = cell_gate.tanh()
            output_gate = output_gate.sigmoid()
            c = forget_gate * c + input_gate * cell_gate
            c = self.c_norm(c)
            h = output_gate * c.tanh()
            outputs.append(h)

        output = torch.stack(outputs, dim=0)
        state = [h, c]

        return output, state
   @   st   e Zd Zeejejeejejf dddZeejejee	e	ej   eejeje	e	ej  f dddZ
dS )_Transcriberr   c                 C   s   d S r   r   )r   r   r   r   r   r   r!      s    z_Transcriber.forwardr   r   statesr   c                 C   s   d S r   r   )r   r   r   rK   r   r   r   infer   s    z_Transcriber.inferN)r"   r#   r$   r   r'   r(   r   r!   r   r   rL   r   r   r   r   rI      s   $rI   c                       s   e Zd ZdZddddddeeeeeeeeeeeeeeedd	 fd
dZe	j
e	j
ee	j
e	j
f dddZe	jje	j
e	j
eeee	j
   ee	j
e	j
eee	j
  f dddZ  ZS )_EmformerEncodera  Emformer-based recurrent neural network transducer (RNN-T) encoder (transcription network).

    Args:
        input_dim (int): feature dimension of each input sequence element.
        output_dim (int): feature dimension of each output sequence element.
        segment_length (int): length of input segment expressed as number of frames.
        right_context_length (int): length of right context expressed as number of frames.
        time_reduction_input_dim (int): dimension to scale each element in input sequences to
            prior to applying time reduction block.
        time_reduction_stride (int): factor by which to reduce length of input sequence.
        transformer_num_heads (int): number of attention heads in each Emformer layer.
        transformer_ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
        transformer_num_layers (int): number of Emformer layers to instantiate.
        transformer_left_context_length (int): length of left context.
        transformer_dropout (float, optional): transformer dropout probability. (Default: 0.0)
        transformer_activation (str, optional): activation function to use in each Emformer layer's
            feedforward network. Must be one of ("relu", "gelu", "silu"). (Default: "relu")
        transformer_max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
        transformer_weight_init_scale_strategy (str, optional): per-layer weight initialization scaling
            strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
        transformer_tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
    """

    def __init__(
        self,
        *,
        input_dim: int,
        output_dim: int,
        segment_length: int,
        right_context_length: int,
        time_reduction_input_dim: int,
        time_reduction_stride: int,
        transformer_num_heads: int,
        transformer_ffn_dim: int,
        transformer_num_layers: int,
        transformer_left_context_length: int,
        transformer_dropout: float = 0.0,
        transformer_activation: str = "relu",
        transformer_max_memory_size: int = 0,
        transformer_weight_init_scale_strategy: str = "depthwise",
        transformer_tanh_on_mem: bool = False,
    ) -> None:
        super().__init__()
        self.input_linear = torch.nn.Linear(
            input_dim,
            time_reduction_input_dim,
            bias=False,
        )
        self.time_reduction = _TimeReduction(time_reduction_stride)
        transformer_input_dim = time_reduction_input_dim * time_reduction_stride
        self.transformer = Emformer(
            transformer_input_dim,
            transformer_num_heads,
            transformer_ffn_dim,
            transformer_num_layers,
            segment_length // time_reduction_stride,
            dropout=transformer_dropout,
            activation=transformer_activation,
            left_context_length=transformer_left_context_length,
            right_context_length=right_context_length // time_reduction_stride,
            max_memory_size=transformer_max_memory_size,
            weight_init_scale_strategy=transformer_weight_init_scale_strategy,
            tanh_on_mem=transformer_tanh_on_mem,
        )
        self.output_linear = torch.nn.Linear(transformer_input_dim, output_dim)
        self.layer_norm = torch.nn.LayerNorm(output_dim)

    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Forward pass for training.

        B: batch size;
        T: maximum input sequence length in batch;
        D: feature dimension of each input sequence frame (input_dim).

        Args:
            input (torch.Tensor): input frame sequences right-padded with right context, with
                shape `(B, T + right context length, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.

        Returns:
            (torch.Tensor, torch.Tensor):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output input lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output frame sequences.
        """
        input_linear_out = self.input_linear(input)
        time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
        transformer_out, transformer_lengths = self.transformer(time_reduction_out, time_reduction_lengths)
        output_linear_out = self.output_linear(transformer_out)
        layer_norm_out = self.layer_norm(output_linear_out)
        return layer_norm_out, transformer_lengths

    @torch.jit.export
    def infer(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        states: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Forward pass for inference.

        B: batch size;
        T: maximum input sequence segment length in batch;
        D: feature dimension of each input sequence frame (input_dim).

        Args:
            input (torch.Tensor): input frame sequence segments right-padded with right context, with
                shape `(B, T + right context length, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.
            state (List[List[torch.Tensor]] or None): list of lists of tensors
                representing internal state generated in preceding invocation
                of ``infer``.

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output input lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing internal state generated in current invocation
                    of ``infer``.
        """
        input_linear_out = self.input_linear(input)
        time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
        (
            transformer_out,
            transformer_lengths,
            transformer_states,
        ) = self.transformer.infer(time_reduction_out, time_reduction_lengths, states)
        output_linear_out = self.output_linear(transformer_out)
        layer_norm_out = self.layer_norm(output_linear_out)
        return layer_norm_out, transformer_lengths, transformer_states


class _Predictor(torch.nn.Module):
    r"""Recurrent neural network transducer (RNN-T) prediction network.

    Args:
        num_symbols (int): size of target token lexicon.
        output_dim (int): feature dimension of each output sequence element.
        symbol_embedding_dim (int): dimension of each target token embedding.
        num_lstm_layers (int): number of LSTM layers to instantiate.
        lstm_hidden_dim (int): output dimension of each LSTM layer.
        lstm_layer_norm (bool, optional): if ``True``, enables layer normalization
            for LSTM layers. (Default: ``False``)
        lstm_layer_norm_epsilon (float, optional): value of epsilon to use in
            LSTM layer normalization layers. (Default: 1e-5)
        lstm_dropout (float, optional): LSTM dropout probability. (Default: 0.0)

    """

    def __init__(
        self,
        num_symbols: int,
        output_dim: int,
        symbol_embedding_dim: int,
        num_lstm_layers: int,
        lstm_hidden_dim: int,
        lstm_layer_norm: bool = False,
        lstm_layer_norm_epsilon: float = 1e-5,
        lstm_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.embedding = torch.nn.Embedding(num_symbols, symbol_embedding_dim)
        self.input_layer_norm = torch.nn.LayerNorm(symbol_embedding_dim)
        self.lstm_layers = torch.nn.ModuleList(
            [
                _CustomLSTM(
                    symbol_embedding_dim if idx == 0 else lstm_hidden_dim,
                    lstm_hidden_dim,
                    layer_norm=lstm_layer_norm,
                    layer_norm_epsilon=lstm_layer_norm_epsilon,
                )
                for idx in range(num_lstm_layers)
            ]
        )
        self.dropout = torch.nn.Dropout(p=lstm_dropout)
        self.linear = torch.nn.Linear(lstm_hidden_dim, output_dim)
        self.output_layer_norm = torch.nn.LayerNorm(output_dim)

        self.lstm_dropout = lstm_dropout

    def forward(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        state: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Forward pass.

        B: batch size;
        U: maximum sequence length in batch;
        D: feature dimension of each input sequence element.

        Args:
            input (torch.Tensor): target sequences, with shape `(B, U)` and each element
                mapping to a target symbol, i.e. in range `[0, num_symbols)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.
            state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing internal state generated in preceding invocation
                of ``forward``. (Default: ``None``)

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output encoding sequences, with shape `(B, U, output_dim)`
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output encoding sequences.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing internal state generated in current invocation of ``forward``.
        """
        input_tb = input.permute(1, 0)
        embedding_out = self.embedding(input_tb)
        input_layer_norm_out = self.input_layer_norm(embedding_out)

        lstm_out = input_layer_norm_out
        state_out: List[List[torch.Tensor]] = []
        for layer_idx, lstm in enumerate(self.lstm_layers):
            lstm_out, lstm_state_out = lstm(lstm_out, None if state is None else state[layer_idx])
            lstm_out = self.dropout(lstm_out)
            state_out.append(lstm_state_out)

        linear_out = self.linear(lstm_out)
        output_layer_norm_out = self.output_layer_norm(linear_out)
        return output_layer_norm_out.permute(1, 0, 2), lengths, state_out
z_Predictor.forward)Fr+   rN   )NrF   r   r   r   r   ro   (  s,      # ro   c                       s\   e Zd ZdZd
eeedd fddZejejejeje	ejejejf ddd	Z
  ZS )_Joinera@  Recurrent neural network transducer (RNN-T) joint network.

    Args:
        input_dim (int): source and target input dimension.
        output_dim (int): output dimension.
        activation (str, optional): activation function to use in the joiner.
            Must be one of ("relu", "tanh"). (Default: "relu")

    """

    def __init__(self, input_dim: int, output_dim: int, activation: str = "relu") -> None:
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim, bias=True)
        if activation == "relu":
            self.activation = torch.nn.ReLU()
        elif activation == "tanh":
            self.activation = torch.nn.Tanh()
        else:
            raise ValueError(f"Unsupported activation {activation}")

    def forward(
        self,
        source_encodings: torch.Tensor,
        source_lengths: torch.Tensor,
        target_encodings: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Forward pass for training.

        B: batch size;
        T: maximum source sequence length in batch;
        U: maximum target sequence length in batch;
        D: dimension of each source and target sequence encoding.

        Args:
            source_encodings (torch.Tensor): source encoding sequences, with
                shape `(B, T, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                valid sequence length of i-th batch element in ``source_encodings``.
            target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                valid sequence length of i-th batch element in ``target_encodings``.

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor):
                torch.Tensor
                    joint network output, with shape `(B, T, U, output_dim)`.
                torch.Tensor
                    output source lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 1 for i-th batch element in joint network output.
                torch.Tensor
                    output target lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 2 for i-th batch element in joint network output.
        """
        joint_encodings = source_encodings.unsqueeze(2).contiguous() + target_encodings.unsqueeze(1).contiguous()
        activation_out = self.activation(joint_encodings)
        output = self.linear(activation_out)
        return output, source_lengths, target_lengths


class RNNT(torch.nn.Module):
    r"""torchaudio.models.RNNT()

    Recurrent neural network transducer (RNN-T) model.

    Note:
        To build the model, please use one of the factory functions.

    See Also:
        :class:`torchaudio.pipelines.RNNTBundle`: ASR pipeline with pre-trained models.

    Args:
        transcriber (torch.nn.Module): transcription network.
        predictor (torch.nn.Module): prediction network.
        joiner (torch.nn.Module): joint network.
    N)transcriber	predictorjoinerr   c                    s    t    || _|| _|| _d S r   )r   r   r   r   r   )r   r   r   r   r   r   r   r     s    
zRNNT.__init__)sourcesr   targetsr   predictor_stater   c           	      C   sL   | j ||d\}}| j|||d\}}}| j||||d\}}}||||fS )a  Forward pass for training.

        B: batch size;
        T: maximum source sequence length in batch;
        U: maximum target sequence length in batch;
        D: feature dimension of each source sequence element.

        Args:
            sources (torch.Tensor): source frame sequences right-padded with right context, with
                shape `(B, T, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``sources``.
            targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
                mapping to a target symbol.
            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``targets``.
            predictor_state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing prediction network internal state generated in preceding invocation
                of ``forward``. (Default: ``None``)

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    joint network output, with shape
                    `(B, max output source length, max output target length, output_dim (number of target symbols))`.
                torch.Tensor
                    output source lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 1 for i-th batch element in joint network output.
                torch.Tensor
                    output target lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 2 for i-th batch element in joint network output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing prediction network internal state generated in current invocation
                    of ``forward``.
        """
        source_encodings, source_lengths = self.transcriber(
            input=sources,
            lengths=source_lengths,
        )
        target_encodings, target_lengths, predictor_state = self.predictor(
            input=targets,
            lengths=target_lengths,
            state=predictor_state,
        )
        output, source_lengths, target_lengths = self.joiner(
            source_encodings=source_encodings,
            source_lengths=source_lengths,
            target_encodings=target_encodings,
            target_lengths=target_lengths,
        )
        return output, source_lengths, target_lengths, predictor_state

    @torch.jit.export
    def transcribe_streaming(
        self,
        sources: torch.Tensor,
        source_lengths: torch.Tensor,
        state: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Applies transcription network to sources in streaming mode.

        B: batch size;
        T: maximum source sequence segment length in batch;
        D: feature dimension of each source sequence frame.

        Args:
            sources (torch.Tensor): source frame sequence segments right-padded with right context, with
                shape `(B, T + right context length, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``sources``.
            state (List[List[torch.Tensor]] or None): list of lists of tensors
                representing transcription network internal state generated in preceding invocation
                of ``transcribe_streaming``.

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing transcription network internal state generated in current invocation
                    of ``transcribe_streaming``.
        """
        return self.transcriber.infer(sources, source_lengths, state)

    @torch.jit.export
    def transcribe(
        self,
        sources: torch.Tensor,
        source_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Applies transcription network to sources in non-streaming mode.

        B: batch size;
        T: maximum source sequence length in batch;
        D: feature dimension of each source sequence frame.

        Args:
            sources (torch.Tensor): source frame sequences right-padded with right context, with
                shape `(B, T + right context length, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``sources``.

        Returns:
            (torch.Tensor, torch.Tensor):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output frame sequences.
        """
        return self.transcriber(sources, source_lengths)

    @torch.jit.export
    def predict(
        self,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
        state: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Applies prediction network to targets.

        B: batch size;
        U: maximum target sequence length in batch;
        D: feature dimension of each target sequence frame.

        Args:
            targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
                mapping to a target symbol, i.e. in range `[0, num_symbols)`.
            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``targets``.
            state (List[List[torch.Tensor]] or None): list of lists of tensors
                representing internal state generated in preceding invocation
                of ``predict``.

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output frame sequences, with shape `(B, U, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing internal state generated in current invocation of ``predict``.
        """
        return self.predictor(input=targets, lengths=target_lengths, state=state)

    @torch.jit.export
    def join(
        self,
        source_encodings: torch.Tensor,
        source_lengths: torch.Tensor,
        target_encodings: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Applies joint network to source and target encodings.

        B: batch size;
        T: maximum source sequence length in batch;
        U: maximum target sequence length in batch;
        D: dimension of each source and target sequence encoding.

        Args:
            source_encodings (torch.Tensor): source encoding sequences, with
                shape `(B, T, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                valid sequence length of i-th batch element in ``source_encodings``.
            target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                valid sequence length of i-th batch element in ``target_encodings``.

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor):
                torch.Tensor
                    joint network output, with shape `(B, T, U, output_dim)`.
                torch.Tensor
                    output source lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 1 for i-th batch element in joint network output.
                torch.Tensor
                    output target lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 2 for i-th batch element in joint network output.
        """
        output, source_lengths, target_lengths = self.joiner(
            source_encodings=source_encodings,
            source_lengths=source_lengths,
            target_encodings=target_encodings,
            target_lengths=target_lengths,
        )
        return output, source_lengths, target_lengths


def emformer_rnnt_model(
    *,
    input_dim: int,
    encoding_dim: int,
    num_symbols: int,
    segment_length: int,
    right_context_length: int,
    time_reduction_input_dim: int,
    time_reduction_stride: int,
    transformer_num_heads: int,
    transformer_ffn_dim: int,
    transformer_num_layers: int,
    transformer_dropout: float,
    transformer_activation: str,
    transformer_left_context_length: int,
    transformer_max_memory_size: int,
    transformer_weight_init_scale_strategy: str,
    transformer_tanh_on_mem: bool,
    symbol_embedding_dim: int,
    num_lstm_layers: int,
    lstm_layer_norm: bool,
    lstm_layer_norm_epsilon: float,
    lstm_dropout: float,
) -> RNNT:
    r"""Builds Emformer-based :class:`~torchaudio.models.RNNT`.

    Note:
        For non-streaming inference, the expectation is for `transcribe` to be called on input
        sequences right-concatenated with `right_context_length` frames.

        For streaming inference, the expectation is for `transcribe_streaming` to be called
        on input chunks comprising `segment_length` frames right-concatenated with `right_context_length`
        frames.

    Args:
        input_dim (int): dimension of input sequence frames passed to transcription network.
        encoding_dim (int): dimension of transcription- and prediction-network-generated encodings
            passed to joint network.
        num_symbols (int): cardinality of set of target tokens.
        segment_length (int): length of input segment expressed as number of frames.
        right_context_length (int): length of right context expressed as number of frames.
        time_reduction_input_dim (int): dimension to scale each element in input sequences to
            prior to applying time reduction block.
        time_reduction_stride (int): factor by which to reduce length of input sequence.
        transformer_num_heads (int): number of attention heads in each Emformer layer.
        transformer_ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
        transformer_num_layers (int): number of Emformer layers to instantiate.
        transformer_left_context_length (int): length of left context considered by Emformer.
        transformer_dropout (float): Emformer dropout probability.
        transformer_activation (str): activation function to use in each Emformer layer's
            feedforward network. Must be one of ("relu", "gelu", "silu").
        transformer_max_memory_size (int): maximum number of memory elements to use.
        transformer_weight_init_scale_strategy (str): per-layer weight initialization scaling
            strategy. Must be one of ("depthwise", "constant", ``None``).
        transformer_tanh_on_mem (bool): if ``True``, applies tanh to memory elements.
        symbol_embedding_dim (int): dimension of each target token embedding.
        num_lstm_layers (int): number of LSTM layers to instantiate.
        lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers.
        lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer normalization layers.
        lstm_dropout (float): LSTM dropout probability.

    Returns:
        RNNT:
            Emformer RNN-T model.
    """
    encoder = _EmformerEncoder(
        input_dim=input_dim,
        output_dim=encoding_dim,
        segment_length=segment_length,
        right_context_length=right_context_length,
        time_reduction_input_dim=time_reduction_input_dim,
        time_reduction_stride=time_reduction_stride,
        transformer_num_heads=transformer_num_heads,
        transformer_ffn_dim=transformer_ffn_dim,
        transformer_num_layers=transformer_num_layers,
        transformer_dropout=transformer_dropout,
        transformer_activation=transformer_activation,
        transformer_left_context_length=transformer_left_context_length,
        transformer_max_memory_size=transformer_max_memory_size,
        transformer_weight_init_scale_strategy=transformer_weight_init_scale_strategy,
        transformer_tanh_on_mem=transformer_tanh_on_mem,
    )
    predictor = _Predictor(
        num_symbols,
        encoding_dim,
        symbol_embedding_dim=symbol_embedding_dim,
        num_lstm_layers=num_lstm_layers,
        lstm_hidden_dim=symbol_embedding_dim,
        lstm_layer_norm=lstm_layer_norm,
        lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
        lstm_dropout=lstm_dropout,
    )
    joiner = _Joiner(encoding_dim, num_symbols)
    return RNNT(encoder, predictor, joiner)


def emformer_rnnt_base(num_symbols: int) -> RNNT:
    r"""Builds basic version of Emformer-based :class:`~torchaudio.models.RNNT`.

    Args:
        num_symbols (int): The size of target token lexicon.

    Returns:
        RNNT:
            Emformer RNN-T model.
    """
    return emformer_rnnt_model(
        input_dim=80,
        encoding_dim=1024,
        num_symbols=num_symbols,
        segment_length=16,
        right_context_length=4,
        time_reduction_input_dim=128,
        time_reduction_stride=4,
        transformer_num_heads=8,
        transformer_ffn_dim=2048,
        transformer_num_layers=20,
        transformer_dropout=0.1,
        transformer_activation="gelu",
        transformer_left_context_length=30,
        transformer_max_memory_size=0,
        transformer_weight_init_scale_strategy="depthwise",
        transformer_tanh_on_mem=True,
        symbol_embedding_dim=512,
        num_lstm_layers=3,
        lstm_layer_norm=True,
        lstm_layer_norm_epsilon=1e-3,
        lstm_dropout=0.3,
    )
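
# End-to-end sketch (illustrative, not part of the original torchaudio source;
# the vocabulary size and tensor contents are arbitrary). ``transcribe`` expects
# full utterances right-padded with right_context_length (4) frames, while
# ``transcribe_streaming`` consumes segment_length (16) new frames plus the 4
# look-ahead frames per call and threads the returned state between calls:
#
#     model = emformer_rnnt_base(num_symbols=1024)
#
#     sources = torch.rand(2, 132, 80)         # (B, T + right context, D), T = 128
#     source_lengths = torch.tensor([132, 132])
#     enc, enc_lengths = model.transcribe(sources, source_lengths)
#
#     targets = torch.randint(0, 1024, (2, 10))
#     target_lengths = torch.tensor([10, 10])
#     logits, _, _, _ = model(sources, source_lengths, targets, target_lengths)
#
#     state = None
#     for start in range(0, 128, 16):          # 16-frame segments + 4 look-ahead frames
#         chunk = sources[:, start : start + 20, :]
#         chunk_lengths = torch.tensor([chunk.size(1)] * 2)
#         enc, enc_lengths, state = model.transcribe_streaming(chunk, chunk_lengths, state)
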