""" PyTorch UniSpeech model."""

import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, Wav2Vec2BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_unispeech import UniSpeechConfig


logger = logging.get_logger(__name__)


_HIDDEN_STATES_START_POSITION = 2

# General docstring
_CONFIG_FOR_DOC = "UniSpeechConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "patrickvonplaten/unispeech-large-1500h-cv-timit"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]

# CTC docstring
_CTC_EXPECTED_OUTPUT = "'mister quilter is the apposl of the midle classes and weare glad to welcom his gosepl'"
_CTC_EXPECTED_LOSS = 17.17

UNISPEECH_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "microsoft/unispeech-large-1500h-cv",
    "microsoft/unispeech-large-multi-lingual-1500h-cv",
]


@dataclass
class UniSpeechForPreTrainingOutput(ModelOutput):
    """
    Output type of [`UniSpeechForPreTraining`], with potential hidden states and attentions.

    Args:
        loss (*optional*, returned when the model is in train mode, `torch.FloatTensor` of shape `(1,)`):
            Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the
            [official paper](https://arxiv.org/pdf/2006.11477.pdf).
        projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
            Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
            projected quantized states.
        projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
            Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
            target vectors for contrastive loss.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nlossprojected_statesprojected_quantized_statescodevector_perplexityhidden_states
attentions)__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r   r   r   r   r   r    r'   r'   q/var/www/html/Darija-Ai-Train/env/lib/python3.8/site-packages/transformers/models/unispeech/modeling_unispeech.pyr   B   s   
r   )shape	mask_probmask_lengthattention_mask	min_masksreturnc                    s  | \}dk rt dkr6t d d dtjd   fdd}|dk	rt|d	  nfd
dt|D }tj	|ft
d}g }	|}
|
dkr|S |D ]v}||}tjjt|d  |dd}t|dkrd }n|d }t|tj|
| tjd| g}|	| qt|	}	t|	dddddf ||
f}	|	||
 }	tddddf }t|||
f||
 }|	| }	|	 d kr҈d |	|	d k< t||	dd	 |S )af  
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Args:
        shape: The shape for which to compute masks. This should be a tuple of size 2 where
               the first element is the batch size and the second element is the length of the axis to span.
        mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
                    independently generated mask spans of length `mask_length` is computed by
                    `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
                    actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
                        each batch dimension.
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
            f" and `sequence_length`: {sequence_length}`"
        )

    # epsilon is used for probabilistic rounding
    epsilon = np.random.rand(1).item()

    def compute_num_masked_span(input_length):
        """Given input length, compute how many spans should be masked"""
        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
        num_masked_span = max(num_masked_span, min_masks)

        # make sure the number of masked spans does not exceed sequence_length
        if num_masked_span * mask_length > sequence_length:
            num_masked_span = sequence_length // mask_length

        # make sure the number of masked spans is also <= input_length - (mask_length - 1)
        if input_length - (mask_length - 1) < num_masked_span:
            num_masked_span = max(input_length - (mask_length - 1), 0)

        return num_masked_span

    # compute number of masked spans per batch element
    input_lengths = (
        attention_mask.sum(-1).detach().tolist()
        if attention_mask is not None
        else [sequence_length for _ in range(batch_size)]
    )

    # SpecAugment mask to fill
    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
    spec_aug_mask_idxs = []

    max_num_masked_span = compute_num_masked_span(sequence_length)

    if max_num_masked_span == 0:
        return spec_aug_mask

    for input_length in input_lengths:
        # compute the number of masked spans for this input
        num_masked_span = compute_num_masked_span(input_length)

        # get random indices to mask
        spec_aug_mask_idx = np.random.choice(
            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
        )

        # pick the first sampled index as a dummy index to pad the vector so that all batch rows
        # have the same number of spans despite probabilistic rounding
        if len(spec_aug_mask_idx) == 0:
            # this can only happen if `input_length` is strictly smaller than `sequence_length`,
            # in which case the last token has to be a padding token usable as a dummy mask id
            dummy_mask_idx = sequence_length - 1
        else:
            dummy_mask_idx = spec_aug_mask_idx[0]

        spec_aug_mask_idx = np.concatenate(
            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
        )
        spec_aug_mask_idxs.append(spec_aug_mask_idx)

    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)

    # expand masked starting indices to masked spans
    spec_aug_mask_idxs = np.broadcast_to(
        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)

    # add offsets to the starting indices so that each index describes one position of a span
    offsets = np.arange(mask_length)[None, None, :]
    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
        batch_size, max_num_masked_span * mask_length
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

    # ensure that indices cannot be larger than sequence_length - 1
    if spec_aug_mask_idxs.max() > sequence_length - 1:
        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1

    # scatter indices to the mask
    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)

    return spec_aug_mask
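
# Illustrative usage sketch for `_compute_mask_indices` (not part of the original module; the helper
# name `_example_spec_augment_mask` is hypothetical and safe to remove). It shows the expected shapes
# and the effect of `mask_prob`/`mask_length` on a small batch.
def _example_spec_augment_mask() -> np.ndarray:
    # mask roughly 15% of a 100-frame sequence with spans of 10 consecutive frames, for a batch of 2
    mask = _compute_mask_indices(shape=(2, 100), mask_prob=0.15, mask_length=10)
    # each row is a boolean vector over the time axis; because spans may overlap, the masked
    # fraction can end up below 15%
    assert mask.shape == (2, 100)
    return mask
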

  rU   c                       s&   e Zd Zd fdd	Zdd Z  ZS )UniSpeechNoLayerNormConvLayerr   c                    sj   t    |dkr |j|d  nd| _|j| | _tj| j| j|j| |j| |j	d| _
t|j | _d S )Nr   r   kernel_sizestridebias)super__init__conv_dimin_conv_dimout_conv_dimr   Conv1dconv_kernelconv_stride	conv_biasconvr	   feat_extract_activation
activationselfconfiglayer_id	__class__r'   r(   r\      s    
z&UniSpeechNoLayerNormConvLayer.__init__c                 C   s   |  |}| |}|S N)rd   rf   rh   r   r'   r'   r(   forward   s    

z%UniSpeechNoLayerNormConvLayer.forward)r   r    r!   r"   r\   ro   __classcell__r'   r'   rk   r(   rV      s   rV   c                       s&   e Zd Zd fdd	Zdd Z  ZS )UniSpeechLayerNormConvLayerr   c                    s|   t    |dkr |j|d  nd| _|j| | _tj| j| j|j| |j| |j	d| _
tj| jdd| _t|j | _d S )Nr   r   rW   T)Zelementwise_affine)r[   r\   r]   r^   r_   r   r`   ra   rb   rc   rd   	LayerNorm
layer_normr	   re   rf   rg   rk   r'   r(   r\      s    
z$UniSpeechLayerNormConvLayer.__init__c                 C   s:   |  |}|dd}| |}|dd}| |}|S )Nr8   )rd   	transposert   rf   rn   r'   r'   r(   ro     s    


z#UniSpeechLayerNormConvLayer.forward)r   rp   r'   r'   rk   r(   rr      s   rr   c                       s&   e Zd Zd fdd	Zdd Z  ZS )UniSpeechGroupNormConvLayerr   c                    s   t    |dkr |j|d  nd| _|j| | _tj| j| j|j| |j| |j	d| _
t|j | _tj| j| jdd| _d S )Nr   r   rW   T)
num_groupsZnum_channelsZaffine)r[   r\   r]   r^   r_   r   r`   ra   rb   rc   rd   r	   re   rf   	GroupNormrt   rg   rk   r'   r(   r\     s    
z$UniSpeechGroupNormConvLayer.__init__c                 C   s"   |  |}| |}| |}|S rm   )rd   rt   rf   rn   r'   r'   r(   ro   "  s    


z#UniSpeechGroupNormConvLayer.forward)r   rp   r'   r'   rk   r(   rw     s   rw   c                       s$   e Zd Z fddZdd Z  ZS ) UniSpeechPositionalConvEmbeddingc              	      s   t    tj|j|j|j|jd |jd| _tjj	}t
tjjdrNtjjj	}t rdd l}|jj| jjdd || jddd| _W 5 Q R X |j| | jj |j| | jj n|| jddd| _t|j| _t|j | _d S )Nr   )rX   paddinggroupsweight_normr   )Zmodifier_rankweight)namedim)r[   r\   r   r`   hidden_sizenum_conv_pos_embeddingsZnum_conv_pos_embedding_groupsrd   utilsr}   hasattrZparametrizationsr
   	deepspeedzeroZGatheredParametersr~   Zregister_external_parameterZweight_vZweight_gUniSpeechSamePadLayerr{   r	   re   rf   )rh   ri   r}   r   rk   r'   r(   r\   +  s(    

z)UniSpeechPositionalConvEmbedding.__init__c                 C   s:   | dd}| |}| |}| |}| dd}|S Nr   r   )rv   rd   r{   rf   rn   r'   r'   r(   ro   F  s    


z(UniSpeechPositionalConvEmbedding.forwardrp   r'   r'   rk   r(   rz   *  s   rz   c                       s$   e Zd Z fddZdd Z  ZS )r   c                    s$   t    |d dkrdnd| _d S )Nr   r   r   )r[   r\   num_pad_remove)rh   r   rk   r'   r(   r\   S  s    
zUniSpeechSamePadLayer.__init__c                 C   s,   | j dkr(|d d d d d | j  f }|S )Nr   )r   rn   r'   r'   r(   ro   W  s    
zUniSpeechSamePadLayer.forwardrp   r'   r'   rk   r(   r   R  s   r   c                       s0   e Zd ZdZ fddZdd Zdd Z  ZS )UniSpeechFeatureEncoderz.Construct the features from raw audio waveformc                    s   t     jdkr@t ddg fddt jd D  }n6 jdkrd fddt jD }ntd	 j d
t|| _	d| _
d| _d S )Ngroupr   rj   c                    s   g | ]}t  |d  dqS )r   r   )rV   r:   iri   r'   r(   r<   e  s   z4UniSpeechFeatureEncoder.__init__.<locals>.<listcomp>r   layerc                    s   g | ]}t  |d qS )r   )rr   r   r   r'   r(   r<   j  s    z`config.feat_extract_norm` is z), but has to be one of ['group', 'layer']FT)r[   r\   Zfeat_extract_normrw   rH   Znum_feat_extract_layersr@   r   
ModuleListconv_layersgradient_checkpointing_requires_grad)rh   ri   r   rk   r   r(   r\   a  s    




z UniSpeechFeatureEncoder.__init__c                 C   s   |   D ]
}d|_qd| _d S )NF)
parametersrequires_gradr   rh   paramr'   r'   r(   _freeze_parametersu  s    z*UniSpeechFeatureEncoder._freeze_parametersc                 C   sj   |d d d f }| j r"| jr"d|_| jD ]<}| j r\| jr\| jr\dd }tjj|||}q(||}q(|S )NTc                    s    fdd}|S )Nc                     s    |  S rm   r'   inputsmoduler'   r(   custom_forward  s    zVUniSpeechFeatureEncoder.forward.<locals>.create_custom_forward.<locals>.custom_forwardr'   r   r   r'   r   r(   create_custom_forward  s    z>UniSpeechFeatureEncoder.forward.<locals>.create_custom_forward)r   trainingr   r   r   r$   r   
checkpoint)rh   input_valuesr   Z
conv_layerr   r'   r'   r(   ro   z  s    

zUniSpeechFeatureEncoder.forward)r    r!   r"   r#   r\   r   ro   rq   r'   r'   rk   r(   r   ^  s   r   c                       s   e Zd Z fddZ  ZS )UniSpeechFeatureExtractorc                    s8   t  | td| jj d| jjd j dt d S )NzThe class `zD` has been depreciated and will be removed in Transformers v5. Use `r   z
` instead.)r[   r\   warningswarnrl   r    	__bases__FutureWarningrh   ri   rk   r'   r(   r\     s
    z"UniSpeechFeatureExtractor.__init__)r    r!   r"   r\   rq   r'   r'   rk   r(   r     s   r   c                       s$   e Zd Z fddZdd Z  ZS )UniSpeechFeatureProjectionc                    sJ   t    tj|jd |jd| _t|jd |j| _	t
|j| _d S )Nr8   Zeps)r[   r\   r   rs   r]   layer_norm_epsrt   Linearr   
projectionDropoutZfeat_proj_dropoutdropoutr   rk   r'   r(   r\     s    
z#UniSpeechFeatureProjection.__init__c                 C   s&   |  |}| |}| |}||fS rm   )rt   r   r   )rh   r   Znorm_hidden_statesr'   r'   r(   ro     s    


z"UniSpeechFeatureProjection.forwardrp   r'   r'   rk   r(   r     s   r   c                       s   e Zd ZdZdeeeeed fddZej	eedd	d
Z
dej	eej	 eeej	  eej	 eej	 eeej	eej	 eeej	  f dddZ  ZS )UniSpeechAttentionz=Multi-headed attention from 'Attention Is All You Need' paper        FT)	embed_dim	num_headsr   
is_decoderrZ   c                    s   t    || _|| _|| _|| | _| j| | jkrNtd| j d| d| jd | _|| _t	j
|||d| _t	j
|||d| _t	j
|||d| _t	j
|||d| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).g      )rZ   )r[   r\   r   r   r   head_dimr@   scalingr   r   r   k_projv_projq_projout_proj)rh   r   r   r   r   rZ   rk   r'   r(   r\     s    

zUniSpeechAttention.__init__)tensorseq_lenbszc                 C   s    | ||| j| jdd S r   )viewr   r   rv   
contiguous)rh   r   r   r   r'   r'   r(   _shape  s    zUniSpeechAttention._shapeN)r   key_value_statespast_key_valuer,   layer_head_maskoutput_attentionsr.   c                 C   sx  |dk	}|  \}}	}
| || j }|r\|dk	r\|d jd |jd kr\|d }|d }n|r| | |d|}| | |d|}n|dk	r| | |d|}| | |d|}tj|d |gdd}tj|d |gdd}n(| | |d|}| | |d|}| j	r ||f}|| j
 d| jf}| ||	|j| }|j| }|j| }| d}t||dd}|  || j
 |	|fkrtd|| j
 |	|f d|   |dk	r |  |d|	|fkrtd	|d|	|f d|   ||| j
|	|| }||| j
 |	|}tjj|dd}|dk	r|  | j
fkrhtd
| j
f d|   |dddd||| j
|	| }||| j
 |	|}|r||| j
|	|}||| j
 |	|}nd}tjj|| j| jd}t||}|  || j
 |	| jfkr4td|| j
 |	| jf d|   ||| j
|	| j}|dd}|||	| j}| |}|||fS )z#Input shape: Batch x Time x ChannelNr   r   r   r8   r   z$Attention weights should be of size z	, but is z!Attention mask should be of size z/Head mask for a single layer should be of size )pr   z `attn_output` should be of size )sizer   r   r)   r   r   r   r$   catr   r   r   r   rQ   Zbmmrv   r@   r   
functionalsoftmaxr   r   r   r   )rh   r   r   r   r,   r   r   Zis_cross_attentionr   Ztgt_lenr;   Zquery_statesZ
key_statesZvalue_statesZ
proj_shapeZsrc_lenattn_weightsZattn_weights_reshapedZ
attn_probsZattn_outputr'   r'   r(   ro     s~    





" 
zUniSpeechAttention.forward)r   FT)NNNNF)r    r!   r"   r#   r0   floatrJ   r\   r$   Tensorr   r   r   ro   rq   r'   r'   rk   r(   r     s4           r   c                       s$   e Zd Z fddZdd Z  ZS )UniSpeechFeedForwardc                    sp   t    t|j| _t|j|j| _	t
|jtrDt|j | _n|j| _t|j|j| _t|j| _d S rm   )r[   r\   r   r   Zactivation_dropoutintermediate_dropoutr   r   Zintermediate_sizeintermediate_dense
isinstanceZ
hidden_actstrr	   intermediate_act_fnoutput_densehidden_dropoutoutput_dropoutr   rk   r'   r(   r\   L  s    
zUniSpeechFeedForward.__init__c                 C   s6   |  |}| |}| |}| |}| |}|S rm   )r   r   r   r   r   rn   r'   r'   r(   ro   Y  s    




zUniSpeechFeedForward.forwardrp   r'   r'   rk   r(   r   K  s   r   c                       s&   e Zd Z fddZdddZ  ZS )UniSpeechEncoderLayerc                    sf   t    t|j|j|jdd| _t|j	| _
tj|j|jd| _t|| _tj|j|jd| _d S )NFr   r   r   r   r   )r[   r\   r   r   num_attention_headsattention_dropout	attentionr   r   r   r   rs   r   rt   r   feed_forwardfinal_layer_normr   rk   r'   r(   r\   e  s    

zUniSpeechEncoderLayer.__init__NFc                 C   sf   |}| j |||d\}}}| |}|| }| |}|| | }| |}|f}|rb||f7 }|S Nr,   r   )r   r   rt   r   r   rh   r   r,   r   Zattn_residualr   r;   outputsr'   r'   r(   ro   r  s      



zUniSpeechEncoderLayer.forward)NFrp   r'   r'   rk   r(   r   d  s   r   c                       s,   e Zd Z fddZejdddZ  ZS )UniSpeechAttnAdapterLayerc                    sZ   t    |j| _|j| _t| j| _t	| j| j| _
t | _t	| j| j| _dS )z
        Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed
        up training throughput.
        N)r[   r\   adapter_attn_dimZ	input_dimr   Z
hidden_dimr   rs   normr   linear_1ZReLUact_fnlinear_2r   rk   r'   r(   r\     s    

z"UniSpeechAttnAdapterLayer.__init__)r   c                 C   s,   |  |}| |}| |}| |}|S rm   )r   r   r   r   rn   r'   r'   r(   ro     s
    



z!UniSpeechAttnAdapterLayer.forward)r    r!   r"   r\   r$   r%   ro   rq   r'   r'   rk   r(   r     s   r   c                       s8   e Zd Z fddZdejeej edddZ  Z	S )	$UniSpeechEncoderLayerStableLayerNormc                    s   t    t|j|j|jdd| _t|j	| _
tj|j|jd| _t|| _tj|j|jd| _t|dd d k	r~t|| _nd | _d S )NFr   r   r   )r[   r\   r   r   r   r   r   r   r   r   r   rs   r   rt   r   r   r   getattrr   adapter_layerr   rk   r'   r(   r\     s    

z-UniSpeechEncoderLayerStableLayerNorm.__init__NF)r   r,   r   c                 C   sz   |}|  |}| j|||d\}}}| |}|| }|| | | }| jd k	rb|| | }|f}|rv||f7 }|S r   )rt   r   r   r   r   r   r   r'   r'   r(   ro     s     
  


z,UniSpeechEncoderLayerStableLayerNorm.forward)NF)
r    r!   r"   r\   r$   r   r   rJ   ro   rq   r'   r'   rk   r(   r     s     r   c                       s<   e Zd Z fddZd	ejeej eeedddZ	  Z
S )
UniSpeechEncoderc                    sf   t     | _t | _tj j jd| _	t
 j| _t fddt jD | _d| _d S )Nr   c                    s   g | ]}t  qS r'   )r   r9   r   r'   r(   r<     s     z-UniSpeechEncoder.__init__.<locals>.<listcomp>Fr[   r\   ri   rz   pos_conv_embedr   rs   r   r   rt   r   r   r   r   rH   num_hidden_layerslayersr   r   rk   r   r(   r\     s    

 zUniSpeechEncoder.__init__NFT)r   r,   r   output_hidden_statesreturn_dictc                    s  |rdnd } rdnd }|d k	r| ddd|jd }d|| < d|d d d d d d f j|jd }|t|jj }||jd d|jd |jd }| 	|}	||	 }| 
|}| |}t }
| jD ]}|r||f }tg }| jr|| jjk rdnd	}|r|
r`| jrJ| jrJ fd
d}tjj||||}n||| d}|d }|rjd} r||d f }q|r||f }|stdd |||fD S t|||dS )Nr'   r8   r   r   r         ?r=   TFc                    s    fdd}|S )Nc                     s    | f S rm   r'   r   r   r   r'   r(   r     s    zOUniSpeechEncoder.forward.<locals>.create_custom_forward.<locals>.custom_forwardr'   r   r   r   r(   r     s    z7UniSpeechEncoder.forward.<locals>.create_custom_forwardr   NNc                 s   s   | ]}|d k	r|V  qd S rm   r'   r:   vr'   r'   r(   	<genexpr>  s      z+UniSpeechEncoder.forward.<locals>.<genexpr>last_hidden_stater   r   )	unsqueezerepeatr)   tor>   r$   finfominexpandr   rt   r   r
   r   rC   r   ri   	layerdropr   r   r   tupler   rh   r   r,   r   r   r   Zall_hidden_statesZall_self_attentionsZexpand_attention_maskZposition_embeddingsZdeepspeed_zero3_is_enabledr   Zdropout_probabilityZskip_the_layerr   Zlayer_outputsr'   r   r(   ro     sd    
&   





  
zUniSpeechEncoder.forward)NFFT)r    r!   r"   r\   r$   r   r   r   rJ   ro   rq   r'   r'   rk   r(   r     s       r   c                       s&   e Zd Z fddZdddZ  ZS )	UniSpeechEncoderStableLayerNormc                    sf   t     | _t | _tj j jd| _	t
 j| _t fddt jD | _d| _d S )Nr   c                    s   g | ]}t  qS r'   )r   r9   r   r'   r(   r<   /  s     z<UniSpeechEncoderStableLayerNorm.__init__.<locals>.<listcomp>Fr   r   rk   r   r(   r\   (  s    

z(UniSpeechEncoderStableLayerNorm.__init__NFTc                    s  |rdnd } rdnd }|d k	r| ddd|jd }d|| < d|d d d d d d f j|jd }|t|jj }||jd d|jd |jd }| 	|}	||	 }| 
|}t }
| jD ]}|r||f }tg }| jr|| jjk rdnd	}|r|
rR| jr<| jr< fd
d}tjj||||}n||| d}|d }|r\d} r||d f }q| |}|r||f }|stdd |||fD S t|||dS )Nr'   r8   r   r   r   r   r=   TFc                    s    fdd}|S )Nc                     s    | f S rm   r'   r   r   r'   r(   r   ^  s    z^UniSpeechEncoderStableLayerNorm.forward.<locals>.create_custom_forward.<locals>.custom_forwardr'   r   r   r   r(   r   ]  s    zFUniSpeechEncoderStableLayerNorm.forward.<locals>.create_custom_forwardr   r   c                 s   s   | ]}|d k	r|V  qd S rm   r'   r   r'   r'   r(   r   z  s      z:UniSpeechEncoderStableLayerNorm.forward.<locals>.<genexpr>r   )r   r   r)   r   r>   r$   r   r   r  r   r   r
   r   rC   r   ri   r  r   r   r   rt   r  r   r  r'   r   r(   ro   3  sd    
&   




  

z'UniSpeechEncoderStableLayerNorm.forward)NFFTrp   r'   r'   rk   r(   r  '  s       r  c                       s4   e Zd ZdZ fddZedd Zdd Z  ZS )UniSpeechGumbelVectorQuantizerz
    Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
    GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
    c                    s   t    |j| _|j| _|j| j dkrDtd|j d| j dt	t
d| j| j |j| j | _t|jd | j| j | _d| _d S )Nr   z`config.codevector_dim z5 must be divisible by `config.num_codevector_groups` z for concatenationr   r8   r   )r[   r\   Znum_codevector_groupsrx   Znum_codevectors_per_groupnum_varscodevector_dimr@   r   	Parameterr$   r%   codevectorsr   r]   weight_projtemperaturer   rk   r'   r(   r\     s    
z'UniSpeechGumbelVectorQuantizer.__init__c                 C   s8   | j dd}ttj|t|d  dd  }|S )Nr   r   gHz>r8   )meanr$   exprE   log)ZprobsZmarginal_probs
perplexityr'   r'   r(   _compute_perplexity  s    (z2UniSpeechGumbelVectorQuantizer._compute_perplexityc                 C   s  |j \}}}| |}||| | j d}| jr~tjj| | j	dd
|}tj||| | jd dd}| |}nH|jdd}|j|j  d|ddd}||| | jd}| |}||| d}|d| j }	|	|| | j| jd}
|
d||d}
|
|fS )Nr8   T)tauhardr   r   r   ru   )r)   r  r   rx   r   r   r   Zgumbel_softmaxr   r  type_asr$   r   r  ZargmaxZ	new_zerosZscatter_r   r
  r  rE   )rh   r   rR   r6   r   Zcodevector_probsZcodevector_soft_distr  Zcodevector_idxZcodevectors_per_groupr
  r'   r'   r(   ro     s:    
    
 
z&UniSpeechGumbelVectorQuantizer.forward)	r    r!   r"   r#   r\   staticmethodr  ro   rq   r'   r'   rk   r(   r    s
   
r  c                   @   s\   e Zd ZdZeZdZdZdZdd Z	e
ejef ddd	Zeejd
ddZdddZdS )UniSpeechPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    	unispeechr   Tc              	   C   s  t |tr>|jjjjddd |jjj  tj	
|j njt |trtj	j|jjddtd|jjd |jj   d tj	|jjd nt |trtd|jj }tj	j
|jj| |d tj	j
|jj| |d nt |tjr|jjjd| jjd |jdk	r|jj  nt |tjtjfrN|jj  |jjd nZt |tjrtj	|j |jdk	rt|j|j|jd   }tj	j
|j| |d dS )	zInitialize the weightsr   r   )r  Zstdr   r   )abNr   )r   r  r  r~   dataZnormal_rZ   Zzero_r   inituniform_r
  rz   rd   mathsqrtrX   Zin_channelsZ	constant_r   r   Zin_featuresr   ri   Zinitializer_rangers   ry   fill_r`   Zkaiming_normal_r|   )rh   r   kr'   r'   r(   _init_weights  s6    

 
z&UniSpeechPreTrainedModel._init_weights)rS   c                 C   s4   dd }t | jj| jjD ]\}}||||}q|S )zH
        Computes the output length of the convolutional layers
        c                 S   s   t j| | |ddd S )Nfloor)Zrounding_moder   )r$   div)r2   rX   rY   r'   r'   r(   _conv_out_length  s    zSUniSpeechPreTrainedModel._get_feat_extract_output_lengths.<locals>._conv_out_length)zipri   ra   rb   )rh   rS   r$  rX   rY   r'   r'   r(    _get_feat_extract_output_lengths  s    z9UniSpeechPreTrainedModel._get_feat_extract_output_lengths)feature_vector_lengthr,   c                 C   s   |j ddd d df }| |tj}|jd }tj||f|j|jd}d|tj	|jd |jd|d f< |
dg d
dg }|S )Nr8   r   r   )r>   devicer   )r(  )Zcumsumr&  r   r$   longr)   rI   r>   r(  rL   fliprJ   )rh   r'  r,   Znon_padded_lengthsZoutput_lengthsrR   r'   r'   r(   "_get_feature_vector_attention_mask  s    
  "z;UniSpeechPreTrainedModel._get_feature_vector_attention_maskFc                 C   s   t |tttfr||_d S rm   )r   r   r  r   r   )rh   r   valuer'   r'   r(   _set_gradient_checkpointing  s    z4UniSpeechPreTrainedModel._set_gradient_checkpointingN)F)r    r!   r"   r#   r   config_classZbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingr!  r   r$   
UNISPEECH_START_DOCSTRING = r"""
    UniSpeech was proposed in [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled
    Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei,
    Michael Zeng, Xuedong Huang.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving etc.).

    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`UniSpeechConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

UNISPEECH_INPUTS_DOCSTRING = r"""
    Args:
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
            into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
            soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
            conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
            1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            <Tip warning={true}>

            `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
            True`. For all models whose processor has `config.return_attention_mask == False`, `attention_mask` should
            **not** be passed to avoid degraded performance when doing batched inference. For such models
            `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware that these
            models also yield slightly different results depending on whether `input_values` is padded or not.

            </Tip>

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
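
    Example of preparing `input_values` (an illustrative sketch, not part of the original docstring; it
    assumes a 16kHz mono waveform and uses the checkpoint referenced elsewhere in this file):

    ```python
    >>> import numpy as np
    >>> from transformers import AutoFeatureExtractor

    >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-large-1500h-cv")
    >>> waveform = np.zeros(16000, dtype=np.float32)  # one second of silence as a stand-in
    >>> inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
    >>> inputs["input_values"].shape
    torch.Size([1, 16000])
    ```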
"""


@add_start_docstrings(
    "The bare UniSpeech Model transformer outputting raw hidden-states without any specific head on top.",
    UNISPEECH_START_DOCSTRING,
)
class UniSpeechModel(UniSpeechPreTrainedModel):
eeeeeded	deej eej eej ee ee ee eeef d
ddZ  ZS )UniSpeechModelr   c                    sz   t  | || _t|| _t|| _|jdks:|jdkrRt	
t|j | _|jrdt|| _n
t|| _|   d S )Nr   )r[   r\   ri   r   feature_extractorr   feature_projectionmask_time_probmask_feature_probr   r	  r$   r%   r   r  masked_spec_embedZdo_stable_layer_normr  encoderr   	post_initr   rk   r'   r(   r\   T  s    


zUniSpeechModel.__init__N)r   mask_time_indicesr,   c                 C   s  t | jdds|S | \}}}|dk	r<| j|j||< nZ| jjdkr| jrt||f| jj| jj	|| jj
d}tj||jtjd}| j|j||< | jjdkr| jrt||f| jj| jj| jjd}tj||jtjd}|dddf d|d}d||< |S )	z
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://arxiv.org/abs/1904.08779).
        Zapply_spec_augmentTNr   )r*   r+   r,   r-   )r(  r>   )r*   r+   r-   r8   )r   ri   r   r5  r   r>   r3  r   rU   Zmask_time_lengthZmask_time_min_masksr$   r   r(  rJ   r4  Zmask_feature_lengthZmask_feature_min_masksr  )rh   r   r8  r,   rR   r6   r   Zmask_feature_indicesr'   r'   r(   _mask_hidden_statesf  s4    z"UniSpeechModel._mask_hidden_statesaudio)r   output_typer.  modalityexpected_output)r   r,   r8  r   r   r   r.   c           
      C   s   |d k	r|n| j j}|d k	r |n| j j}|d k	r4|n| j j}| |}|dd}|d k	rl| |jd |}| |\}}| j	|||d}| j
|||||d}	|	d }|s||f|	dd   S t|||	j|	jdS )Nr   r   )r8  r,   r,   r   r   r   r   )r   extract_featuresr   r   )ri   r   r   use_return_dictr1  rv   r+  r)   r2  r9  r6  r   r   r   )
rh   r   r,   r8  r   r   r   r?  r   Zencoder_outputsr'   r'   r(   ro     s<    
  zUniSpeechModel.forward)NN)NNNNN)r    r!   r"   r   r\   r$   r%   r   r/  r9  r   UNISPEECH_INPUTS_DOCSTRINGr   _CHECKPOINT_FOR_DOCr   _CONFIG_FOR_DOC_EXPECTED_OUTPUT_SHAPEr   rJ   r   r   ro   rq   r'   r'   rk   r(   r0  O  s<     .
     
r0  zPUniSpeech Model with a vector-quantization module and ctc loss for pre-training.c                       s   e Zd Zed fddZedddZdd Zd	d
 Ze	de
je
je
jedddZeeeeeddee
j ee
j ee ee ee eeef dddZ  ZS )UniSpeechForPreTrainingr   c                    s~   t  | t|| _t|j| _t|| _	t
|j|j| _t
|j|j| _t
|j|j| _t|j| _|   d S rm   )r[   r\   r0  r  r   r   Zfeat_quantizer_dropoutdropout_featuresr  	quantizerr   r  Zproj_codevector_dim	project_qr   project_hidZnum_ctc_classesctc_projfinal_dropoutr   r7  r   rk   r'   r(   r\     s    

z UniSpeechForPreTraining.__init__)r  c                 C   s   || j _dS )zb
        Set the Gumbel softmax temperature to a given value. Only necessary for training
        N)rG  r  )rh   r  r'   r'   r(   set_gumbel_temperature  s    z.UniSpeechForPreTraining.set_gumbel_temperaturec                 C   s   t dt |   dS z
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5.Please use the equivalent `freeze_feature_encoder` method instead.Nr   r   r   freeze_feature_encoderrh   r'   r'   r(   freeze_feature_extractor  s
    z0UniSpeechForPreTraining.freeze_feature_extractorc                 C   s   | j j  dS 
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        Nr  r1  r   rQ  r'   r'   r(   rP    s    z.UniSpeechForPreTraining.freeze_feature_encoderr   )target_featuresnegative_featurespredicted_featuresr  c                 C   s@   t j| |gdd} t j| |  dd}|| }|| }|S )z
        Compute logits for contrastive loss using cosine similarity as the distance measure between
        `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, a temperature can be applied.
        r   r   r8   )r$   r   Zcosine_similarityr   r  )rV  rW  rX  r  logitsr'   r'   r(   compute_contrastive_logits  s
    
z2UniSpeechForPreTraining.compute_contrastive_logits)r;  r.  N)r   r,   r   r   r   r.   c                 C   sB  |dk	r|n| j j}| j|||||d}|d }| |d }| |\}	}
| |	}	| |	}	t|	d|	d
| j j}|dd}t| |j}|dd}|d}||d|	| d }| |}| |}d}|s*|dk	r|||	|
f|dd  S ||	|
f|dd  S t|||	|
|j|jdS )	a  
        mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices to mask extracted features for contrastive loss. When in training mode, the model learns to predict
            masked extracted features in *config.proj_codevector_dim* space.
        sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
            Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
            Required input for pre-training.

        Returns:

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoFeatureExtractor, UniSpeechForPreTraining

        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-large-1500h-cv")
        >>> model = UniSpeechForPreTraining.from_pretrained("microsoft/unispeech-large-1500h-cv")
        >>> # TODO: Add full pretraining example
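        >>> # a minimal forward-pass sketch (illustrative only; real pre-training additionally applies
        >>> # time masking and probabilistically replaces transformer features with quantized targets)
        >>> dummy_audio = torch.randn(1, 16000)
        >>> inputs = feature_extractor(dummy_audio.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        >>> outputs.projected_states.shape == outputs.projected_quantized_states.shape
        True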
        ```Nr>  r   r   r8   r   r   )r   r   r   r   r   r   )ri   r@  r  rF  rG  rH  rI  r$   emptyr   r  Zreplace_probrv   Z	bernoullirJ   r   r(  r   Zmasked_fillr   rJ  r   r   r   )rh   r   r,   r   r   r   r   Ztransformer_featuresr?  Zquantized_featuresr   Zprob_replace_matrixZsampled_replace_matrixrY  r   r'   r'   r(   ro     sL    






zUniSpeechForPreTraining.forward)r   )NNNN)r    r!   r"   r   r\   r0   rL  rR  rP  r  r$   r%   rZ  r   rA  r   r   rC  r   r   rJ   r   r   ro   rq   r'   r'   rk   r(   rE    s4    
    
rE  zgUniSpeech Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).c                       s   e Zd Zdee d fddZdd Zdd Zd	d
 Zdd Z	e
eeeeeeeddeej eej ee ee ee eej eeef dddZ  ZS )UniSpeechForCTCN)target_langc                    s~   t  | t|| _t|j| _|| _|j	d krFt
d| j dt|dr\|jr\|jn|j}t||j	| _|   d S )NzYou are trying to instantiate z with a configuration that does not define the vocabulary size of the language model head. Please instantiate the model as follows: `UniSpeechForCTC.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of your model's configuration.add_adapter)r[   r\   r0  r  r   r   rK  r   r]  
vocab_sizer@   rl   r   r^  output_hidden_sizer   r   lm_headr7  )rh   ri   r]  r`  rk   r'   r(   r\   d  s    

zUniSpeechForCTC.__init__c                 C   sr   | j }|dk	r2t| jdddkr2td| dn<|dkrXt| jdddk	rXtd n|dk	rn| j|dd dS )a'  
        This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
        passing `target_lang=...` to `from_pretrained(...)`.

        This method is **not** supposed to be called by the user and is prone to be changed in the future.
        Nr   zCannot pass `target_lang`: z- if `config.adapter_attn_dim` is not defined.z)By default `target_lang` is set to 'eng'.T)Z
force_load)r]  r   ri   r@   loggerinfoZload_adapter)rh   r]  r'   r'   r(   tie_weights{  s    zUniSpeechForCTC.tie_weightsc                 C   s   t dt |   dS )rT  rN  NrO  rQ  r'   r'   r(   rR    s
    z(UniSpeechForCTC.freeze_feature_extractorc                 C   s   | j j  dS rS  rU  rQ  r'   r'   r(   rP    s    z&UniSpeechForCTC.freeze_feature_encoderc                 C   s   | j  D ]
}d|_q
dS z
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        FNr  r   r   r   r'   r'   r(   freeze_base_model  s    z!UniSpeechForCTC.freeze_base_model)r   r;  r.  r=  Zexpected_lossr   r,   r   r   r   labelsr.   c              
   C   sf  |dk	r|n| j j}| j|||||d}|d }| |}| |}	d}
|dk	r"| | j jkrttd| j j |dk	r|ntj	|tj
d}| |dtj
}|dk}|d}||}tjj|	dtjddd}tjjjd	d
, tjj||||| j j| j j| j jd}
W 5 Q R X |sR|	f|td  }|
dk	rN|
f| S |S t|
|	|j|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        Nr>  r   z$Label values must be <= vocab_size: r=   r8   )r   r>   r   F)enabled)blankZ	reductionZzero_infinityr   rY  r   r   )ri   r@  r  r   ra  r1   r_  r@   r$   Z	ones_liker)  r&  rE   r   Zmasked_selectr   r   Zlog_softmaxZfloat32rv   backendsZcudnnflagsZctc_lossZpad_token_idZctc_loss_reductionZctc_zero_infinity_HIDDEN_STATES_START_POSITIONr   r   r   )rh   r   r,   r   r   r   ri  r   r   rY  r   rS   Zlabels_maskZtarget_lengthsZflattened_targetsZ	log_probsoutputr'   r'   r(   ro     sR    





   zUniSpeechForCTC.forward)N)NNNNN)r    r!   r"   r   r   r\   rd  rR  rP  rg  r   rA  r   rB  r   rC  _CTC_EXPECTED_OUTPUT_CTC_EXPECTED_LOSSr$   r   rJ   r   r   ro   rq   r'   r'   rk   r(   r\  ^  s6   
     
r\  z
    UniSpeech Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
    SUPERB Keyword Spotting.
    c                       s   e Zd Z fddZdd Zdd Zdd Zeee	e
eed	d
deej eej ee ee ee eej eeef dddZ  ZS )"UniSpeechForSequenceClassificationc                    s   t  | t|dr$|jr$tdt|| _|jd }|jrTt	
t|| | _t	|j|j| _t	|j|j| _|   d S )Nr^  z`Sequence classification does not support the use of UniSpeech adapters (config.add_adapter=True)r   )r[   r\   r   r^  r@   r0  r  r   use_weighted_layer_sumr   r	  r$   rN   layer_weightsr   r   Zclassifier_proj_size	projector
num_labels
classifierr7  )rh   ri   Z
num_layersrk   r'   r(   r\     s    

z+UniSpeechForSequenceClassification.__init__c                 C   s   t dt |   dS rM  rO  rQ  r'   r'   r(   rR    s
    z;UniSpeechForSequenceClassification.freeze_feature_extractorc                 C   s   | j j  dS rS  rU  rQ  r'   r'   r(   rP  "  s    z9UniSpeechForSequenceClassification.freeze_feature_encoderc                 C   s   | j  D ]
}d|_q
dS re  rf  r   r'   r'   r(   rg  *  s    z4UniSpeechForSequenceClassification.freeze_base_modelr:  )r   r;  r.  r<  Nrh  c                 C   sf  |dk	r|n| j j}| j jr dn|}| j|||||d}| j jr|t }tj|dd}tjj	| j
dd}	||	ddd jdd}n|d }| |}|dkr|jdd}
n<| |jd |}d|| < |jdd|jdddd }
| |
}d}|dk	r"t }||d| j j|d}|sR|f|td  }|dk	rN|f| S |S t|||j|jd	S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        NTr>  r   r   r8   r   r   rl  )ri   r@  rt  r  ro  r$   stackr   r   r   ru  r   rE   rv  r  r+  r)   rx  r   rw  r   r   r   )rh   r   r,   r   r   r   ri  r   r   Znorm_weightsZpooled_outputZpadding_maskrY  r   Zloss_fctrp  r'   r'   r(   ro   2  sF    

 

z*UniSpeechForSequenceClassification.forward)NNNNN)r    r!   r"   r\   rR  rP  rg  r   rA  r   rB  r   rC  r   r$   r   rJ   r   r   ro   rq   r'   r'   rk   r(   rs    s2   
     
rs  )Nr   )Lr#   r  r   dataclassesr   typingr   r   r   numpyrA   r$   Ztorch.utils.checkpointr   Ztorch.nnr   Zactivationsr	   Zintegrations.deepspeedr
   Zmodeling_outputsr   r   r   r   Zmodeling_utilsr   r   r   r   r   r   r   r   Zconfiguration_unispeechr   Z
get_loggerr    rb  ro  rC  rB  rD  rq  rr  Z'UNISPEECH_PRETRAINED_MODEL_ARCHIVE_LISTr   r0   r   r/  ZndarrayrU   ModulerV   rr   rw   rz   r   r   r   r   r   r   r   r   r   r   r  r  r  ZUNISPEECH_START_DOCSTRINGrA  r0  rE  r\  rs  r'   r'   r'   r(   <module>   s    

(  
x(6 #.X[FO%y   