import warnings
from typing import List, Optional, Tuple, Union

import torch
from torch import nn, Tensor
from torch.nn import functional as F


__all__ = ["Tacotron2"]


def _get_linear_layer(in_dim: int, out_dim: int, bias: bool = True, w_init_gain: str = "linear") -> torch.nn.Linear:
    """Linear layer with xavier uniform initialization.

    Args:
        in_dim (int): Size of each input sample.
        out_dim (int): Size of each output sample.
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias. (Default: ``True``)
        w_init_gain (str, optional): Parameter passed to ``torch.nn.init.calculate_gain``
            for setting the gain parameter of ``xavier_uniform_``. (Default: ``linear``)

    Returns:
        (torch.nn.Linear): The corresponding linear layer.
    """
    linear = torch.nn.Linear(in_dim, out_dim, bias=bias)
    torch.nn.init.xavier_uniform_(linear.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
    return linear
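

# A minimal usage sketch (sizes are illustrative only): the helper returns a
# plain ``torch.nn.Linear`` whose weight is re-initialized with xavier uniform
# scaled for the requested non-linearity.
#
#   >>> layer = _get_linear_layer(256, 128, bias=False, w_init_gain="tanh")
#   >>> layer(torch.rand(4, 256)).shape
#   torch.Size([4, 128])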
   © r   úZ/var/www/html/Darija-Ai-API/env/lib/python3.8/site-packages/torchaudio/models/tacotron2.pyÚ_get_linear_layer)   s    r   é   )	Úin_channelsÚout_channelsÚkernel_sizeÚstrideÚpaddingÚdilationr   r   r   c           	   	   C   sl   |dkr0|d dkrt dƒ‚t||d  d ƒ}tjj| ||||||d}tjjj|jtjj |¡d |S )al  1D convolution with xavier uniform initialization.

    Args:
        in_channels (int): Number of channels in the input image.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int, optional): Number of channels in the input image. (Default: ``1``)
        stride (int, optional): Number of channels in the input image. (Default: ``1``)
        padding (str, int or tuple, optional): Padding added to both sides of the input.
            (Default: dilation * (kernel_size - 1) / 2)
        dilation (int, optional): Number of channels in the input image. (Default: ``1``)
        w_init_gain (str, optional): Parameter passed to ``torch.nn.init.calculate_gain``
            for setting the gain parameter of ``xavier_uniform_``. (Default: ``linear``)

    Returns:
        (torch.nn.Conv1d): The corresponding Conv1D layer.
    Né   r   zkernel_size must be odd)r   r   r    r!   r   r   )	Ú
ValueErrorÚintr   r   ÚConv1dr   r   r   r   )	r   r   r   r   r    r!   r   r   Zconv1dr   r   r   Ú_get_conv1d_layer;   s    ù
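

# A minimal usage sketch (sizes are illustrative only): with the default
# padding of ``dilation * (kernel_size - 1) / 2``, the time dimension of the
# input is preserved.
#
#   >>> conv = _get_conv1d_layer(80, 512, kernel_size=5, w_init_gain="relu")
#   >>> conv(torch.rand(4, 80, 100)).shape
#   torch.Size([4, 512, 100])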


def _get_mask_from_lengths(lengths: Tensor) -> Tensor:
    """Returns a binary mask based on ``lengths``. The ``i``-th row and ``j``-th column of the mask
    is ``1`` if ``j`` is not smaller than the ``i``-th element of ``lengths``, i.e. the mask
    marks the padded positions.

    Args:
        lengths (Tensor): The length of each element in the batch, with shape (n_batch, ).

    Returns:
        mask (Tensor): The binary mask, with shape (n_batch, max of ``lengths``).
    """
    max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len, device=lengths.device, dtype=lengths.dtype)
    mask = (ids < lengths.unsqueeze(1)).byte()
    mask = torch.le(mask, 0)
    return mask
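

# A minimal usage sketch: positions at or past each sequence's length are
# marked ``True``; the attention module later fills those positions with
# ``-inf`` before the softmax.
#
#   >>> _get_mask_from_lengths(torch.tensor([3, 1]))
#   tensor([[False, False, False],
#           [False,  True,  True]])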


class _LocationLayer(nn.Module):
    """Location layer used in the Attention model.

    Args:
        attention_n_filter (int): Number of filters for attention model.
        attention_kernel_size (int): Kernel size for attention model.
        attention_hidden_dim (int): Dimension of attention hidden representation.
    """

    def __init__(
        self,
        attention_n_filter: int,
        attention_kernel_size: int,
        attention_hidden_dim: int,
    ):
        super().__init__()
        padding = int((attention_kernel_size - 1) / 2)
        self.location_conv = _get_conv1d_layer(
            2,
            attention_n_filter,
            kernel_size=attention_kernel_size,
            padding=padding,
            bias=False,
            stride=1,
            dilation=1,
        )
        self.location_dense = _get_linear_layer(
            attention_n_filter, attention_hidden_dim, bias=False, w_init_gain="tanh"
        )

    def forward(self, attention_weights_cat: Tensor) -> Tensor:
        """Pass the concatenated attention weights through the location layer.

        Args:
            attention_weights_cat (Tensor): Cumulative and previous attention weights
                with shape (n_batch, 2, max of ``text_lengths``).

        Returns:
            processed_attention (Tensor): Processed attention weights
                with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
        """
        processed_attention = self.location_conv(attention_weights_cat)
        processed_attention = processed_attention.transpose(1, 2)
        processed_attention = self.location_dense(processed_attention)
        return processed_attention


class _Attention(nn.Module):
    """Location-sensitive attention model.

    Args:
        attention_rnn_dim (int): Number of hidden units for RNN.
        encoder_embedding_dim (int): Number of embedding dimensions in the Encoder.
        attention_hidden_dim (int): Dimension of attention hidden representation.
        attention_location_n_filter (int): Number of filters for Attention model.
        attention_location_kernel_size (int): Kernel size for Attention model.
    """

    def __init__(
        self,
        attention_rnn_dim: int,
        encoder_embedding_dim: int,
        attention_hidden_dim: int,
        attention_location_n_filter: int,
        attention_location_kernel_size: int,
    ) -> None:
        super().__init__()
        self.query_layer = _get_linear_layer(attention_rnn_dim, attention_hidden_dim, bias=False, w_init_gain="tanh")
        self.memory_layer = _get_linear_layer(
            encoder_embedding_dim, attention_hidden_dim, bias=False, w_init_gain="tanh"
        )
        self.v = _get_linear_layer(attention_hidden_dim, 1, bias=False)
        self.location_layer = _LocationLayer(
            attention_location_n_filter,
            attention_location_kernel_size,
            attention_hidden_dim,
        )
        self.score_mask_value = -float("inf")

    def _get_alignment_energies(
        self, query: Tensor, processed_memory: Tensor, attention_weights_cat: Tensor
    ) -> Tensor:
        """Get the alignment vector.

        Args:
            query (Tensor): Decoder output with shape (n_batch, n_mels * n_frames_per_step).
            processed_memory (Tensor): Processed Encoder outputs
                with shape (n_batch, max of ``text_lengths``, attention_hidden_dim).
            attention_weights_cat (Tensor): Cumulative and previous attention weights
                with shape (n_batch, 2, max of ``text_lengths``).

        Returns:
            alignment (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
        """
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(processed_query + processed_attention_weights + processed_memory))
        alignment = energies.squeeze(2)
        return alignment

    def forward(
        self,
        attention_hidden_state: Tensor,
        memory: Tensor,
        processed_memory: Tensor,
        attention_weights_cat: Tensor,
        mask: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Pass the input through the Attention model.

        Args:
            attention_hidden_state (Tensor): Attention rnn last output with shape (n_batch, ``attention_rnn_dim``).
            memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
            processed_memory (Tensor): Processed Encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
            attention_weights_cat (Tensor): Previous and cumulative attention weights
                with shape (n_batch, 2, max of ``text_lengths``).
            mask (Tensor): Binary mask for padded data with shape (n_batch, max of ``text_lengths``).

        Returns:
            attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
            attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
        """
        alignment = self._get_alignment_energies(attention_hidden_state, processed_memory, attention_weights_cat)
        alignment = alignment.masked_fill(mask, self.score_mask_value)

        attention_weights = F.softmax(alignment, dim=1)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights


class _Prenet(nn.Module):
    """Prenet Module. It consists of ``len(out_sizes)`` linear layers.

    Args:
        in_dim (int): The size of each input sample.
        out_sizes (list): The output dimension of each linear layer.
    """

    def __init__(self, in_dim: int, out_sizes: List[int]) -> None:
        super().__init__()
        in_sizes = [in_dim] + out_sizes[:-1]
        self.layers = nn.ModuleList(
            [_get_linear_layer(in_size, out_size, bias=False) for (in_size, out_size) in zip(in_sizes, out_sizes)]
        )

    def forward(self, x: Tensor) -> Tensor:
        """Pass the input through Prenet.

        Args:
            x (Tensor): The input sequence to Prenet with shape (n_batch, in_dim).

        Return:
            x (Tensor): Tensor with shape (n_batch, ``out_sizes[-1]``).
        """
        for linear in self.layers:
            x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
        return x


class _Postnet(nn.Module):
    """Postnet Module.

    Args:
        n_mels (int): Number of mel bins.
        postnet_embedding_dim (int): Postnet embedding dimension.
        postnet_kernel_size (int): Postnet kernel size.
        postnet_n_convolution (int): Number of postnet convolutions.
    """

    def __init__(
        self,
        n_mels: int,
        postnet_embedding_dim: int,
        postnet_kernel_size: int,
        postnet_n_convolution: int,
    ):
        super().__init__()
        self.convolutions = nn.ModuleList()

        for i in range(postnet_n_convolution):
            in_channels = n_mels if i == 0 else postnet_embedding_dim
            out_channels = n_mels if i == (postnet_n_convolution - 1) else postnet_embedding_dim
            init_gain = "linear" if i == (postnet_n_convolution - 1) else "tanh"
            num_features = n_mels if i == (postnet_n_convolution - 1) else postnet_embedding_dim
            self.convolutions.append(
                nn.Sequential(
                    _get_conv1d_layer(
                        in_channels,
                        out_channels,
                        kernel_size=postnet_kernel_size,
                        stride=1,
                        padding=int((postnet_kernel_size - 1) / 2),
                        dilation=1,
                        w_init_gain=init_gain,
                    ),
                    nn.BatchNorm1d(num_features),
                )
            )

        self.n_convs = len(self.convolutions)

    def forward(self, x: Tensor) -> Tensor:
        """Pass the input through Postnet.

        Args:
            x (Tensor): The input sequence with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).

        Return:
            x (Tensor): Tensor with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
        """
        for i, conv in enumerate(self.convolutions):
            if i < self.n_convs - 1:
                x = F.dropout(torch.tanh(conv(x)), 0.5, training=self.training)
            else:
                x = F.dropout(conv(x), 0.5, training=self.training)
        return x
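

# A minimal shape sketch (sizes are illustrative only): the postnet maps a mel
# spectrogram to a residual of the same shape, which ``Tacotron2.forward`` adds
# back onto the decoder output.
#
#   >>> postnet = _Postnet(80, 512, 5, 5)
#   >>> postnet(torch.rand(4, 80, 200)).shape
#   torch.Size([4, 80, 200])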
z_Postnet.forwardrB   r   r   r=   r   ro      s   û ro   c                       s>   e Zd ZdZeeeddœ‡ fdd„Zeeedœdd„Z‡  ZS )	Ú_Encodera§  Encoder Module.

    Args:
        encoder_embedding_dim (int): Number of embedding dimensions in the encoder.
        encoder_n_convolution (int): Number of convolution layers in the encoder.
        encoder_kernel_size (int): The kernel size in the encoder.

    Examples
        >>> encoder = _Encoder(512, 3, 5)
        >>> input = torch.rand(10, 512, 30)
        >>> input_lengths = torch.full((10,), 30, dtype=torch.int64)
        >>> output = encoder(input, input_lengths)  # shape: (10, 30, 512)
    N)rJ   Úencoder_n_convolutionÚencoder_kernel_sizer   c                    sŠ   t ƒ  ¡  t ¡ | _t|ƒD ]@}t t|||dt|d d ƒdddt 	|¡¡}| j 
|¡ qtj|t|d ƒdddd| _| j ¡  d S )Nr   r"   rn   rt   T)Úbatch_firstÚbidirectional)r8   r9   r   re   ru   rv   rx   r&   r$   ry   rw   ZLSTMÚlstmZflatten_parameters)r<   rJ   r€   r   Ú_Z
conv_layerr=   r   r   r9   k  s0    

ù	ö
ûz_Encoder.__init__)ri   Úinput_lengthsr   c                 C   sv   | j D ]}t t ||ƒ¡d| j¡}q| dd¡}| ¡ }tjj	j
||dd}|  |¡\}}tjj	j|dd\}}|S )a_  Pass the input through the Encoder.

        Args:
            x (Tensor): The input sequences with shape (n_batch, encoder_embedding_dim, n_seq).
            input_lengths (Tensor): The length of each input sequence with shape (n_batch, ).

        Return:
            x (Tensor): A tensor with shape (n_batch, n_seq, encoder_embedding_dim).
        """
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, training=self.training)

        x = x.transpose(1, 2)

        input_lengths = input_lengths.cpu()
        x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True)

        outputs, _ = self.lstm(x)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)

        return outputs


class _Decoder(nn.Module):
    """Decoder with Attention model.
    Args:
        n_mels (int): number of mel bins
        n_frames_per_step (int): number of frames processed per step, only 1 is supported
        encoder_embedding_dim (int): the number of embedding dimensions in the encoder.
        decoder_rnn_dim (int): number of units in decoder LSTM
        decoder_max_step (int): maximum number of output mel spectrograms
        decoder_dropout (float): dropout probability for decoder LSTM
        decoder_early_stopping (bool): stop decoding when all samples are finished
        attention_rnn_dim (int): number of units in attention LSTM
        attention_hidden_dim (int): dimension of attention hidden representation
        attention_location_n_filter (int): number of filters for attention model
        attention_location_kernel_size (int): kernel size for attention model
        attention_dropout (float): dropout probability for attention LSTM
        prenet_dim (int): number of ReLU units in prenet layers
        gate_threshold (float): probability threshold for stop token
    N)rp   Ún_frames_per_steprJ   Údecoder_rnn_dimÚdecoder_max_stepÚdecoder_dropoutÚdecoder_early_stoppingrI   r5   rK   rL   Úattention_dropoutÚ
prenet_dimÚgate_thresholdr   c                    sÆ   t ƒ  ¡  || _|| _|| _|| _|| _|| _|| _|| _	|| _
|| _|| _t|| ||gƒ| _t || |¡| _t|||	|
|ƒ| _t || |d¡| _t|| || ƒ| _t|| dddd| _d S )NTr   Úsigmoidr7   )r8   r9   rp   r‹   rJ   rI   rŒ   r‘   r   r’   r   rŽ   r   r`   Úprenetr   ZLSTMCellÚattention_rnnrH   Úattention_layerÚdecoder_rnnr   Úlinear_projectionÚ
gate_layer)r<   rp   r‹   rJ   rŒ   r   rŽ   r   rI   r5   rK   rL   r   r‘   r’   r=   r   r   r9   ¹  s:    
û   ÿz_Decoder.__init__)rZ   r   c                 C   s4   |  d¡}|j}|j}tj|| j| j ||d}|S )am  Gets all zeros frames to use as the first decoder input.

        Args:
            memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).

        Returns:
            decoder_input (Tensor): all zeros frames with shape
                (n_batch, max of ``text_lengths``, ``n_mels * n_frames_per_step``).
        r   ©r)   r(   ©Úsizer)   r(   r   Úzerosrp   r‹   ©r<   rZ   Ún_batchr)   r(   Údecoder_inputr   r   r   Ú_get_initial_frameì  s
    
z_Decoder._get_initial_framec                 C   sÈ   |  d¡}|  d¡}|j}|j}tj|| j||d}tj|| j||d}tj|| j||d}tj|| j||d}	tj||||d}
tj||||d}tj|| j||d}| j 	|¡}||||	|
|||fS )a  Initializes attention rnn states, decoder rnn states, attention
        weights, attention cumulative weights, attention context, stores memory
        and stores processed memory.

        Args:
            memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).

        Returns:
            attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            attention_cell (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            decoder_cell (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
            attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``).
            attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
            processed_memory (Tensor): Processed encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
        r   r   rš   )
        """
        n_batch = memory.size(0)
        max_time = memory.size(1)
        dtype = memory.dtype
        device = memory.device

        attention_hidden = torch.zeros(n_batch, self.attention_rnn_dim, dtype=dtype, device=device)
        attention_cell = torch.zeros(n_batch, self.attention_rnn_dim, dtype=dtype, device=device)

        decoder_hidden = torch.zeros(n_batch, self.decoder_rnn_dim, dtype=dtype, device=device)
        decoder_cell = torch.zeros(n_batch, self.decoder_rnn_dim, dtype=dtype, device=device)

        attention_weights = torch.zeros(n_batch, max_time, dtype=dtype, device=device)
        attention_weights_cum = torch.zeros(n_batch, max_time, dtype=dtype, device=device)
        attention_context = torch.zeros(n_batch, self.encoder_embedding_dim, dtype=dtype, device=device)

        processed_memory = self.attention_layer.memory_layer(memory)

        return (
            attention_hidden,
            attention_cell,
            decoder_hidden,
            decoder_cell,
            attention_weights,
            attention_weights_cum,
            attention_context,
            processed_memory,
        )

    def _parse_decoder_inputs(self, decoder_inputs: Tensor) -> Tensor:
        """Prepares decoder inputs.
        Args:
            decoder_inputs (Tensor): Inputs used for teacher-forced training, i.e. mel-specs,
                with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``)

        Returns:
            inputs (Tensor): Processed decoder inputs with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``).
        r   r"   r   rb   )r@   Úviewrœ   r$   r‹   )r<   r¨   r   r   r   Ú_parse_decoder_inputs.  s    ýz_Decoder._parse_decoder_inputs)Úmel_specgramÚgate_outputsÚ
alignmentsr   c                 C   sb   |  dd¡ ¡ }|  dd¡ ¡ }|  dd¡ ¡ }|jd d| jf}|j|Ž }|  dd¡}|||fS )aq  Prepares decoder outputs for output

        Args:
            mel_specgram (Tensor): mel spectrogram with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``)
            gate_outputs (Tensor): predicted stop token with shape (max of ``mel_specgram_lengths``, n_batch)
            alignments (Tensor): sequence of attention weights from the decoder
                with shape (max of ``mel_specgram_lengths``, n_batch, max of ``text_lengths``)

        Returns:
            mel_specgram (Tensor): mel spectrogram with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``)
            gate_outputs (Tensor): predicted stop token with shape (n_batch, max of ``mel_specgram_lengths``)
            alignments (Tensor): sequence of attention weights from the decoder
                with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``)
        """
        # (max of mel_specgram_lengths, n_batch) -> (n_batch, max of mel_specgram_lengths)
        gate_outputs = gate_outputs.transpose(0, 1).contiguous()
        # (max of mel_specgram_lengths, n_batch, max of text_lengths)
        #   -> (n_batch, max of mel_specgram_lengths, max of text_lengths)
        alignments = alignments.transpose(0, 1).contiguous()
        # decouple frames per step
        shape = (mel_specgram.shape[1], -1, self.n_mels)
        mel_specgram = mel_specgram.transpose(0, 1).contiguous().view(*shape)
        # (n_batch, max of mel_specgram_lengths, n_mels) -> (n_batch, n_mels, max of mel_specgram_lengths)
        mel_specgram = mel_specgram.transpose(1, 2)

        return mel_specgram, gate_outputs, alignments

    def decode(
        self,
        decoder_input: Tensor,
        attention_hidden: Tensor,
        attention_cell: Tensor,
        decoder_hidden: Tensor,
        decoder_cell: Tensor,
        attention_weights: Tensor,
        attention_weights_cum: Tensor,
        attention_context: Tensor,
        memory: Tensor,
        processed_memory: Tensor,
        mask: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Decoder step using stored states, attention and memory.

        Args:
            decoder_input (Tensor): Output of the Prenet with shape (n_batch, ``prenet_dim``).
            attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            attention_cell (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            decoder_cell (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
            attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``).
            attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
            memory (Tensor): Encoder output with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
            processed_memory (Tensor): Processed Encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
            mask (Tensor): Binary mask for padded data with shape (n_batch, max of ``text_lengths``).

        Returns:
            decoder_output: Predicted mel spectrogram for the current frame with shape (n_batch, ``n_mels``).
            gate_prediction (Tensor): Prediction of the stop token with shape (n_batch, ``1``).
            attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            attention_cell (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            decoder_cell (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
            attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``).
            attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
        """
        cell_input = torch.cat((decoder_input, attention_context), -1)

        attention_hidden, attention_cell = self.attention_rnn(cell_input, (attention_hidden, attention_cell))
        attention_hidden = F.dropout(attention_hidden, self.attention_dropout, self.training)

        attention_weights_cat = torch.cat((attention_weights.unsqueeze(1), attention_weights_cum.unsqueeze(1)), dim=1)
        attention_context, attention_weights = self.attention_layer(
            attention_hidden, memory, processed_memory, attention_weights_cat, mask
        )

        attention_weights_cum += attention_weights
        decoder_input = torch.cat((attention_hidden, attention_context), -1)

        decoder_hidden, decoder_cell = self.decoder_rnn(decoder_input, (decoder_hidden, decoder_cell))
        decoder_hidden = F.dropout(decoder_hidden, self.decoder_dropout, self.training)

        decoder_hidden_attention_context = torch.cat((decoder_hidden, attention_context), dim=1)
        decoder_output = self.linear_projection(decoder_hidden_attention_context)

        gate_prediction = self.gate_layer(decoder_hidden_attention_context)
        return (
            decoder_output,
            gate_prediction,
            attention_hidden,
            attention_cell,
            decoder_hidden,
            decoder_cell,
            attention_weights,
            attention_weights_cum,
            attention_context,
        )

    def forward(
        self, memory: Tensor, mel_specgram_truth: Tensor, memory_lengths: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Decoder forward pass for training.

        Args:
            memory (Tensor): Encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
            mel_specgram_truth (Tensor): Decoder ground-truth mel-specs for teacher forcing
                with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
            memory_lengths (Tensor): Encoder output lengths for attention masking
                (the same as ``text_lengths``) with shape (n_batch, ).

        Returns:
            mel_specgram (Tensor): Predicted mel spectrogram
                with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
            gate_outputs (Tensor): Predicted stop token for each timestep
                with shape (n_batch,  max of ``mel_specgram_lengths``).
            alignments (Tensor): Sequence of attention weights from the decoder
                with shape (n_batch,  max of ``mel_specgram_lengths``, max of ``text_lengths``).
        """
        decoder_input = self._get_initial_frame(memory).unsqueeze(0)
        decoder_inputs = self._parse_decoder_inputs(mel_specgram_truth)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        mask = _get_mask_from_lengths(memory_lengths)
        (
            attention_hidden,
            attention_cell,
            decoder_hidden,
            decoder_cell,
            attention_weights,
            attention_weights_cum,
            attention_context,
            processed_memory,
        ) = self._initialize_decoder_states(memory)

        mel_outputs, gate_outputs, alignments = [], [], []
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            (
                mel_output,
                gate_output,
                attention_hidden,
                attention_cell,
                decoder_hidden,
                decoder_cell,
                attention_weights,
                attention_weights_cum,
                attention_context,
            ) = self.decode(
                decoder_input,
                attention_hidden,
                attention_cell,
                decoder_hidden,
                decoder_cell,
                attention_weights,
                attention_weights_cum,
                attention_context,
                memory,
                processed_memory,
                mask,
            )

            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze(1)]
            alignments += [attention_weights]

        mel_specgram, gate_outputs, alignments = self._parse_decoder_outputs(
            torch.stack(mel_outputs), torch.stack(gate_outputs), torch.stack(alignments)
        )

        return mel_specgram, gate_outputs, alignments

    def _get_go_frame(self, memory: Tensor) -> Tensor:
        """Gets all zeros frames to use as the first decoder input.

        Args:
            memory (Tensor): Encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).

        Returns:
            decoder_input (Tensor): All zeros frames with shape (n_batch, ``n_mels * n_frames_per_step``).
        """
        n_batch = memory.size(0)
        dtype = memory.dtype
        device = memory.device
        decoder_input = torch.zeros(n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device)
        return decoder_input

    @torch.jit.export
    def infer(self, memory: Tensor, memory_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Decoder inference.

        Args:
            memory (Tensor): Encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
            memory_lengths (Tensor): Encoder output lengths for attention masking
                (the same as ``text_lengths``) with shape (n_batch, ).

        Returns:
            mel_specgram (Tensor): Predicted mel spectrogram
                with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
            mel_specgram_lengths (Tensor): The length of the predicted mel spectrogram with shape (n_batch, ).
            gate_outputs (Tensor): Predicted stop token for each timestep
                with shape (n_batch,  max of ``mel_specgram_lengths``).
            alignments (Tensor): Sequence of attention weights from the decoder
                with shape (n_batch,  max of ``mel_specgram_lengths``, max of ``text_lengths``).
        """
        batch_size, device = memory.size(0), memory.device

        decoder_input = self._get_go_frame(memory)

        mask = _get_mask_from_lengths(memory_lengths)
        (
            attention_hidden,
            attention_cell,
            decoder_hidden,
            decoder_cell,
            attention_weights,
            attention_weights_cum,
            attention_context,
            processed_memory,
        ) = self._initialize_decoder_states(memory)

        mel_specgram_lengths = torch.zeros([batch_size], dtype=torch.int32, device=device)
        finished = torch.zeros([batch_size], dtype=torch.bool, device=device)
        mel_specgrams: List[Tensor] = []
        gate_outputs: List[Tensor] = []
        alignments: List[Tensor] = []
        for _ in range(self.decoder_max_step):
            decoder_input = self.prenet(decoder_input)
            (
                mel_specgram,
                gate_output,
                attention_hidden,
                attention_cell,
                decoder_hidden,
                decoder_cell,
                attention_weights,
                attention_weights_cum,
                attention_context,
            ) = self.decode(
                decoder_input,
                attention_hidden,
                attention_cell,
                decoder_hidden,
                decoder_cell,
                attention_weights,
                attention_weights_cum,
                attention_context,
                memory,
                processed_memory,
                mask,
            )

            mel_specgrams.append(mel_specgram.unsqueeze(0))
            gate_outputs.append(gate_output.transpose(0, 1))
            alignments.append(attention_weights)
            mel_specgram_lengths[~finished] += 1

            finished |= torch.sigmoid(gate_output.squeeze(1)) > self.gate_threshold
            if self.decoder_early_stopping and torch.all(finished):
                break

            decoder_input = mel_specgram

        if len(mel_specgrams) == self.decoder_max_step:
            warnings.warn("Reached max decoder steps. The generated spectrogram might not cover the whole transcript.")

        mel_specgrams = torch.cat(mel_specgrams, dim=0)
        gate_outputs = torch.cat(gate_outputs, dim=0)
        alignments = torch.cat(alignments, dim=0)

        mel_specgram, gate_outputs, alignments = self._parse_decoder_outputs(mel_specgrams, gate_outputs, alignments)

        return mel_specgram, mel_specgram_lengths, gate_outputs, alignments
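

# A minimal shape sketch of decoder inference (all sizes are illustrative
# only; the positional arguments mirror ``_Decoder.__init__`` with its usual
# defaults): the decoder rolls out mel frames autoregressively until every
# gate fires or ``decoder_max_step`` is reached.
#
#   >>> decoder = _Decoder(80, 1, 512, 1024, 2000, 0.1, True, 1024, 128, 32, 31, 0.1, 256, 0.5)
#   >>> memory = torch.rand(2, 40, 512)
#   >>> mel, mel_lengths, gates, alignments = decoder.infer(memory, torch.tensor([40, 40]))
#   >>> mel.shape[:2]
#   torch.Size([2, 80])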


class Tacotron2(nn.Module):
    """Tacotron2 model from *Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions*
    :cite:`shen2018natural` based on the implementation from
    `Nvidia Deep Learning Examples <https://github.com/NVIDIA/DeepLearningExamples/>`_.

    See Also:
        * :class:`torchaudio.pipelines.Tacotron2TTSBundle`: TTS pipeline with pretrained model.

    Args:
        mask_padding (bool, optional): Use mask padding (Default: ``False``).
        n_mels (int, optional): Number of mel bins (Default: ``80``).
        n_symbol (int, optional): Number of symbols for the input text (Default: ``148``).
        n_frames_per_step (int, optional): Number of frames processed per step, only 1 is supported (Default: ``1``).
        symbol_embedding_dim (int, optional): Input embedding dimension (Default: ``512``).
        encoder_n_convolution (int, optional): Number of encoder convolutions (Default: ``3``).
        encoder_kernel_size (int, optional): Encoder kernel size (Default: ``5``).
        encoder_embedding_dim (int, optional): Encoder embedding dimension (Default: ``512``).
        decoder_rnn_dim (int, optional): Number of units in decoder LSTM (Default: ``1024``).
        decoder_max_step (int, optional): Maximum number of output mel spectrograms (Default: ``2000``).
        decoder_dropout (float, optional): Dropout probability for decoder LSTM (Default: ``0.1``).
        decoder_early_stopping (bool, optional): Stop decoding when all samples are finished (Default: ``True``).
        attention_rnn_dim (int, optional): Number of units in attention LSTM (Default: ``1024``).
        attention_hidden_dim (int, optional): Dimension of attention hidden representation (Default: ``128``).
        attention_location_n_filter (int, optional): Number of filters for attention model (Default: ``32``).
        attention_location_kernel_size (int, optional): Kernel size for attention model (Default: ``31``).
        attention_dropout (float, optional): Dropout probability for attention LSTM (Default: ``0.1``).
        prenet_dim (int, optional): Number of ReLU units in prenet layers (Default: ``256``).
        postnet_n_convolution (int, optional): Number of postnet convolutions (Default: ``5``).
        postnet_kernel_size (int, optional): Postnet kernel size (Default: ``5``).
        postnet_embedding_dim (int, optional): Postnet embedding dimension (Default: ``512``).
        gate_threshold (float, optional): Probability threshold for stop token (Default: ``0.5``).
    """

    def __init__(
        self,
        mask_padding: bool = False,
        n_mels: int = 80,
        n_symbol: int = 148,
        n_frames_per_step: int = 1,
        symbol_embedding_dim: int = 512,
        encoder_embedding_dim: int = 512,
        encoder_n_convolution: int = 3,
        encoder_kernel_size: int = 5,
        decoder_rnn_dim: int = 1024,
        decoder_max_step: int = 2000,
        decoder_dropout: float = 0.1,
        decoder_early_stopping: bool = True,
        attention_rnn_dim: int = 1024,
        attention_hidden_dim: int = 128,
        attention_location_n_filter: int = 32,
        attention_location_kernel_size: int = 31,
        attention_dropout: float = 0.1,
        prenet_dim: int = 256,
        postnet_n_convolution: int = 5,
        postnet_kernel_size: int = 5,
        postnet_embedding_dim: int = 512,
        gate_threshold: float = 0.5,
    ) -> None:
        super().__init__()

        self.mask_padding = mask_padding
        self.n_mels = n_mels
        self.n_frames_per_step = n_frames_per_step
        self.embedding = nn.Embedding(n_symbol, symbol_embedding_dim)
        torch.nn.init.xavier_uniform_(self.embedding.weight)
        self.encoder = _Encoder(encoder_embedding_dim, encoder_n_convolution, encoder_kernel_size)
        self.decoder = _Decoder(
            n_mels,
            n_frames_per_step,
            encoder_embedding_dim,
            decoder_rnn_dim,
            decoder_max_step,
            decoder_dropout,
            decoder_early_stopping,
            attention_rnn_dim,
            attention_hidden_dim,
            attention_location_n_filter,
            attention_location_kernel_size,
            attention_dropout,
            prenet_dim,
            gate_threshold,
        )
        self.postnet = _Postnet(n_mels, postnet_embedding_dim, postnet_kernel_size, postnet_n_convolution)

    def forward(
        self,
        tokens: Tensor,
        token_lengths: Tensor,
        mel_specgram: Tensor,
        mel_specgram_lengths: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Pass the input through the Tacotron2 model. This is in teacher
        forcing mode, which is generally used for training.

        The input ``tokens`` should be padded with zeros to length max of ``token_lengths``.
        The input ``mel_specgram`` should be padded with zeros to length max of ``mel_specgram_lengths``.

        Args:
            tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of token_lengths)`.
            token_lengths (Tensor): The valid length of each sample in ``tokens`` with shape `(n_batch, )`.
            mel_specgram (Tensor): The target mel spectrogram
                with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
            mel_specgram_lengths (Tensor): The length of each mel spectrogram with shape `(n_batch, )`.

        Returns:
            [Tensor, Tensor, Tensor, Tensor]:
                Tensor
                    Mel spectrogram before Postnet with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
                Tensor
                    Mel spectrogram after Postnet with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
                Tensor
                    The output for stop token at each time step with shape `(n_batch, max of mel_specgram_lengths)`.
                Tensor
                    Sequence of attention weights from the decoder with
                    shape `(n_batch, max of mel_specgram_lengths, max of token_lengths)`.
        """
        embedded_inputs = self.embedding(tokens).transpose(1, 2)

        encoder_outputs = self.encoder(embedded_inputs, token_lengths)
        mel_specgram, gate_outputs, alignments = self.decoder(
            encoder_outputs, mel_specgram, memory_lengths=token_lengths
        )

        mel_specgram_postnet = self.postnet(mel_specgram)
        mel_specgram_postnet = mel_specgram + mel_specgram_postnet

        if self.mask_padding:
            mask = _get_mask_from_lengths(mel_specgram_lengths)
            mask = mask.expand(self.n_mels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            mel_specgram.masked_fill_(mask, 0.0)
            mel_specgram_postnet.masked_fill_(mask, 0.0)
            gate_outputs.masked_fill_(mask[:, 0, :], 1e3)

        return mel_specgram, mel_specgram_postnet, gate_outputs, alignments

    @torch.jit.export
    def infer(self, tokens: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
        """Using Tacotron2 for inference. The input is a batch of encoded
        sentences (``tokens``) and its corresponding lengths (``lengths``). The
        output is the generated mel spectrograms, its corresponding lengths, and
        the attention weights from the decoder.

        The input `tokens` should be padded with zeros to length max of ``lengths``.

        Args:
            tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of lengths)`.
            lengths (Tensor or None, optional):
                The valid length of each sample in ``tokens`` with shape `(n_batch, )`.
                If ``None``, it is assumed that all the tokens are valid. Default: ``None``

        Returns:
            (Tensor, Tensor, Tensor):
                Tensor
                    The predicted mel spectrogram with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
                Tensor
                    The length of the predicted mel spectrogram with shape `(n_batch, )`.
                Tensor
                    Sequence of attention weights from the decoder with shape
                    `(n_batch, max of mel_specgram_lengths, max of lengths)`.
        """
        n_batch, max_length = tokens.shape
        if lengths is None:
            lengths = torch.tensor([max_length]).expand(n_batch).to(tokens.device, tokens.dtype)

        assert lengths is not None  # For TorchScript compiler

        embedded_inputs = self.embedding(tokens).transpose(1, 2)
        encoder_outputs = self.encoder(embedded_inputs, lengths)
        mel_specgram, mel_specgram_lengths, _, alignments = self.decoder.infer(encoder_outputs, lengths)

        mel_outputs_postnet = self.postnet(mel_specgram)
        mel_outputs_postnet = mel_specgram + mel_outputs_postnet

        alignments = alignments.unfold(1, n_batch, n_batch).transpose(0, 2)

        return mel_outputs_postnet, mel_specgram_lengths, alignments