""" PyTorch LXMERT model."""


import math
import os
import warnings
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, SmoothL1Loss

from ...activations import ACT2FN, gelu
from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_lxmert import LxmertConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "unc-nlp/lxmert-base-uncased"
_CONFIG_FOR_DOC = "LxmertConfig"

LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "unc-nlp/lxmert-base-uncased",
]


class GeLU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return gelu(x)


@dataclass
class LxmertModelOutput(ModelOutput):
    """
    Lxmert's outputs that contain the last hidden states, pooled outputs, and attention probabilities for the language,
    visual, and cross-modality encoders. (note: the visual encoder in Lxmert is referred to as the "relation-ship"
    encoder)

    Args:
        language_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the language encoder.
        vision_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the visual encoder.
        pooled_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification, CLS, token) further processed
            by a Linear layer and a Tanh activation function.
        language_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        language_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
        vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
        cross_encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    Nlanguage_outputvision_outputpooled_outputlanguage_hidden_statesvision_hidden_stateslanguage_attentionsvision_attentionscross_encoder_attentions)r#   r$   r%   __doc__r(   r   torchFloatTensor__annotations__r)   r*   r+   r   r,   r-   r.   r/   r   r   r   r   r'   ;   s   
"r'   c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeej  ed< dZeeej  ed< dZeeej  ed< dZeeej  ed< dZeeej  ed	< dS )
 LxmertForQuestionAnsweringOutputa	  
    Output type of [`LxmertForQuestionAnswering`].

    Args:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Total loss as the sum of the masked language modeling loss and the next sequence prediction
            (classification) loss.
        question_answering_score (`torch.FloatTensor` of shape `(batch_size, n_qa_answers)`, *optional*):
            Prediction scores of question answering objective (classification).
        language_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        language_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
        vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
        cross_encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    Nlossquestion_answering_scorer+   r,   r-   r.   r/   )r#   r$   r%   r0   r5   r   r1   r2   r3   r6   r+   r   r,   r-   r.   r/   r   r   r   r   r4   i   s   
r4   c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeej ed< dZeeej  ed< dZeeej  ed< dZeeej  ed	< dZeeej  ed
< dZeeej  ed< dS )LxmertForPreTrainingOutputak  
    Output type of [`LxmertForPreTraining`].

    Args:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Total loss as the sum of the masked language modeling loss and the next sequence prediction
            (classification) loss.
        prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        cross_relationship_score (`torch.FloatTensor` of shape `(batch_size, 2)`):
            Prediction scores of the textual matching objective (classification) head (scores of True/False
            continuation before SoftMax).
        question_answering_score (`torch.FloatTensor` of shape `(batch_size, n_qa_answers)`):
            Prediction scores of question answering objective (classification).
        language_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        language_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
        vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
        cross_encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.

    Nr5   prediction_logitscross_relationship_scorer6   r+   r,   r-   r.   r/   )r#   r$   r%   r0   r5   r   r1   r2   r3   r8   r9   r6   r+   r   r,   r-   r.   r/   r   r   r   r   r7      s   
#r7   c                 C   s  zddl }ddl}ddl}W n  tk
r<   td  Y nX tj|}t	d|  |j
|}g }g }	|D ]@\}
}t	d|
 d|  |j
||
}||
 |	| qrt||	D ]\}
}|
d}
tdd	 |
D rt	d
d|
  q| }|
D ]}|d|r&|d|}n|g}|d dksH|d dkrTt|d}n|d dksp|d dkr|t|d}nz|d dkrt|d}n`|d dkrt|d}nFzt||d }W n2 tk
r   t	d
d|
  Y qY nX t|dkrt|d }|| }q|dd dkr:t|d}n|dkrN||}z|j|jksbtW n< tk
r } z| j|j|jf7  _ W 5 d}~X Y nX t	d|
  t||_q| S )z'Load tf checkpoints in a pytorch model.r   NzLoading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see https://www.tensorflow.org/install/ for installation instructions.z&Converting TensorFlow checkpoint from zLoading TF weight z with shape /c                 s   s   | ]}|d kV  qdS ))Zadam_vZadam_mZAdamWeightDecayOptimizerZAdamWeightDecayOptimizer_1Zglobal_stepNr   ).0nr   r   r   	<genexpr>   s   	z,load_tf_weights_in_lxmert.<locals>.<genexpr>z	Skipping z[A-Za-z]+_\d+z_(\d+)ZkernelgammaweightZoutput_biasbetabiasZoutput_weightsZsquad
classifier   r   iZ_embeddingszInitialize PyTorch weight )renumpyZ
tensorflowImportErrorloggererrorospathabspathinfotrainZlist_variablesZload_variableappendzipsplitanyjoin	fullmatchgetattrAttributeErrorlenint	transposeshapeAssertionErrorargsr1   Z
from_numpydata)modelconfigZtf_checkpoint_pathrD   nptfZtf_pathZ	init_varsnamesZarraysnamerY   arrayZpointerZm_nameZscope_namesnumer   r   r   load_tf_weights_in_lxmert   sv    

	


rf   c                       s*   e Zd ZdZ fddZdddZ  ZS )LxmertEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    sp   t    tj|j|jdd| _tj|j|jdd| _tj|j	|jdd| _
tj|jdd| _t|j| _d S )Nr   )padding_idx-q=Zeps)r   r   r   	Embedding
vocab_sizehidden_sizeword_embeddingsZmax_position_embeddingsposition_embeddingsZtype_vocab_sizetoken_type_embeddings	LayerNormDropouthidden_dropout_probdropoutr   r^   r   r   r   r     s    
zLxmertEmbeddings.__init__Nc                 C   s   |d k	r|  }|j}n|  d d }|j}|d }tj|tj|d}|d|}|d krvtj|tj| jjd}|d kr| 	|}| 
|}| |}	|| |	 }
| |
}
| |
}
|
S )Nr   dtypedevicer   )sizery   r1   Zarangelong	unsqueezeexpandzerosposition_idsrn   ro   rp   rq   rt   )r   	input_idstoken_type_idsinputs_embedsinput_shapery   Z
seq_lengthr   ro   rp   
embeddingsr   r   r   r!     s$    




zLxmertEmbeddings.forward)NN)r#   r$   r%   r0   r   r!   r&   r   r   r   r   rg     s   rg   c                       s0   e Zd Zd	 fdd	Zdd Zd
ddZ  ZS )LxmertAttentionNc                    s   t    |j|j dkr4td|j d|j d|j| _t|j|j | _| j| j | _|d krj|j}t	|j| j| _
t	|| j| _t	|| j| _t|j| _d S )Nr   zThe hidden size (z6) is not a multiple of the number of attention heads ())r   r   rm   num_attention_heads
ValueErrorrW   attention_head_size	head_sizer   Linearquerykeyvaluerr   Zattention_probs_dropout_probrt   )r   r^   Zctx_dimr   r   r   r   9  s    
zLxmertAttention.__init__c                 C   s6   |  d d | j| jf }||}|ddddS )Nrv   r   rC   r   r
   )rz   r   r   viewpermute)r   r    Znew_x_shaper   r   r   transpose_for_scoresM  s    
z$LxmertAttention.transpose_for_scoresFc                 C   s   |  |}| |}| |}| |}| |}	| |}
t||	dd}|t| j	 }|d k	rp|| }t
jj|dd}| |}t||
}|dddd }| d d | jf }||}|r||fn|f}|S )Nrv   )dimr   rC   r   r
   )r   r   r   r   r1   matmulrX   mathsqrtr   r   Z
functionalZsoftmaxrt   r   
contiguousrz   r   r   )r   hidden_statescontextattention_maskoutput_attentionsZmixed_query_layerZmixed_key_layerZmixed_value_layerZquery_layerZ	key_layerZvalue_layerZattention_scoresattention_probsZcontext_layerZnew_context_layer_shapeoutputsr   r   r   r!   U  s$    







zLxmertAttention.forward)N)NF)r#   r$   r%   r   r   r!   r&   r   r   r   r   r   8  s   r   c                       s$   e Zd Z fddZdd Z  ZS )LxmertAttentionOutputc                    s@   t    t|j|j| _tj|jdd| _t|j| _	d S Nri   rj   )
r   r   r   r   rm   denserq   rr   rs   rt   ru   r   r   r   r   v  s    
zLxmertAttentionOutput.__init__c                 C   s&   |  |}| |}| || }|S r   r   rt   rq   r   r   input_tensorr   r   r   r!   |  s    

zLxmertAttentionOutput.forwardr"   r   r   r   r   r   u  s   r   c                       s&   e Zd Z fddZdddZ  ZS )LxmertCrossAttentionLayerc                    s"   t    t|| _t|| _d S r   )r   r   r   attr   outputru   r   r   r   r     s    

z"LxmertCrossAttentionLayer.__init__NFc           	      C   sD   | j ||||d}|r|d }| |d |}|r:||fn|f}|S Nr   r   r   )r   r   )	r   r   Z
ctx_tensorctx_att_maskr   r   r   attention_outputr   r   r   r   r!     s    z!LxmertCrossAttentionLayer.forward)NFr"   r   r   r   r   r     s   r   c                       s&   e Zd Z fddZdddZ  ZS )LxmertSelfAttentionLayerc                    s"   t    t|| _t|| _d S r   )r   r   r   r   r   r   ru   r   r   r   r     s    

z!LxmertSelfAttentionLayer.__init__Fc                 C   sD   | j ||||d}|r|d }| |d |}|r:||fn|f}|S r   )r   r   )r   r   r   r   r   r   r   r   r   r   r   r!     s    z LxmertSelfAttentionLayer.forward)Fr"   r   r   r   r   r     s   r   c                       s$   e Zd Z fddZdd Z  ZS )LxmertIntermediatec                    s,   t    t|j|j| _t|j | _	d S r   )
r   r   r   r   rm   intermediate_sizer   r   
hidden_actintermediate_act_fnru   r   r   r   r     s    
zLxmertIntermediate.__init__c                 C   s   |  |}| |}|S r   )r   r   r   r   r   r   r   r!     s    

zLxmertIntermediate.forwardr"   r   r   r   r   r     s   r   c                       s$   e Zd Z fddZdd Z  ZS )LxmertOutputc                    s@   t    t|j|j| _tj|jdd| _t|j	| _
d S r   )r   r   r   r   r   rm   r   rq   rr   rs   rt   ru   r   r   r   r     s    
zLxmertOutput.__init__c                 C   s&   |  |}| |}| || }|S r   r   r   r   r   r   r!     s    

zLxmertOutput.forwardr"   r   r   r   r   r     s   r   c                       s&   e Zd Z fddZdddZ  ZS )LxmertLayerc                    s,   t    t|| _t|| _t|| _d S r   )r   r   r   	attentionr   intermediater   r   ru   r   r   r   r     s    


zLxmertLayer.__init__NFc                 C   sD   | j |||d}|d }| |}| ||}|f|dd   }|S )Nr   r   r   )r   r   r   )r   r   r   r   r   r   Zintermediate_outputZlayer_outputr   r   r   r!     s    
zLxmertLayer.forward)NFr"   r   r   r   r   r     s   r   c                       s@   e Zd Z fddZdddZdd Zdd	 Zdd
dZ  ZS )LxmertXLayerc                    sT   t    t|| _t|| _t|| _t|| _t	|| _
t|| _t	|| _d S r   )r   r   r   visual_attentionr   lang_self_attvisn_self_attr   
lang_interr   lang_output
visn_intervisn_outputru   r   r   r   r     s    






zLxmertXLayer.__init__Fc                 C   s,   | j ||||d}| j |||dd}||fS )N)r   r   F)r   )r   
lang_inputlang_attention_maskvisual_inputvisual_attention_maskoutput_x_attentionslang_att_outputvisual_att_outputr   r   r   	cross_att  s    	zLxmertXLayer.cross_attc                 C   s0   | j ||dd}| j||dd}|d |d fS )NFr   r   )r   r   )r   r   r   r   r   r   r   r   r   r   self_att  s    zLxmertXLayer.self_attc                 C   s4   |  |}| |}| ||}| ||}||fS r   )r   r   r   r   )r   r   r   Zlang_inter_outputZvisual_inter_outputr   visual_outputr   r   r   	output_fc  s
    

zLxmertXLayer.output_fcc                 C   sj   | j |||||d\}}|dd  }| |d ||d |\}}| ||\}	}
|rb|	|
|d fS |	|
fS )N)r   r   r   r   r   r   r   )r   r   r   )r   
lang_featsr   visual_featsr   r   r   r   r   r   r   r   r   r   r!   	  s.    
zLxmertXLayer.forward)F)F)	r#   r$   r%   r   r   r   r   r!   r&   r   r   r   r   r     s    
 r   c                       s$   e Zd Z fddZdd Z  ZS )LxmertVisualFeatureEncoderc                    sl   t    |j}|j}t||j| _tj|jdd| _	t||j| _
tj|jdd| _t|j| _d S r   )r   r   visual_feat_dimZvisual_pos_dimr   r   rm   visn_fcrq   visn_layer_normbox_fcbox_layer_normrr   rs   rt   )r   r^   Zfeat_dimZpos_dimr   r   r   r   -  s    
z#LxmertVisualFeatureEncoder.__init__c                 C   sB   |  |}| |}| |}| |}|| d }| |}|S NrC   )r   r   r   r   rt   )r   r   
visual_posr    yr   r   r   r   r!   <  s    




z"LxmertVisualFeatureEncoder.forwardr"   r   r   r   r   r   ,  s   r   c                       s&   e Zd Z fddZdddZ  ZS )LxmertEncoderc                    s   t    t | _ | _ j| _ j| _ j	| _
t fddt| jD | _t fddt| jD | _t fddt| j
D | _	d S )Nc                    s   g | ]}t  qS r   r   r;   _r^   r   r   
<listcomp>V  s     z*LxmertEncoder.__init__.<locals>.<listcomp>c                    s   g | ]}t  qS r   )r   r   r   r   r   r   W  s     c                    s   g | ]}t  qS r   r   r   r   r   r   r   X  s     )r   r   r   r   r^   Zl_layersZnum_l_layersx_layersZnum_x_layersr_layersZnum_r_layersr   Z
ModuleListrangelayerru   r   r   r   r   H  s    

  zLxmertEncoder.__init__Nc                 C   sd  d}d}|s| j jrdnd }	|s(| j jr,dnd }
|s<| j jr@dnd }| ||}| jD ]:}||||d}|d }||f }|
d k	rV|
|d f }
qV| jD ]:}||||d}|d }||f }|	d k	r|	|d f }	q| jD ]P}||||||d}|d d \}}||f }||f }|d k	r||d f }q||r8|	nd f}||rJ|
nd f}|||r^|nd fS )Nr   r   r   r   rC   )r^   r   r   r   r   r   )r   r   r   r   r   r   r   r,   r+   r.   r-   r/   Zlayer_moduleZ	l_outputsZ	v_outputsZ	x_outputsvisual_encoder_outputslang_encoder_outputsr   r   r   r!   Z  sR    	






zLxmertEncoder.forward)NNr"   r   r   r   r   r   G  s     r   c                       s$   e Zd Z fddZdd Z  ZS )LxmertPoolerc                    s.   t t|   t|j|j| _t | _d S r   )	r   r   r   r   r   rm   r   ZTanh
activationru   r   r   r   r     s    zLxmertPooler.__init__c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r   )r   r   Zfirst_token_tensorr*   r   r   r   r!     s    

zLxmertPooler.forwardr"   r   r   r   r   r     s   r   c                       s$   e Zd Z fddZdd Z  ZS )LxmertPredictionHeadTransformc                    sB   t t|   t|j|j| _t|j | _	tj
|jdd| _
d S r   )r   r   r   r   r   rm   r   r   r   transform_act_fnrq   ru   r   r   r   r     s    z&LxmertPredictionHeadTransform.__init__c                 C   s"   |  |}| |}| |}|S r   )r   r   rq   r   r   r   r   r!     s    


z%LxmertPredictionHeadTransform.forwardr"   r   r   r   r   r     s   r   c                       s$   e Zd Z fddZdd Z  ZS )LxmertLMPredictionHeadc                    sZ   t t|   t|| _tj|d|ddd| _|| j_	t
t|d| _d S )Nr   r   FrA   )r   r   r   r   	transformr   r   rz   decoderr?   	Parameterr1   r~   rA   r   r^   Zlxmert_model_embedding_weightsr   r   r   r     s    
zLxmertLMPredictionHead.__init__c                 C   s   |  |}| || j }|S r   )r   r   rA   r   r   r   r   r!     s    
zLxmertLMPredictionHead.forwardr"   r   r   r   r   r     s   r   c                       s$   e Zd Z fddZdd Z  ZS )LxmertVisualAnswerHeadc              	      sN   t    |j}tt||d t tj|d ddt|d || _d S )NrC   ri   rj   )	r   r   rm   r   Z
Sequentialr   r   rq   logit_fc)r   r^   
num_labelsZhid_dimr   r   r   r     s    
zLxmertVisualAnswerHead.__init__c                 C   s
   |  |S r   )r   r   r   r   r   r!     s    zLxmertVisualAnswerHead.forwardr"   r   r   r   r   r     s   
r   c                       s$   e Zd Z fddZdd Z  ZS )LxmertVisualObjHeadc                    s   t    t _i } jr.d jd|d<  jrDd jd|d<  jr`d j	f j	d|d< |_
t fddj
D _d S )	Nrv   )rY   rd   objattrrv   featc                    s&   i | ]}|t  jj| d  qS )rd   )r   r   rm   visual_losses)r;   r   r^   r   r   r   
<dictcomp>  s      z0LxmertVisualObjHead.__init__.<locals>.<dictcomp>)r   r   r   r   visual_obj_lossnum_object_labelsvisual_attr_lossnum_attr_labelsvisual_feat_lossr   r   r   Z
ModuleDictdecoder_dictr   r^   r   r   r   r   r     s    


zLxmertVisualObjHead.__init__c                 C   s0   |  |}i }| jD ]}| j| |||< q|S r   )r   r   r   )r   r   r   r   r   r   r   r!     s
    

zLxmertVisualObjHead.forwardr"   r   r   r   r   r     s   r   c                       s$   e Zd Z fddZdd Z  ZS )LxmertPreTrainingHeadsc                    s.   t t|   t||| _t|jd| _d S r   )	r   r   r   r   predictionsr   r   rm   seq_relationshipr   r   r   r   r     s    zLxmertPreTrainingHeads.__init__c                 C   s   |  |}| |}||fS r   )r   r   )r   Zsequence_outputr*   Zprediction_scoresZseq_relationship_scorer   r   r   r!     s    

zLxmertPreTrainingHeads.forwardr"   r   r   r   r   r     s   r   c                   @   s$   e Zd ZdZeZeZdZdd Z	dS )LxmertPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    lxmertc                 C   s   t |tjr:|jjjd| jjd |jdk	r|jj	  nft |tj
rz|jjjd| jjd |jdk	r|jj|j 	  n&t |tjr|jj	  |jjd dS )zInitialize the weights        )meanZstdN      ?)
isinstancer   r   r?   r\   Znormal_r^   Zinitializer_rangerA   Zzero_rk   rh   rq   Zfill_)r   moduler   r   r   _init_weights  s    

z#LxmertPreTrainedModel._init_weightsN)
r#   r$   r%   r0   r   config_classrf   Zload_tf_weightsZbase_model_prefixr  r   r   r   r   r     s
   r   aR  

    The LXMERT model was proposed in [LXMERT: Learning Cross-Modality Encoder Representations from
    Transformers](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal. It's a vision and language transformer
    model, pretrained on a variety of multi-modal datasets comprising GQA, VQAv2.0, MSCOCO captions, and Visual
    Genome, using a combination of masked language modeling, region-of-interest feature regression, cross-entropy loss
    for question answering attribute prediction, and object tag prediction.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`LxmertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
a  

    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        visual_feats (`torch.FloatTensor` of shape `(batch_size, num_visual_features, visual_feat_dim)`):
            This input represents visual features. They are ROI-pooled object features extracted from bounding boxes
            using a faster-RCNN model.

            These are currently not provided by the transformers library.
        visual_pos (`torch.FloatTensor` of shape `(batch_size, num_visual_features, visual_pos_dim)`):
            This input represents spatial features corresponding to their relative (via index) visual features. The
            pre-trained LXMERT model expects these spatial features to be normalized bounding boxes on a scale of 0 to
            1.

            These are currently not provided by the transformers library.
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        visual_attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare Lxmert Model transformer outputting raw hidden-states without any specific head on top.",
    LXMERT_START_DOCSTRING,
)
class LxmertModel(LxmertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.embeddings = LxmertEmbeddings(config)
        self.encoder = LxmertEncoder(config)
        self.pooler = LxmertPooler(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings.word_embeddings = new_embeddings

    @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=LxmertModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        visual_feats: Optional[torch.FloatTensor] = None,
        visual_pos: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        visual_attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[LxmertModelOutput, Tuple[torch.FloatTensor]]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if visual_feats is None:
            raise ValueError("`visual_feats` cannot be `None`")
        if visual_pos is None:
            raise ValueError("`visual_pos` cannot be `None`")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We create a 3D attention mask from a 2D tensor mask with sizes
        # [batch_size, 1, 1, to_seq_length] so it broadcasts to
        # [batch_size, num_heads, from_seq_length, to_seq_length].
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for masked
        # positions, this creates a tensor which is 0.0 for positions to attend and the
        # dtype's minimum value for masked positions (an additive mask).
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min

        # Process the visual attention mask the same way
        if visual_attention_mask is not None:
            extended_visual_attention_mask = visual_attention_mask.unsqueeze(1).unsqueeze(2)
            extended_visual_attention_mask = extended_visual_attention_mask.to(dtype=self.dtype)
            extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * torch.finfo(self.dtype).min
        else:
            extended_visual_attention_mask = None

        # Positional word embeddings
        embedding_output = self.embeddings(input_ids, token_type_ids, inputs_embeds)

        # Run Lxmert encoder
        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,
            visual_feats=visual_feats,
            visual_pos=visual_pos,
            visual_attention_mask=extended_visual_attention_mask,
            output_attentions=output_attentions,
        )
        visual_encoder_outputs, lang_encoder_outputs = encoder_outputs[:2]
        vision_hidden_states = visual_encoder_outputs[0]
        language_hidden_states = lang_encoder_outputs[0]

        all_attentions = ()
        if output_attentions:
            language_attentions = lang_encoder_outputs[1]
            vision_attentions = visual_encoder_outputs[1]
            cross_encoder_attentions = encoder_outputs[2]
            all_attentions = (
                language_attentions,
                vision_attentions,
                cross_encoder_attentions,
            )

        hidden_states = (language_hidden_states, vision_hidden_states) if output_hidden_states else ()

        visual_output = vision_hidden_states[-1]
        lang_output = language_hidden_states[-1]
        pooled_output = self.pooler(lang_output)

        if not return_dict:
            return (lang_output, visual_output, pooled_output) + hidden_states + all_attentions

        return LxmertModelOutput(
            pooled_output=pooled_output,
            language_output=lang_output,
            vision_output=visual_output,
            language_hidden_states=language_hidden_states if output_hidden_states else None,
            vision_hidden_states=vision_hidden_states if output_hidden_states else None,
            language_attentions=language_attentions if output_attentions else None,
            vision_attentions=vision_attentions if output_attentions else None,
            cross_encoder_attentions=cross_encoder_attentions if output_attentions else None,
        )


@add_start_docstrings(
    """Lxmert Model with a specified pretraining head on top.""",
    LXMERT_START_DOCSTRING,
)
class LxmertForPreTraining(LxmertPreTrainedModel):
    _tied_weights_keys = ["cls.predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)
        # Configuration
        self.config = config
        self.num_qa_labels = config.num_qa_labels
        self.visual_loss_normalizer = config.visual_loss_normalizer

        # Use of pretraining tasks
        self.task_mask_lm = config.task_mask_lm
        self.task_obj_predict = config.task_obj_predict
        self.task_matched = config.task_matched
        self.task_qa = config.task_qa

        # Lxmert backbone
        self.lxmert = LxmertModel(config)

        # Pre-training heads
        self.cls = LxmertPreTrainingHeads(config, self.lxmert.embeddings.word_embeddings.weight)
        if self.task_obj_predict:
            self.obj_predict_head = LxmertVisualObjHead(config)
        if self.task_qa:
            self.answer_head = LxmertVisualAnswerHead(config, self.num_qa_labels)

        # Initialize weights and apply final processing
        self.post_init()

        # Loss functions
        self.loss_fcts = {
            "l2": SmoothL1Loss(reduction="none"),
            "visual_ce": CrossEntropyLoss(reduction="none"),
            "ce": CrossEntropyLoss(),
        }

        visual_losses = {}
        if config.visual_obj_loss:
            visual_losses["obj"] = {
                "shape": (-1,),
                "num": config.num_object_labels,
                "loss": "visual_ce",
            }
        if config.visual_attr_loss:
            visual_losses["attr"] = {
                "shape": (-1,),
                "num": config.num_attr_labels,
                "loss": "visual_ce",
            }
        if config.visual_feat_loss:
            visual_losses["feat"] = {
                "shape": (-1, config.visual_feat_dim),
                "num": config.visual_feat_dim,
                "loss": "l2",
            }
        self.visual_losses = visual_losses

    def resize_num_qa_labels(self, num_labels):
        """
        Build a resized question answering linear layer Module from a provided new linear layer. Increasing the size
        will add newly initialized weights. Reducing the size will remove weights from the end.

        Args:
            num_labels (`int`, *optional*):
                New number of labels in the linear layer weight matrix. Increasing the size will add newly initialized
                weights at the end. Reducing the size will remove weights from the end. If not provided or `None`, just
                returns a pointer to the qa labels `torch.nn.Linear` module of the model without doing anything.

        Return:
            `torch.nn.Linear`: Pointer to the resized Linear layer or the old Linear layer
        """

        cur_qa_logit_layer = self.get_qa_logit_layer()
        if num_labels is None or cur_qa_logit_layer is None:
            return
        new_qa_logit_layer = self._resize_qa_labels(num_labels)
        self.config.num_qa_labels = num_labels
        self.num_qa_labels = num_labels

        return new_qa_logit_layer

    def _resize_qa_labels(self, num_labels):
        cur_qa_logit_layer = self.get_qa_logit_layer()
        new_qa_logit_layer = self._get_resized_qa_labels(cur_qa_logit_layer, num_labels)
        self._set_qa_logit_layer(new_qa_logit_layer)
        return self.get_qa_logit_layer()

    def get_qa_logit_layer(self) -> nn.Module:
        """
        Returns the linear layer that produces question answering logits.

        Returns:
            `nn.Module`: A torch module mapping the question answering prediction hidden states or `None` if LXMERT
            does not have a visual answering head.
        """
        if hasattr(self, "answer_head"):
            return self.answer_head.logit_fc[-1]

    def _set_qa_logit_layer(self, qa_logit_layer):
        self.answer_head.logit_fc[-1] = qa_logit_layer

    def _get_resized_qa_labels(self, cur_qa_logit_layer, num_labels):
        if num_labels is None:
            return cur_qa_logit_layer

        cur_qa_labels, hidden_dim = cur_qa_logit_layer.weight.size()
        if cur_qa_labels == num_labels:
            return cur_qa_logit_layer

        # Build new linear output
        if getattr(cur_qa_logit_layer, "bias", None) is not None:
            new_qa_logit_layer = nn.Linear(hidden_dim, num_labels)
        else:
            new_qa_logit_layer = nn.Linear(hidden_dim, num_labels, bias=False)

        new_qa_logit_layer.to(cur_qa_logit_layer.weight.device)

        # Initialize all new labels
        self._init_weights(new_qa_logit_layer)

        # Copy labels from the previous weights
        num_labels_to_copy = min(cur_qa_labels, num_labels)
        new_qa_logit_layer.weight.data[:num_labels_to_copy, :] = cur_qa_logit_layer.weight.data[:num_labels_to_copy, :]
        if getattr(cur_qa_logit_layer, "bias", None) is not None:
            new_qa_logit_layer.bias.data[:num_labels_to_copy] = cur_qa_logit_layer.bias.data[:num_labels_to_copy]

        return new_qa_logit_layer

    @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=LxmertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        visual_feats: Optional[torch.FloatTensor] = None,
        visual_pos: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        visual_attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        obj_labels: Optional[Dict[str, Tuple[torch.FloatTensor, torch.FloatTensor]]] = None,
        matched_label: Optional[torch.LongTensor] = None,
        ans: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[LxmertForPreTrainingOutput, Tuple[torch.FloatTensor]]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        obj_labels (`Dict[Str: Tuple[Torch.FloatTensor, Torch.FloatTensor]]`, *optional*):
            each key is named after each one of the visual losses and each element of the tuple is of the shape
            `(batch_size, num_features)` and `(batch_size, num_features, visual_feature_dim)` for each the label id and
            the label score respectively
        matched_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing whether or not the text input matches the image (classification) loss. Input
            should be a sequence pair (see `input_ids` docstring). Indices should be in `[0, 1]`:

            - 0 indicates that the sentence does not match the image,
            - 1 indicates that the sentence does match the image.
        ans (`Torch.Tensor` of shape `(batch_size)`, *optional*):
            a one-hot representation of the correct answer *optional*

        Returns:
        """

        if "masked_lm_labels" in kwargs:
            warnings.warn(
                "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels`"
                " instead.",
                FutureWarning,
            )
            labels = kwargs.pop("masked_lm_labels")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        device = input_ids.device if input_ids is not None else inputs_embeds.device
        lxmert_output = self.lxmert(
            input_ids=input_ids,
            visual_feats=visual_feats,
            visual_pos=visual_pos,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            visual_attention_mask=visual_attention_mask,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )

        lang_output, visual_output, pooled_output = (
            lxmert_output[0],
            lxmert_output[1],
            lxmert_output[2],
        )
        lang_prediction_scores, cross_relationship_score = self.cls(lang_output, pooled_output)
        if self.task_qa:
            answer_score = self.answer_head(pooled_output)
        else:
            answer_score = pooled_output[0][0]

        total_loss = (
            None
            if (labels is None and matched_label is None and obj_labels is None and ans is None)
            else torch.tensor(0.0, device=device)
        )
        if labels is not None and self.task_mask_lm:
            masked_lm_loss = self.loss_fcts["ce"](
                lang_prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )
            total_loss += masked_lm_loss
        if matched_label is not None and self.task_matched:
            matched_loss = self.loss_fcts["ce"](cross_relationship_score.view(-1, 2), matched_label.view(-1))
            total_loss += matched_loss
        if obj_labels is not None and self.task_obj_predict:
            total_visual_loss = torch.tensor(0.0, device=device)
            visual_prediction_scores_dict = self.obj_predict_head(visual_output)
            for key, key_info in self.visual_losses.items():
                label, mask_conf = obj_labels[key]
                output_dim = key_info["num"]
                loss_fct_name = key_info["loss"]
                label_shape = key_info["shape"]
                weight = self.visual_loss_normalizer
                visual_loss_fct = self.loss_fcts[loss_fct_name]
                visual_prediction_scores = visual_prediction_scores_dict[key]
                visual_loss = visual_loss_fct(
                    visual_prediction_scores.view(-1, output_dim),
                    label.view(label_shape),
                )
                if visual_loss.dim() > 1:  # Regression losses
                    visual_loss = visual_loss.mean(1)
                visual_loss = (visual_loss * mask_conf.view(-1)).mean() * weight
                total_visual_loss += visual_loss
            total_loss += total_visual_loss
        if ans is not None and self.task_qa:
            answer_loss = self.loss_fcts["ce"](answer_score.view(-1, self.num_qa_labels), ans.view(-1))
            total_loss += answer_loss

        if not return_dict:
            output = (
                lang_prediction_scores,
                cross_relationship_score,
                answer_score,
            ) + lxmert_output[3:]
            return ((total_loss,) + output) if total_loss is not None else output

        return LxmertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=lang_prediction_scores,
            cross_relationship_score=cross_relationship_score,
            question_answering_score=answer_score,
            language_hidden_states=lxmert_output.language_hidden_states,
            vision_hidden_states=lxmert_output.vision_hidden_states,
            language_attentions=lxmert_output.language_attentions,
            vision_attentions=lxmert_output.vision_attentions,
            cross_encoder_attentions=lxmert_output.cross_encoder_attentions,
        )


@add_start_docstrings(
    """Lxmert Model with a visual-answering head on top for downstream QA tasks""",
    LXMERT_START_DOCSTRING,
)
class LxmertForQuestionAnswering(LxmertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # Configuration
        self.config = config
        self.num_qa_labels = config.num_qa_labels
        self.visual_loss_normalizer = config.visual_loss_normalizer

        # Lxmert backbone
        self.lxmert = LxmertModel(config)

        self.answer_head = LxmertVisualAnswerHead(config, self.num_qa_labels)

        # Initialize weights and apply final processing
        self.post_init()

        # Loss function
        self.loss = CrossEntropyLoss()

    def resize_num_qa_labels(self, num_labels):
        """
        Build a resized question answering linear layer Module from a provided new linear layer. Increasing the size
        will add newly initialized weights. Reducing the size will remove weights from the end.

        Args:
            num_labels (`int`, *optional*):
                New number of labels in the linear layer weight matrix. Increasing the size will add newly initialized
                weights at the end. Reducing the size will remove weights from the end. If not provided or `None`, just
                returns a pointer to the qa labels `torch.nn.Linear` module of the model without doing anything.

        Return:
            `torch.nn.Linear`: Pointer to the resized Linear layer or the old Linear layer
        """

        cur_qa_logit_layer = self.get_qa_logit_layer()
        if num_labels is None or cur_qa_logit_layer is None:
            return
        new_qa_logit_layer = self._resize_qa_labels(num_labels)
        self.config.num_qa_labels = num_labels
        self.num_qa_labels = num_labels

        return new_qa_logit_layer

    def _resize_qa_labels(self, num_labels):
        cur_qa_logit_layer = self.get_qa_logit_layer()
        new_qa_logit_layer = self._get_resized_qa_labels(cur_qa_logit_layer, num_labels)
        self._set_qa_logit_layer(new_qa_logit_layer)
        return self.get_qa_logit_layer()

    def get_qa_logit_layer(self) -> nn.Module:
        """
        Returns the linear layer that produces question answering logits.

        Returns:
            `nn.Module`: A torch module mapping the question answering prediction hidden states. `None`: A NoneType
            object if Lxmert does not have the visual answering head.
        """
        if hasattr(self, "answer_head"):
            return self.answer_head.logit_fc[-1]

    def _set_qa_logit_layer(self, qa_logit_layer):
        self.answer_head.logit_fc[-1] = qa_logit_layer

    def _get_resized_qa_labels(self, cur_qa_logit_layer, num_labels):
        if num_labels is None:
            return cur_qa_logit_layer

        cur_qa_labels, hidden_dim = cur_qa_logit_layer.weight.size()
        if cur_qa_labels == num_labels:
            return cur_qa_logit_layer

        # Build new linear output
        if getattr(cur_qa_logit_layer, "bias", None) is not None:
            new_qa_logit_layer = nn.Linear(hidden_dim, num_labels)
        else:
            new_qa_logit_layer = nn.Linear(hidden_dim, num_labels, bias=False)

        new_qa_logit_layer.to(cur_qa_logit_layer.weight.device)

        # Initialize all new labels
        self._init_weights(new_qa_logit_layer)

        # Copy labels from the previous weights
        num_labels_to_copy = min(cur_qa_labels, num_labels)
        new_qa_logit_layer.weight.data[:num_labels_to_copy, :] = cur_qa_logit_layer.weight.data[:num_labels_to_copy, :]
        if getattr(cur_qa_logit_layer, "bias", None) is not None:
            new_qa_logit_layer.bias.data[:num_labels_to_copy] = cur_qa_logit_layer.bias.data[:num_labels_to_copy]

        return new_qa_logit_layer

    @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=LxmertForQuestionAnsweringOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        visual_feats: Optional[torch.FloatTensor] = None,
        visual_pos: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        visual_attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[LxmertForQuestionAnsweringOutput, Tuple[torch.FloatTensor]]:
        r"""
        labels (`Torch.Tensor` of shape `(batch_size)`, *optional*):
            A one-hot representation of the correct answer
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        lxmert_output = self.lxmert(
            input_ids=input_ids,
            visual_feats=visual_feats,
            visual_pos=visual_pos,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            visual_attention_mask=visual_attention_mask,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )

        pooled_output = lxmert_output[2]
        answer_score = self.answer_head(pooled_output)
        loss = None
        if labels is not None:
            loss = self.loss(answer_score.view(-1, self.num_qa_labels), labels.view(-1))

        if not return_dict:
            output = (answer_score,) + lxmert_output[3:]
            return (loss,) + output if loss is not None else output

        return LxmertForQuestionAnsweringOutput(
            loss=loss,
            question_answering_score=answer_score,
            language_hidden_states=lxmert_output.language_hidden_states,
            vision_hidden_states=lxmert_output.vision_hidden_states,
            language_attentions=lxmert_output.language_attentions,
            vision_attentions=lxmert_output.vision_attentions,
            cross_encoder_attentions=lxmert_output.cross_encoder_attentions,
        )
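

# ----------------------------------------------------------------------------------------------
# Illustrative smoke test added for clarity; it is not part of the original transformers source.
# It builds a tiny, randomly initialised LxmertConfig (all sizes below are arbitrary placeholder
# values), fabricates dummy token ids, ROI features and normalised box coordinates, and runs
# LxmertModel end to end. Because this module uses relative imports, run it with
# `python -m transformers.models.lxmert.modeling_lxmert` rather than as a plain script.
if __name__ == "__main__":
    config = LxmertConfig(
        vocab_size=250,
        hidden_size=64,
        num_attention_heads=4,
        intermediate_size=128,
        l_layers=2,
        x_layers=1,
        r_layers=1,
        visual_feat_dim=16,
        visual_pos_dim=4,
        num_qa_labels=10,
    )
    model = LxmertModel(config)
    model.eval()

    batch_size, seq_len, num_boxes = 2, 7, 5
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    visual_feats = torch.randn(batch_size, num_boxes, config.visual_feat_dim)
    visual_pos = torch.rand(batch_size, num_boxes, config.visual_pos_dim)  # boxes normalised to [0, 1]

    with torch.no_grad():
        outputs = model(input_ids=input_ids, visual_feats=visual_feats, visual_pos=visual_pos)

    print(outputs.language_output.shape)  # (batch_size, seq_len, hidden_size)
    print(outputs.vision_output.shape)  # (batch_size, num_boxes, hidden_size)
    print(outputs.pooled_output.shape)  # (batch_size, hidden_size)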