U
    9%e                     @   s\  d Z ddlZddlZddlmZmZmZ ddlZddl	Z	ddl
Z	ddl	mZ ddlmZ ddlmZ ddlmZ dd	lmZmZmZ dd
lmZ ddlmZmZmZmZ ddlmZ ee Z!dZ"dZ#dZ$dddgZ%dZ&dZ'dZ(dZ)dZ*dddgZ+dBee,e,f e-e,ee	j. e,ej/dddZ0G dd dej1Z2G dd  d ej1Z3G d!d" d"ej1Z4G d#d$ d$ej1Z5G d%d& d&ej1Z6G d'd( d(ej1Z7G d)d* d*ej1Z8G d+d, d,e8Z9G d-d. d.ej1Z:G d/d0 d0ej1Z;G d1d2 d2ej1Z<G d3d4 d4ej1Z=G d5d6 d6eZ>d7Z?d8Z@ed9e?G d:d; d;e>ZAed<e?G d=d> d>e>ZBed?e?G d@dA dAe>ZCdS )Cz PyTorch SEW model.    N)OptionalTupleUnion)nn)CrossEntropyLoss   )ACT2FN)is_deepspeed_zero3_enabled)BaseModelOutputCausalLMOutputSequenceClassifierOutput)PreTrainedModel)add_code_sample_docstringsadd_start_docstrings%add_start_docstrings_to_model_forwardlogging   )	SEWConfigr   zasapp/sew-tiny-100k-ft-ls100hi$  i   z_'MISTER QUILTER IS THE APPOSTILE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPOLLE'gzG?z(anton-l/sew-mid-100k-ft-keyword-spottingz'_unknown_'g
ףp=
#@zasapp/sew-tiny-100kzasapp/sew-small-100kzasapp/sew-mid-100k)shape	mask_probmask_lengthattention_mask	min_masksreturnc                    s  | \}dk rt dkr6t d d dtjd   fdd}|dk	rt|d	  nfd
dt|D }tj	|ft
d}g }	|}
|
dkr|S |D ]v}||}tjjt|d  |dd}t|dkrd }n|d }t|tj|
| tjd| g}|	| qt|	}	t|	dddddf ||
f}	|	||
 }	tddddf }t|||
f||
 }|	| }	|	 d kr҈d |	|	d k< t||	dd	 |S )af  
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Args:
        shape: The shape for which to compute masks. This should be of a tuple of size 2 where
               the first element is the batch size and the second element is the length of the axis to span.
        mask_prob:  The percentage of the whole axis (between 0 and 1) which will be masked. The number of
                    independently generated mask spans of length `mask_length` is computed by
                    `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
                    actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
                        each batch dimension.
    r   z&`mask_length` has to be bigger than 0.zO`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: z and `sequence_length`: `c                    sX   t |     }t|}| kr2 }| d  |k rTt| d  d}|S )z;Given input length, compute how many spans should be maskedr   r   )intmax)input_lengthnum_masked_spanepsilonr   r   r   sequence_length c/var/www/html/Darija-Ai-API/env/lib/python3.8/site-packages/transformers/models/sew/modeling_sew.pycompute_num_masked_spani   s    
z6_compute_mask_indices.<locals>.compute_num_masked_spanNc                    s   g | ]} qS r"   r"   .0_)r!   r"   r#   
<listcomp>|   s     z)_compute_mask_indices.<locals>.<listcomp>dtyper   F)replace)
ValueErrornprandomranditemsumdetachtolistrangezerosboolchoicearangelenZconcatenateonesZint32appendarrayZbroadcast_toreshaper   Zput_along_axis)r   r   r   r   r   
batch_sizer$   input_lengthsZspec_aug_maskZspec_aug_mask_idxsZmax_num_masked_spanr   r   Zspec_aug_mask_idxZdummy_mask_idxoffsetsr"   r   r#   _compute_mask_indicesC   s`      

  rB   c                       s&   e Zd Zd fdd	Zdd Z  ZS )SEWNoLayerNormConvLayerr   c                    sj   t    |dkr |j|d  nd| _|j| | _tj| j| j|j| |j| |j	d| _
t|j | _d S )Nr   r   kernel_sizestridebias)super__init__conv_dimin_conv_dimout_conv_dimr   Conv1dconv_kernelconv_stride	conv_biasconvr   feat_extract_activation
activationselfconfiglayer_id	__class__r"   r#   rI      s    
z SEWNoLayerNormConvLayer.__init__c                 C   s   |  |}| |}|S N)rQ   rS   rU   hidden_statesr"   r"   r#   forward   s    

zSEWNoLayerNormConvLayer.forward)r   __name__
__module____qualname__rI   r]   __classcell__r"   r"   rX   r#   rC      s   rC   c                       s&   e Zd Zd fdd	Zdd Z  ZS )SEWLayerNormConvLayerr   c                    s|   t    |dkr |j|d  nd| _|j| | _tj| j| j|j| |j| |j	d| _
tj| jdd| _t|j | _d S )Nr   r   rD   T)Zelementwise_affine)rH   rI   rJ   rK   rL   r   rM   rN   rO   rP   rQ   	LayerNorm
layer_normr   rR   rS   rT   rX   r"   r#   rI      s    
zSEWLayerNormConvLayer.__init__c                 C   s:   |  |}|dd}| |}|dd}| |}|S )Nr%   )rQ   	transposere   rS   r[   r"   r"   r#   r]      s    


zSEWLayerNormConvLayer.forward)r   r^   r"   r"   rX   r#   rc      s   rc   c                       s&   e Zd Zd fdd	Zdd Z  ZS )SEWGroupNormConvLayerr   c                    s   t    |dkr |j|d  nd| _|j| | _tj| j| j|j| |j| |j	d| _
t|j | _tj| j| jdd| _d S )Nr   r   rD   T)Z
num_groupsZnum_channelsZaffine)rH   rI   rJ   rK   rL   r   rM   rN   rO   rP   rQ   r   rR   rS   	GroupNormre   rT   rX   r"   r#   rI      s    
zSEWGroupNormConvLayer.__init__c                 C   s"   |  |}| |}| |}|S rZ   )rQ   re   rS   r[   r"   r"   r#   r]      s    


zSEWGroupNormConvLayer.forward)r   r^   r"   r"   rX   r#   rh      s   rh   c                       s$   e Zd Z fddZdd Z  ZS )SEWPositionalConvEmbeddingc              	      s   t    tj|j|j|j|jd |j|jd| _t	 rdd l
}|jj| jjdd tjj| jddd| _W 5 Q R X |j| | jj |j| | jj ntjj| jddd| _t|j| _t|j | _d S )N   )rE   paddinggroupsrF   r   Zmodifier_rankweight)namedim)rH   rI   r   rM   hidden_sizenum_conv_pos_embeddingsZnum_conv_pos_embedding_groupssqueeze_factorrQ   r	   	deepspeedzeroGatheredParametersro   utilsZweight_normZregister_external_parameterweight_vweight_gSEWSamePadLayerrl   r   rR   rS   )rU   rV   ru   rX   r"   r#   rI     s$    
	 z#SEWPositionalConvEmbedding.__init__c                 C   s"   |  |}| |}| |}|S rZ   )rQ   rl   rS   r[   r"   r"   r#   r]     s    


z"SEWPositionalConvEmbedding.forwardr^   r"   r"   rX   r#   rj     s   rj   c                       s$   e Zd Z fddZdd Z  ZS )r{   c                    s$   t    |d dkrdnd| _d S )Nrk   r   r   )rH   rI   num_pad_remove)rU   rs   rX   r"   r#   rI   (  s    
zSEWSamePadLayer.__init__c                 C   s,   | j dkr(|d d d d d | j  f }|S )Nr   )r|   r[   r"   r"   r#   r]   ,  s    
zSEWSamePadLayer.forwardr^   r"   r"   rX   r#   r{   '  s   r{   c                       s$   e Zd Z fddZdd Z  ZS )SEWUpsamplingc                    s:   t    t|j|j|j | _t|j | _	|j| _d S rZ   )
rH   rI   r   Linearrr   rt   
projectionr   rR   rS   rU   rV   rX   r"   r#   rI   3  s    
zSEWUpsampling.__init__c                 C   sd   |  |}| |}| jdkr`| \}}}|| j }|| j }|||| j|}||||}|S )Nr   )r   rS   rt   sizer>   )rU   r\   bszsrc_lenZsrc_embed_dimtgt_lenZtgt_embed_dimr"   r"   r#   r]   9  s    




zSEWUpsampling.forwardr^   r"   r"   rX   r#   r}   2  s   r}   c                       s0   e Zd ZdZ fddZdd Zdd Z  ZS )SEWFeatureEncoderz.Construct the features from raw audio waveformc                    s   t     jdkr@t ddg fddt jd D  }n6 jdkrd fddt jD }ntd	 j d
t|| _	d| _
d| _d S )Ngroupr   rW   c                    s   g | ]}t  |d  dqS )r   r   )rC   r'   irV   r"   r#   r)   P  s    z.SEWFeatureEncoder.__init__.<locals>.<listcomp>r   layerc                    s   g | ]}t  |d qS )r   )rc   r   r   r"   r#   r)   T  s     z`config.feat_extract_norm` is z), but has to be one of ['group', 'layer']FT)rH   rI   Zfeat_extract_normrh   r5   Znum_feat_extract_layersr-   r   
ModuleListconv_layersgradient_checkpointing_requires_grad)rU   rV   r   rX   r   r#   rI   L  s    



zSEWFeatureEncoder.__init__c                 C   s   |   D ]
}d|_qd| _d S )NF)
parametersrequires_gradr   rU   paramr"   r"   r#   _freeze_parameters]  s    z$SEWFeatureEncoder._freeze_parametersc                 C   sj   |d d d f }| j r"| jr"d|_| jD ]<}| j r\| jr\| jr\dd }tjj|||}q(||}q(|S )NTc                    s    fdd}|S )Nc                     s    |  S rZ   r"   inputsmoduler"   r#   custom_forwardm  s    zPSEWFeatureEncoder.forward.<locals>.create_custom_forward.<locals>.custom_forwardr"   r   r   r"   r   r#   create_custom_forwardl  s    z8SEWFeatureEncoder.forward.<locals>.create_custom_forward)r   trainingr   r   r   torchrx   
checkpoint)rU   input_valuesr\   Z
conv_layerr   r"   r"   r#   r]   b  s    

zSEWFeatureEncoder.forward)r_   r`   ra   __doc__rI   r   r]   rb   r"   r"   rX   r#   r   I  s   r   c                       s   e Zd Z fddZ  ZS )SEWFeatureExtractorc                    s8   t  | td| jj d| jjd j dt d S )NzThe class `zD` has been depreciated and will be removed in Transformers v5. Use `r   z
` instead.)rH   rI   warningswarnrY   r_   	__bases__FutureWarningr   rX   r"   r#   rI   }  s
    zSEWFeatureExtractor.__init__)r_   r`   ra   rI   rb   r"   r"   rX   r#   r   |  s   r   c                       s   e Zd ZdZdeeeeed fddZej	eedd	d
Z
dej	eej	 eeej	  eej	 eej	 eeej	eej	 eeej	  f dddZ  ZS )SEWAttentionz=Multi-headed attention from 'Attention Is All You Need' paper        FT)	embed_dim	num_headsdropout
is_decoderrG   c                    s   t    || _|| _|| _|| | _| j| | jkrNtd| j d| d| jd | _|| _t	j
|||d| _t	j
|||d| _t	j
|||d| _t	j
|||d| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).g      )rG   )rH   rI   r   r   r   head_dimr-   scalingr   r   r~   k_projv_projq_projout_proj)rU   r   r   r   r   rG   rX   r"   r#   rI     s    

zSEWAttention.__init__)tensorseq_lenr   c                 C   s    | ||| j| jdd S )Nr   rk   )viewr   r   rg   
contiguous)rU   r   r   r   r"   r"   r#   _shape  s    zSEWAttention._shapeN)r\   key_value_statespast_key_valuer   layer_head_maskoutput_attentionsr   c                 C   sx  |dk	}|  \}}	}
| || j }|r\|dk	r\|d jd |jd kr\|d }|d }n|r| | |d|}| | |d|}n|dk	r| | |d|}| | |d|}tj|d |gdd}tj|d |gdd}n(| | |d|}| | |d|}| j	r ||f}|| j
 d| jf}| ||	|j| }|j| }|j| }| d}t||dd}|  || j
 |	|fkrtd|| j
 |	|f d|   |dk	r |  |d|	|fkrtd	|d|	|f d|   ||| j
|	|| }||| j
 |	|}tjj|dd}|dk	r|  | j
fkrhtd
| j
f d|   |dddd||| j
|	| }||| j
 |	|}|r||| j
|	|}||| j
 |	|}nd}tjj|| j| jd}t||}|  || j
 |	| jfkr4td|| j
 |	| jf d|   ||| j
|	| j}|dd}|||	| j}| |}|||fS )z#Input shape: Batch x Time x ChannelNr   rk   r   r%   rq   z$Attention weights should be of size z	, but is z!Attention mask should be of size z/Head mask for a single layer should be of size )pr   z `attn_output` should be of size )r   r   r   r   r   r   r   r   catr   r   r   r   r>   Zbmmrg   r-   r   
functionalsoftmaxr   r   r   r   )rU   r\   r   r   r   r   r   Zis_cross_attentionr   r   r(   Zquery_statesZ
key_statesZvalue_statesZ
proj_shaper   attn_weightsZattn_weights_reshapedZ
attn_probsZattn_outputr"   r"   r#   r]     s~    





" 
zSEWAttention.forward)r   FT)NNNNF)r_   r`   ra   r   r   floatr7   rI   r   Tensorr   r   r   r]   rb   r"   r"   rX   r#   r     s4           r   c                       s$   e Zd Z fddZdd Z  ZS )SEWFeedForwardc                    sp   t    t|j| _t|j|j| _	t
|jtrDt|j | _n|j| _t|j|j| _t|j| _d S rZ   )rH   rI   r   DropoutZactivation_dropoutintermediate_dropoutr~   rr   Zintermediate_sizeintermediate_dense
isinstanceZ
hidden_actstrr   intermediate_act_fnoutput_densehidden_dropoutoutput_dropoutr   rX   r"   r#   rI   $  s    
zSEWFeedForward.__init__c                 C   s6   |  |}| |}| |}| |}| |}|S rZ   )r   r   r   r   r   r[   r"   r"   r#   r]   1  s    




zSEWFeedForward.forwardr^   r"   r"   rX   r#   r   #  s   r   c                       s&   e Zd Z fddZdddZ  ZS )SEWEncoderLayerc                    sf   t    t|j|j|jdd| _t|j	| _
tj|j|jd| _t|| _tj|j|jd| _d S )NF)r   r   r   r   Zeps)rH   rI   r   rr   Znum_attention_headsZattention_dropout	attentionr   r   r   r   rd   layer_norm_epsre   r   feed_forwardfinal_layer_normr   rX   r"   r#   rI   =  s    

zSEWEncoderLayer.__init__NFc                 C   sf   |}| j |||d\}}}| |}|| }| |}|| | }| |}|f}|rb||f7 }|S )Nr   r   )r   r   re   r   r   )rU   r\   r   r   Zattn_residualr   r(   outputsr"   r"   r#   r]   J  s      



zSEWEncoderLayer.forward)NFr^   r"   r"   rX   r#   r   <  s   r   c                       s&   e Zd Z fddZdddZ  ZS )	
SEWEncoderc                    s   t     | _t | _t j j| _tj	 j
 jd| _t j| _t fddt jD | _t | _d| _d S )Nr   c                    s   g | ]}t  qS r"   )r   r&   r   r"   r#   r)   f  s     z'SEWEncoder.__init__.<locals>.<listcomp>F)rH   rI   rV   rj   pos_conv_embedr   Z	AvgPool1drt   poolrd   rr   r   re   r   r   r   r   r5   num_hidden_layerslayersr}   upsampler   r   rX   r   r#   rI   _  s    

 
zSEWEncoder.__init__NFTc              	      s  |rdnd } rdnd }|d k	rd|| < |  d}|| jj }	|jd | jj }
tjd|
|	jddd	|	jd d}||	ddk   }d|d d d d d d f j
|jd }|t|jj }|	|jd d|jd |jd }|jd }|dd	}| |}| |}t|d|d}|d
d |f |d
d |f  }|dd	}| |}| |}t }| jD ]}|r||f }tg }| jr|| jjk rdnd}|r|r
| jr| jr fdd}tjj||||}n||| d}|d }|rd} rx||d f }qx|r<||f }| |}|jd |k rvtj|ddd||jd  f}|st dd |||fD S t!|||dS )Nr"   r   r%   r   r   device      ?r*   rk   .TFc                    s    fdd}|S )Nc                     s    | f S rZ   r"   r   )r   r   r"   r#   r     s    zISEWEncoder.forward.<locals>.create_custom_forward.<locals>.custom_forwardr"   r   r   r   r#   r     s    z1SEWEncoder.forward.<locals>.create_custom_forwardr   )NNc                 s   s   | ]}|d k	r|V  qd S rZ   r"   )r'   vr"   r"   r#   	<genexpr>  s      z%SEWEncoder.forward.<locals>.<genexpr>Zlast_hidden_stater\   
attentions)"longr2   rV   rt   r   r   r9   r   r   expandtor+   Zfinfominrg   r   r   r   re   r   r	   r   r0   r   Z	layerdropr   rx   r   r   r   r   padtupler
   )rU   r\   r   r   output_hidden_statesreturn_dictZall_hidden_statesZall_self_attentionsr@   output_lengthsZmax_encoder_lengthZattention_idsZn_input_timestepsZposition_embeddingsZpooled_hidden_statesZ
min_lengthZdeepspeed_zero3_is_enabledr   Zdropout_probabilityZskip_the_layerr   Zlayer_outputsr"   r   r#   r]   j  s    
  &   


 




  

 zSEWEncoder.forward)NFFTr^   r"   r"   rX   r#   r   ^  s       r   c                   @   s\   e Zd ZdZeZdZdZdZdd Z	ddd	Z
eejef d
ddZeejdddZdS )SEWPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    sewr   Tc              	   C   s  t |trRtjj|jjddtd|jj	d |jj
   d tj|jjd nt |tjrv|jjjd| jjd nt |tjtjfr|jj  |jjd nt |tjrPt r@ddl}t|drt|d	r|jj|j|jgdd
 tj|jj W 5 Q R X n.|jj|jdd
 tj|jj W 5 Q R X ntj|jj t |tjtjfr||jdk	r||jj  dS )zInitialize the weightsr   rk   r   )meanZstdr   r   Nry   rz   rn   )r   rj   r   initZnormal_rQ   ro   mathsqrtrE   Zin_channelsZ	constant_rG   r~   datarV   Zinitializer_rangerd   ri   Zzero_Zfill_rM   r	   ru   hasattrrv   rw   ry   rz   Zkaiming_normal_)rU   r   ru   r"   r"   r#   _init_weights  s.    
  z SEWPreTrainedModel._init_weightsFc                 C   s   t |ttfr||_d S rZ   )r   r   r   r   )rU   r   valuer"   r"   r#   _set_gradient_checkpointing  s    z.SEWPreTrainedModel._set_gradient_checkpointing)r@   c                 C   s4   dd }t | jj| jjD ]\}}||||}q|S )zH
        Computes the output length of the convolutional layers
        c                 S   s   t j| | |ddd S )Nfloor)Zrounding_moder   )r   div)r   rE   rF   r"   r"   r#   _conv_out_length   s    zMSEWPreTrainedModel._get_feat_extract_output_lengths.<locals>._conv_out_length)ziprV   rN   rO   )rU   r@   r   rE   rF   r"   r"   r#    _get_feat_extract_output_lengths  s    z3SEWPreTrainedModel._get_feat_extract_output_lengths)feature_vector_lengthr   c                 C   s~   |  |dtj}|jd }tj||f|j|jd}d|tj	|jd |jd|d f< |
dgd
dg }|S )Nr%   r   )r+   r   r   r   )r   r2   r   r   r   r   r6   r+   r   r9   flipZcumsumr7   )rU   r   r   r   r?   r"   r"   r#   "_get_feature_vector_attention_mask
  s    
  "z5SEWPreTrainedModel._get_feature_vector_attention_maskN)F)r_   r`   ra   r   r   config_classZbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingr   r   r   r   
LongTensorr   r   r   r"   r"   r"   r#   r     s    
r   a  
    SEW was proposed in [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech
    Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger,
    Yoav Artzi.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving etc.).

    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`SEWConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
a  
    Args:
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
            into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
            soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
            conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
            1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
z]The bare SEW Model transformer outputting raw hidden-states without any specific head on top.c                       s   e Zd Zed fddZdejeej eej dddZ	e
eeeeeded	deej eej eej ee ee ee eeef d
ddZ  ZS )SEWModelr   c                    s   t  | || _t|| _tj|jd |jd| _	|jd |j
k| _| jrbt|jd |j
| _t|j| _|jdks|jdkrtt|j
 | _t|| _|   d S )Nr%   r   r   )rH   rI   rV   r   feature_extractorr   rd   rJ   r   re   rr   project_featuresr~   feature_projectionr   Zfeat_proj_dropoutfeature_dropoutmask_time_probmask_feature_prob	Parameterr   FloatTensorZuniform_masked_spec_embedr   encoder	post_initr   rX   r"   r#   rI   J  s    

zSEWModel.__init__N)r\   mask_time_indicesr   c                 C   s  t | jdds|S | \}}}|dk	r<| j|j||< nZ| jjdkr| jrt||f| jj| jj	|| jj
d}tj||jtjd}| j|j||< | jjdkr| jrt||f| jj| jj| jjd}tj||jtjd}|dddf d|d}d||< |S )	z
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://arxiv.org/abs/1904.08779).
        Zapply_spec_augmentTNr   )r   r   r   r   )r   r+   )r   r   r   r%   )getattrrV   r   r
  r   r+   r  r   rB   Zmask_time_lengthZmask_time_min_masksr   r   r   r7   r  Zmask_feature_lengthZmask_feature_min_masksr   )rU   r\   r  r   r?   r!   rr   Zmask_feature_indicesr"   r"   r#   _mask_hidden_states^  s4    zSEWModel._mask_hidden_statesaudio)r   output_typer   modalityexpected_output)r   r   r  r   r   r   r   c           
      C   s   |d k	r|n| j j}|d k	r |n| j j}|d k	r4|n| j j}| |}|dd}| |}| jrl| |}| 	|}|d k	r| 
|jd |}| j||d}| j|||||d}	|	d }|s|f|	dd   S t||	j|	jdS )Nr   rk   )r  r   r   r   r   r   r   )rV   r   r   use_return_dictr  rg   re   r  r  r  r   r   r  r  r
   r\   r   )
rU   r   r   r  r   r   r   Zextract_featuresr\   Zencoder_outputsr"   r"   r#   r]     s8    



zSEWModel.forward)NN)NNNNN)r_   r`   ra   r   rI   r   r	  r   r   r  r   SEW_INPUTS_DOCSTRINGr   _CHECKPOINT_FOR_DOCr
   _CONFIG_FOR_DOC_EXPECTED_OUTPUT_SHAPEr   r7   r   r   r]   rb   r"   r"   rX   r#   r  E  s<     .
     
r  zaSEW Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).c                       s   e Zd Zdee d fddZdd Zdd Zd	d
 Zdd Z	e
eeeeeeeddeej eej ee ee ee eej eeef dddZ  ZS )	SEWForCTCN)target_langc                    s~   t  | t|| _t|j| _|| _|j	d krFt
d| j dt|dr\|jr\|jn|j}t||j	| _|   d S )NzYou are trying to instantiate z with a configuration that does not define the vocabulary size of the language model head. Please instantiate the model as follows: `SEWForCTC.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of your model's configuration.add_adapter)rH   rI   r  r   r   r   Zfinal_dropoutr   r  
vocab_sizer-   rY   r   r  output_hidden_sizerr   r~   lm_headr  )rU   rV   r  r  rX   r"   r#   rI     s    

zSEWForCTC.__init__c                 C   sr   | j }|dk	r2t| jdddkr2td| dn<|dkrXt| jdddk	rXtd n|dk	rn| j|dd dS )a'  
        This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
        passing `target_lang=...` to `from_pretrained(...)`.

        This method is **not** supposed to be called by the user and is prone to be changed in the future.
        NZadapter_attn_dimzCannot pass `target_lang`: z- if `config.adapter_attn_dim` is not defined.z)By default `target_lang` is set to 'eng'.T)Z
force_load)r  r  rV   r-   loggerinfoZload_adapter)rU   r  r"   r"   r#   tie_weights  s    zSEWForCTC.tie_weightsc                 C   s   t dt |   dS )
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5.Please use the equivalent `freeze_feature_encoder` method instead.Nr   r   r   freeze_feature_encoderrU   r"   r"   r#   freeze_feature_extractor  s
    z"SEWForCTC.freeze_feature_extractorc                 C   s   | j j  dS r#  Nr   r  r   r'  r"   r"   r#   r&    s    z SEWForCTC.freeze_feature_encoderc                 C   s   | j  D ]
}d|_q
dS z
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        FNr   r   r   r   r"   r"   r#   freeze_base_model
  s    zSEWForCTC.freeze_base_model)r   r  r   r  expected_lossr   r   r   r   r   labelsr   c              
   C   sf  |dk	r|n| j j}| j|||||d}|d }| |}| |}	d}
|dk	r"| | j jkrttd| j j |dk	r|ntj	|tj
d}| |dtj
}|dk}|d}||}tjj|	dtjddd}tjjjd	d
, tjj||||| j j| j j| j jd}
W 5 Q R X |sR|	f|td  }|
dk	rN|
f| S |S t|
|	|j|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        Nr  r   z$Label values must be <= vocab_size: r*   r%   )rq   r+   r   F)enabled)blankZ	reductionZzero_infinitylosslogitsr\   r   )rV   r  r   r   r  r   r  r-   r   Z	ones_liker   r   r2   r   Zmasked_selectr   r   Zlog_softmaxZfloat32rg   backendsZcudnnflagsZctc_lossZpad_token_idZctc_loss_reductionZctc_zero_infinity_HIDDEN_STATES_START_POSITIONr   r\   r   )rU   r   r   r   r   r   r0  r   r\   r5  r4  r@   Zlabels_maskZtarget_lengthsZflattened_targetsZ	log_probsoutputr"   r"   r#   r]     sR    





   zSEWForCTC.forward)N)NNNNN)r_   r`   ra   r   r   rI   r"  r(  r&  r-  r   r  r   r  r   r  _CTC_EXPECTED_OUTPUT_CTC_EXPECTED_LOSSr   r   r7   r   r   r]   rb   r"   r"   rX   r#   r    s6   
     
r  z
    SEW Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like SUPERB
    Keyword Spotting.
    c                       s   e Zd Z fddZdd Zdd Zdd Zeee	e
eed	eed
deej eej ee ee ee eej eeef dddZ  ZS )SEWForSequenceClassificationc                    s   t  | t|dr$|jr$tdt|| _|jd }|jrTt	
t|| | _t	|j|j| _t	|j|j| _|   d S )Nr  zZSequence classification does not support the use of SEW adapters (config.add_adapter=True)r   )rH   rI   r   r  r-   r  r   r   use_weighted_layer_sumr   r  r   r;   layer_weightsr~   rr   Zclassifier_proj_size	projector
num_labels
classifierr  )rU   rV   Z
num_layersrX   r"   r#   rI   k  s    

z%SEWForSequenceClassification.__init__c                 C   s   t dt |   dS )z
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        r$  Nr%  r'  r"   r"   r#   r(  |  s
    z5SEWForSequenceClassification.freeze_feature_extractorc                 C   s   | j j  dS r)  r*  r'  r"   r"   r#   r&    s    z3SEWForSequenceClassification.freeze_feature_encoderc                 C   s   | j  D ]
}d|_q
dS r+  r,  r   r"   r"   r#   r-    s    z.SEWForSequenceClassification.freeze_base_modelr  )r   r  r   r  r  r.  Nr/  c                 C   sf  |dk	r|n| j j}| j jr dn|}| j|||||d}| j jr|t }tj|dd}tjj	| j
dd}	||	ddd jdd}n|d }| |}|dkr|jdd}
n<| |jd |}d|| < |jdd|jdddd }
| |
}d}|dk	r"t }||d| j j|d}|sR|f|td  }|dk	rN|f| S |S t|||j|jd	S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        NTr  r   r   r%   r   r   r3  )rV   r  r=  r   r8  r   stackr   r   r   r>  r   r2   r?  r   r   r   rA  r   r@  r   r\   r   )rU   r   r   r   r   r   r0  r   r\   Znorm_weightsZpooled_outputZpadding_maskr5  r4  Zloss_fctr9  r"   r"   r#   r]     sF    

 

z$SEWForSequenceClassification.forward)NNNNN)r_   r`   ra   rI   r(  r&  r-  r   r  r   _SEQ_CLASS_CHECKPOINTr   r  _SEQ_CLASS_EXPECTED_OUTPUT_SEQ_CLASS_EXPECTED_LOSSr   r   r   r7   r   r   r]   rb   r"   r"   rX   r#   r<  b  s6   	     
r<  )Nr   )Dr   r   r   typingr   r   r   numpyr.   r   Ztorch.utils.checkpointr   Ztorch.nnr   Zactivationsr   Zintegrations.deepspeedr	   Zmodeling_outputsr
   r   r   Zmodeling_utilsr   rx   r   r   r   r   Zconfiguration_sewr   Z
get_loggerr_   r   r8  r  r  r  r:  r;  rC  rD  rE  Z!SEW_PRETRAINED_MODEL_ARCHIVE_LISTr   r   r   ZndarrayrB   ModulerC   rc   rh   rj   r{   r}   r   r   r   r   r   r   r   ZSEW_START_DOCSTRINGr  r  r  r<  r"   r"   r"   r#   <module>   s   

  
x"3 "nK| 