""" PyTorch ViLT model."""

import collections.abc
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    MaskedLMOutput,
    ModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_vilt import ViltConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "ViltConfig"
_CHECKPOINT_FOR_DOC = "dandelin/vilt-b32-mlm"

VILT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "dandelin/vilt-b32-mlm",
]


@dataclass
class ViltForImagesAndTextClassificationOutput(ModelOutput):
    """
    Class for outputs of [`ViltForImagesAndTextClassification`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`List[tuple(torch.FloatTensor)]`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            List of tuples of `torch.FloatTensor` (one for each image-text pair, each tuple containing the output of
            the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`List[tuple(torch.FloatTensor)]`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            List of tuples of `torch.FloatTensor` (one for each image-text pair, each tuple containing the attention
            weights of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the
            attention softmax, used to compute the weighted average in the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[List[Tuple[torch.FloatTensor]]] = None
    attentions: Optional[List[Tuple[torch.FloatTensor]]] = None


class ViltEmbeddings(nn.Module):
    """
    Construct the text and patch embeddings.

    Text embeddings are equivalent to BERT embeddings.

    Patch embeddings are equivalent to ViT embeddings.
    """

    def __init__(self, config):
        super().__init__()

        # text embeddings
        self.text_embeddings = TextEmbeddings(config)
        # patch embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.patch_embeddings = ViltPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
        # modality type (text/patch) embeddings
        self.token_type_embeddings = nn.Embedding(config.modality_type_vocab_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def visual_embed(self, pixel_values, pixel_mask, max_image_length=200):
        _, _, ph, pw = self.patch_embeddings.projection.weight.shape

        x = self.patch_embeddings(pixel_values)
        x_mask = pixel_mask[:, None, :, :].float()
        x_mask = nn.functional.interpolate(x_mask, size=(x.shape[2], x.shape[3])).long()
        x_h = x_mask[:, 0].sum(dim=1)[:, 0]
        x_w = x_mask[:, 0].sum(dim=2)[:, 0]

        batch_size, num_channels, height, width = x.shape
        patch_dim = self.config.image_size // self.config.patch_size
        spatial_pos = self.position_embeddings[:, 1:, :].transpose(1, 2).view(1, num_channels, patch_dim, patch_dim)
        pos_embed = torch.cat(
            [
                nn.functional.pad(
                    nn.functional.interpolate(spatial_pos, size=(h, w), mode="bilinear", align_corners=True),
                    (0, width - w, 0, height - h),
                )
                for h, w in zip(x_h, x_w)
            ],
            dim=0,
        )

        pos_embed = pos_embed.flatten(2).transpose(1, 2)
        x = x.flatten(2).transpose(1, 2)
        # patch_index keeps the (row, col) grid position of every patch, on the same device as the mask
        patch_index = torch.stack(
            meshgrid(torch.arange(x_mask.shape[-2]), torch.arange(x_mask.shape[-1]), indexing="ij"), dim=-1
        ).to(device=x_mask.device)
        patch_index = patch_index[None, None, :, :, :]
        patch_index = patch_index.expand(x_mask.shape[0], x_mask.shape[1], -1, -1, -1)
        patch_index = patch_index.flatten(1, 3)
        x_mask = x_mask.flatten(1)

        if max_image_length < 0 or max_image_length is None or not isinstance(max_image_length, int):
            # use as many patches as the largest (non-padded) image in the batch provides
            effective_resolution = x_h * x_w
            max_image_length = effective_resolution.max()
        else:
            effective_resolution = x_h * x_w
            max_image_length = min(effective_resolution.max(), max_image_length)

        valid_idx = x_mask.nonzero(as_tuple=False)
        non_valid_idx = (1 - x_mask).nonzero(as_tuple=False)
        unique_rows = valid_idx[:, 0].unique()
        valid_row_idx = [valid_idx[valid_idx[:, 0] == u] for u in unique_rows]
        non_valid_row_idx = [non_valid_idx[non_valid_idx[:, 0] == u] for u in unique_rows]

        valid_nums = [v.size(0) for v in valid_row_idx]
        non_valid_nums = [v.size(0) for v in non_valid_row_idx]
        pad_nums = [max_image_length - v for v in valid_nums]

        select = []
        for i, (v, nv, p) in enumerate(zip(valid_nums, non_valid_nums, pad_nums)):
            if p <= 0:
                # sample exactly max_image_length patches among the valid ones
                valid_choice = torch.multinomial(torch.ones(v).float(), max_image_length)
                select.append(valid_row_idx[i][valid_choice])
            else:
                # keep every valid patch and pad with (repeated) masked positions
                pad_choice = torch.multinomial(torch.ones(nv).float(), p, replacement=True)
                select.append(torch.cat([valid_row_idx[i], non_valid_row_idx[i][pad_choice]], dim=0))

        select = torch.cat(select, dim=0)
        x = x[select[:, 0], select[:, 1]].view(batch_size, -1, num_channels)
        x_mask = x_mask[select[:, 0], select[:, 1]].view(batch_size, -1)
        patch_index = patch_index[select[:, 0], select[:, 1]].view(batch_size, -1, 2)
        pos_embed = pos_embed[select[:, 0], select[:, 1]].view(batch_size, -1, num_channels)

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        pos_embed = torch.cat(
            (self.position_embeddings[:, 0, :][:, None, :].expand(batch_size, -1, -1), pos_embed), dim=1
        )
        x = x + pos_embed
        x = self.dropout(x)

        x_mask = torch.cat([torch.ones(x_mask.shape[0], 1).to(x_mask), x_mask], dim=1)

        return x, x_mask, (patch_index, (height, width))

    def forward(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        pixel_values,
        pixel_mask,
        inputs_embeds,
        image_embeds,
        image_token_type_idx=1,
    ):
        # PART 1: text embeddings
        text_embeds = self.text_embeddings(
            input_ids=input_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )

        # PART 2: patch embeddings (with interpolated position encodings)
        if image_embeds is None:
            image_embeds, image_masks, patch_index = self.visual_embed(
                pixel_values, pixel_mask, max_image_length=self.config.max_image_length
            )
        else:
            image_masks = pixel_mask.flatten(1)

        # PART 3: add modality type embeddings
        # 0 indicates text, 1 indicates image, larger indices are used for additional images (e.g. NLVR2)
        if image_token_type_idx is None:
            image_token_type_idx = 1
        text_embeds = text_embeds + self.token_type_embeddings(
            torch.zeros_like(attention_mask, dtype=torch.long, device=text_embeds.device)
        )
        image_embeds = image_embeds + self.token_type_embeddings(
            torch.full_like(image_masks, image_token_type_idx, dtype=torch.long, device=text_embeds.device)
        )

        # PART 4: concatenate the text and image sequences
        embeddings = torch.cat([text_embeds, image_embeds], dim=1)
        masks = torch.cat([attention_mask, image_masks], dim=1)

        return embeddings, masks


class TextEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        # fall back to the all-zeros buffered token_type_ids when none are passed
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class ViltPatchEmbeddings(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values):
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        x = self.projection(pixel_values)
        return x


class ViltSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the ViltModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class ViltSelfOutput(nn.Module):
    """
    The residual connection is defined in ViltLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: ViltConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


class ViltAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = ViltSelfAttention(config)
        self.output = ViltSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        self_outputs = self.attention(hidden_states, attention_mask, head_mask, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class ViltIntermediate(nn.Module):
    def __init__(self, config: ViltConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states


class ViltOutput(nn.Module):
    def __init__(self, config: ViltConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states + input_tensor

        return hidden_states


class ViltLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = ViltAttention(config)
        self.intermediate = ViltIntermediate(config)
        self.output = ViltOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in ViLT, layernorm is applied before self-attention
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states.to(attention_output.device)

        # in ViLT, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs


class ViltEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([ViltLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                )
            else:
                layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class ViltPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ViltConfig
    base_model_prefix = "vilt"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ViltEmbeddings", "ViltSelfAttention"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, ViltEncoder):
            module.gradient_checkpointing = value


VILT_START_DOCSTRING = r"""
    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ subclass. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`ViltConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

VILT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)

        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            [What are attention masks?](../glossary#attention-mask)

        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:
            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            [What are token type IDs?](../glossary#token-type-ids)

        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`ViltImageProcessor.__call__`] for details.

        pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:

            - 1 for pixels that are real (i.e. **not masked**),
            - 0 for pixels that are padding (i.e. **masked**).
            `What are attention masks? <../glossary.html#attention-mask>`__

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.

        image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
            Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `pixel_values` into patch embeddings.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

VILT_IMAGES_AND_TEXT_CLASSIFICATION_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)

        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            [What are attention masks?](../glossary#attention-mask)

        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:
            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            [What are token type IDs?](../glossary#token-type-ids)

        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_images, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`ViltImageProcessor.__call__`] for details.

        pixel_mask (`torch.LongTensor` of shape `(batch_size, num_images, height, width)`, *optional*):
            Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:

            - 1 for pixels that are real (i.e. **not masked**),
            - 0 for pixels that are padding (i.e. **masked**).
            `What are attention masks? <../glossary.html#attention-mask>`__

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.

        image_embeds (`torch.FloatTensor` of shape `(batch_size, num_images, num_patches, hidden_size)`, *optional*):
            Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `pixel_values` into patch embeddings.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare ViLT Model transformer outputting raw hidden-states without any specific head on top.",
    VILT_START_DOCSTRING,
)
class ViltModel(ViltPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = ViltEmbeddings(config)
        self.encoder = ViltEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = ViltPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.text_embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.text_embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_embeds: Optional[torch.FloatTensor] = None,
        image_token_type_idx: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[BaseModelOutputWithPooling, Tuple[torch.FloatTensor]]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import ViltProcessor, ViltModel
        >>> from PIL import Image
        >>> import requests

        >>> # prepare image and text
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> text = "hello world"

        >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
        >>> model = ViltModel.from_pretrained("dandelin/vilt-b32-mlm")

        >>> inputs = processor(image, text, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        text_batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones((text_batch_size, seq_length), device=device)

        if pixel_values is not None and image_embeds is not None:
            raise ValueError("You cannot specify both pixel_values and image_embeds at the same time")
        elif pixel_values is None and image_embeds is None:
            raise ValueError("You have to specify either pixel_values or image_embeds")

        image_batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeds.shape[0]
        if image_batch_size != text_batch_size:
            raise ValueError("The text inputs and image inputs need to have the same batch size")
        if pixel_mask is None:
            pixel_mask = torch.ones((image_batch_size, self.config.image_size, self.config.image_size), device=device)

        # Prepare head mask if needed: shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output, attention_mask = self.embeddings(
            input_ids,
            attention_mask,
            token_type_ids,
            pixel_values,
            pixel_mask,
            inputs_embeds,
            image_embeds,
            image_token_type_idx=image_token_type_idx,
        )

        # Make the combined text+image attention mask broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class ViltPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


@add_start_docstrings(
    """
    ViLT Model with a language modeling head on top as done during pretraining.
    """,
    VILT_START_DOCSTRING,
)
class ViltForMaskedLM(ViltPreTrainedModel):
    _tied_weights_keys = ["mlm_score.decoder.weight", "mlm_score.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.vilt = ViltModel(config)
        self.mlm_score = ViltMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.mlm_score.decoder

    def set_output_embeddings(self, new_embeddings):
        self.mlm_score.decoder = new_embeddings

    @add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[MaskedLMOutput, Tuple[torch.FloatTensor]]:
        r"""
        labels (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in *[-100, 0, ...,
            config.vocab_size]* (see *input_ids* docstring) Tokens with indices set to *-100* are ignored (masked), the
            loss is only computed for the tokens with labels in *[0, ..., config.vocab_size]*

        Returns:

        Examples:

        ```python
        >>> from transformers import ViltProcessor, ViltForMaskedLM
        >>> import requests
        >>> from PIL import Image
        >>> import re
        >>> import torch

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> text = "a bunch of [MASK] laying on a [MASK]."

        >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
        >>> model = ViltForMaskedLM.from_pretrained("dandelin/vilt-b32-mlm")

        >>> # prepare inputs
        >>> encoding = processor(image, text, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**encoding)

        >>> tl = len(re.findall("\[MASK\]", text))
        >>> inferred_token = [text]

        >>> # gradually fill in the MASK tokens, one by one
        >>> with torch.no_grad():
        ...     for i in range(tl):
        ...         encoded = processor.tokenizer(inferred_token)
        ...         input_ids = torch.tensor(encoded.input_ids)
        ...         encoded = encoded["input_ids"][0][1:-1]
        ...         outputs = model(input_ids=input_ids, pixel_values=encoding.pixel_values)
        ...         mlm_logits = outputs.logits[0]  # shape (seq_len, vocab_size)
        ...         # only take into account text features (minus CLS and SEP token)
        ...         mlm_logits = mlm_logits[1 : input_ids.shape[1] - 1, :]
        ...         mlm_values, mlm_ids = mlm_logits.softmax(dim=-1).max(dim=-1)
        ...         # only take into account text
        ...         mlm_values[torch.tensor(encoded) != 103] = 0
        ...         select = mlm_values.argmax().item()
        ...         encoded[select] = mlm_ids[select].item()
        ...         inferred_token = [processor.decode(encoded)]

        >>> selected_token = ""
        >>> encoded = processor.tokenizer(inferred_token)
        >>> output = processor.decode(encoded.input_ids[0], skip_special_tokens=True)
        >>> print(output)
        a bunch of cats laying on a couch.
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vilt(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            image_embeds=image_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]
        # split the final hidden states into text and image features
        text_seq_len = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
        text_features, _ = (sequence_output[:, :text_seq_len], sequence_output[:, text_seq_len:])

        mlm_logits = self.mlm_score(text_features)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            labels = labels.to(mlm_logits.device)
            masked_lm_loss = loss_fct(mlm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (mlm_logits,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=mlm_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class ViltPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class ViltMLMHead(nn.Module):
    def __init__(self, config, weight=None):
        super().__init__()
        self.config = config
        self.transform = ViltPredictionHeadTransform(config)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        if weight is not None:
            self.decoder.weight = weight

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, x):
        x = self.transform(x)
        x = self.decoder(x)
        return x


@add_start_docstrings(
    """
    Vilt Model transformer with a classifier head on top (a linear layer on top of the final hidden state of the [CLS]
    token) for visual question answering, e.g. for VQAv2.
    """,
    VILT_START_DOCSTRING,
)
class ViltForQuestionAnswering(ViltPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.vilt = ViltModel(config)

        # Classifier head
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 2),
            nn.LayerNorm(config.hidden_size * 2),
            nn.GELU(),
            nn.Linear(config.hidden_size * 2, config.num_labels),
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
        r"""
        labels (`torch.FloatTensor` of shape `(batch_size, num_labels)`, *optional*):
            Labels for computing the visual question answering loss. This tensor must be either a one-hot encoding of
            all answers that are applicable for a given example in the batch, or a soft encoding indicating which
            answers are applicable, where 1.0 is the highest score.

        Returns:

        Examples:

        ```python
        >>> from transformers import ViltProcessor, ViltForQuestionAnswering
        >>> import requests
        >>> from PIL import Image

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> text = "How many cats are there?"

        >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
        >>> model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

        >>> # prepare inputs
        >>> encoding = processor(image, text, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**encoding)
        >>> logits = outputs.logits
        >>> idx = logits.argmax(-1).item()
        >>> print("Predicted answer:", model.config.id2label[idx])
        Predicted answer: 2
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vilt(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            image_embeds=image_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooler_output = outputs.pooler_output if return_dict else outputs[1]

        logits = self.classifier(pooler_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss = nn.functional.binary_cross_entropy_with_logits(logits, labels) * labels.shape[1]

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Vilt Model transformer with a classifier head on top (a linear layer on top of the final hidden state of the [CLS]
    token) for image-to-text or text-to-image retrieval, e.g. MSCOCO and F30K.
    """,
    VILT_START_DOCSTRING,
)
class ViltForImageAndTextRetrieval(ViltPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.vilt = ViltModel(config)

        # Classifier head
        self.rank_output = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels are currently not supported.

        Returns:

        Examples:

        ```python
        >>> from transformers import ViltProcessor, ViltForImageAndTextRetrieval
        >>> import requests
        >>> from PIL import Image

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> texts = ["An image of two cats chilling on a couch", "A football player scoring a goal"]

        >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-coco")
        >>> model = ViltForImageAndTextRetrieval.from_pretrained("dandelin/vilt-b32-finetuned-coco")

        >>> # forward pass
        >>> scores = dict()
        >>> for text in texts:
        ...     # prepare inputs
        ...     encoding = processor(image, text, return_tensors="pt")
        ...     outputs = model(**encoding)
        ...     scores[text] = outputs.logits[0, :].item()
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vilt(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            image_embeds=image_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooler_output = outputs.pooler_output if return_dict else outputs[1]

        logits = self.rank_output(pooler_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            raise NotImplementedError("Training is not yet supported.")

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Vilt Model transformer with a classifier head on top for natural language visual reasoning, e.g. NLVR2.
    """,
    VILT_IMAGES_AND_TEXT_CLASSIFICATION_INPUTS_DOCSTRING,
)
class ViltForImagesAndTextClassification(ViltPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.vilt = ViltModel(config)

        # Classifier head
        num_images = config.num_images
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size * num_images, config.hidden_size * num_images),
            nn.LayerNorm(config.hidden_size * num_images),
            nn.GELU(),
            nn.Linear(config.hidden_size * num_images, config.num_labels),
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(VILT_IMAGES_AND_TEXT_CLASSIFICATION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ViltForImagesAndTextClassificationOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[ViltForImagesAndTextClassificationOutput, Tuple[torch.FloatTensor]]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Binary classification labels.

        Returns:

        Examples:

        ```python
        >>> from transformers import ViltProcessor, ViltForImagesAndTextClassification
        >>> import requests
        >>> from PIL import Image

        >>> image1 = Image.open(requests.get("https://lil.nlp.cornell.edu/nlvr/exs/ex0_0.jpg", stream=True).raw)
        >>> image2 = Image.open(requests.get("https://lil.nlp.cornell.edu/nlvr/exs/ex0_1.jpg", stream=True).raw)
        >>> text = "The left image contains twice the number of dogs as the right image."

        >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-nlvr2")
        >>> model = ViltForImagesAndTextClassification.from_pretrained("dandelin/vilt-b32-finetuned-nlvr2")

        >>> # prepare inputs
        >>> encoding = processor([image1, image2], text, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(input_ids=encoding.input_ids, pixel_values=encoding.pixel_values.unsqueeze(0))
        >>> logits = outputs.logits
        >>> idx = logits.argmax(-1).item()
        >>> print("Predicted answer:", model.config.id2label[idx])
        Predicted answer: True
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is not None and pixel_values.ndim == 4:
            # add dummy num_images dimension
            pixel_values = pixel_values.unsqueeze(1)

        if image_embeds is not None and image_embeds.ndim == 3:
            # add dummy num_images dimension
            image_embeds = image_embeds.unsqueeze(1)

        num_images = pixel_values.shape[1] if pixel_values is not None else None
        if num_images is None:
            num_images = image_embeds.shape[1] if image_embeds is not None else None
        if num_images != self.config.num_images:
            raise ValueError(
                "Make sure to match the number of images in the model with the number of images in the input."
            )
        pooler_outputs = []
        hidden_states = [] if output_hidden_states else None
        attentions = [] if output_attentions else None
        for i in range(num_images):
            # forward every image through the model
            outputs = self.vilt(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                pixel_values=pixel_values[:, i, :, :, :] if pixel_values is not None else None,
                pixel_mask=pixel_mask[:, i, :, :] if pixel_mask is not None else None,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                image_embeds=image_embeds[:, i, :, :] if image_embeds is not None else None,
                image_token_type_idx=i + 1,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            pooler_output = outputs.pooler_output if return_dict else outputs[1]
            pooler_outputs.append(pooler_output)
            if output_hidden_states:
                hidden_states.append(outputs.hidden_states)
            if output_attentions:
                attentions.append(outputs.attentions)

        pooled_output = torch.cat(pooler_outputs, dim=-1)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            labels = labels.to(logits.device)
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits, hidden_states, attentions)
            return ((loss,) + output) if loss is not None else output

        return ViltForImagesAndTextClassificationOutput(
            loss=loss,
            logits=logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )


@add_start_docstrings(
    """
    ViLT Model with a token classification head on top (a linear layer on top of the final hidden-states of the text
    tokens) e.g. for Named-Entity-Recognition (NER) tasks.
    """,
    VILT_START_DOCSTRING,
)
class ViltForTokenClassification(ViltPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.vilt = ViltModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[TokenClassifierOutput, Tuple[torch.FloatTensor]]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.

        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vilt(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            image_embeds=image_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        text_input_size = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output[:, :text_input_size])

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            labels = labels.to(logits.device)
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )