""" PyTorch CLAP model."""
import collections
import math
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPooling,
    BaseModelOutputWithPoolingAndCrossAttentions,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "laion/clap-htsat-fused"

CLAP_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "laion/clap-htsat-fused",
    "laion/clap-htsat-unfused",
]


def interpolate(hidden_states, ratio):
    """
    Interpolate data in the time domain. This is used to compensate for the resolution reduction caused by downsampling in a CNN.

    Args:
        hidden_states (`torch.FloatTensor` of shape `(batch_size, time_length, classes_num)`):
            Input hidden states
        ratio (`int`):
            The ratio of the length of the output to the length of the input.
    """
    (batch_size, time_length, classes_num) = hidden_states.shape
    upsampled = hidden_states[:, :, None, :].repeat(1, 1, ratio, 1)
    upsampled = upsampled.reshape(batch_size, time_length * ratio, classes_num)
    return upsampled


def window_partition(hidden_states, window_size):
    """
    Returns the resized hidden states. The output shape should be `(batch_size * num_windows, window_size, window_size,
    num_channels)`

    Args:
        hidden_states (`torch.FloatTensor` of shape `(batch_size, height, width, num_channels)`):
            Input hidden states
        window_size (`int`):
            Window size
    r   r   r	            )r   viewpermute
contiguous)r   window_sizer!   heightwidthnum_channelswindowsr#   r#   r$   window_partitionH   s         $r2   c                 C   sb   t | jd || | |  }| ||| || ||d}|dddddd |||d}|S )aQ  
    Args:
        windows (`torch.FloatTensor` of shape `(num_windows * batch_size, window_size, window_size, num_channels)`):
            Input windows
        window_size (`int`):
            Window size
        height (`int`):
            Height of the resized audio
        width (`int`):
            Width of the resized audio
    r   r)   r   r	   r&   r'   r(   )intr   r*   r+   r,   )r1   r-   r.   r/   r!   r   r#   r#   r$   window_reverse]   s    $r4   c                 C   s6   |  | }tj|dd|| | }| | S )a  
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        input_ids: torch.Tensor

    Returns: torch.Tensor
    r   dim)ner3   torchZcumsumZtype_aslong)	input_idspadding_idxpast_key_values_lengthmaskZincremental_indicesr#   r#   r$   "create_position_ids_from_input_idsq   s    r>   )logitsreturnc                 C   s"   t jt| | jd}tj| |S )Ndevice)r8   arangelenrB   r   
functionalZcross_entropy)r?   labelsr#   r#   r$   contrastive_loss   s    rG   c                   @   s^   e Zd ZU dZdZeej ed< dZ	ejed< dZ
eeej  ed< dZeeej  ed< dS )ClapTextModelOutputa  
    Base class for text model's outputs that also contains a pooling of the last hidden states.

    Args:
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when the model is initialized with `with_projection=True`):
            The text embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    text_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class ClapAudioModelOutput(ModelOutput):
    """
    ClapAudio model output to mimic the output of the original implementation.

    Args:
        audio_embeds (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            The audio embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    audio_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class ClapOutput(ModelOutput):
    """
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Contrastive loss for audio-text similarity.
        logits_per_audio (`torch.FloatTensor` of shape `(audio_batch_size, text_batch_size)`):
            The scaled dot product scores between `audio_embeds` and `text_embeds`. This represents the audio-text
            similarity scores.
        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, audio_batch_size)`):
            The scaled dot product scores between `text_embeds` and `audio_embeds`. This represents the text-audio
            similarity scores.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The text embeddings obtained by applying the projection layer to the pooled output of [`ClapTextModel`].
        audio_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The audio embeddings obtained by applying the projection layer to the pooled output of [`ClapAudioModel`].
        text_model_output (`BaseModelOutputWithPooling`):
            The output of the [`ClapTextModel`].
        audio_model_output (`BaseModelOutputWithPooling`):
            The output of the [`ClapAudioModel`].
    """

    loss: Optional[torch.FloatTensor] = None
    logits_per_audio: torch.FloatTensor = None
    logits_per_text: torch.FloatTensor = None
    text_embeds: torch.FloatTensor = None
    audio_embeds: torch.FloatTensor = None
    text_model_output: BaseModelOutputWithPooling = None
    audio_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k] if k not in ["text_model_output", "audio_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )


class ClapDropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is a slightly
    refactored version of the `SwinDropPath` implementation.
    Nc                    s   t    || _d S N)super__init__	drop_prob)r`   rh   	__class__r#   r$   rg      s    
zClapDropPath.__init__c                 C   sj   | j dks| js|S d| j  }|jd fd|jd   }|tj||j|jd }|  |	|| }|S )N        r   r   )r   dtyperB   )
rh   trainingr   ndimr8   Zrandrm   rB   Zfloor_div)r`   r   Z	keep_probr   Zrandom_tensoroutputr#   r#   r$   forward   s    
zClapDropPath.forward)N)rL   rM   rN   rO   rg   rr   __classcell__r#   r#   ri   r$   rd      s   rd   c                       s.   e Zd ZdZed fddZdd Z  ZS )ClapAudioAFFBlockz
    Attentional Feature Fusion block from CLAP. Since CLAP always operates in 2D mode, the 1D version
    does not need to be implemented.
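# Illustrative sketch of the fusion computation described above, under stated assumptions: the real
# block produces its gate from learned local and global attention branches; here a single 1x1
# convolution stands in for both, and all names are hypothetical.
def _aff_fusion_sketch():
    import torch
    from torch import nn

    hidden_states = torch.randn(1, 8, 4, 4)
    residual = torch.randn(1, 8, 4, 4)
    gate_conv = nn.Conv2d(8, 8, kernel_size=1)  # stand-in for the local + global attention branches
    sig = torch.sigmoid(gate_conv(hidden_states + residual))
    # gated combination of the two inputs, as in the block's forward pass
    return 2 * hidden_states * sig + 2 * residual * (1 - sig)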
    configc                    s   t    |j}|j}t|| }ttj||ddddt|tj	ddtj||ddddt|| _
ttdtj||ddddt|tj	ddtj||ddddt|| _t | _d S )Nr   r   Zkernel_sizeZstridepaddingT)Zinplace)rf   rg   patch_embeds_hidden_sizeZaff_block_rr3   r   Z
SequentialConv2dBatchNorm2dZReLU	local_attZAdaptiveAvgPool2d
global_attZSigmoidsigmoid)r`   rv   channelsZdownsize_ratioZinter_channelsri   r#   r$   rg   	  s(    


	zClapAudioAFFBlock.__init__c                 C   sF   || }|  || | }| |}d| | d| d|   }|S )Nr&   r   )r|   r}   r~   )r`   r   ZresidualZattention_inputZfused_layer_outputrq   r#   r#   r$   rr   !  s
    
zClapAudioAFFBlock.forwardrL   rM   rN   rO   r   rg   rr   rs   r#   r#   ri   r$   rt     s   rt   c                       s0   e Zd ZdZed fddZdddZ  ZS )	ClapAudioPatchEmbedz
    This module converts the hidden states reshaped as an image to patch embeddings ready to be passed to the
    Transformer block.
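# Minimal sketch of the patch-embedding idea described above, assuming a (batch, channels, height, width)
# spectrogram treated as an image: a strided convolution cuts it into non-overlapping patches and projects
# each patch to `hidden_size`, after which the spatial grid is flattened into a token sequence.
# Function and parameter names are illustrative, not the module's API.
def _patch_embed_sketch(spectrogram, patch_size=4, hidden_size=96):
    from torch import nn

    proj = nn.Conv2d(spectrogram.shape[1], hidden_size, kernel_size=patch_size, stride=patch_size)
    patches = proj(spectrogram)                # (batch, hidden_size, height / patch_size, width / patch_size)
    return patches.flatten(2).transpose(1, 2)  # (batch, num_patches, hidden_size)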
    ru   c                    s  t    t|jtr"|j|jfn|j}t|jtr@|j|jfn|j}t|jtr^|j|jfn|j}|| _|| _|d |d  |d |d  f| _| jd | jd  | _	|j
| _|j| _|d |d  d |d |d  d f}| jr|jdkrdnd}tj|j| |j|||d| _|jr*t|jnt | _| jrt|| _tj|j|j|d |d d f|d |d d f|d| _d S )Nr   r   r&   Zchannel_mapr'   rw   r	   )rf   rg   
isinstance	spec_sizer3   
patch_sizepatch_strideimg_size	grid_sizeZnum_patchesZflatten_patch_embedsflattenenable_fusionZfusion_typer   rz   Zpatch_embed_input_channelsry   projZenable_patch_layer_norm	LayerNormIdentitynormrt   fusion_model
mel_conv2d)r`   rv   r   r   r   rx   Zscale_factorri   r#   r$   rg   1  s>    
"(
zClapAudioPatchEmbed.__init__Nc              
   C   s  | j rb|d d ddd d d d f }|j\}}}}|| jd ksR|| jd krtd| d| d| jd  d| jd  d	| |}|d}t|dkr\||dd d d d d f  }	|	j\}}}}|	|| d||}	| 	|	}	|	j\}
}}}|	|||||}	|	
d d	}	|	d}tjj|	d|| fd
d}	| || |	||< |}nf|j\}
}
}}|| jd ks|| jd krtd| d| d| jd  d| jd  d	| |}| jr|ddd}| |}|S )Nr   r   zInput audio size (*z) doesn't match model (z).r)   )r   r&   r	   r   r'   r	   Zconstantr&   )r   r   r   
ValueErrorr   sizerD   r,   r*   r   r+   r   r8   r   rE   padr   	transposer   )r`   r   Zis_longer_idxZglobal_hidden_statesr!   r0   r.   r/   Zoutput_widthZlocal_hidden_states_featuresZlocal_widthr#   r#   r$   rr   [  sN     (

 

 
    (

zClapAudioPatchEmbed.forward)Nr   r#   r#   ri   r$   r   +  s   *r   c                       sT   e Zd Z fddZdd Zd
ejeej eej ee	 e
ej ddd	Z  ZS )ClapAudioSelfAttentionc                    s
  t    || dkr,td| d| d|| _t|| | _| j| j | _t|tj	j
r`|n||f| _ttd| jd  d d| jd  d  || _t| jd }t| jd }tt||gdd}t|d}|d d d d d f |d d d d d f  }	|	ddd }	|	d d d d df  | jd d 7  < |	d d d d df  | jd d 7  < |	d d d d df  d| jd  d 9  < |	d	}
| d
|
 tj| j| j|jd| _tj| j| j|jd| _tj| j| j|jd| _t|j| _ d S )Nr   The hidden size (6) is not a multiple of the number of attention heads ()r&   r   Zij)Zindexingr)   relative_position_indexbias)!rf   rg   r   num_attention_headsr3   attention_head_sizeall_head_sizer   collectionsabcIterabler-   r   	Parameterr8   zerosrelative_position_bias_tablerC   stackr   r   r+   r,   sumregister_bufferLinearZqkv_biasquerykeyvalueDropoutattention_probs_dropout_probdropout)r`   rv   r6   	num_headsr-   Zcoords_hZcoords_wZcoordsZcoords_flattenZrelative_coordsr   ri   r#   r$   rg     s8    
*,((,
zClapAudioSelfAttention.__init__c                 C   s6   |  d d | j| jf }||}|ddddS Nr)   r   r&   r   r	   r   r   r   r*   r+   r`   xZnew_x_shaper#   r#   r$   transpose_for_scores  s    
z+ClapAudioSelfAttention.transpose_for_scoresNFr   attention_mask	head_maskoutput_attentionsr@   c                 C   s  |j \}}}| |}| | |}	| | |}
| |}t||	dd}|t	| j
 }| j| jd }|| jd | jd  | jd | jd  d}|ddd }||d }|d k	r|j d }||| || j||}||dd }|d| j||}tjj|dd}| |}|d k	rB|| }t||
}|dddd }| d d | jf }||}|r||fn|f}|S )Nr)   r   r   r&   r5   r	   )r   r   r   r   r   r8   matmulr   mathsqrtr   r   r   r*   r-   r+   r,   	unsqueezer   r   rE   softmaxr   r   r   )r`   r   r   r   r   r!   r6   r0   mixed_query_layer	key_layervalue_layerquery_layerattention_scoresZrelative_position_biasZ
mask_shapeattention_probscontext_layernew_context_layer_shapeoutputsr#   r#   r$   rr     sH    

  

    


zClapAudioSelfAttention.forward)NNF)rL   rM   rN   rg   r   r8   Tensorr   rP   boolr   rr   rs   r#   r#   ri   r$   r     s   %   r   c                       s4   e Zd Z fddZejejejdddZ  ZS )ClapAudioSelfOutputc                    s*   t    t||| _t|j| _d S re   )rf   rg   r   r   denser   r   r   r`   rv   r6   ri   r#   r$   rg     s    
zClapAudioSelfOutput.__init__r   input_tensorr@   c                 C   s   |  |}| |}|S re   r   r   r`   r   r   r#   r#   r$   rr     s    

zClapAudioSelfOutput.forwardrL   rM   rN   rg   r8   r   rr   rs   r#   r#   ri   r$   r     s   r   c                       sT   e Zd Z fddZdd Zd
ejeej eej ee	 e
ej ddd	Z  ZS )ClapAudioAttentionc                    s2   t    t||||| _t||| _t | _d S re   )rf   rg   r   r`   r   rq   setpruned_heads)r`   rv   r6   r   r-   ri   r#   r$   rg     s    
zClapAudioAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S Nr   r   r5   rD   r   r`   r   r   r   r   r   r   r   rq   r   r   unionr`   Zheadsindexr#   r#   r$   prune_heads  s       zClapAudioAttention.prune_headsNFr   c                 C   s6   |  ||||}| |d |}|f|dd   }|S Nr   r   r`   rq   )r`   r   r   r   r   self_outputsattention_outputr   r#   r#   r$   rr     s    zClapAudioAttention.forward)NNF)rL   rM   rN   rg   r   r8   r   r   rP   r   r   rr   rs   r#   r#   ri   r$   r     s      r   c                       s0   e Zd Z fddZejejdddZ  ZS )ClapAudioIntermediatec                    sH   t    t|t|j| | _t|jt	r<t
|j | _n|j| _d S re   )rf   rg   r   r   r3   	mlp_ratior   r   
hidden_actstrr
   intermediate_act_fnr   ri   r#   r$   rg   )  s
    
zClapAudioIntermediate.__init__r   r@   c                 C   s   |  |}| |}|S re   r   r   r`   r   r#   r#   r$   rr   1  s    

zClapAudioIntermediate.forwardr   r#   r#   ri   r$   r   (  s   r   c                       s0   e Zd Z fddZejejdddZ  ZS )ClapAudioOutputc                    s4   t    tt|j| || _t|j| _	d S re   )
rf   rg   r   r   r3   r   r   r   hidden_dropout_probr   r   ri   r#   r$   rg   9  s    
zClapAudioOutput.__init__r   c                 C   s   |  |}| |}|S re   r   r   r#   r#   r$   rr   >  s    

zClapAudioOutput.forwardr   r#   r#   ri   r$   r   8  s   r   c                	       st   e Zd Zd fdd	Zdd Zdd Zdd	 Zdeje	e
e
f eej ee ee e	ejejf dddZ  ZS )ClapAudioLayerr   c                    s   t    |j| _|| _|j| _|| _tj||jd| _	t
|||| jd| _|jdkr`t|jnt | _tj||jd| _t||| _t||| _d S )NZeps)r-   rk   )rf   rg   chunk_size_feed_forward
shift_sizer-   input_resolutionr   r   layer_norm_epslayernorm_beforer   	attentiondrop_path_raterd   r   	drop_pathlayernorm_afterr   intermediater   rq   )r`   rv   r6   r   r   r   ri   r#   r$   rg   F  s    
zClapAudioLayer.__init__c                 C   s"   t || jkrd| _t || _d S Nr   )minr-   r   )r`   r   r#   r#   r$   set_shift_and_window_sizeS  s    z(ClapAudioLayer.set_shift_and_window_sizec              	   C   s  | j dkrtjd||df|d}td| j t| j | j  t| j  d f}td| j t| j | j  t| j  d f}d}|D ].}|D ]$}	||d d ||	d d f< |d7 }qqt|| j}
|
d| j| j }
|
d|
d }||dkt	d|dkt	d}nd }|S )Nr   r   rm   r)   r&   g      Yrk   )
r   r8   r   slicer-   r2   r*   r   Zmasked_fillfloat)r`   r.   r/   rm   Zimg_maskZheight_slicesZwidth_slicescountZheight_sliceZwidth_sliceZmask_windows	attn_maskr#   r#   r$   get_attn_maskY  s*    &zClapAudioLayer.get_attn_maskc                 C   sR   | j || j   | j  }| j || j   | j  }ddd|d|f}tj||}||fS r   )r-   r   rE   r   )r`   r   r.   r/   	pad_rightZ
pad_bottom
pad_valuesr#   r#   r$   	maybe_padu  s
    zClapAudioLayer.maybe_padNFr   input_dimensionsr   r   always_partitionr@   c                 C   s  |s|  | n |\}}| \}}	}
|}| |}|||||
}| |||\}}|j\}	}}}	| jdkrtj|| j | j fdd}n|}t	|| j
}|d| j
| j
 |
}| j|||jd}|d k	r||j}| j||||d}|d }|d| j
| j
|
}t|| j
||}| jdkr<tj|| j| jfdd}n|}|d dkpX|d dk}|r|d d d |d |d d f  }|||| |
}|| | }| |}| |}|| | }|r||d	 fn|f}|S )
Nr   )r   r&   )ZshiftsZdimsr)   r   r   r	   r(   r   )r   r   r   r*   r  r   r   r8   Zrollr2   r-   r   rm   torB   r   r4   r,   r   r   r   rq   )r`   r   r  r   r   r  r.   r/   r!   r   r   Zshortcutr  Z
height_padZ	width_padZshifted_hidden_statesZhidden_states_windowsr   Zattention_outputsr   Zattention_windowsZshifted_windowsZ
was_paddedlayer_outputlayer_outputsr#   r#   r$   rr   |  sN    

   $

zClapAudioLayer.forward)r   )NFF)rL   rM   rN   rg   r   r   r  r8   r   r   r3   r   rP   r   rr   rs   r#   r#   ri   r$   r   E  s      
r   c                       sT   e Zd Z fddZdejeeef eej	 ee
 ee
 eej dddZ  ZS )	ClapAudioStagec                    sf   t     | _| _t fddt|D | _|d k	rV|tjd| _	nd | _	d| _
d S )Nc              	      s4   g | ],}t  |d  dkr"dn jd  dqS )r&   r   )rv   r6   r   r   r   )r   r-   r]   irv   r6   r   r   r#   r$   
<listcomp>  s   z+ClapAudioStage.__init__.<locals>.<listcomp>)r6   
norm_layerF)rf   rg   rv   r6   r   
ModuleListrangeblocksr   
downsampleZpointing)r`   rv   r6   r   depthr   r   r  ri   r  r$   rg     s    
zClapAudioStage.__init__NFr  c                 C   s   |\}}t | jD ]4\}}	|d k	r*|| nd }
|	|||
||}|d }q|}| jd k	r|d d |d d  }}||||f}| ||}n||||f}|||f}|r||dd  7 }|S )Nr   r   r&   )	enumerater  r  )r`   r   r  r   r   r  r.   r/   r  layer_modulelayer_head_maskr	  !hidden_states_before_downsamplingZheight_downsampledZwidth_downsampledoutput_dimensionsZstage_outputsr#   r#   r$   rr     s*        


zClapAudioStage.forward)NFF)rL   rM   rN   rg   r8   r   r   r3   r   rP   r   rr   rs   r#   r#   ri   r$   r
    s      
r
  c                       s^   e Zd ZdZejfee eejdd fddZ	dd Z
ejeeef ejdd	d
Z  ZS )ClapAudioPatchMerginga'  
    Patch Merging Layer.

    Args:
        input_resolution (`Tuple[int]`):
            Resolution of input feature.
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
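# Illustrative sketch of patch merging, assuming an input of shape (batch, height * width, channels):
# each 2x2 neighborhood of patches is concatenated along the channel axis (4 * channels) and then
# linearly reduced to 2 * channels, halving the spatial resolution. Names are hypothetical.
def _patch_merging_sketch(input_feature, height, width):
    import torch
    from torch import nn

    batch_size, _, channels = input_feature.shape
    input_feature = input_feature.view(batch_size, height, width, channels)
    top_left = input_feature[:, 0::2, 0::2, :]
    bottom_left = input_feature[:, 1::2, 0::2, :]
    top_right = input_feature[:, 0::2, 1::2, :]
    bottom_right = input_feature[:, 1::2, 1::2, :]
    merged = torch.cat([top_left, bottom_left, top_right, bottom_right], dim=-1)
    merged = merged.view(batch_size, -1, 4 * channels)
    return nn.Linear(4 * channels, 2 * channels, bias=False)(merged)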
    N)r   r6   r  r@   c                    sB   t    || _|| _tjd| d| dd| _|d| | _d S )Nr'   r&   Fr   )rf   rg   r   r6   r   r   	reductionr   )r`   r   r6   r  ri   r#   r$   rg   
  s
    
zClapAudioPatchMerging.__init__c                 C   sF   |d dkp|d dk}|rBddd|d d|d f}t j||}|S )Nr&   r   r   )r   rE   r   )r`   input_featurer.   r/   Z
should_padr  r#   r#   r$   r    s
    zClapAudioPatchMerging.maybe_pad)r  r  r@   c                 C   s   |\}}|j \}}}|||||}| |||}|d d dd ddd dd d f }|d d dd ddd dd d f }	|d d dd ddd dd d f }
|d d dd ddd dd d f }t||	|
|gd}||dd| }| |}| |}|S )Nr   r&   r   r)   r'   )r   r*   r  r8   catr   r  )r`   r  r  r.   r/   r!   r6   r0   Zinput_feature_0Zinput_feature_1Zinput_feature_2Zinput_feature_3r#   r#   r$   rr     s    $$$$

zClapAudioPatchMerging.forward)rL   rM   rN   rO   r   r   r   r3   Modulerg   r  r8   r   rr   rs   r#   r#   ri   r$   r    s   $r  c                       sj   e Zd Z fddZdd Zdeej eej ee ee ee ee ee e	e
ef dd	d
Z  ZS )ClapAudioEncoderc                    s  t    t j_ _t _ j_jj	_	 j
_
 j
 j _t jdjd   _dd td jt jD jjfddtjD _t fddtjD _d_t j_tj_ j_td_ d S )	Nr&   r   c                 S   s   g | ]}|  qS r#   )item)r]   r   r#   r#   r$   r  A  s     z-ClapAudioEncoder.__init__.<locals>.<listcomp>r   c                    s,   g | ]$} d  d|   d d|  fqS )r   r&   r   r#   r  )r   r#   r$   r  D  s     c                    s|   g | ]t}t  t jd |  j|  j|  j| t jd| t jd|d   |jd k rptnddqS )r&   Nr   )rv   r6   r   r  r   r   r  )	r
  r3   ry   input_resolutionsdepthsr   r   
num_layersr  )r]   Zi_layer)rv   r   r`   r#   r$   r  G  s   
*F)!rf   rg   rD   r"  r#  rv   r   patch_embedr   r   r   Znum_mel_bins
freq_ratior3   ry   Znum_featuresr8   Zlinspacer   r   r   r  r!  r   r  layersgradient_checkpointingr{   
batch_normr   r   ZAdaptiveAvgPool1davgpoolr`   rv   ri   )rv   r   r   r`   r$   rg   4  s,    


 
zClapAudioEncoder.__init__c                 C   s   |j \}}}}t| j| j }| j| j }||ks:||krBtd||k rbtjj|||fddd}||k rtjj|||fddd}|j \}}}	}
|||| j |	| j |
}|	dddd
 }||||
| j |	| j }|S )	z
        The input consists of 4 normalized log mel spectrograms. It is reshaped to the common image shape. Each channel
        should represent 1 of the 4 crops of the spectrogram. For more details, refer to the [`ClapFeatureExtractor`].
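        # Illustrative sketch of the folding step described above, assuming a (batch, 1, time, freq)
        # log-mel input whose time axis is divisible by `freq_ratio`: chunks of the time axis are
        # stacked onto the frequency axis so the result is closer to the square image shape the
        # Swin-style encoder expects. Names are hypothetical, not the module's own helpers.
        #
        #     def _mel_to_img_sketch(spectrogram, freq_ratio=4):
        #         batch, channels, time, freq = spectrogram.shape
        #         folded = spectrogram.reshape(batch, channels * freq_ratio, time // freq_ratio, freq)
        #         folded = folded.permute(0, 1, 3, 2).contiguous()
        #         return folded.reshape(batch, channels, freq * freq_ratio, time // freq_ratio)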
        z@the wav size should be less than or equal to the swin input sizeZbicubicT)modeZalign_cornersr   r   r	   r&   )r   r3   r   r%  r   r   rE   r%   r   r+   r,   )r`   normalized_input_featuresr   r"   Zfreq_lengthZ
spec_widthZspec_heigthbatchr   timefreqr#   r#   r$   reshape_mel2img\  sD                z ClapAudioEncoder.reshape_mel2imgNFT)	is_longerr   r   output_hidden_states(output_hidden_states_before_downsamplingr  return_dictr@   c	           %         sX  | dd}| |}	|	 dd}	d }
| jrJ||j}t|dkd }
| |	}|jd }| 	||
}|rrdnd }|r~dnd } rdnd }| j
d }|r|j\}}}|j|f||f }|dddd}||f7 }||f7 }t| jD ]X\}}|d k	r|| nd }| j
| }| jrJ| jrJ fdd}tjj|||||}n|||| |}|d }|d }|d }|d |d	 f}|r|r|j\}}}|j|f|d |d f|f }|dddd}||f7 }||f7 }nP|r0|s0|j\}}}|j|f||f }|dddd}||f7 }||f7 } r||dd  7 }q| |}|j\}}}|dt| jd   | jd  }|dt| jd   | jd  } |ddd |||| }|j\}}}!}"|!| j }#||||!|# |#|"}|ddddd
 |||#d	}| t|d}$t|$d}$|sHtdd ||$||fD S t||$||dS )Nr   r	   r   r&   r#   c                    s    fdd}|S )Nc                     s    | f S re   r#   inputs)moduler   r#   r$   custom_forward  s    zOClapAudioEncoder.forward.<locals>.create_custom_forward.<locals>.custom_forwardr#   r7  r8  r  r7  r$   create_custom_forward  s    z7ClapAudioEncoder.forward.<locals>.create_custom_forwardr   r)   r'   c                 s   s   | ]}|d k	r|V  qd S re   r#   r]   vr#   r#   r$   ra     s   z+ClapAudioEncoder.forward.<locals>.<genexpr>)rJ   pooler_outputr   rK   )r   r(  r   r  rB   r8   wherer0  r   r$  r!  r*   r+   r  r&  r'  rn   utils
checkpointr   rD   r"  r   r,   r   r%  r)  r   rb   r   )%r`   input_featuresr1  r   r   r2  r3  r  r4  r,  Zis_longer_list_idxZis_longer_listr   Z
frames_numall_hidden_statesZall_reshaped_hidden_statesall_self_attentionsr  r!   r   hidden_sizeZreshaped_hidden_stater  r  r  r;  r	  r  r  rJ   Z
n_channelsZ
freq_shapeZtemporal_shapeZn_frequenciesZn_tempZ
c_freq_binZlatent_outputr#   r  r$   rr     s    






        



  
     zClapAudioEncoder.forward)NNFFFFT)rL   rM   rN   rg   r0  r   r8   rP   r   r   r   rR   rr   rs   r#   r#   ri   r$   r  3  s&   ('       
CLAP_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`ClapConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

CLAP_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

CLAP_AUDIO_INPUTS_DOCSTRING = r"""
    Args:
        input_features (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Input audio features. These should be returned by the [`ClapFeatureExtractor`] class, which you can also
            retrieve from [`AutoFeatureExtractor`]. See [`ClapFeatureExtractor.__call__`] for details.
        is_longer (`torch.FloatTensor`, of shape `(batch_size, 1)`, *optional*):
            Whether the audio clip is longer than `max_length`. If `True`, a feature fusion will be enabled to enhance
            the features.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

CLAP_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        input_features (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Input audio features. These should be returned by the [`ClapFeatureExtractor`] class, which you can also
            retrieve from [`AutoFeatureExtractor`]. See [`ClapFeatureExtractor.__call__`] for details.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class ClapProjectionLayer(nn.Module):
    def __init__(self, config: Union[ClapAudioConfig, ClapTextConfig]):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size
        projection_dim = config.projection_dim

        self.linear1 = nn.Linear(hidden_size, projection_dim)
        self.activation = ACT2FN[config.projection_hidden_act]
        self.linear2 = nn.Linear(projection_dim, projection_dim)

    def forward(self, hidden_states):
        hidden_states = self.linear1(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.linear2(hidden_states)
        return hidden_states
ddZdd	 Z  ZS )ClapTextEmbeddingszV
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    c                    s   t    tj|j|j|jd| _t|j|j| _	t|j
|j| _tj|j|jd| _t|j| _t|dd| _| jdt|jddd | jd	tj| j tjd
dd |j| _tj|j|j| jd| _	d S )N)r;   r   position_embedding_typeabsoluteposition_ids)r   r)   T)
persistenttoken_type_idsr   )rf   rg   r   	EmbeddingZ
vocab_sizerE  Zpad_token_idword_embeddingsmax_position_embeddingsposition_embeddingsZtype_vocab_sizetoken_type_embeddingsr   r   r   r   r   r[   rL  r   r8   rC   expandr   rN  r   r9   r;   r*  ri   r#   r$   rg     s.    
      zClapTextEmbeddings.__init__Nr   c                 C   s   |d kr*|d k	r t || j|}n
| |}|d k	r<| }n| d d }|d }|d krt| dr| jd d d |f }||d |}	|	}ntj|tj	| j
jd}|d kr| |}| |}
||
 }| jdkr| |}||7 }| |}| |}|S )Nr)   r   rP  r   rl   rM  )r>   r;   &create_position_ids_from_inputs_embedsr   hasattrrP  rV  r8   r   r9   rN  rB   rR  rU  rL  rT  r   r   )r`   r:   rP  rN  inputs_embedsr<   input_shape
seq_lengthbuffered_token_type_ids buffered_token_type_ids_expandedrU  
embeddingsrT  r#   r#   r$   rr     s0    








zClapTextEmbeddings.forwardc                 C   sN   |  dd }|d }tj| jd || j d tj|jd}|d|S )z
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        Nr)   r   rl   r   )r   r8   rC   r;   r9   rB   r   rV  )r`   rY  rZ  Zsequence_lengthrN  r#   r#   r$   rW    s    	   z9ClapTextEmbeddings.create_position_ids_from_inputs_embeds)NNNNr   )rL   rM   rN   rO   rg   rr   rW  rs   r#   r#   ri   r$   rK  ~  s            
(rK  c                
       s   e Zd Zd fdd	ZejejdddZdejeej eej eej eej ee	e	ej   ee
 e	ej dd	d
Z  ZS )ClapTextSelfAttentionNc                    s   t    |j|j dkr>t|ds>td|j d|j d|j| _t|j|j | _| j| j | _t	
|j| j| _t	
|j| j| _t	
|j| j| _t	|j| _|pt|dd| _| jdks| jd	kr|j| _t	d
|j d | j| _|j| _d S )Nr   Zembedding_sizer   r   r   rL  rM  relative_keyrelative_key_queryr&   r   )rf   rg   rE  r   rX  r   r3   r   r   r   r   r   r   r   r   r   r   r[   rL  rS  rQ  distance_embedding
is_decoderr`   rv   rL  ri   r#   r$   rg     s*    
  zClapTextSelfAttention.__init__)r   r@   c                 C   s6   |  d d | j| jf }||}|ddddS r   r   r   r#   r#   r$   r     s    
z*ClapTextSelfAttention.transpose_for_scoresFr   r   r   encoder_hidden_statesencoder_attention_maskpast_key_valuer   r@   c                 C   s  |  |}|d k	}	|	r4|d k	r4|d }
|d }|}n|	r^| | |}
| | |}|}nv|d k	r| | |}
| | |}tj|d |
gdd}
tj|d |gdd}n | | |}
| | |}| |}|d k	}| jr|
|f}t||
dd}| j	dks | j	dkr|j
d |
j
d  }}|r^tj|d tj|jd	dd}ntj|tj|jd	dd}tj|tj|jd	dd}|| }| || j d }|j|jd
}| j	dkrtd||}|| }n4| j	dkrtd||}td|
|}|| | }|t| j }|d k	r:|| }tjj|dd}| |}|d k	rf|| }t||}|dddd }| d d | jf }||}|r||fn|f}| jr||f }|S )Nr   r   r&   r5   r)   r   r`  ra  rl   r   zbhld,lrd->bhlrzbhrd,lrd->bhlrr	   ) r   r   r   r   r8   r  rc  r   r   rL  r   tensorr9   rB   r*   rC   rb  rS  r  rm   Zeinsumr   r   r   r   rE   r   r   r+   r,   r   r   )r`   r   r   r   rf  rg  rh  r   r   Zis_cross_attentionr   r   r   	use_cacher   Zquery_lengthZ
key_lengthZposition_ids_lZposition_ids_rZdistanceZpositional_embeddingZrelative_position_scoresZrelative_position_scores_queryZrelative_position_scores_keyr   r   r   r   r#   r#   r$   rr     sp    


 





zClapTextSelfAttention.forward)N)NNNNNF)rL   rM   rN   rg   r8   r   r   r   rP   r   r   rr   rs   r#   r#   ri   r$   r_    s$         r_  c                       s4   e Zd Z fddZejejejdddZ  ZS )ClapTextSelfOutputc                    sB   t    t|j|j| _tj|j|jd| _t|j	| _
d S Nr   )rf   rg   r   r   rE  r   r   r   r   r   r   r*  ri   r#   r$   rg   `  s    
zClapTextSelfOutput.__init__r   c                 C   s&   |  |}| |}| || }|S re   r   r   r   r   r#   r#   r$   rr   f  s    

zClapTextSelfOutput.forwardr   r#   r#   ri   r$   rk  _  s   rk  c                
       sv   e Zd Zd
 fdd	Zdd Zdejeej eej eej eej ee	e	ej   ee
 e	ej ddd	Z  ZS )ClapTextAttentionNc                    s.   t    t||d| _t|| _t | _d S )NrL  )rf   rg   r_  r`   rk  rq   r   r   rd  ri   r#   r$   rg   o  s    

zClapTextAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S r   r   r   r#   r#   r$   r   u  s       zClapTextAttention.prune_headsFre  c              	   C   s<   |  |||||||}| |d |}	|	f|dd   }
|
S r   r   )r`   r   r   r   rf  rg  rh  r   r   r   r   r#   r#   r$   rr     s    
	zClapTextAttention.forward)N)NNNNNF)rL   rM   rN   rg   r   r8   r   r   rP   r   r   rr   rs   r#   r#   ri   r$   rn  n  s$         rn  c                       s0   e Zd Z fddZejejdddZ  ZS )ClapTextIntermediatec                    sB   t    t|j|j| _t|jt	r6t
|j | _n|j| _d S re   )rf   rg   r   r   rE  intermediate_sizer   r   r   r   r
   r   r*  ri   r#   r$   rg     s
    
zClapTextIntermediate.__init__r   c                 C   s   |  |}| |}|S re   r   r   r#   r#   r$   rr     s    

zClapTextIntermediate.forwardr   r#   r#   ri   r$   rp    s   rp  c                       s4   e Zd Z fddZejejejdddZ  ZS )ClapTextOutputc                    sB   t    t|j|j| _tj|j|jd| _t	|j
| _d S rl  )rf   rg   r   r   rq  rE  r   r   r   r   r   r   r*  ri   r#   r$   rg     s    
zClapTextOutput.__init__r   c                 C   s&   |  |}| |}| || }|S re   rm  r   r#   r#   r$   rr     s    

zClapTextOutput.forwardr   r#   r#   ri   r$   rr    s   rr  c                
       st   e Zd Z fddZd
ejeej eej eej eej eeeej   ee	 eej dddZ
dd	 Z  ZS )ClapTextLayerc                    sr   t    |j| _d| _t|| _|j| _|j| _| jrZ| jsLt|  dt|dd| _	t
|| _t|| _d S )Nr   z> should be used as a decoder model if cross attention is addedrM  ro  )rf   rg   r   seq_len_dimrn  r   rc  add_cross_attentionr   crossattentionrp  r   rr  rq   r*  ri   r#   r$   rg     s    


zClapTextLayer.__init__NFre  c              	   C   s  |d k	r|d d nd }| j |||||d}	|	d }
| jrP|	dd }|	d }n|	dd  }d }| jr|d k	rt| dstd|  d|d k	r|d	d  nd }| |
||||||}|d }
||dd  }|d }|| }t| j| j| j|
}|f| }| jr||f }|S )
Nr&   r   rh  r   r   r)   rv  z'If `encoder_hidden_states` are passed, z` has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`r   )	r   rc  rX  r   rv  r   feed_forward_chunkr   rt  )r`   r   r   r   rf  rg  rh  r   Zself_attn_past_key_valueZself_attention_outputsr   r   Zpresent_key_valueZcross_attn_present_key_valueZcross_attn_past_key_valueZcross_attention_outputsr  r#   r#   r$   rr     sV    


	   

zClapTextLayer.forwardc                 C   s   |  |}| ||}|S re   )r   rq   )r`   r   Zintermediate_outputr  r#   r#   r$   rx    s    
z ClapTextLayer.feed_forward_chunk)NNNNNF)rL   rM   rN   rg   r8   r   r   rP   r   r   rr   rx  rs   r#   r#   ri   r$   rs    s$         Ars  c                       s   e Zd Z fddZd	ejeej eej eej eej eeeej   ee	 ee	 ee	 ee	 e
eej ef dddZ  ZS )
ClapTextEncoderc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r#   )rs  )r]   r   ru   r#   r$   r    s     z,ClapTextEncoder.__init__.<locals>.<listcomp>F)	rf   rg   rv   r   r  r  num_hidden_layerslayerr'  r*  ri   ru   r$   rg     s    
 zClapTextEncoder.__init__NFT)r   r   r   rf  rg  past_key_valuesrj  r   r2  r4  r@   c              	      st  |	rdnd } rdnd } r(| j jr(dnd }| jrJ| jrJ|rJtd d}|rRdnd }t| jD ]\}}|	rv||f }|d k	r|| nd }|d k	r|| nd | jr| jrև fdd}tj	j

|||||||}n|||||| }|d }|r||d f7 } r`||d f }| j jr`||d	 f }q`|	r@||f }|
sbtd
d |||||fD S t|||||dS )Nr#   zZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...Fc                    s    fdd}|S )Nc                     s    | f S re   r#   r5  )r7  r   rh  r#   r$   r8  @  s    zNClapTextEncoder.forward.<locals>.create_custom_forward.<locals>.custom_forwardr#   r9  rw  r:  r$   r;  ?  s    z6ClapTextEncoder.forward.<locals>.create_custom_forwardr   r)   r   r&   c                 s   s   | ]}|d k	r|V  qd S re   r#   r<  r#   r#   r$   ra   d  s   z*ClapTextEncoder.forward.<locals>.<genexpr>)rJ   r|  r   rK   cross_attentions)rv   ru  r'  rn   loggerZwarning_oncer  r{  r8   r@  rA  rb   r   )r`   r   r   r   rf  rg  r|  rj  r   r2  r4  rC  rD  Zall_cross_attentionsZnext_decoder_cacher  r  r  r;  r	  r#   rw  r$   rr     sv    
	

zClapTextEncoder.forward)	NNNNNNFFT)rL   rM   rN   rg   r8   r   r   rP   r   r   r   r   rr   rs   r#   r#   ri   r$   ry    s.   	         ry  c                       s0   e Zd Z fddZejejdddZ  ZS )ClapTextPoolerc                    s*   t    t|j|j| _t | _d S re   )rf   rg   r   r   rE  r   ZTanhrI  r*  ri   r#   r$   rg   z  s    
zClapTextPooler.__init__r   c                 C   s(   |d d df }|  |}| |}|S r   )r   rI  )r`   r   Zfirst_token_tensorpooled_outputr#   r#   r$   rr     s    

zClapTextPooler.forwardr   r#   r#   ri   r$   r  y  s   r  c                   @   s.   e Zd ZdZeZdZdZdd Zd	ddZ	dS )
ClapPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    clapFc                 C   s*  | j j}t|trD|jjjjd|d d |jjjjd|d d nt|t	r|t
jj|j|d d t
jj|j|d d nt|t
jr|jjjd|d d nt|t
jr|jj  |jjd n^t|t
jt
jfr&| j jd d| j j d  | }t
jj|j|d |jdk	r&|jj  dS )	zInitialize the weightsrk   g{Gz?)Zmeanstd)r  g      ?g      r&   N)rv   Zinitializer_factorr   rK  rT  weightdataZnormal_rU  	ClapModelr   initlogit_scale_alogit_scale_trQ  r   r   Zzero_Zfill_rz   r   rE  rz  )r`   r7  factorZin_proj_stdr#   r#   r$   _init_weights  s"    

 z!ClapPreTrainedModel._init_weightsc                 C   s   t |tr||_d S re   )r   ry  r'  )r`   r7  r   r#   r#   r$   _set_gradient_checkpointing  s    
z/ClapPreTrainedModel._set_gradient_checkpointingN)F)
rL   rM   rN   rO   r   config_classZbase_model_prefixZsupports_gradient_checkpointingr  r  r#   r#   r#   r$   r    s   r  c                       s   e Zd ZeZdZed fddZejdddZ	e
eeeeddeej eej ee ee ee eeef d
ddZ  ZS )ClapAudioModelrB  ru   c                    s"   t  | t|| _|   d S re   )rf   rg   r  audio_encoder	post_initr*  ri   r#   r$   rg     s    
zClapAudioModel.__init__rZ   c                 C   s
   | j jjS re   )r  r$  r   r_   r#   r#   r$   get_input_embeddings  s    z#ClapAudioModel.get_input_embeddingsoutput_typer  NrB  r1  r   r2  r4  r@   c                 C   sP   |dk	r|n| j j}|dk	r |n| j j}|dk	r4|n| j j}| j|||||dS )ar  
        Returns:

        Examples:

        ```python
        >>> from datasets import load_dataset
        >>> from transformers import AutoProcessor, ClapAudioModel

        >>> dataset = load_dataset("ashraq/esc50")
        >>> audio_sample = dataset["train"]["audio"][0]["array"]

        >>> model = ClapAudioModel.from_pretrained("laion/clap-htsat-fused")
        >>> processor = AutoProcessor.from_pretrained("laion/clap-htsat-fused")

        >>> inputs = processor(audios=audio_sample, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        ```NrB  r1  r   r2  r4  )rv   use_return_dictr   r2  r  )r`   rB  r1  r   r2  r4  r#   r#   r$   rr     s    zClapAudioModel.forward)NNNNN)rL   rM   rN   r   r  main_input_namerg   r   r  r  r   CLAP_AUDIO_INPUTS_DOCSTRINGr   r   r   r8   rP   
BoolTensorr   r   r   rr   rs   r#   r#   ri   r$   r    s&   
     
r  c                       s   e Zd ZdZeZd fdd	Zdd Zdd Zde	e
j e	e
j e	e
j e	e
j e	e
j e	e
j e	e
j e	e
j e	ee
j  e	e e	e e	e e	e eee
j ef d
ddZ  ZS )ClapTextModela*  

    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in *Attention is
    all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
    Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.

    .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762

    Tc                    sD   t  | || _t|| _t|| _|r2t|nd | _| 	  d S re   )
rf   rg   rv   rK  r^  ry  encoderr  poolerr  )r`   rv   Zadd_pooling_layerri   r#   r$   rg     s    

zClapTextModel.__init__c                 C   s   | j jS re   r^  rR  r_   r#   r#   r$   r    s    z"ClapTextModel.get_input_embeddingsc                 C   s   || j _d S re   r  r`   r   r#   r#   r$   set_input_embeddings	  s    z"ClapTextModel.set_input_embeddingsN)r:   r   rP  rN  r   rY  rf  rg  r|  rj  r   r2  r4  r@   c                 C   s^  |dk	r|n| j j}|dk	r |n| j j}|dk	r4|n| j j}| j jrZ|
dk	rP|
n| j j}
nd}
|dk	rx|dk	rxtdn@|dk	r| || | }n"|dk	r| dd }ntd|\}}|dk	r|j	n|j	}|	dk	r|	d d j
d nd}|dkrtj||| f|d}|dkrft| jd	rT| jjddd|f }|||}|}ntj|tj|d
}| ||}| j jr|dk	r| \}}}||f}|dkrtj||d}| |}nd}| || j j}| j|||||d}| j||||||	|
|||d
}|d }| jdk	r$| |nd}|sB||f|dd  S t|||j|j|j|jdS )a  
        encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        NFzDYou cannot specify both input_ids and inputs_embeds at the same timer)   z5You have to specify either input_ids or inputs_embedsr   r&   rA   rP  rl   )r:   rN  rP  rY  r<   )	r   r   rf  rg  r|  rj  r   r2  r4  r   )rJ   r>  r|  r   rK   r}  )rv   r   r2  r  rc  rj  r   Z%warn_if_padding_and_no_attention_maskr   rB   r   r8   ZonesrX  r^  rP  rV  r   r9   Zget_extended_attention_maskZinvert_attention_maskZget_head_maskrz  r  r  r   r|  r   rK   r}  )r`   r:   r   rP  rN  r   rY  rf  rg  r|  rj  r   r2  r4  rZ  r!   r[  rB   r<   r\  r]  Zextended_attention_maskZencoder_batch_sizeZencoder_sequence_lengthr   Zencoder_hidden_shapeZencoder_extended_attention_maskZembedding_outputZencoder_outputsZsequence_outputr  r#   r#   r$   rr     s    $




zClapTextModel.forward)T)NNNNNNNNNNNNN)rL   rM   rN   rO   r   r  rg   r  r  r   r8   r   r   rP   r   r   r   r   rr   rs   r#   r#   ri   r$   r    sD                r  c                       s  e Zd ZeZed fddZeedee	j
 ee	j
 ee	j
 ee ee ee e	jdddZeedee	j
 ee	j
 ee	j
 ee ee ee e	jdd	d
Zeeeeeddee	j ee	j ee	j ee	j
 ee	j ee ee ee ee eeef d
ddZ  ZS )r  ru   c                    s   t  | t|jts.tdt|j dt|jtsPtdt|j d|j}|j}t	
tt|j| _t	
tt|j| _|j| _t|| _t|| _t|| _t|| _|   d S )NzKconfig.text_config is expected to be of type ClapTextConfig but is of type .zMconfig.audio_config is expected to be of type ClapAudioConfig but is of type )rf   rg   r   text_configr   r   typeaudio_configr   r   r   r8   ri  r   logZlogit_scale_init_valuer  r  rG  r  
text_modelrF  text_projectionr  audio_modelaudio_projectionr  )r`   rv   r  r  ri   r#   r$   rg     s&    



zClapModel.__init__Nr:   r   rN  r   r2  r4  r@   c           
      C   s   |dk	r|n| j j}|dk	r |n| j j}|dk	r4|n| j j}| j||||||d}|dk	rb|d n|j}| |}	tj|	dd}	|	S )a  
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`ClapTextModel`].

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, ClapModel

        >>> model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
        >>> tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")

        >>> inputs = tokenizer(["the sound of a cat", "the sound of a dog"], padding=True, return_tensors="pt")
        >>> text_features = model.get_text_features(**inputs)
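        >>> # text_features holds one L2-normalized embedding of size output_dim per input text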
        ```Nr:   r   rN  r   r2  r4  r   r)   r5   )	rv   r   r2  r  r  r>  r  F	normalize)
r`   r:   r   rN  r   r2  r4  text_outputsr  Ztext_featuresr#   r#   r$   get_text_features  s     	
zClapModel.get_text_features)rB  r1  r   r   r2  r4  r@   c           
      C   sz   |dk	r|n| j j}|dk	r |n| j j}|dk	r4|n| j j}| j|||d}|sX|d n|j}| |}	tj|	dd}	|	S )a  
        Returns:
            audio_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The audio embeddings obtained by
            applying the projection layer to the pooled output of [`ClapAudioModel`].

        Examples:

        ```python
        >>> from transformers import AutoFeatureExtractor, ClapModel
        >>> import torch

        >>> model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused")
        >>> random_audio = torch.rand((16_000))
        >>> inputs = feature_extractor(random_audio, return_tensors="pt")
        >>> audio_features = model.get_audio_features(**inputs)
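        >>> # audio_features holds one L2-normalized embedding of size output_dim per input audio clip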
        ```N)rB  r1  r4  r   r)   r5   )	rv   r   r2  r  r  r>  r  r  r  )
r`   rB  r1  r   r   r2  r4  audio_outputsr  Zaudio_featuresr#   r#   r$   get_audio_features  s    
zClapModel.get_audio_featuresr  )
r:   rB  r1  r   rN  return_lossr   r2  r4  r@   c
              	   C   sp  |dk	r|n| j j}|dk	r |n| j j}|	dk	r4|	n| j j}	| j|||||	d}
| j||||||	d}|	sr|
d n|
j}| |}|	s|d n|j}| |}||j	dddd }||j	dddd }| j
 }| j }t|| | }t|| | }d}|r,t|}t| }|| d	 }|	sZ||||||
f}|dk	rV|f| S |S t|||||||
d
S )ak  
        Returns:

        Examples:

        ```python
        >>> from datasets import load_dataset
        >>> from transformers import AutoProcessor, ClapModel

        >>> dataset = load_dataset("ashraq/esc50")
        >>> audio_sample = dataset["train"]["audio"][0]["array"]

        >>> model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
        >>> processor = AutoProcessor.from_pretrained("laion/clap-htsat-unfused")

        >>> input_text = ["Sound of a dog", "Sound of vacuum cleaner"]

        >>> inputs = processor(text=input_text, audios=audio_sample, return_tensors="pt", padding=True)

        >>> outputs = model(**inputs)
        >>> logits_per_audio = outputs.logits_per_audio  # this is the audio-text similarity score
        >>> probs = logits_per_audio.softmax(dim=-1)  # we can take the softmax to get the label probabilities
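        >>> # e.g. retrieve the caption that best matches the audio clip (illustrative follow-up)
        >>> best_caption = input_text[probs[0].argmax().item()]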
        ```"""
        # Use CLAP's top-level config for these flags when they are not passed explicitly.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        audio_outputs = self.audio_model(
            input_features=input_features,
            is_longer=is_longer,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        audio_embeds = audio_outputs[1] if not return_dict else audio_outputs.pooler_output
        audio_embeds = self.audio_projection(audio_embeds)

        text_embeds = text_outputs[1] if not return_dict else text_outputs.pooler_output
        text_embeds = self.text_projection(text_embeds)

        # normalize the embeddings so the dot products below are cosine similarities
        audio_embeds = audio_embeds / audio_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits, scaled by the learned temperatures
        logit_scale_text = self.logit_scale_t.exp()
        logit_scale_audio = self.logit_scale_a.exp()
        logits_per_text = torch.matmul(text_embeds, audio_embeds.t()) * logit_scale_text
        logits_per_audio = torch.matmul(audio_embeds, text_embeds.t()) * logit_scale_audio

        loss = None
        if return_loss:
            caption_loss = contrastive_loss(logits_per_text)
            audio_loss = contrastive_loss(logits_per_audio.t())
            loss = (caption_loss + audio_loss) / 2.0

        if not return_dict:
            output = (logits_per_audio, logits_per_text, text_embeds, audio_embeds, text_outputs, audio_outputs)
            return ((loss,) + output) if loss is not None else output

        return ClapOutput(
            loss=loss,
            logits_per_audio=logits_per_audio,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            audio_embeds=audio_embeds,
            text_model_output=text_outputs,
            audio_model_output=audio_outputs,
        )


@add_start_docstrings(
    """
    CLAP Text Model with a projection layer on top (a linear layer on top of the pooled output).
    """,
    CLAP_START_DOCSTRING,
)
class ClapTextModelWithProjection(ClapPreTrainedModel):
    config_class = ClapTextConfig

    def __init__(self, config: ClapTextConfig):
        super().__init__(config)
        self.text_model = ClapTextModel(config)
        self.text_projection = ClapProjectionLayer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.text_model.embeddings.word_embeddings = value

    @add_start_docstrings_to_model_forward(CLAP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ClapTextModelOutput, config_class=ClapTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ClapTextModelOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, ClapTextModelWithProjection

        >>> model = ClapTextModelWithProjection.from_pretrained("laion/clap-htsat-unfused")
        >>> tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")

        >>> inputs = tokenizer(["a sound of a cat", "a sound of a dog"], padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> text_embeds = outputs.text_embeds
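        >>> # text_embeds are the projected (un-normalized) pooled text embeddings, shape (batch_size, projection_dim)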
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = text_outputs[1] if not return_dict else text_outputs.pooler_output

        text_embeds = self.text_projection(pooled_output)

        if not return_dict:
            outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
            return tuple(output for output in outputs if output is not None)

        return ClapTextModelOutput(
            text_embeds=text_embeds,
            last_hidden_state=text_outputs.last_hidden_state,
            hidden_states=text_outputs.hidden_states,
            attentions=text_outputs.attentions,
        )


@add_start_docstrings(
    """
    CLAP Audio Model with a projection layer on top (a linear layer on top of the pooled output).
    """,
    CLAP_START_DOCSTRING,
)
class ClapAudioModelWithProjection(ClapPreTrainedModel):
    config_class = ClapAudioConfig
    main_input_name = "input_features"

    def __init__(self, config: ClapAudioConfig):
        super().__init__(config)
        self.audio_model = ClapAudioModel(config)
        self.audio_projection = ClapProjectionLayer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.audio_model.audio_encoder.patch_embed.proj

    @add_start_docstrings_to_model_forward(CLAP_AUDIO_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ClapAudioModelOutput, config_class=ClapAudioConfig)
    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        is_longer: Optional[torch.BoolTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ClapAudioModelOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from datasets import load_dataset
        >>> from transformers import ClapAudioModelWithProjection, ClapProcessor

        >>> model = ClapAudioModelWithProjection.from_pretrained("laion/clap-htsat-fused")
        >>> processor = ClapProcessor.from_pretrained("laion/clap-htsat-fused")

        >>> dataset = load_dataset("ashraq/esc50")
        >>> audio_sample = dataset["train"]["audio"][0]["array"]

        >>> inputs = processor(audios=audio_sample, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> audio_embeds = outputs.audio_embeds
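        >>> # audio_embeds are the projected (un-normalized) pooled audio embeddings, shape (batch_size, projection_dim)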
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        audio_outputs = self.audio_model(
            input_features=input_features,
            is_longer=is_longer,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = audio_outputs[1] if not return_dict else audio_outputs.pooler_output

        audio_embeds = self.audio_projection(pooled_output)

        if not return_dict:
            outputs = (audio_embeds, audio_outputs[0]) + audio_outputs[2:]
            return tuple(output for output in outputs if output is not None)

        return ClapAudioModelOutput(
            audio_embeds=audio_embeds,
            last_hidden_state=audio_outputs.last_hidden_state,
            attentions=audio_outputs.attentions,
            hidden_states=audio_outputs.hidden_states,
        )