""" PyTorch TVLT model."""


import collections.abc
import math
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_tvlt import TvltConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "TvltConfig"
_CHECKPOINT_FOR_DOC = "ZinengTang/tvlt-base"

TVLT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "ZinengTang/tvlt-base",
]


@dataclass
class TvltModelOutput(ModelOutput):
    """
    Class for TvltModel's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        last_pixel_hidden_state (`torch.FloatTensor` of shape `(batch_size, pixel_sequence_length, hidden_size)`):
            Pixel sequence of hidden-states at the output of the last layer of the model.
        last_audio_hidden_state (`torch.FloatTensor` of shape `(batch_size, audio_sequence_length, hidden_size)`):
            Audio sequence of hidden-states at the output of the last layer of the model.
        pixel_label_masks (`torch.FloatTensor` of shape `(batch_size, pixel_patch_length)`):
            Tensor indicating which pixel patches are masked (1) and which are not (0).
        audio_label_masks (`torch.FloatTensor` of shape `(batch_size, audio_patch_length)`):
            Tensor indicating which audio patches are masked (1) and which are not (0).
        pixel_ids_restore (`torch.LongTensor` of shape `(batch_size, pixel_patch_length)`):
            Tensor containing the ids permutation of pixel masking.
        audio_ids_restore (`torch.LongTensor` of shape `(batch_size, audio_patch_length)`):
            Tensor containing the ids permutation of audio masking.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    Nlast_hidden_statelast_pixel_hidden_statelast_audio_hidden_statepixel_label_masksaudio_label_maskspixel_ids_restoreaudio_ids_restorehidden_states
attentions)__name__
__module____qualname____doc__r   torchFloatTensor__annotations__r   r   r   
LongTensorr   r   r    r!   r   r   r"    r+   r+   g/var/www/html/Darija-Ai-Train/env/lib/python3.8/site-packages/transformers/models/tvlt/modeling_tvlt.pyr   6   s   
r   c                   @   sL   e Zd ZU dZdZejed< dZe	e
ej  ed< dZe	e
ej  ed< dS )TvltDecoderOutputaM  
    Class for TvltDecoder's outputs, with potential hidden states and attentions.

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
            Pixel reconstruction logits.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    Nlogitsr!   r"   )r#   r$   r%   r&   r.   r'   r(   r)   r!   r   r   r"   r+   r+   r+   r,   r-   _   s   
r-   c                   @   sz   e Zd ZU dZdZeej ed< dZ	ejed< dZ
ejed< dZejed< dZeeej  ed< dZeeej  ed< dS )	TvltForPreTrainingOutputa
  
    Class for TvltForPreTraining's outputs, with potential hidden states and attentions.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`):
            Pixel reconstruction loss.
        matching_logits (`torch.FloatTensor` of shape `(batch_size, 1)`):
            Matching objective logits.
        pixel_logits (`torch.FloatTensor` of shape
            `(batch_size, pixel_patch_length, image_patch_size ** 3 * pixel_num_channels)`): Pixel reconstruction
            logits.
        audio_logits (`torch.FloatTensor` of shape
            `(batch_size, audio_patch_length, image_patch_size[0] * image_patch_size[1])`): Audio reconstruction
            logits.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    Nlossmatching_logitspixel_logitsaudio_logitsr!   r"   )r#   r$   r%   r&   r0   r   r'   r(   r)   r1   r2   r3   r!   r   r"   r+   r+   r+   r,   r/   v   s   
r/         ?c                 C   s>   | j dd \}}tj||f| jd}t|d|  }||fS )!Generate noise for audio masking.N   devicer   )shaper'   randr8   int)pixel_values
pixel_mask
mask_ratio
batch_sizeseq_lennoiselen_keepr+   r+   r,   generate_pixel_mask_noise   s    rC   patch-level   c           
      C   s   | j dd \}}|dkrN|| }tj||| jdddd|||}n|dkrhtj||| jd}t|d|  }	||	fS )r5   Nr6   zframe-levelr7   r   rD   )r9   r'   r:   r8   	unsqueezerepeatviewr;   )
audio_values
audio_maskr>   	mask_typefreq_lenr?   r@   num_time_patchesrA   rB   r+   r+   r,   generate_audio_mask_noise   s$       rO   c                 C   s   | j \}}}tj|dd}tj|dd}|ddd|f }	tj| d|	ddd|d}
tj||g| jd}d|ddd|f< tj|d|d}|dk	r||9 }tj|d|	d}|
|||fS )z
    Perform random masking by per-sample shuffling on frame-level. Per-sample shuffling is done by argsort random
    noise. sequence: [batch_size, seq_len, hidden_dim], sequence
    r   dimNrF   rQ   indexr7   r   )r9   r'   ZargsortgatherrG   rH   Zonesr8   )sequencerA   rB   attention_masksr?   r@   Z
hidden_dimZids_shuffleids_restoreZids_keepZsequence_maskedZlabel_masksr+   r+   r,   random_masking   s     rX   c                       s*   e Zd ZdZ fddZdddZ  ZS )TvltPixelEmbeddings,Construct the patch and position embeddings.c                    st   t    t|| _| jj| _ttdd|j	| _
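# Minimal usage sketch for the two helpers above (the tensor sizes are illustrative
# assumptions, not values read from TvltConfig): `generate_pixel_mask_noise` draws one noise
# value per patch, and `random_masking` keeps the `len_keep` patches with the smallest noise.
#
#     sequence = torch.randn(2, 1568, 768)                        # (batch, num_patches, hidden)
#     noise, len_keep = generate_pixel_mask_noise(sequence, mask_ratio=0.75)
#     masked, _, ids_restore, label_masks = random_masking(sequence, noise, len_keep)
#     # masked: (2, 392, 768); label_masks: (2, 1568) with 1 marking removed patches;
#     # ids_restore is later used to unshuffle decoder tokens back into input order.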
class TvltPixelEmbeddings(nn.Module):
    """Construct the patch and position embeddings."""

    def __init__(self, config):
        super().__init__()

        self.patch_embeddings = TvltPixelPatchEmbeddings(config)
        self.num_patches_per_image = self.patch_embeddings.num_patches_per_image

        self.type_embed_v = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.temporal_embed = nn.Parameter(torch.zeros(1, config.num_frames, config.hidden_size))
        self.pos_embed_v = nn.Parameter(torch.zeros(1, self.num_patches_per_image, config.hidden_size))

        self.config = config

    def forward(self, pixel_values, attention_masks=None):
        # create patch embeddings
        batch_size, num_frames, num_channels, height, width = pixel_values.shape

        embeddings = self.patch_embeddings(pixel_values)
        embeddings += self.pos_embed_v.repeat(1, num_frames, 1)
        embeddings += torch.repeat_interleave(self.temporal_embed[:, :num_frames], self.num_patches_per_image, dim=1)
        embeddings += self.type_embed_v

        return embeddings, attention_masks


class TvltAudioEmbeddings(nn.Module):
    """Construct the patch and position embeddings."""

    def __init__(self, config):
        super().__init__()

        self.patch_embeddings = TvltAudioPatchEmbeddings(config)
        self.num_patches = self.patch_embeddings.num_patches

        self.type_embed_a = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.num_freq_patches = config.frequency_length // config.audio_patch_size[1]
        self.pos_embed_a = nn.Parameter(torch.zeros(1, self.num_patches // self.num_freq_patches, config.hidden_size))
        self.freq_embed = nn.Parameter(torch.zeros(1, self.num_freq_patches, config.hidden_size))

        self.config = config

    def forward(self, audio_values, attention_masks=None):
        # create patch embeddings
        embeddings = self.patch_embeddings(audio_values)

        num_time_patches = embeddings.size(1) // self.num_freq_patches
        embeddings += self.freq_embed.repeat(1, num_time_patches, 1)
        embeddings += torch.repeat_interleave(self.pos_embed_a[:, :num_time_patches], self.num_freq_patches, dim=1)
        embeddings += self.type_embed_a

        return embeddings, attention_masks
class TvltPixelPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.image_patch_size
        num_channels, hidden_size = config.num_image_channels, config.hidden_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches_per_image = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches_per_image = num_patches_per_image
        self.hidden_size = hidden_size

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size, num_frames, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )

        pixel_values = pixel_values.reshape(batch_size * num_frames, num_channels, height, width)
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        embeddings = embeddings.reshape(batch_size, num_frames * self.num_patches_per_image, self.hidden_size)

        return embeddings
class TvltAudioPatchEmbeddings(nn.Module):
    """
    This class turns `audio_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        spectrogram_length, frequency_length = config.spectrogram_length, config.frequency_length
        patch_size = config.audio_patch_size
        num_channels, hidden_size = config.num_audio_channels, config.hidden_size

        spectrogram_size = (spectrogram_length, frequency_length)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (spectrogram_size[1] // patch_size[1]) * (spectrogram_size[0] // patch_size[0])
        patch_shape = (spectrogram_size[0] // patch_size[0], spectrogram_size[1] // patch_size[1])

        self.spectrogram_size = spectrogram_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.patch_shape = patch_shape

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, audio_values: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, height, width = audio_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        if height > self.spectrogram_size[0] or width != self.spectrogram_size[1]:
            raise ValueError(
                f"Input audio size ({height}*{width}) doesn't match model"
                f" ({self.spectrogram_size[0]}*{self.spectrogram_size[1]})."
            )
        embeddings = self.projection(audio_values).flatten(2).transpose(1, 2)

        return embeddings
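# The projection above is the standard ViT-style "patchify": a Conv2d whose kernel size and
# stride both equal the patch size, so each output position corresponds to one
# non-overlapping patch. A small sketch with made-up sizes (not this file's config):
#
#     proj = nn.Conv2d(1, 768, kernel_size=(16, 16), stride=(16, 16))
#     spec = torch.randn(4, 1, 2048, 128)               # (batch, channels, time, freq)
#     patches = proj(spec).flatten(2).transpose(1, 2)   # (4, 128 * 8, 768)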
class TvltSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention heads "
                f"{config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the forward() of the model)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # Dropout drops entire tokens to attend to, following the original Transformer.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class TvltSelfOutput(nn.Module):
    """
    The residual connection is defined in TvltLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: TvltConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


class TvltAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = TvltSelfAttention(config)
        self.output = TvltSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        self_outputs = self.attention(hidden_states, attention_mask, head_mask, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class TvltIntermediate(nn.Module):
    def __init__(self, config: TvltConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states


class TvltOutput(nn.Module):
    def __init__(self, config: TvltConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states + input_tensor

        return hidden_states


class TvltLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = TvltAttention(config)
        self.intermediate = TvltIntermediate(config)
        self.output = TvltOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        self_attention_outputs = self.attention(
            # in Tvlt, layernorm is applied before self-attention
            self.layernorm_before(hidden_states),
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states.to(attention_output.device)

        # in Tvlt, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs


class TvltEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([TvltLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                )
            else:
                layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
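# Minimal standalone sketch of the encoder above (the sizes are placeholders, not values
# taken from this file's configuration):
#
#     config = TvltConfig(hidden_size=768, num_hidden_layers=12, num_attention_heads=12)
#     encoder = TvltEncoder(config)
#     tokens = torch.randn(2, 64, 768)
#     out = encoder(tokens, output_hidden_states=True)
#     # out.last_hidden_state: (2, 64, 768); len(out.hidden_states) == 13 (input + 12 layers)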
class TvltPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = TvltConfig
    base_model_prefix = "tvlt"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, TvltEncoder):
            module.gradient_checkpointing = value


TVLT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`TvltConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

TVLT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for
            details.

        audio_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Audio values. Audio values can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for
            details.

        pixel_mask (`torch.FloatTensor` of shape `(batch_size, num_pixel_patches)`):
            Pixel masks. Pixel masks can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for
            details.

        audio_mask (`torch.FloatTensor` of shape `(batch_size, num_audio_patches)`):
            Audio masks. Audio masks can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for
            details.

        pixel_values_mixed (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            Pixel values that mix positive and negative samples in Tvlt vision-audio matching. Pixel values mixed can
            be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for details.

        pixel_mask_mixed (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel masks of pixel_values_mixed. Pixel masks mixed can be obtained using [`TvltProcessor`]. See
            [`TvltProcessor.__call__`] for details.

        mask_pixel (`bool`, *optional*):
            Whether to mask pixel for MAE tasks. Only set to True in TvltForPreTraining.

        mask_audio (`bool`, *optional*):
            Whether to mask audio for MAE tasks. Only set to True in TvltForPreTraining.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.

        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare TVLT Model transformer outputting raw hidden-states without any specific head on top.",
    TVLT_START_DOCSTRING,
)
class TvltModel(TvltPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.pixel_embeddings = TvltPixelEmbeddings(config)
        self.audio_embeddings = TvltAudioEmbeddings(config)
        self.encoder = TvltEncoder(config)

        self.cls_embedding = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        if config.use_mean_pooling:
            self.layernorm = None
        else:
            self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.pixel_embeddings.patch_embeddings, self.audio_embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(TVLT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TvltModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        audio_values: torch.FloatTensor,
        pixel_mask: Optional[torch.FloatTensor] = None,
        audio_mask: Optional[torch.FloatTensor] = None,
        mask_pixel: bool = False,
        mask_audio: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], TvltModelOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import TvltProcessor, TvltModel
        >>> import numpy as np
        >>> import torch

        >>> num_frames = 8
        >>> images = list(np.random.randn(num_frames, 3, 224, 224))
        >>> audio = list(np.random.randn(10000))

        >>> processor = TvltProcessor.from_pretrained("ZinengTang/tvlt-base")
        >>> model = TvltModel.from_pretrained("ZinengTang/tvlt-base")

        >>> input_dict = processor(images, audio, sampling_rate=44100, return_tensors="pt")

        >>> outputs = model(**input_dict)
        >>> last_hidden_state = outputs.last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        pixel_embedding_output, pixel_mask = self.pixel_embeddings(pixel_values, pixel_mask)
        audio_embedding_output, audio_mask = self.audio_embeddings(audio_values, audio_mask)

        # Mask pixel patches if mask_pixel is set (used by TvltForPreTraining)
        pixel_label_masks = None
        pixel_ids_restore = None
        if mask_pixel:
            pixel_mask_noise, pixel_len_keep = generate_pixel_mask_noise(
                pixel_embedding_output, pixel_mask=pixel_mask, mask_ratio=self.config.pixel_mask_ratio
            )
            pixel_embedding_output, pixel_mask, pixel_ids_restore, pixel_label_masks = random_masking(
                pixel_embedding_output, pixel_mask_noise, pixel_len_keep, attention_masks=pixel_mask
            )

        # Mask audio patches if mask_audio is set (used by TvltForPreTraining)
        audio_label_masks = None
        audio_ids_restore = None
        if mask_audio:
            num_freq_patches = self.config.frequency_length // self.config.audio_patch_size[1]
            audio_mask_noise, audio_len_keep = generate_audio_mask_noise(
                audio_embedding_output,
                audio_mask=audio_mask,
                mask_ratio=self.config.audio_mask_ratio,
                mask_type=self.config.audio_mask_type,
                freq_len=num_freq_patches,
            )
            audio_embedding_output, audio_mask, audio_ids_restore, audio_label_masks = random_masking(
                audio_embedding_output, audio_mask_noise, audio_len_keep, attention_masks=audio_mask
            )

        # Prepare the encoder input: [CLS] token, pixel patches, then audio patches
        batch_size = pixel_values.size(0)
        embedding_output = torch.cat(
            [self.cls_embedding.repeat(batch_size, 1, 1), pixel_embedding_output, audio_embedding_output], 1
        )
        masked_pixel_len = pixel_embedding_output.size(1)

        attention_mask = None
        if pixel_mask is not None and audio_mask is not None:
            attention_mask = torch.cat([pixel_mask[:, :1], pixel_mask, audio_mask], 1)

        input_shape = embedding_output.size()
        extended_attention_mask = None
        if attention_mask is not None:
            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        if self.layernorm is not None:
            sequence_output = self.layernorm(sequence_output)

        pixel_sequence_output = sequence_output[:, 1 : 1 + masked_pixel_len]
        audio_sequence_output = sequence_output[:, 1 + masked_pixel_len :]

        if not return_dict:
            return (
                sequence_output,
                pixel_sequence_output,
                audio_sequence_output,
                pixel_label_masks,
                audio_label_masks,
                pixel_ids_restore,
                audio_ids_restore,
            ) + encoder_outputs[1:]

        return TvltModelOutput(
            last_hidden_state=sequence_output,
            last_pixel_hidden_state=pixel_sequence_output,
            last_audio_hidden_state=audio_sequence_output,
            pixel_label_masks=pixel_label_masks,
            audio_label_masks=audio_label_masks,
            pixel_ids_restore=pixel_ids_restore,
            audio_ids_restore=audio_ids_restore,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class TvltDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        decoder_config = deepcopy(config)
        decoder_config.hidden_size = config.decoder_hidden_size
        decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
        decoder_config.num_attention_heads = config.decoder_num_attention_heads
        decoder_config.intermediate_size = config.decoder_intermediate_size
        self.decoder_layers = nn.ModuleList(
            [TvltLayer(decoder_config) for _ in range(config.decoder_num_hidden_layers)]
        )

        self.layernorm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)

        self.gradient_checkpointing = False
        self.config = config

    def forward(
        self,
        hidden_states,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        # apply Transformer layers (blocks)
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        for layer_module in self.decoder_layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    None,
                )
            else:
                layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # predictor projection
        logits = self.layernorm(hidden_states)

        if not return_dict:
            return tuple(v for v in [logits, all_hidden_states, all_self_attentions] if v is not None)
        return TvltDecoderOutput(logits=logits, hidden_states=all_hidden_states, attentions=all_self_attentions)


@add_start_docstrings(
    "The TVLT Model transformer with the decoder on top for self-supervised pre-training.",
    TVLT_START_DOCSTRING,
)
class TvltForPreTraining(TvltPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.task_matching = config.task_matching
        self.task_mae = config.task_mae
        if not (self.task_matching or self.task_mae):
            raise ValueError("Must set at least one of matching task and MAE task to true")

        self.tvlt = TvltModel(config)

        if self.task_matching:
            self.matching_head = TvltMatchingHead(config)

        if self.task_mae:
            self.encoder_to_decoder = nn.Linear(config.hidden_size, config.decoder_hidden_size, bias=True)

            self.pixel_mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))
            self.audio_mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))

            self.decoder = TvltDecoder(config)

            decoder_hidden_size = config.decoder_hidden_size

            num_frames = config.num_frames
            num_patches_per_image = self.tvlt.pixel_embeddings.num_patches_per_image
            self.decoder_pixel_pos_embed = nn.Parameter(torch.zeros(1, num_patches_per_image, decoder_hidden_size))
            self.decoder_temporal_embed = nn.Parameter(torch.zeros(1, config.num_frames, decoder_hidden_size))
            self.decoder_pixel_type_embed = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))

            num_audio_patches = self.tvlt.audio_embeddings.num_patches
            num_freq_patches = config.frequency_length // config.audio_patch_size[1]
            self.decoder_audio_pos_embed = nn.Parameter(
                torch.zeros(1, num_audio_patches // num_freq_patches, decoder_hidden_size)
            )
            self.decoder_freq_embed = nn.Parameter(torch.zeros(1, num_freq_patches, decoder_hidden_size))
            self.decoder_audio_type_embed = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))

            pixel_mae_output_dim = self.config.image_patch_size[0] ** 2 * self.config.num_image_channels
            self.pixel_mae_head = TvltMAEHead(config, pixel_mae_output_dim)
            audio_mae_output_dim = (
                self.config.audio_patch_size[0] * self.config.audio_patch_size[1] * self.config.num_audio_channels
            )
            self.audio_mae_head = TvltMAEHead(config, audio_mae_output_dim)

            self.num_frames = num_frames
            self.num_patches_per_image = num_patches_per_image
            self.num_freq_patches = num_freq_patches
            self.image_patch_size = config.image_patch_size
            self.audio_patch_size = config.audio_patch_size

        # Initialize weights and apply final processing
        self.post_init()

    def patchify_pixel(self, pixel_values):
        """
        pixel_values: [batch_size, num_frames, 3, height, width]
        """
        batch_size, num_frames, num_channels, height, width = pixel_values.shape
        num_patches_height = height // self.image_patch_size[0]
        num_patches_width = width // self.image_patch_size[1]
        patchified_pixel_values = pixel_values.reshape(
            shape=(
                batch_size,
                num_frames,
                num_channels,
                num_patches_height,
                self.image_patch_size[0],
                num_patches_width,
                self.image_patch_size[1],
            )
        )
        patchified_pixel_values = torch.einsum("ntchpwq->nthwpqc", patchified_pixel_values)
        patchified_pixel_values = patchified_pixel_values.reshape(
            shape=(
                batch_size,
                num_patches_height * num_patches_width * num_frames,
                self.image_patch_size[0] * self.image_patch_size[1] * num_channels,
            )
        )
        return patchified_pixel_values

    def patchify_audio(self, audio_values):
        """
        audio_values: [batch_size, 1, height, width]
        """
        batch_size, num_channels, height, width = audio_values.shape
        num_patches_height = height // self.audio_patch_size[0]
        num_patches_width = width // self.audio_patch_size[1]
        patchified_audio_values = audio_values.reshape(
            shape=(
                batch_size,
                num_channels,
                num_patches_height,
                self.audio_patch_size[0],
                num_patches_width,
                self.audio_patch_size[1],
            )
        )
        patchified_audio_values = torch.einsum("nchpwq->nhwpqc", patchified_audio_values)
        patchified_audio_values = patchified_audio_values.reshape(
            shape=(
                batch_size,
                num_patches_height * num_patches_width,
                self.audio_patch_size[0] * self.audio_patch_size[1] * num_channels,
            )
        )
        return patchified_audio_values

    def pixel_mae_loss(self, pixel_values, pixel_predictions, mask):
        patchified_pixel_values = self.patchify_pixel(pixel_values)
        loss = (pixel_predictions - patchified_pixel_values) ** 2
        loss = loss.mean(dim=-1)  # [batch_size, pixel_patch_length], mean loss per patch
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches only
        return loss

    def audio_mae_loss(self, audio_values, audio_predictions, mask):
        patchified_audio_values = self.patchify_audio(audio_values)
        loss = (audio_predictions - patchified_audio_values) ** 2
        loss = loss.mean(dim=-1)  # [batch_size, audio_patch_length], mean loss per patch
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches only
        return loss

    def concatenate_mask(self, mask_token, sequence, ids_restore):
        batch_size, seq_length, dim = sequence.shape
        mask_tokens = mask_token.repeat(batch_size, ids_restore.shape[1] - seq_length, 1)
        padded_sequence = torch.cat([sequence, mask_tokens], dim=1)
        padded_sequence = torch.gather(
            padded_sequence, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, dim)
        )  # unshuffle back to the original patch order
        return padded_sequence

    @add_start_docstrings_to_model_forward(TVLT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TvltForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        audio_values: torch.FloatTensor,
        pixel_mask: Optional[torch.FloatTensor] = None,
        audio_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        pixel_values_mixed: Optional[torch.FloatTensor] = None,
        pixel_mask_mixed: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], TvltForPreTrainingOutput]:
        r"""
        pixel_values_mixed (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            Pixel values that mix positive and negative samples in Tvlt vision-audio matching. Pixel values mixed can
            be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for details.

        pixel_mask_mixed (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel masks of pixel_values_mixed. Pixel masks mixed can be obtained using [`TvltProcessor`]. See
            [`TvltProcessor.__call__`] for details.

        labels (`torch.LongTensor` of shape `(batch_size, num_labels)`, *optional*):
            Labels for computing the vision audio matching loss. Indices should be in `[0, 1]`. num_labels has to be 1.

        Return:

        Examples:

        ```python
        >>> from transformers import TvltProcessor, TvltForPreTraining
        >>> import numpy as np
        >>> import torch

        >>> num_frames = 8
        >>> images = list(np.random.randn(num_frames, 3, 224, 224))
        >>> images_mixed = list(np.random.randn(num_frames, 3, 224, 224))
        >>> audio = list(np.random.randn(10000))
        >>> processor = TvltProcessor.from_pretrained("ZinengTang/tvlt-base")
        >>> model = TvltForPreTraining.from_pretrained("ZinengTang/tvlt-base")
        >>> input_dict = processor(
        ...     images, audio, images_mixed, sampling_rate=44100, mask_pixel=True, mask_audio=True, return_tensors="pt"
        ... )

        >>> outputs = model(**input_dict)
        >>> loss = outputs.loss
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        total_loss = 0.0
        matching_logits = None
        pixel_logits = None
        audio_logits = None

        if self.task_matching:
            if labels is None:
                raise ValueError("Matching task requires labels")
            if pixel_values_mixed is None:
                raise ValueError("Matching task requires pixel_values_mixed")
            outputs = self.tvlt(
                pixel_values_mixed,
                audio_values,
                pixel_mask=pixel_mask_mixed,
                audio_mask=audio_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            sequence_output = outputs[0]
            matching_logits = self.matching_head(sequence_output)

            loss_fct = BCEWithLogitsLoss()
            loss = loss_fct(matching_logits.view(-1), labels.view(-1))
            total_loss += loss

        if self.task_mae and self.training:
            outputs = self.tvlt(
                pixel_values,
                audio_values,
                pixel_mask=pixel_mask,
                audio_mask=audio_mask,
                mask_pixel=True,
                mask_audio=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            pixel_sequence_output = outputs.last_pixel_hidden_state if return_dict else outputs[1]
            audio_sequence_output = outputs.last_audio_hidden_state if return_dict else outputs[2]
            pixel_label_masks = outputs.pixel_label_masks if return_dict else outputs[3]
            audio_label_masks = outputs.audio_label_masks if return_dict else outputs[4]
            pixel_ids_restore = outputs.pixel_ids_restore if return_dict else outputs[5]
            audio_ids_restore = outputs.audio_ids_restore if return_dict else outputs[6]

            # project encoder outputs into the decoder width and re-insert mask tokens
            pixel_decoder_input = self.encoder_to_decoder(pixel_sequence_output)
            audio_decoder_input = self.encoder_to_decoder(audio_sequence_output)

            num_frames = pixel_values.size(1)
            pixel_decoder_input = self.concatenate_mask(self.pixel_mask_token, pixel_decoder_input, pixel_ids_restore)
            pixel_decoder_input = pixel_decoder_input + self.decoder_pixel_pos_embed.repeat(1, num_frames, 1)
            pixel_decoder_input = pixel_decoder_input + torch.repeat_interleave(
                self.decoder_temporal_embed[:, :num_frames], self.num_patches_per_image, dim=1
            )
            pixel_decoder_input = pixel_decoder_input + self.decoder_pixel_type_embed
            pixel_decoder_outputs = self.decoder(pixel_decoder_input)
            pixel_logits = self.pixel_mae_head(pixel_decoder_outputs.logits)

            audio_decoder_input = self.concatenate_mask(self.audio_mask_token, audio_decoder_input, audio_ids_restore)
            num_time_patches = audio_decoder_input.size(1) // self.num_freq_patches
            audio_decoder_input = audio_decoder_input + self.decoder_freq_embed.repeat(1, num_time_patches, 1)
            audio_decoder_input = audio_decoder_input + torch.repeat_interleave(
                self.decoder_audio_pos_embed[:, :num_time_patches], self.num_freq_patches, dim=1
            )
            audio_decoder_input = audio_decoder_input + self.decoder_audio_type_embed
            audio_decoder_outputs = self.decoder(audio_decoder_input)
            audio_logits = self.audio_mae_head(audio_decoder_outputs.logits)

            loss = self.pixel_mae_loss(pixel_values, pixel_logits, pixel_label_masks) + self.audio_mae_loss(
                audio_values, audio_logits, audio_label_masks
            )
            total_loss += loss

        if not return_dict:
            output = (matching_logits, pixel_logits, audio_logits) + outputs[7:]
            return ((total_loss,) + output) if total_loss is not None else output

        return TvltForPreTrainingOutput(
            loss=total_loss,
            matching_logits=matching_logits,
            pixel_logits=pixel_logits,
            audio_logits=audio_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class TvltPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class TvltMatchingHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.pooler = TvltPooler(config)
        self.fc = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states):
        hidden_states = self.fc(self.pooler(hidden_states))
        return hidden_states


class TvltMAEHead(nn.Module):
    def __init__(self, config, output_dim=None):
        super().__init__()
        self.config = config
        self.decoder = nn.Linear(config.decoder_hidden_size, output_dim)

    def forward(self, hidden_states):
        hidden_states = self.decoder(hidden_states)
        return hidden_states
@add_start_docstrings(
    """
    Tvlt Model transformer with a classifier head on top (an MLP on top of the final hidden state of the [CLS] token)
    for audiovisual classification tasks, e.g. CMU-MOSEI Sentiment Analysis and Audio to Video Retrieval.
    """,
    TVLT_START_DOCSTRING,
)
class TvltForAudioVisualClassification(TvltPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.tvlt = TvltModel(config)

        # Classifier head
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 2),
            nn.LayerNorm(config.hidden_size * 2, eps=config.layer_norm_eps),
            nn.GELU(),
            nn.Linear(config.hidden_size * 2, config.num_labels),
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(TVLT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        audio_values: torch.FloatTensor,
        pixel_mask: Optional[torch.FloatTensor] = None,
        audio_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.FloatTensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, num_labels)`, *optional*):
            Labels for computing the audiovisual loss. Indices should be in `[0, ..., num_classes-1]` where num_classes
            refers to the number of classes in audiovisual tasks.

        Return:

        Examples:
        ```python
        >>> from transformers import TvltProcessor, TvltForAudioVisualClassification
        >>> import numpy as np
        >>> import torch

        >>> num_frames = 8
        >>> images = list(np.random.randn(num_frames, 3, 224, 224))
        >>> audio = list(np.random.randn(10000))
        >>> processor = TvltProcessor.from_pretrained("ZinengTang/tvlt-base")
        >>> model = TvltForAudioVisualClassification.from_pretrained("ZinengTang/tvlt-base")
        >>> input_dict = processor(images, audio, sampling_rate=44100, return_tensors="pt")

        >>> outputs = model(**input_dict)
        >>> loss = outputs.loss
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.tvlt(
            pixel_values,
            audio_values,
            pixel_mask=pixel_mask,
            audio_mask=audio_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # classify on the [CLS] token of the joint pixel-audio sequence
        sequence_output = outputs[0][:, 0]
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            if self.config.loss_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits, labels)
            elif self.config.loss_type == "classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[4:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )