U
    ,-et                    @   sN  d Z ddlZddlmZmZmZ ddlZddlZddl	Zddlm
Z
 ddlmZ ddlmZ ddlmZ dd	lmZmZmZmZmZ dd
lmZ ddlmZmZmZmZ ddlmZ ddl m!Z!m"Z" e#e$Z%dZ&dZ'dgZ(ej)e*e*dddZ+d>ej,ej-ej.e*dddZ/d?ej)ej-ee* dddZ0d@ee*e*f e1e*eej2 e*ej3dddZ4ej)e*ej)ddd Z5ej3d!d"d#Z6G d$d% d%e
j7Z8G d&d' d'e
j9Z:G d(d) d)e
j9Z;G d*d+ d+e
j9Z<G d,d- d-eZ=d.Z>d/Z?d0Z@G d1d2 d2e=ZAG d3d4 d4e=ZBed5e>G d6d7 d7e=ZCed8e>G d9d: d:e=ZDed;e@G d<d= d=e=ZEdS )Az PyTorch Whisper model.    N)OptionalTupleUnion)nn)CrossEntropyLoss   )ACT2FN)WhisperTimeStampLogitsProcessor)BaseModelOutput)BaseModelOutputWithPastAndCrossAttentionsSeq2SeqLMOutputSeq2SeqModelOutputSequenceClassifierOutput)PreTrainedModel)add_start_docstrings%add_start_docstrings_to_model_forwardloggingreplace_return_docstrings   )WhisperConfig)TASK_IDSTO_LANGUAGE_CODEr   zopenai/whisper-tinyzopenai/whisper-base)	input_idspad_token_iddecoder_start_token_idc                 C   sh   |  | j}| ddddf  |ddddf< ||dddf< |dkrTtd||dk| |S )z1
    Shift input ids one token to the right.
    Nr   r   z1self.model.config.pad_token_id has to be defined.i)Z	new_zerosshapeclone
ValueErrormasked_fill_)r   r   r   Zshifted_input_ids r    m/var/www/html/Darija-Ai-Train/env/lib/python3.8/site-packages/transformers/models/whisper/modeling_whisper.pyshift_tokens_right;   s    (r"   )input_ids_shapedtypedevicepast_key_values_lengthc                 C   s   | \}}t j||ft |j|d}t j|d|d}|||d |ddk d ||}|dkrt j	t j
||||d|gdd}|ddddddf |d||| S )zB
    Make causal mask used for bi-directional self-attention.
    )r%   r   r   r   )r$   r%   dimN)torchfullfinfominarangesizer   viewtocatzerosexpand)r#   r$   r%   r&   bsztgt_lenmaskZ	mask_condr    r    r!   _make_causal_maskL   s    "
 r7   )r6   r$   r5   c                 C   sj   |   \}}|dk	r|n|}| ddddddf |d|||}d| }||tjt|jS )z_
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    Nr         ?)r.   r3   r0   Zmasked_fillr)   boolr+   r,   )r6   r$   r5   r4   src_lenZexpanded_maskZinverted_maskr    r    r!   _expand_mask^   s
    *r;   )r   	mask_probmask_lengthattention_mask	min_masksreturnc                    s  | \}dk rt dkr6t d d dtjd   fdd}|dk	rt|d	  nfd
dt|D }tj	|ft
d}g }	|}
|
dkr|S |D ]v}||}tjjt|d  |dd}t|dkrd }n|d }t|tj|
| tjd| g}|	| qt|	}	t|	dddddf ||
f}	|	||
 }	tddddf }t|||
f||
 }|	| }	|	 d kr҈d |	|	d k< t||	dd	 |S )af  
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Args:
        shape: The shape for which to compute masks. This should be of a tuple of size 2 where
               the first element is the batch size and the second element is the length of the axis to span.
        mask_prob:  The percentage of the whole axis (between 0 and 1) which will be masked. The number of
                    independently generated mask spans of length `mask_length` is computed by
                    `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
                    actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
                        each batch dimension.
    r   z&`mask_length` has to be bigger than 0.zO`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: z and `sequence_length`: `c                    sX   t |     }t|}| kr2 }| d  |k rTt| d  d}|S )z;Given input length, compute how many spans should be maskedr   r   )intmax)input_lengthnum_masked_spanepsilonr=   r<   r?   sequence_lengthr    r!   compute_num_masked_span   s    
z6_compute_mask_indices.<locals>.compute_num_masked_spanNr   c                    s   g | ]} qS r    r    .0_)rH   r    r!   
<listcomp>   s     z)_compute_mask_indices.<locals>.<listcomp>r$   r   F)replace)r   nprandomranditemsumdetachtolistranger2   r9   choicer-   lenZconcatenateonesZint32appendarrayZbroadcast_toreshaperC   Zput_along_axis)r   r<   r=   r>   r?   
batch_sizerI   input_lengthsZspec_aug_maskZspec_aug_mask_idxsZmax_num_masked_spanrD   rE   Zspec_aug_mask_idxZdummy_mask_idxoffsetsr    rF   r!   _compute_mask_indicesm   s`      

  ra   )inputsfilter_widthr@   c                 C   sr   |dks|d dkrt d|d }| jd |kr6| S tjj| ||ddfdd} | d|d d d|f }|S )	z
    Applies a median filter of width `filter_width` along the last dimension of the input.

    The `inputs` tensor is assumed to be 3- or 4-dimensional.
    r      r   z&`filter_width` should be an odd numberr   Zreflect)mode.)r   r   r   
functionalpadZunfoldsort)rb   rc   Z	pad_widthresultr    r    r!   _median_filter   s    rj   )matrixc                 C   s6  | j \}}tj|d |d ftjdtj }tj|d |d ftjd }d|d< td|d D ]}td|d D ]}||d |d f }||d |f }|||d f }	||k r||	k r|d }
}n&||k r||	k r|d }
}n
|	d }
}| |d |d f |
 |||f< ||||f< qrq`|j d d }|j d d }d|dddf< d|dddf< g }g }|dkst|dkr||d  ||d  |||f dkr|d8 }|d8 }nN|||f dkr|d8 }n2|||f dkr|d8 }ntd| d| d	q`t|ddd
 }t|ddd
 }||fS )z
    Measures similarity between two temporal sequences: the input audio and the output tokens. Used to generate
    token-level timestamps.
    r   rN   r   )r   r   rd   Nz9Internal error in dynamic time warping. Unexpected trace[z, z]. Please file a bug report.r   )	r   rP   rZ   float32infrW   r[   RuntimeErrorr\   )rk   Zoutput_lengthrD   ZcosttracejiZc0c1c2cttext_indicestime_indicesr    r    r!   _dynamic_time_warping   sL    
"
 


rx   c                       s6   e Zd Zdeeee d fddZd	ddZ  ZS )
WhisperPositionalEmbeddingN)num_positionsembedding_dimpadding_idxc                    s   t  || d S N)super__init__)selfrz   r{   r|   	__class__r    r!   r   0  s    z#WhisperPositionalEmbedding.__init__r   c                 C   s   | j |||jd   S Nr   )weightr   )r   r   r&   r    r    r!   forward3  s    z"WhisperPositionalEmbedding.forward)N)r   )__name__
__module____qualname__rB   r   r   r   __classcell__r    r    r   r!   ry   /  s   ry   c                       s   e Zd ZdZdeeeeed fddZej	eedd	d
Z
dej	eej	 eeej	  eej	 eej	 eeej	eej	 eeej	  f dddZ  ZS )WhisperAttentionz=Multi-headed attention from 'Attention Is All You Need' paper        FT)	embed_dim	num_headsdropout
is_decoderbiasc                    s   t    || _|| _|| _|| | _| j| | jkrNtd| j d| d| jd | _|| _t	j
||dd| _t	j
|||d| _t	j
|||d| _t	j
|||d| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).g      Fr   )r~   r   r   r   r   head_dimr   scalingr   r   Lineark_projv_projq_projout_proj)r   r   r   r   r   r   r   r    r!   r   :  s    

zWhisperAttention.__init__)tensorseq_lenr4   c                 C   s    | ||| j| jdd S )Nr   rd   )r/   r   r   	transpose
contiguous)r   r   r   r4   r    r    r!   _shapeV  s    zWhisperAttention._shapeN)hidden_stateskey_value_statespast_key_valuer>   layer_head_maskoutput_attentionsr@   c                 C   sx  |dk	}|  \}}	}
| || j }|r\|dk	r\|d jd |jd kr\|d }|d }n|r| | |d|}| | |d|}n|dk	r| | |d|}| | |d|}tj|d |gdd}tj|d |gdd}n(| | |d|}| | |d|}| j	r ||f}|| j
 d| jf}| ||	|j| }|j| }|j| }| d}t||dd}|  || j
 |	|fkrtd|| j
 |	|f d|   |dk	r |  |d|	|fkrtd	|d|	|f d|   ||| j
|	|| }||| j
 |	|}tjj|dd}|dk	r|  | j
fkrhtd
| j
f d|   |dddd||| j
|	| }||| j
 |	|}|r||| j
|	|}||| j
 |	|}nd}tjj|| j| jd}t||}|  || j
 |	| jfkr4td|| j
 |	| jf d|   ||| j
|	| j}|dd}|||	| j}| |}|||fS )z#Input shape: Batch x Time x ChannelNr   rd   r   r   r'   z$Attention weights should be of size z	, but is z!Attention mask should be of size z/Head mask for a single layer should be of size ptrainingz `attn_output` should be of size )r.   r   r   r   r   r   r   r)   r1   r   r   r   r/   r]   Zbmmr   r   r   rf   softmaxr   r   r   r   )r   r   r   r   r>   r   r   Zis_cross_attentionr4   r5   rL   Zquery_statesZ
key_statesZvalue_statesZ
proj_shaper:   attn_weightsZattn_weights_reshapedZ
attn_probsZattn_outputr    r    r!   r   Z  s~    





" 
zWhisperAttention.forward)r   FT)NNNNF)r   r   r   __doc__rB   floatr9   r   r)   Tensorr   r   r   r   r   r    r    r   r!   r   7  s4           r   c                       sB   e Zd Zed fddZdejejejeejdddZ  Z	S )	WhisperEncoderLayerconfigc                    s   t    |j| _t| j|j|jd| _t	| j| _
|j| _t|j | _|j| _t| j|j| _t|j| j| _t	| j| _d S )N)r   r   r   )r~   r   d_modelr   r   Zencoder_attention_headsattention_dropout	self_attnr   	LayerNormself_attn_layer_normr   r   activation_functionactivation_fnactivation_dropoutr   Zencoder_ffn_dimfc1fc2final_layer_normr   r   r   r    r!   r     s    
zWhisperEncoderLayer.__init__F)r   r>   r   r   r@   c           
      C   s  |}|  |}| j||||d\}}}tjj|| j| jd}|| }|}| |}| | |}tjj|| j	| jd}| 
|}tjj|| j| jd}|| }|jtjkrt| st| rt|jjd }tj|| |d}|f}	|r|	|f7 }	|	S )a  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        )r   r>   r   r   r   i  )r,   rC   )r   r   r   rf   r   r   r   r   r   r   r   r$   r)   Zfloat16isinfanyisnanr+   rC   clamp)
r   r   r>   r   r   residualr   rL   Zclamp_valueoutputsr    r    r!   r     s8    



zWhisperEncoderLayer.forward)F)
r   r   r   r   r   r)   r   r9   r   r   r    r    r   r!   r     s    r   c                       sx   e Zd Zed fddZd
ejeej eej eej eej eej eeej  ee	 ee	 ejd
dd	Z
  ZS )WhisperDecoderLayerr   c                    s   t    |j| _t| j|j|jdd| _|j| _t	|j
 | _|j| _t| j| _t| j|j|jdd| _t| j| _t| j|j| _t|j| j| _t| j| _d S )NT)r   r   r   r   )r   r   )r~   r   r   r   r   Zdecoder_attention_headsr   r   r   r   r   r   r   r   r   r   encoder_attnencoder_attn_layer_normr   Zdecoder_ffn_dimr   r   r   r   r   r    r!   r     s,    
zWhisperDecoderLayer.__init__NFT)
r   r>   encoder_hidden_statesencoder_attention_maskr   cross_attn_layer_head_maskr   r   	use_cacher@   c
                 C   s^  |}
|  |}|dk	r"|dd nd}| j|||||d\}}}tjj|| j| jd}|
| }d}d}|dk	r|}
| |}|dk	r|dd nd}| j||||||d\}}}tjj|| j| jd}|
| }|| }|}
| |}| 	| 
|}tjj|| j| jd}| |}tjj|| j| jd}|
| }|f}|rJ|||f7 }|	rZ||f7 }|S )a  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
                size `(decoder_attention_heads,)`.
            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        Nrd   )r   r   r>   r   r   r   )r   r   r>   r   r   r   )r   r   r   rf   r   r   r   r   r   r   r   r   r   )r   r   r>   r   r   r   r   r   r   r   r   Zself_attn_past_key_valueZself_attn_weightsZpresent_key_valueZcross_attn_present_key_valueZcross_attn_weightsZcross_attn_past_key_valuer   r    r    r!   r   4  sT    




zWhisperDecoderLayer.forward)NNNNNNFT)r   r   r   r   r   r)   r   r   r   r9   r   r   r    r    r   r!   r     s*           r   c                   @   sF   e Zd ZeZdZdZdZddgZdd Z	dd	d
Z
ejdddZdS )WhisperPreTrainedModelmodelinput_featuresTr   r   c                 C   s   | j j}t|tjtjfrD|jjjd|d |j	d k	r~|j	j
  n:t|tjr~|jjjd|d |jd k	r~|jj|j 
  d S )Nr   )meanstd)r   Zinit_std
isinstancer   r   Conv1dr   dataZnormal_r   Zzero_	Embeddingr|   )r   moduler   r    r    r!   _init_weights  s    

z$WhisperPreTrainedModel._init_weightsFc                 C   s   t |ttfr||_d S r}   )r   WhisperDecoderWhisperEncodergradient_checkpointing)r   r   valuer    r    r!   _set_gradient_checkpointing  s    z2WhisperPreTrainedModel._set_gradient_checkpointing)r_   c                 C   s   |d d d }|S )zH
        Computes the output length of the convolutional layers
        r   rd   r    )r   r_   r    r    r!    _get_feat_extract_output_lengths  s    z7WhisperPreTrainedModel._get_feat_extract_output_lengthsN)F)r   r   r   r   config_classbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingZ_no_split_modulesr   r   r)   
LongTensorr   r    r    r    r!   r     s   
r   aL  
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`WhisperConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
aN  
    Args:
        input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
            Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
            loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
            the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
            [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
            tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing *SpecAugment* data augmentation on padding token indices. Mask values selected in
            `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            Whisper uses the `decoder_start_token_id` as the starting token for `decoder_input_ids` generation. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            If you want to change padding behavior, you should read
            [`modeling_whisper._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the BART
            paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
            input (see `past_key_values`). This is useful if you want more control over how to convert
            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
aY  
    Args:
        input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
            Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
            loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
            the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
            [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
            tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
            hidden-states at the output of the last layer of the encoder.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
c                       sX   e Zd ZdZed fddZdd Zejddd	Z	ejd
ddZ
dddZ  ZS )r   z
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`WhisperEncoderLayer`].

    Args:
        config: WhisperConfig
    r   c                    s   t     j| _ j| _ j} j| _ j| _ j	| _	 j
rJt|nd| _tj| j|ddd| _tj||dddd| _t| j	|| _t fddt jD | _t j| _d	| _|   d S )
Nr8   r   r   )kernel_sizepaddingrd   )r   Zstrider   c                    s   g | ]}t  qS r    )r   rJ   r   r    r!   rM   G  s     z+WhisperEncoder.__init__.<locals>.<listcomp>F)r~   r   r   Zencoder_layerdrop	layerdropr   Znum_mel_binsr   r|   max_source_positionsscale_embeddingmathsqrtembed_scaler   r   conv1conv2r   embed_positions
ModuleListrW   Zencoder_layerslayersr   
layer_normr   	post_init)r   r   r   r   r   r!   r   7  s     zWhisperEncoder.__init__c                 C   s   |   D ]
}d|_qd| _d S )NF)
parametersZrequires_gradZ_requires_grad)r   paramr    r    r!   _freeze_parametersN  s    z!WhisperEncoder._freeze_parametersr@   c                 C   s   | j S r}   r   r   r    r    r!   get_input_embeddingsS  s    z#WhisperEncoder.get_input_embeddingsr   c                 C   s
   || _ d S r}   r   r   r   r    r    r!   set_input_embeddingsV  s    z#WhisperEncoder.set_input_embeddingsNc                    s   dk	r n| j j |dk	r |n| j j}|dk	r4|n| j j}tj| |}tj| |}|	ddd}| j
j}|| }	tjj|	| j| jd}	|rdnd}
 rdnd}|dk	r| d t| jkstdt| j d| d  d	t| jD ]\}}|r|
|	f }
d
}| jr6tg }|| jk r6d}|rBd}nr| jr| jr fdd}tjj|||	d|dk	r|| nd}n"||	d|dk	r|| nd d}|d }	 r||d f }q| |	}	|r|
|	f }
|s tdd |	|
|fD S t|	|
|dS )aK  
        Args:
            input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
                Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
                obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
                `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
                `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
                and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
            attention_mask (`torch.Tensor`)`, *optional*):
                Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
                but it is not used. By default the silence in the input log mel spectrogram are ignored.
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        Nr   rd   r   r   r    z&The head_mask should be specified for  layers, but it is for .FT)NNc                    s    fdd}|S )Nc                     s    | f S r}   r    rb   )r   r   r    r!   custom_forward  s    zMWhisperEncoder.forward.<locals>.create_custom_forward.<locals>.custom_forwardr    r   r   r   r   r!   create_custom_forward  s    z5WhisperEncoder.forward.<locals>.create_custom_forward)r   r   c                 s   s   | ]}|d k	r|V  qd S r}   r    rK   vr    r    r!   	<genexpr>  s      z)WhisperEncoder.forward.<locals>.<genexpr>last_hidden_stater   
attentions)r   r   output_hidden_statesuse_return_dictr   rf   Zgelur   r   permuter   r   r   r   r.   rY   r   AssertionError	enumerater)   rR   r   r   utils
checkpointr   tupler
   )r   r   r>   	head_maskr   r  return_dictinputs_embedsZ	embed_posr   Zencoder_statesZall_attentionsidxZencoder_layerZto_dropdropout_probabilitylayer_outputsr   r    r   r!   r   Y  sn    "




  zWhisperEncoder.forward)NNNNN)r   r   r   r   r   r   r   r   Moduler   r   r   r   r    r    r   r!   r   .  s        r   c                       sH   e Zd ZdZed fddZdd Zdd Zd	d
 ZdddZ	  Z
S )r   z
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`WhisperDecoderLayer`]

    Args:
        config: WhisperConfig
    r   c                    s   t     j| _ j| _ j| _ j| _ j| _ j	rFt
 jnd| _t j j| j| _t| j j| _t fddt jD | _t j| _d| _|   d S )Nr8   c                    s   g | ]}t  qS r    )r   rJ   r   r    r!   rM     s     z+WhisperDecoder.__init__.<locals>.<listcomp>F)r~   r   r   Zdecoder_layerdropr   r   r|   max_target_positionsr   r   r   r   r   r   r   r   
vocab_sizeembed_tokensry   r   r   rW   decoder_layersr   r   r   r   r   r   r   r   r!   r     s     zWhisperDecoder.__init__c                 C   s   | j S r}   r  r   r    r    r!   r     s    z#WhisperDecoder.get_input_embeddingsc                 C   s
   || _ d S r}   r  r   r    r    r!   r     s    z#WhisperDecoder.set_input_embeddingsc                 C   sX   d }|d dkr$t ||j|j|d}|d k	rTt||j|d d}|d krL|n|| }|S )Nr   r   )r%   r&   )r5   )r7   r$   r%   r;   )r   r>   input_shaper  r&   Zcombined_attention_maskZexpanded_attn_maskr    r    r!   _prepare_decoder_attention_mask  s    z.WhisperDecoder._prepare_decoder_attention_maskNc                    s   dk	r n| j j |
dk	r |
n| j j}
dk	r4n| j j|dk	rH|n| j j}|dk	rj|dk	rjtdnD|dk	r| }|d|d }n"|dk	r| dd }ntd|dk	r|d d jd nd}|dkr| 	|}| 
||||}|dk	r| j||d}n| j||d}|| }tjj|| j| jd}| jrX| jrXrXtd	 d
|
rbdnd} rpdnd} r|dk	rdnd}rdnd}t||gddgD ]V\}}|dk	r| d t| jkstd| dt| j d| d  dqt| jD ]D\}}|
r(||f7 }| jrJtg }|| jk rJq|dk	r\|| nd}| jr| jrć fdd}tjj|||||d|dk	r|| nd|dk	r|| ndd}n<|||||dk	r|| nd|dk	r|| nd| d}|d }r&|| rdnd f7 } r||d f7 }|dk	r||d f7 }q| |}|
rp||f7 }rz|nd}|stdd |||||fD S t|||||dS )aA  
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
                on hidden heads. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
                shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
                `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
                control over how to convert `input_ids` indices into associated vectors than the model's internal
                embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        NzTYou cannot specify both decoder_input_ids and decoder_inputs_embeds at the same timer   zEYou have to specify either decoder_input_ids or decoder_inputs_embedsr   rd   )r&   r   z^`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...Fr    r  cross_attn_head_maskThe `z` should be specified for r   r   c                    s    fdd}|S )Nc                     s    | f S r}   r    r   )r   r   r   r    r!   r     s    zMWhisperDecoder.forward.<locals>.create_custom_forward.<locals>.custom_forwardr    r   r   r   r   r!   r     s    z5WhisperDecoder.forward.<locals>.create_custom_forward)r>   r   r   r   r   r   r   r   r   c                 s   s   | ]}|d k	r|V  qd S r}   r    r   r    r    r!   r     s   z)WhisperDecoder.forward.<locals>.<genexpr>)r  past_key_valuesr   r  cross_attentions) r   r   r  r   r  r   r.   r/   r   r  r  r   r   rf   r   r   r   loggerZwarning_onceziprY   r   r  r  r)   rR   r   r  r	  r   r
  r   )r   r   r>   r   r  r  r  r  r   r   r  r  r  r&   Z	positionsr   Zall_hidden_statesZall_self_attnsZall_cross_attentionsZnext_decoder_cacheZ	attn_maskZ	mask_namer  Zdecoder_layerr  r   r   r  Z
next_cacher    r  r!   r     s    G

   

$




zWhisperDecoder.forward)NNNNNNNNNNN)r   r   r   r   r   r   r   r   r  r   r   r    r    r   r!   r     s"              r   zUThe bare Whisper Model outputting raw hidden-states without any specific head on top.c                       s  e Zd Zed fddZdd Zdd Zdd	 Zd
d Zdd Z	de
jee
j dddZeeeeeddee
j ee
j ee
j ee
j ee
j ee
j ee
j eeee
j   eeee
j   eee
j  ee ee ee ee eee
j ef dddZ  ZS )WhisperModelr   c                    s,   t  | t|| _t|| _|   d S r}   )r~   r   r   encoderr   decoderr   r   r   r    r!   r     s    

zWhisperModel.__init__c                 C   s   | j jS r}   r"  r  r   r    r    r!   r     s    z!WhisperModel.get_input_embeddingsc                 C   s   || j _d S r}   r#  r   r    r    r!   r     s    z!WhisperModel.set_input_embeddingsc                 C   s   | j S r}   )r!  r   r    r    r!   get_encoder  s    zWhisperModel.get_encoderc                 C   s   | j S r}   )r"  r   r    r    r!   get_decoder  s    zWhisperModel.get_decoderc                 C   s   | j   dS z
        Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
        not be updated during training.
        Nr!  r   r   r    r    r!   freeze_encoder  s    zWhisperModel.freeze_encoderN)r   r>   c                 C   s   t | jdds|S | \}}}| jjdkr| jrt||f| jj| jj|| jjd}tj	||j
tjd}|dddf d|d}d||< | jjdkr| jrt||f| jj| jj| jjd}tj	||j
tjd}d||< |S )	z
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://arxiv.org/abs/1904.08779).
        Zapply_spec_augmentTr   )r<   r=   r>   r?   )r%   r$   Nr   )r<   r=   r?   )getattrr   r.   Zmask_time_probr   ra   Zmask_time_lengthZmask_time_min_masksr)   r   r%   r9   r3   Zmask_feature_probZmask_feature_lengthZmask_feature_min_masks)r   r   r>   r^   hidden_sizerH   Zmask_time_indicesZmask_feature_indicesr    r    r!   _mask_input_features  s0    z!WhisperModel._mask_input_featuresoutput_typer   )r   r>   decoder_input_idsdecoder_attention_maskr  decoder_head_maskr  encoder_outputsr  decoder_inputs_embedsr   r   r  r  r@   c                 C   s  |dk	r|n| j j}|dk	r |n| j j}|dk	r4|n| j j}|dk	rH|n| j j}|dkr|| j||d}| j|||||d}nH|rt|tst|d t	|dkr|d ndt	|dkr|d ndd}| j
|||d |||	|
||||d}|s|| S t|j|j|j|j|j|j|j|jd	S )
a{  
        Returns:

        Example:
         ```python
         >>> import torch
         >>> from transformers import AutoFeatureExtractor, WhisperModel
         >>> from datasets import load_dataset

         >>> model = WhisperModel.from_pretrained("openai/whisper-base")
         >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
         >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
         >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
         >>> input_features = inputs.input_features
         >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
         >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
         >>> list(last_hidden_state.shape)
         [1, 2, 512]
         ```N)r>   r  r   r  r  r   r   rd   r   )r   r>   r   r  r  r  r  r   r   r  r  )r  r  decoder_hidden_statesdecoder_attentionsr  encoder_last_hidden_stater   encoder_attentions)r   r   r  r   r  r+  r!  r   r
   rY   r"  r   r  r  r   r  r  )r   r   r>   r.  r/  r  r0  r  r1  r  r2  r   r   r  r  Zdecoder_outputsr    r    r!   r     sZ    &zWhisperModel.forward)N)NNNNNNNNNNNNNN)r   r   r   r   r   r   r   r$  r%  r(  r)   FloatTensorr   r   r+  r   WHISPER_INPUTS_DOCSTRINGr   r   _CONFIG_FOR_DOCr   r   r9   r   r   r   r    r    r   r!   r     sX   
 +
              r   z^The Whisper Model with a language modeling head. Can be used for automatic speech recognition.c                       s\  e Zd ZdZdgZed fddZdd Zdd	 Zd
d Z	dd Z
ejdddZdd Zeeeeedd#eej eej eej eej eej eej eej eeeej   eeeej   eeej  eej ee ee ee ee eeej ef dddZd$eej eej d fddZd%ddZedd Z d&d!d"Z!  Z"S )'WhisperForConditionalGenerationr   zproj_out.weightr   c                    s8   t  | t|| _tj|j|jdd| _| 	  d S )NFr   )
r~   r   r   r   r   r   r   r  proj_outr   r   r   r    r!   r   w  s    
z(WhisperForConditionalGeneration.__init__c                 C   s
   | j  S r}   )r   r$  r   r    r    r!   r$    s    z+WhisperForConditionalGeneration.get_encoderc                 C   s
   | j  S r}   )r   r%  r   r    r    r!   r%    s    z+WhisperForConditionalGeneration.get_decoderc                 C   s   | j S r}   r<  r   r    r    r!   get_output_embeddings  s    z5WhisperForConditionalGeneration.get_output_embeddingsc                 C   s
   || _ d S r}   r=  )r   Znew_embeddingsr    r    r!   set_output_embeddings  s    z5WhisperForConditionalGeneration.set_output_embeddingsr   c                 C   s
   | j  S r}   )r   r   r   r    r    r!   r     s    z4WhisperForConditionalGeneration.get_input_embeddingsc                 C   s   | j j  dS r&  )r   r!  r   r   r    r    r!   r(    s    z.WhisperForConditionalGeneration.freeze_encoderr,  N)r   r>   r.  r/  r  r0  r  r1  r  r2  labelsr   r   r  r  r@   c                 C   s  |dk	r|n| j j}|dk	r@|dkr@|
dkr@t|| j j| j j}| j|||||||||	|
||||d}| |d }d}|dk	rt }||j	}||
d| j j|d}|s|f|dd  }|dk	r|f| S |S t|||j|j|j|j|j|j|jd	S )a5  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
            or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
            only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
        >>> from datasets import load_dataset

        >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
        >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

        >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
        >>> input_features = inputs.input_features

        >>> generated_ids = model.generate(inputs=input_features)

        >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> transcription
        ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
        ```N)r>   r.  r1  r/  r  r0  r  r  r2  r   r   r  r  r   r   r   )	losslogitsr  r4  r5  r  r6  r   r7  )r   r  r"   r   r   r   r<  r   r0   r%   r/   r  r]   r   r  r4  r5  r  r6  r   r7  )r   r   r>   r.  r/  r  r0  r  r1  r  r2  r@  r   r   r  r  r   Z	lm_logitsrA  loss_fctoutputr    r    r!   r     sX    1  z'WhisperForConditionalGeneration.forwardF)rb   
prompt_idsc                    s\  |dkr| j }|dk	r0t|ds(td||_nd|_|	dk	r^t|dsPtd|	 }	|	|_|dk	r~t|dsxtd||_d}t| jd	r| jjdk	r| jj}n.t| j d	r| j jdk	r| j j}n|	d	d}|dk	s|	dk	s|dkrl|dk	rlg }t|d
r|j|j
 kr |j}n|jt krDdt|j  d}n`|jt krdd|j d}n@t|jdk}td|j d|rtt n
tt  d|d|j
| f n
|d t|dr|jtkr|d|j|j f ntd|j dt dn t|dr.|d|jd f t|drl|jsl|rX|d d d nd}|||jf |dk	r|||_|dk	r|	ddk	rtd| }|^}}|| jj d d d }|d|i |	dddk	rR|d  t|7  < |d | jjkrRtdt| d|d t|  d|d  d | jj d!| jj d|d	dpd|j}||jfd"d# |D }d$d# t|D }||_|jrt|g}|r
d%|d&< d%|d'< t|ddd(krtd) t|d*std+|	d,dk	r
|d,|_t j ||||||f|}|rXt|d*rXt|d,d}| j!||j"|d-|d.< |S )/a  

        Generates sequences of token ids for models with a language modeling head.

        <Tip warning={true}>

        Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
        model's default generation configuration. You can override any `generation_config` by passing the corresponding
        parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.

        For an overview of generation strategies and code examples, check out the [following
        guide](./generation_strategies).

        </Tip>

        Parameters:
            inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
                The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
                method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
                should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
                `input_ids`, `input_values`, `input_features`, or `pixel_values`.
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which had the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.
            logits_processor (`LogitsProcessorList`, *optional*):
                Custom logits processors that complement the default logits processors built from arguments and
                generation config. If a logit processor is passed that is already created with the arguments or a
                generation config an error is thrown. This feature is intended for advanced users.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                Custom stopping criteria that complement the default stopping criteria built from arguments and a
                generation config. If a stopping criteria is passed that is already created with the arguments or a
                generation config an error is thrown. This feature is intended for advanced users.
            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
                If provided, this function constraints the beam search to allowed tokens only at each step. If not
                provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
                `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
                on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
                for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
                Retrieval](https://arxiv.org/abs/2010.00904).
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            return_timestamps (`bool`, *optional*):
                Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`.
            task (`str`, *optional*):
                Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids`
                will be updated accordingly.
            language (`str`, *optional*):
                Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. You can
                find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary.
            is_multilingual (`bool`, *optional*):
                Whether or not the model is multilingual.
            prompt_ids (`torch.Tensor`, *optional*):
                Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is
                provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
                transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
                correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
            return_token_timestamps (`bool`, *optional*):
                Whether to return token-level timestamps with the text. This can be used with or without the
                `return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
                words.
            kwargs (`Dict[str, Any]`, *optional*):
                Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
                forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
                specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.

        Return:
            [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
            or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.

                If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
                [`~utils.ModelOutput`] types are:

                    - [`~generation.GreedySearchDecoderOnlyOutput`],
                    - [`~generation.SampleDecoderOnlyOutput`],
                    - [`~generation.BeamSearchDecoderOnlyOutput`],
                    - [`~generation.BeamSampleDecoderOnlyOutput`]

                If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
                [`~utils.ModelOutput`] types are:

                    - [`~generation.GreedySearchEncoderDecoderOutput`],
                    - [`~generation.SampleEncoderDecoderOutput`],
                    - [`~generation.BeamSearchEncoderDecoderOutput`],
                    - [`~generation.BeamSampleEncoderDecoderOutput`]
        Nno_timestamps_token_idab  You are trying to return timestamps, but the generation config is not properly set.Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`.For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363F
lang_to_ida<  The generation config is outdated and is thus not compatible with the `language` argumentto `generate`. Either set the language using the `forced_decoder_ids` in the model config, or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224
task_to_ida4  The generation config is outdated and is thus not compatible with the `task` argumentto `generate`. Either set the task using the `forced_decoder_ids` in the model config, or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224forced_decoder_idslanguagez<|z|>rd   zUnsupported language: z. Language should be one of: r   r   )r   Ntaskr  z3`task is not supported. The task should be one of `rA   Z
transcriber   r   r   zfWhen specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten.Zmax_new_tokensz)The length of the sliced `prompt_ids` is z, and the `max_new_tokens` zP. Thus, the combined length of the sliced `prompt_ids` and `max_new_tokens` is: z@. This exceeds the `max_target_positions` of the Whisper model: z. You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, so that their combined length is less that c                 S   s   g | ]\}}|qS r    r    )rK   Z_ranktokenr    r    r!   rM     s     z<WhisperForConditionalGeneration.generate.<locals>.<listcomp>c                 S   s   g | ]\}}|d  |fqS )r   r    )rK   ZrankrL  r    r    r!   rM     s     Tr   Zreturn_dict_in_generate	translatez@Token-level timestamps may not be reliable for task 'translate'.alignment_headszModel generation config has no `alignment_heads`, token-level timestamps not available. See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config.
num_frames)rO  Ztoken_timestamps)#generation_confighasattrr   return_timestampslowerrJ  rK  r   rI  getrG  keysr   valuesrY   listr[   r   rH  rF  rV   r  updatepopr   r  r	   r)  r  warningrO  r~   generate_extract_token_timestampsrN  )r   rb   rP  Zlogits_processorZstopping_criteriaZprefix_allowed_tokens_fnZsynced_gpusrR  rK  rJ  Zis_multilingualrE  Zreturn_token_timestampskwargsrI  Zlanguage_tokenZis_language_coder  r   Ztext_prompt_idsZnon_prompt_forced_decoder_idsr   rO  r   r    r!   r[    s    i






$,


>


  
z(WhisperForConditionalGeneration.generatec                 K   s,   |d k	r|d d dd f }||||d dS )Nr   )r1  r  r.  r   r/  r    )r   r.  r  r   r1  r>   r]  r    r    r!   prepare_inputs_for_generation  s    
z=WhisperForConditionalGeneration.prepare_inputs_for_generationc                    s.   d}| D ] }|t  fdd|D f7 }q|S )Nr    c                 3   s"   | ]}| d  |jV  qdS )r   N)Zindex_selectr0   r%   )rK   Z
past_statebeam_idxr    r!   r     s     zAWhisperForConditionalGeneration._reorder_cache.<locals>.<genexpr>)r
  )r  r`  Zreordered_pastZ
layer_pastr    r_  r!   _reorder_cache  s    z.WhisperForConditionalGeneration._reorder_cache{Gz?c                    sB  g  t | jjD ]( tjfdd|jD dd qt fdd|D }|ddddg}|d	k	r|d
d	|d f }tj	|dddd\}}|| | }t
|| jj}|jdd}tj|jtjd}	t |	jd D ]`}
t||
     \}}tjt|dddt}|| | }t||	|
dd	f< q|	S )a  
        Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to
        map each output token to a position in the input audio. If `num_frames` is specified, the encoder-decoder
        cross-attentions will be cropped before applying DTW.

        Returns:
            tensor containing the timestamps in seconds for each predicted token
        c                    s   g | ]}|  qS r    r    )rK   x)rq   r    r!   rM   &  s     zMWhisperForConditionalGeneration._extract_token_timestamps.<locals>.<listcomp>rd   r'   c                    s$   g | ]\}} | d d |f qS r}   r    )rK   lh)r  r    r!   rM   *  s     r   r   r   N.r   TF)r(   ZkeepdimZunbiasedrN   )r   r   )Zconstant_values)rW   r   r  r[   r)   r1   r  stackr  Zstd_meanrj   Zmedian_filter_widthr   Z
zeros_like	sequencesrl   r   rx   doublecpunumpyrP   rg   diffZastyper9   r   )r   Zgenerate_outputsrN  Ztime_precisionrO  weightsr   r   rk   Z
timestampsZ	batch_idxrv   rw   ZjumpsZ
jump_timesr    )r  rq   r!   r\    s$    &z9WhisperForConditionalGeneration._extract_token_timestamps)NNNNNNNNNNNNNNN)NNNNNFNNNNNN)NNNN)rb  N)#r   r   r   r   Z_tied_weights_keysr   r   r$  r%  r>  r?  r   r  r   r(  r   r9  r   r   r:  r   r)   r8  r   r   r   r9   r   r   r[  r^  staticmethodra  r\  r   r    r    r   r!   r;  o  s   
               b            
  	    

r;  z
    Whisper Encoder Model with a sequence classification head on top (a linear layer over the pooled output) for tasks
    like SUPERB Keyword Spotting.
    c                       s   e Zd Z fddZdd ZejdddZejdd	d
Ze	e
eeeddeej eej eeeej   eej ee ee ee eeej ef dddZ  ZS )WhisperForAudioClassificationc                    sl   t  | t|| _|jd }|jr<tt	|| | _
t|j|j| _t|j|j| _|   d S r   )r~   r   r   r!  Znum_hidden_layersuse_weighted_layer_sumr   	Parameterr)   rZ   layer_weightsr   r*  Zclassifier_proj_size	projector
num_labels
classifierr   )r   r   Z
num_layersr   r    r!   r   K  s    

z&WhisperForAudioClassification.__init__c                 C   s   | j   dS )z
        Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
        not be updated during training. Only the projection layers and classification head will be updated.
        Nr'  r   r    r    r!   r(  X  s    z,WhisperForAudioClassification.freeze_encoderr   c                 C   s
   | j  S r}   )r!  r   r   r    r    r!   r   _  s    z2WhisperForAudioClassification.get_input_embeddingsr   c                 C   s   | j | d S r}   )r!  r   r   r    r    r!   r   b  s    z2WhisperForAudioClassification.set_input_embeddingsr,  N)r   r  r1  r@  r   r  r  r@   c                 C   sB  |dk	r|n| j j}|dk	r |n| j j}|dk	r4|n| j j}|dkrX| j|||||d}| j jrtj|dd}tj	j
| jdd}	||	ddd jdd}n|d }| |}|jdd}
| |
}d}|dk	rt }||j}||d| j j|d}|s.|f|dd  }|dk	r*|f| S |S t|||j|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoFeatureExtractor, WhisperForAudioClassification
        >>> from datasets import load_dataset

        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
        >>> model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")

        >>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True)
        >>> sample = next(iter(ds))

        >>> inputs = feature_extractor(
        ...     sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="pt"
        ... )
        >>> input_features = inputs.input_features

        >>> with torch.no_grad():
        ...     logits = model(input_features).logits

        >>> predicted_class_ids = torch.argmax(logits).item()
        >>> predicted_label = model.config.id2label[predicted_class_ids]
        >>> predicted_label
        'Afrikaans'
        ```Nr3  r   r'   r   r   )rA  rB  r   r  )r   r   r  r  r!  ro  r)   rf  r   rf   r   rq  r/   rT   rr  r   rt  r   r0   r%   rs  r   r   r  )r   r   r  r1  r@  r   r  r  r   Znorm_weightsZpooled_outputrB  rA  rC  rD  r    r    r!   r   e  sD    /

z%WhisperForAudioClassification.forward)NNNNNNN)r   r   r   r   r(  r   r  r   r   r    WHISPER_ENCODER_INPUTS_DOCSTRINGr   r   r:  r   r)   r   r   r   r8  r9   r   r   r   r    r    r   r!   rn  C  s.   
       rn  )r   )N)Nr   )Fr   r   typingr   r   r   rj  rP   r)   Ztorch.utils.checkpointr   Ztorch.nnr   Zactivationsr   Zgeneration.logits_processr	   Zmodeling_outputsr
   r   r   r   r   Zmodeling_utilsr   r  r   r   r   r   Zconfiguration_whisperr   Ztokenization_whisperr   r   Z
get_loggerr   r  r:  Z_CHECKPOINT_FOR_DOCZ%WHISPER_PRETRAINED_MODEL_ARCHIVE_LISTr   rB   r"   Sizer$   r%   r7   r;   r   r   Zndarrayra   rj   rx   r   ry   r  r   r   r   r   ZWHISPER_START_DOCSTRINGr9  ru  r   r   r   r;  rn  r    r    r    r!   <module>   s   
      
w6 EuU    %   S