"""PyTorch LUKE model."""

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN, gelu
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_luke import LukeConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "LukeConfig"
_CHECKPOINT_FOR_DOC = "studio-ousia/luke-base"

LUKE_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "studio-ousia/luke-base",
    "studio-ousia/luke-large",
]


@dataclass
class BaseLukeModelOutputWithPooling(BaseModelOutputWithPooling):
    """
    Base class for outputs of the LUKE model.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        entity_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, entity_length, hidden_size)`):
            Sequence of entity hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token) further processed by a
            Linear layer and a Tanh activation function.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length +
            entity_length, sequence_length + entity_length)`. Attentions weights after the attention softmax, used to
            compute the weighted average in the self-attention heads.
    Nentity_last_hidden_stateentity_hidden_states__name__
__module____qualname____doc__r   torchFloatTensor__annotations__r   r   r    r$   r$   g/var/www/html/Darija-Ai-Train/env/lib/python3.8/site-packages/transformers/models/luke/modeling_luke.pyr   5   s   
r   c                   @   s6   e Zd ZU dZdZejed< dZe	e
ej  ed< dS )BaseLukeModelOutputa#  
    Base class for model's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        entity_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, entity_length, hidden_size)`):
            Sequence of entity hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nr   r   r   r$   r$   r$   r%   r&   T   s   
r&   c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZejed< dZejed< dZeeej  ed< dZeeej  ed	< dZeeej  ed
< dS )LukeMaskedLMOutputa>	  
    Base class for model's outputs, with potential hidden states and attentions.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            The sum of masked language modeling (MLM) loss and entity prediction loss.
        mlm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Masked language modeling (MLM) loss.
        mep_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Masked entity prediction (MEP) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        entity_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the entity prediction head (scores for each entity vocabulary token before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nlossmlm_lossmep_losslogitsentity_logitshidden_statesr   
attentions)r   r   r   r    r(   r   r!   r"   r#   r)   r*   r+   r,   r-   r   r   r.   r$   r$   r$   r%   r'   s   s   
r'   c                   @   st   e Zd ZU dZdZeej ed< dZ	ejed< dZ
eeej  ed< dZeeej  ed< dZeeej  ed< dS )EntityClassificationOutputay  
    Outputs of entity classification models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class EntityPairClassificationOutput(ModelOutput):
    """
    Outputs of entity pair classification models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class EntitySpanClassificationOutput(ModelOutput):
    """
    Outputs of entity span classification models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification loss.
        logits (`torch.FloatTensor` of shape `(batch_size, entity_length, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class LukeSequenceClassifierOutput(ModelOutput):
    """
    Outputs of sentence classification models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class LukeTokenClassifierOutput(ModelOutput):
    """
    Base class for outputs of token classification models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) :
            Classification loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class LukeQuestionAnsweringModelOutput(ModelOutput):
    """
    Outputs of question answering models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
        start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Span-start scores (before SoftMax).
        end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Span-end scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nr(   start_logits
end_logitsr-   r   r.   )r   r   r   r    r(   r   r!   r"   r#   r6   r7   r-   r   r   r.   r$   r$   r$   r%   r5   ?  s   
r5   c                   @   st   e Zd ZU dZdZeej ed< dZ	ejed< dZ
eeej  ed< dZeeej  ed< dZeeej  ed< dS )LukeMultipleChoiceModelOutputa  
    Outputs of multiple choice models.

    Args:
        loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided):
            Classification loss.
        logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
            *num_choices* is the second dimension of the input tensors. (see *input_ids* above).

            Classification scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class LukeEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    """
|j| _tj|j|jd| _t|j| _|j| _tj|j|j| jd| _	d S )Npadding_idxZeps)super__init__r   	Embedding
vocab_sizehidden_sizeZpad_token_idword_embeddingsmax_position_embeddingsposition_embeddingstype_vocab_sizetoken_type_embeddings	LayerNormlayer_norm_epsDropouthidden_dropout_probdropoutr;   selfconfig	__class__r$   r%   r>     s    
  zLukeEmbeddings.__init__Nc           	      C   s   |d kr0|d k	r&t || j|j}n
| |}|d k	rB| }n| d d }|d krptj|tj| j	jd}|d kr| 
|}| |}| |}|| | }| |}| |}|S )Ndtypedevice)"create_position_ids_from_input_idsr;   torT   &create_position_ids_from_inputs_embedssizer!   zeroslongposition_idsrB   rD   rF   rG   rK   )	rM   	input_idstoken_type_idsr[   inputs_embedsinput_shaperD   rF   
embeddingsr$   r$   r%   forward  s"    






zLukeEmbeddings.forwardc                 C   sN   |  dd }|d }tj| jd || j d tj|jd}|d|S )z
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        NrQ   r   rR   r   )rX   r!   Zaranger;   rZ   rT   	unsqueezeexpand)rM   r^   r_   Zsequence_lengthr[   r$   r$   r%   rW     s    	   z5LukeEmbeddings.create_position_ids_from_inputs_embeds)NNNN)r   r   r   r    r>   ra   rW   __classcell__r$   r$   rO   r%   r9     s       
!r9   c                       s<   e Zd Zed fddZdejejejdddZ  ZS )	LukeEntityEmbeddingsrN   c                    s   t    || _tj|j|jdd| _|j|jkrHtj	|j|jdd| _
t|j|j| _t|j|j| _tj|j|jd| _t|j| _d S )Nr   r:   Fbiasr<   )r=   r>   rN   r   r?   entity_vocab_sizeentity_emb_sizeentity_embeddingsrA   Linearentity_embedding_denserC   rD   rE   rF   rG   rH   rI   rJ   rK   rL   rO   r$   r%   r>     s    
zLukeEntityEmbeddings.__init__N)
entity_idsr[   r]   c           	      C   s   |d krt |}| |}| jj| jjkr6| |}| |jdd}|dk	|
d}|| }t j|dd}||jddjdd }| |}|| | }| |}| |}|S )Nr   )minrQ   dimgHz>)r!   Z
zeros_likerk   rN   rj   rA   rm   rD   clamptype_asrb   sumrF   rG   rK   )	rM   rn   r[   r]   rk   rD   Zposition_embedding_maskrF   r`   r$   r$   r%   ra     s    





zLukeEntityEmbeddings.forward)N)	r   r   r   r   r>   r!   
LongTensorra   rd   r$   r$   rO   r%   re     s      re   c                       s.   e Zd Z fddZdd Zd	ddZ  ZS )
LukeSelfAttentionc                    s   t    |j|j dkr@t|ds@td|jf d|j d|j| _t|j|j | _| j| j | _|j	| _	t
|j| j| _t
|j| j| _t
|j| j| _| j	rt
|j| j| _t
|j| j| _t
|j| j| _t
|j| _d S )Nr   Zembedding_sizezThe hidden size z4 is not a multiple of the number of attention heads .)r=   r>   rA   num_attention_headshasattr
ValueErrorintattention_head_sizeall_head_sizeuse_entity_aware_attentionr   rl   querykeyvalue	w2e_query	e2w_query	e2e_queryrI   Zattention_probs_dropout_probrK   rL   rO   r$   r%   r>     s"    
zLukeSelfAttention.__init__c                 C   s6   |  d d | j| jf }|j| }|ddddS )NrQ   r      r   r
   )rX   ry   r}   viewpermute)rM   xZnew_x_shaper$   r$   r%   transpose_for_scores  s    
z&LukeSelfAttention.transpose_for_scoresNFc                  C   s  | d}|d kr|}ntj||gdd}| | |}| | |}	| jr|d k	r| | |}
| | |}| | 	|}| | 
|}|d d d d d |d d f }|d d d d d |d d f }|d d d d |d d d f }|d d d d |d d d f }t|
|dd}t||dd}t||dd}t||dd}tj||gdd}tj||gdd}tj||gdd}n$| | |}t||dd}|t| j }|d k	r|| }tjj|dd}| |}|d k	r|| }t||	}|dddd }|  d d | jf }|j| }|d d d |d d f }|d krd }n|d d |d d d f }|r|||f}n||f}|S )Nr   rq   rQ   rp   r
   r   r   )rX   r!   catr   r   r   r   r   r   r   r   matmulZ	transposemathsqrtr}   r   
functionalZsoftmaxrK   r   
contiguousr~   r   ) rM   word_hidden_statesr   attention_mask	head_maskoutput_attentions	word_sizeconcat_hidden_statesZ	key_layerZvalue_layerZw2w_query_layerZw2e_query_layerZe2w_query_layerZe2e_query_layerZw2w_key_layerZe2w_key_layerZw2e_key_layerZe2e_key_layerZw2w_attention_scoresZw2e_attention_scoresZe2w_attention_scoresZe2e_attention_scoresZword_attention_scoresZentity_attention_scoresZattention_scoresZquery_layerZattention_probsZcontext_layerZnew_context_layer_shapeZoutput_word_hidden_statesZoutput_entity_hidden_statesoutputsr$   r$   r%   ra     sV    
    




zLukeSelfAttention.forward)NNF)r   r   r   r>   r   ra   rd   r$   r$   rO   r%   rw     s   	   rw   c                       s4   e Zd Z fddZejejejdddZ  ZS )LukeSelfOutputc                    sB   t    t|j|j| _tj|j|jd| _t|j	| _
d S Nr<   )r=   r>   r   rl   rA   denserG   rH   rI   rJ   rK   rL   rO   r$   r%   r>   l  s    
zLukeSelfOutput.__init__r-   input_tensorreturnc                 C   s&   |  |}| |}| || }|S Nr   rK   rG   rM   r-   r   r$   r$   r%   ra   r  s    

zLukeSelfOutput.forwardr   r   r   r>   r!   ZTensorra   rd   r$   r$   rO   r%   r   k  s   r   c                       s.   e Zd Z fddZdd Zd	ddZ  ZS )
LukeAttentionc                    s*   t    t|| _t|| _t | _d S r   )r=   r>   rw   rM   r   outputsetZpruned_headsrL   rO   r$   r%   r>   z  s    


zLukeAttention.__init__c                 C   s   t dd S Nz4LUKE does not support the pruning of attention headsNotImplementedError)rM   Zheadsr$   r$   r%   prune_heads  s    zLukeAttention.prune_headsNFc                 C   s   | d}| |||||}|d kr2|d }|}	n(tj|d d dd}tj||gdd}	| ||	}
|
d d d |d d f }|d krd }n|
d d |d d d f }||f|dd   }|S )Nr   r   r   rq   )rX   rM   r!   r   r   )rM   r   r   r   r   r   r   Zself_outputsZconcat_self_outputsr   attention_outputZword_attention_outputZentity_attention_outputr   r$   r$   r%   ra     s(    
zLukeAttention.forward)NNF)r   r   r   r>   r   ra   rd   r$   r$   rO   r%   r   y  s      r   c                       s0   e Zd Z fddZejejdddZ  ZS )LukeIntermediatec                    sB   t    t|j|j| _t|jt	r6t
|j | _n|j| _d S r   )r=   r>   r   rl   rA   intermediate_sizer   
isinstance
hidden_actstrr   intermediate_act_fnrL   rO   r$   r%   r>     s
    
zLukeIntermediate.__init__r-   r   c                 C   s   |  |}| |}|S r   )r   r   rM   r-   r$   r$   r%   ra     s    

zLukeIntermediate.forwardr   r$   r$   rO   r%   r     s   r   c                       s4   e Zd Z fddZejejejdddZ  ZS )
LukeOutputc                    sB   t    t|j|j| _tj|j|jd| _t	|j
| _d S r   )r=   r>   r   rl   r   rA   r   rG   rH   rI   rJ   rK   rL   rO   r$   r%   r>     s    
zLukeOutput.__init__r   c                 C   s&   |  |}| |}| || }|S r   r   r   r$   r$   r%   ra     s    

zLukeOutput.forwardr   r$   r$   rO   r%   r     s   r   c                       s.   e Zd Z fddZd	ddZdd Z  ZS )
	LukeLayerc                    s:   t    |j| _d| _t|| _t|| _t|| _	d S Nr   )
r=   r>   chunk_size_feed_forwardseq_len_dimr   	attentionr   intermediater   r   rL   rO   r$   r%   r>     s    


zLukeLayer.__init__NFc                 C   s   | d}| j|||||d}|d kr0|d }ntj|d d dd}|dd  }	t| j| j| j|}
|
d d d |d d f }|d krd }n|
d d |d d d f }||f|	 }	|	S )Nr   r   r   r   rq   )rX   r   r!   r   r   feed_forward_chunkr   r   )rM   r   r   r   r   r   r   Zself_attention_outputsZconcat_attention_outputr   layer_outputZword_layer_outputZentity_layer_outputr$   r$   r%   ra     s0    

   zLukeLayer.forwardc                 C   s   |  |}| ||}|S r   )r   r   )rM   r   Zintermediate_outputr   r$   r$   r%   r     s    
zLukeLayer.feed_forward_chunk)NNF)r   r   r   r>   ra   r   rd   r$   r$   rO   r%   r     s      
%r   c                       s&   e Zd Z fddZdddZ  ZS )	LukeEncoderc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r$   )r   ).0_rf   r$   r%   
<listcomp>  s     z(LukeEncoder.__init__.<locals>.<listcomp>F)	r=   r>   rN   r   Z
ModuleListrangenum_hidden_layerslayergradient_checkpointingrL   rO   rf   r%   r>     s    
 zLukeEncoder.__init__NFTc                    s  |rdnd }|rdnd }	 r dnd }
t | jD ]\}}|rN||f }|	|f }	|d k	r^|| nd }| jr| jr fdd}tjj||||||}n||||| }|d }|d k	r|d } r.|
|d f }
q.|r||f }|	|f }	|stdd |||
||	fD S t|||
||	d	S )
Nr$   c                    s    fdd}|S )Nc                     s    | f S r   r$   )inputs)moduler   r$   r%   custom_forward  s    zJLukeEncoder.forward.<locals>.create_custom_forward.<locals>.custom_forwardr$   )r   r   r   r   r%   create_custom_forward  s    z2LukeEncoder.forward.<locals>.create_custom_forwardr   r   r   c                 s   s   | ]}|d k	r|V  qd S r   r$   r   vr$   r$   r%   	<genexpr>;  s   z&LukeEncoder.forward.<locals>.<genexpr>)last_hidden_stater-   r.   r   r   )		enumerater   r   Ztrainingr!   utils
checkpointtupler&   )rM   r   r   r   r   r   output_hidden_statesreturn_dictZall_word_hidden_statesZall_entity_hidden_statesZall_self_attentionsiZlayer_moduleZlayer_head_maskr   Zlayer_outputsr$   r   r%   ra     s`    




zLukeEncoder.forward)NNFFTr   r   r   r>   ra   rd   r$   r$   rO   r%   r     s   
     r   c                       s0   e Zd Z fddZejejdddZ  ZS )
LukePoolerc                    s*   t    t|j|j| _t | _d S r   )r=   r>   r   rl   rA   r   ZTanh
activationrL   rO   r$   r%   r>   Q  s    
zLukePooler.__init__r   c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r   )rM   r-   Zfirst_token_tensorpooled_outputr$   r$   r%   ra   V  s    

zLukePooler.forwardr   r$   r$   rO   r%   r   P  s   r   c                       s$   e Zd Z fddZdd Z  ZS )EntityPredictionHeadTransformc                    sV   t    t|j|j| _t|jt	r6t
|j | _n|j| _tj|j|jd| _d S r   )r=   r>   r   rl   rA   rj   r   r   r   r   r   transform_act_fnrG   rH   rL   rO   r$   r%   r>   `  s    
z&EntityPredictionHeadTransform.__init__c                 C   s"   |  |}| |}| |}|S r   )r   r   rG   r   r$   r$   r%   ra   i  s    


z%EntityPredictionHeadTransform.forwardr   r$   r$   rO   r%   r   _  s   	r   c                       s$   e Zd Z fddZdd Z  ZS )EntityPredictionHeadc                    sH   t    || _t|| _tj|j|jdd| _	t
t|j| _d S )NFrg   )r=   r>   rN   r   	transformr   rl   rj   ri   decoder	Parameterr!   rY   rh   rL   rO   r$   r%   r>   q  s
    

zEntityPredictionHead.__init__c                 C   s   |  |}| || j }|S r   )r   r   rh   r   r$   r$   r%   ra   x  s    
zEntityPredictionHead.forwardr   r$   r$   rO   r%   r   p  s   r   c                   @   s>   e Zd ZdZeZdZdZddgZe	j
dddZdd
dZdS )LukePreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = LukeConfig
    base_model_prefix = "luke"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LukeAttention", "LukeEntityEmbeddings"]
r|jdkr^|jj	  n|jjjd| jjd |jdk	r|jj|j 	  n&t |tjr|jj	  |jjd dS )zInitialize the weightsg        )ZmeanZstdNr         ?)r   r   rl   weightdataZnormal_rN   Zinitializer_rangerh   Zzero_r?   Zembedding_dimr;   rG   Zfill_)rM   r   r$   r$   r%   _init_weights  s    


z!LukePreTrainedModel._init_weightsFc                 C   s   t |tr||_d S r   )r   r   r   )rM   r   r   r$   r$   r%   _set_gradient_checkpointing  s    

LUKE_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`LukeConfig`]): Model configuration class with all the parameters of the
            model. Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

LUKE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)

        entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
            Indices of entity tokens in the entity vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

        entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
            Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:

            - 1 for entity tokens that are **not masked**,
            - 0 for entity tokens that are **masked**.

        entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
            Segment token indices to indicate first and second portions of the entity token inputs. Indices are
            selected in `[0, 1]`:

            - 0 corresponds to a *portion A* entity token,
            - 1 corresponds to a *portion B* entity token.

        entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
            Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare LUKE model transformer outputting raw hidden-states for both word tokens and entities without any"
    " specific head on top.",
    LUKE_START_DOCSTRING,
)
class LukeModel(LukePreTrainedModel):
 Zdd Z	dd Z
eedeeeddeej eej eej eej eej eej eej eej eej eej ee ee ee eeef dddZejeej dddZ  ZS )	LukeModelT)rN   add_pooling_layerc                    sN   t  | || _t|| _t|| _t|| _|r<t	|nd | _
|   d S r   )r=   r>   rN   r9   r`   re   rk   r   encoderr   pooler	post_init)rM   rN   r   rO   r$   r%   r>     s    


zLukeModel.__init__c                 C   s   | j jS r   r`   rB   rM   r$   r$   r%   get_input_embeddings  s    zLukeModel.get_input_embeddingsc                 C   s   || j _d S r   r   rM   r   r$   r$   r%   set_input_embeddings  s    zLukeModel.set_input_embeddingsc                 C   s   | j j S r   rk   r   r$   r$   r%   get_entity_embeddings  s    zLukeModel.get_entity_embeddingsc                 C   s   || j _ d S r   r   r   r$   r$   r%   set_entity_embeddings  s    zLukeModel.set_entity_embeddingsc                 C   s   t dd S r   r   )rM   Zheads_to_pruner$   r$   r%   _prune_heads  s    zLukeModel._prune_headsbatch_size, sequence_lengthoutput_typer   N)r\   r   r]   r[   rn   entity_attention_maskentity_token_type_idsentity_position_idsr   r^   r   r   r   r   c              	   C   s  |dk	r|n| j j}|dk	r |n| j j}|dk	r4|n| j j}|dk	rV|
dk	rVtdn@|dk	rt| || | }n"|
dk	r|
 dd }ntd|\}}|dk	r|jn|
j}|dkrtj	||f|d}|dkrtj
|tj|d}|dk	r6|d}|dkrtj	||f|d}|dkr6tj
||ftj|d}| |	| j j}	| j||||
d}| ||}|dkrtd}n| |||}| j||||	|||d	}|d
 }| jdk	r| |nd}|s||f|dd  S t|||j|j|j|jdS )u  

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, LukeModel

        >>> tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-base")
        >>> model = LukeModel.from_pretrained("studio-ousia/luke-base")
        # Compute the contextualized entity representation corresponding to the entity mention "Beyoncé"

        >>> text = "Beyoncé lives in Los Angeles."
        >>> entity_spans = [(0, 7)]  # character-based entity span corresponding to "Beyoncé"

        >>> encoding = tokenizer(text, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt")
        >>> outputs = model(**encoding)
        >>> word_last_hidden_state = outputs.last_hidden_state
        >>> entity_last_hidden_state = outputs.entity_last_hidden_state
        # Input Wikipedia entities to obtain enriched contextualized representations of word tokens

        >>> text = "Beyoncé lives in Los Angeles."
        >>> entities = [
        ...     "Beyoncé",
        ...     "Los Angeles",
        ... ]  # Wikipedia entity titles corresponding to the entity mentions "Beyoncé" and "Los Angeles"
        >>> entity_spans = [
        ...     (0, 7),
        ...     (17, 28),
        ... ]  # character-based entity spans corresponding to "Beyoncé" and "Los Angeles"

        >>> encoding = tokenizer(
        ...     text, entities=entities, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt"
        ... )
        >>> outputs = model(**encoding)
        >>> word_last_hidden_state = outputs.last_hidden_state
        >>> entity_last_hidden_state = outputs.entity_last_hidden_state
        ```NzDYou cannot specify both input_ids and inputs_embeds at the same timerQ   z5You have to specify either input_ids or inputs_embeds)rT   rR   r   )r\   r[   r]   r^   )r   r   r   r   r   r   )r   pooler_outputr-   r.   r   r   )rN   r   r   use_return_dictr{   Z%warn_if_padding_and_no_attention_maskrX   rT   r!   ZonesrY   rZ   Zget_head_maskr   r`   get_extended_attention_maskrk   r   r   r   r-   r.   r   r   )rM   r\   r   r]   r[   rn   r   r   r   r   r^   r   r   r   r_   Z
batch_sizeZ
seq_lengthrT   Zentity_seq_lengthZword_embedding_outputextended_attention_maskZentity_embedding_outputZencoder_outputssequence_outputr   r$   r$   r%   ra     sp    9






zLukeModel.forward)word_attention_maskr   c                 C   s   |}|dk	rt j||gdd}| dkrH|dddddddf }n8| dkrn|ddddddf }ntd|j d|j| jd}d	| t | jj }|S )
ac  
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            word_attention_mask (`torch.LongTensor`):
                Attention mask for word tokens with ones indicating tokens to attend to, zeros for tokens to ignore.
            entity_attention_mask (`torch.LongTensor`, *optional*):
                Attention mask for entity tokens with ones indicating tokens to attend to, zeros for tokens to ignore.

        Returns:
            `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
        NrQ   rq   r
   r   z&Wrong shape for attention_mask (shape ))rS   r   )	r!   r   rr   r{   shaperV   rS   Zfinforo   )rM   r   r   r   r   r$   r$   r%   r     s    z%LukeModel.get_extended_attention_mask)T)NNNNNNNNNNNNN)r   r   r   r   boolr>   r   r   r   r   r   r   LUKE_INPUTS_DOCSTRINGformatr   r   _CONFIG_FOR_DOCr   r!   rv   r"   r   r   ra   r   rd   r$   r$   rO   r%   r     sR   
             
def create_position_ids_from_input_ids(input_ids, padding_idx):
    """
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        x: torch.Tensor x:

    Returns: torch.Tensor
    """
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
    return incremental_indices.long() + padding_idx
LukeLMHeadz*Roberta Head for masked language modeling.c                    sd   t    t|j|j| _tj|j|jd| _t|j|j	| _
tt|j	| _| j| j
_d S r   )r=   r>   r   rl   rA   r   rG   rH   
layer_normr@   r   r   r!   rY   rh   rL   rO   r$   r%   r>     s    
zLukeLMHead.__init__c                 K   s*   |  |}t|}| |}| |}|S r   )r   r   r
  r   )rM   featureskwargsr   r$   r$   r%   ra     s
    


zLukeLMHead.forwardc                 C   s*   | j jjjdkr| j| j _n
| j j| _d S )Nmeta)r   rh   rT   typer   r$   r$   r%   _tie_weights  s    zLukeLMHead._tie_weights)r   r   r   r    r>   ra   r  rd   r$   r$   rO   r%   r	    s   	
r	  z
    The LUKE model with a language modeling head and entity prediction head on top for masked language modeling and
    masked entity prediction.
    c                       s   e Zd ZdddgZ fddZ fddZdd	 Zd
d Zee	
deeeddeej eej eej eej eej eej eej eej eej eej eej eej ee ee ee eeef dddZ  ZS )LukeForMaskedLMzlm_head.decoder.weightzlm_head.decoder.biasz!entity_predictions.decoder.weightc                    s@   t  | t|| _t|| _t|| _t	 | _
|   d S r   )r=   r>   r   r   r	  lm_headr   entity_predictionsr   r   loss_fnr   rL   rO   r$   r%   r>     s    



zLukeForMaskedLM.__init__c                    s$   t    | | jj| jjj d S r   )r=   tie_weightsZ_tie_or_clone_weightsr  r   r   rk   r   rO   r$   r%   r    s    
zLukeForMaskedLM.tie_weightsc                 C   s   | j jS r   r  r   r   r$   r$   r%   get_output_embeddings  s    z%LukeForMaskedLM.get_output_embeddingsc                 C   s   || j _d S r   r  )rM   Znew_embeddingsr$   r$   r%   set_output_embeddings  s    z%LukeForMaskedLM.set_output_embeddingsr   r   N)r\   r   r]   r[   rn   r   r   r   labelsentity_labelsr   r^   r   r   r   r   c                 C   s0  |dk	r|n| j j}| j||||||||||||dd}d}d}| |j}|	dk	r|	|j}	| |d| j j	|	d}|dkr|}d}d}|j
dk	r| |j
}|
dk	r| |d| j j|
d}|dkr|}n|| }|stdd ||||||j|j|jfD S t||||||j|j|jdS )aS  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        entity_labels (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        Returns:

        NTr\   r   r]   r[   rn   r   r   r   r   r^   r   r   r   rQ   c                 s   s   | ]}|d k	r|V  qd S r   r$   r   r$   r$   r%   r   d  s   
z*LukeForMaskedLM.forward.<locals>.<genexpr>)r(   r)   r*   r+   r,   r-   r   r.   )rN   r   r   r  r   rV   rT   r  r   r@   r   r  ri   r   r-   r   r.   r'   )rM   r\   r   r]   r[   rn   r   r   r   r  r  r   r^   r   r   r   r   r(   r)   r+   r*   r,   r$   r$   r%   ra     sn    "
zLukeForMaskedLM.forward)NNNNNNNNNNNNNNN)r   r   r   Z_tied_weights_keysr>   r  r  r  r   r  r  r   r'   r  r   r!   rv   r"   r  r   r   ra   rd   r$   r$   rO   r%   r    sP   

               
r  z
    The LUKE model with a classification head on top (a linear layer on top of the hidden state of the first entity
    token) for entity classification tasks, such as Open Entity.
    c                       s   e Zd Z fddZeedeee	dd	e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
e e
e e
e eeef dddZ  ZS )
LukeForEntityClassificationc                    sJ   t  | t|| _|j| _t|j| _t	|j
|j| _|   d S r   r=   r>   r   r   
num_labelsr   rI   rJ   rK   rl   rA   
classifierr   rL   rO   r$   r%   r>     s    
z$LukeForEntityClassification.__init__r   r   Nr\   r   r]   r[   rn   r   r   r   r   r^   r  r   r   r   r   c                 C   s   |dk	r|n| j j}| j|||||||||	|
||dd}|jdddddf }| |}| |}d}|dk	r||j}|jdkrt	j
||}n t	j
|d|d|}|stdd |||j|j|jfD S t|||j|j|jd	S )
u  
        labels (`torch.LongTensor` of shape `(batch_size,)` or `(batch_size, num_labels)`, *optional*):
            Labels for computing the classification loss. If the shape is `(batch_size,)`, the cross entropy loss is
            used for the single-label classification. In this case, labels should contain the indices that should be in
            `[0, ..., config.num_labels - 1]`. If the shape is `(batch_size, num_labels)`, the binary cross entropy
            loss is used for the multi-label classification. In this case, labels should only contain `[0, 1]`, where 0
            and 1 indicate false and true, respectively.

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, LukeForEntityClassification

        >>> tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-open-entity")
        >>> model = LukeForEntityClassification.from_pretrained("studio-ousia/luke-large-finetuned-open-entity")

        >>> text = "Beyoncé lives in Los Angeles."
        >>> entity_spans = [(0, 7)]  # character-based entity span corresponding to "Beyoncé"
        >>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> predicted_class_idx = logits.argmax(-1).item()
        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
        Predicted class: person
        ```NTr  r   r   rQ   c                 s   s   | ]}|d k	r|V  qd S r   r$   r   r$   r$   r%   r     s   z6LukeForEntityClassification.forward.<locals>.<genexpr>r(   r+   r-   r   r.   )rN   r   r   r   rK   r  rV   rT   ndimr   r   cross_entropy binary_cross_entropy_with_logitsr   rt   r   r-   r   r.   r/   rM   r\   r   r]   r[   rn   r   r   r   r   r^   r  r   r   r   r   feature_vectorr+   r(   r$   r$   r%   ra     sH    .


 z#LukeForEntityClassification.forward)NNNNNNNNNNNNNN)r   r   r   r>   r   r  r  r   r/   r  r   r!   rv   r"   r  r   r   ra   rd   r$   r$   rO   r%   r    sD   
              
r  z
    The LUKE model with a classification head on top (a linear layer on top of the hidden states of the two entity
    tokens) for entity pair classification tasks, such as TACRED.
    c                       s   e Zd Z fddZeedeee	dd	e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
e e
e e
e eeef dddZ  ZS )
LukeForEntityPairClassificationc                    sP   t  | t|| _|j| _t|j| _t	|j
d |jd| _|   d S )Nr   Fr  rL   rO   r$   r%   r>     s    
z(LukeForEntityPairClassification.__init__r   r   Nr  c                 C   s  |dk	r|n| j j}| j|||||||||	|
||dd}tj|jdddddf |jdddddf gdd}| |}| |}d}|dk	r||j	}|j
dkrtj||}n tj|d|d|}|stdd	 |||j|j|jfD S t|||j|j|jd
S )u"  
        labels (`torch.LongTensor` of shape `(batch_size,)` or `(batch_size, num_labels)`, *optional*):
            Labels for computing the classification loss. If the shape is `(batch_size,)`, the cross entropy loss is
            used for the single-label classification. In this case, labels should contain the indices that should be in
            `[0, ..., config.num_labels - 1]`. If the shape is `(batch_size, num_labels)`, the binary cross entropy
            loss is used for the multi-label classification. In this case, labels should only contain `[0, 1]`, where 0
            and 1 indicate false and true, respectively.

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, LukeForEntityPairClassification

        >>> tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-tacred")
        >>> model = LukeForEntityPairClassification.from_pretrained("studio-ousia/luke-large-finetuned-tacred")

        >>> text = "Beyoncé lives in Los Angeles."
        >>> entity_spans = [
        ...     (0, 7),
        ...     (17, 28),
        ... ]  # character-based entity spans corresponding to "Beyoncé" and "Los Angeles"
        >>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> predicted_class_idx = logits.argmax(-1).item()
        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
        Predicted class: per:cities_of_residence
        ```"""
        # [Body lost to bytecode compilation. Recoverable structure: self.luke(..., return_dict=True) is
        # called, the hidden states of the two entity tokens (entity_last_hidden_state[:, 0, :] and
        # [:, 1, :]) are concatenated and passed through dropout and the classifier, the loss is
        # cross-entropy for 1-D labels or binary cross-entropy with logits for 2-D labels, and an
        # EntityPairClassificationOutput (or plain tuple) is returned.]


@add_start_docstrings(
    """
    The LUKE model with a span classification head on top (a linear layer on top of the hidden states output) for tasks
    such as named entity recognition.
    """,
    LUKE_START_DOCSTRING,
)
class LukeForEntitySpanClassification(LukePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.luke = LukeModel(config)

        self.num_labels = config.num_labels
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size * 3, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=EntitySpanClassificationOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        entity_ids: Optional[torch.LongTensor] = None,
        entity_attention_mask: Optional[torch.FloatTensor] = None,
        entity_token_type_ids: Optional[torch.LongTensor] = None,
        entity_position_ids: Optional[torch.LongTensor] = None,
        entity_start_positions: Optional[torch.LongTensor] = None,
        entity_end_positions: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, EntitySpanClassificationOutput]:
        r"""
        entity_start_positions (`torch.LongTensor`):
            The start positions of entities in the word token sequence.

        entity_end_positions (`torch.LongTensor`):
            The end positions of entities in the word token sequence.

        labels (`torch.LongTensor` of shape `(batch_size, entity_length)` or `(batch_size, entity_length, num_labels)`, *optional*):
            Labels for computing the classification loss. If the shape is `(batch_size, entity_length)`, the cross
            entropy loss is used for the single-label classification. In this case, labels should contain the indices
            that should be in `[0, ..., config.num_labels - 1]`. If the shape is `(batch_size, entity_length,
            num_labels)`, the binary cross entropy loss is used for the multi-label classification. In this case,
            labels should only contain `[0, 1]`, where 0 and 1 indicate false and true, respectively.

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, LukeForEntitySpanClassification

        >>> tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-conll-2003")
        >>> model = LukeForEntitySpanClassification.from_pretrained("studio-ousia/luke-large-finetuned-conll-2003")

        >>> text = "Beyoncé lives in Los Angeles"
        >>> # List all possible entity spans in the text

        >>> word_start_positions = [0, 8, 14, 17, 21]  # character-based start positions of word tokens
        >>> word_end_positions = [7, 13, 16, 20, 28]  # character-based end positions of word tokens
        >>> entity_spans = []
        >>> for i, start_pos in enumerate(word_start_positions):
        ...     for end_pos in word_end_positions[i:]:
        ...         entity_spans.append((start_pos, end_pos))

        >>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> predicted_class_indices = logits.argmax(-1).squeeze().tolist()
        >>> for span, predicted_class_idx in zip(entity_spans, predicted_class_indices):
        ...     if predicted_class_idx != 0:
        ...         print(text[span[0] : span[1]], model.config.id2label[predicted_class_idx])
        Beyoncé PER
        Los Angeles LOC
        ```"""
        # [Body lost to bytecode compilation. Recoverable structure: self.luke(..., return_dict=True) is
        # called, the word hidden states at entity_start_positions and entity_end_positions are gathered and
        # concatenated with the entity hidden states (matching the hidden_size * 3 classifier input), passed
        # through dropout and the classifier, the loss is cross-entropy for 2-D labels or binary
        # cross-entropy with logits for 3-D labels, and an EntitySpanClassificationOutput (or plain tuple)
        # is returned.]
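
# Hedged sketch, not part of the original module: how the span feature vector described above can be
# assembled. `word_hidden_states` and `entity_hidden_states` stand in for the word and entity outputs of
# LukeModel; the names are illustrative, but the gather-and-concatenate pattern follows the fragments
# recoverable from the compiled forward body and explains the `hidden_size * 3` classifier input.
def _sketch_entity_span_features(
    word_hidden_states, entity_hidden_states, entity_start_positions, entity_end_positions
):
    import torch

    hidden_size = word_hidden_states.size(-1)
    # Expand (batch, entity_length) position tensors to (batch, entity_length, hidden_size) for gather.
    start_index = entity_start_positions.unsqueeze(-1).expand(-1, -1, hidden_size)
    end_index = entity_end_positions.unsqueeze(-1).expand(-1, -1, hidden_size)
    start_states = torch.gather(word_hidden_states, -2, start_index)
    end_states = torch.gather(word_hidden_states, -2, end_index)
    # (batch, entity_length, hidden_size * 3): start, end and entity representations side by side.
    return torch.cat([start_states, end_states, entity_hidden_states], dim=2)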
                


@add_start_docstrings(
    """
    The LUKE Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    LUKE_START_DOCSTRING,
)
class LukeForSequenceClassification(LukePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.luke = LukeModel(config)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=LukeSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        entity_ids: Optional[torch.LongTensor] = None,
        entity_attention_mask: Optional[torch.FloatTensor] = None,
        entity_token_type_ids: Optional[torch.LongTensor] = None,
        entity_position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, LukeSequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # [Body lost to bytecode compilation. Recoverable structure: self.luke(..., return_dict=True) is
        # called, pooler_output goes through dropout and the classifier, config.problem_type is inferred as
        # "regression", "single_label_classification" or "multi_label_classification", the matching MSELoss,
        # CrossEntropyLoss or BCEWithLogitsLoss is applied, and a LukeSequenceClassifierOutput (or plain
        # tuple) is returned.]


@add_start_docstrings(
    """
    The LUKE Model with a token classification head on top (a linear layer on top of the hidden-states output). To
    solve the Named-Entity Recognition (NER) task with LUKE, `LukeForEntitySpanClassification` is more suitable than
    this class.
    """,
    LUKE_START_DOCSTRING,
)
class LukeForTokenClassification(LukePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.luke = LukeModel(config, add_pooling_layer=False)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=LukeTokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        entity_ids: Optional[torch.LongTensor] = None,
        entity_attention_mask: Optional[torch.FloatTensor] = None,
        entity_token_type_ids: Optional[torch.LongTensor] = None,
        entity_position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, LukeTokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`.
        """
        # [Body lost to bytecode compilation. Recoverable structure: self.luke(..., return_dict=True) is
        # called, last_hidden_state goes through dropout and the classifier, CrossEntropyLoss is applied to
        # the flattened logits and labels, and a LukeTokenClassifierOutput (or plain tuple) is returned.]


@add_start_docstrings(
    """
    The LUKE Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    LUKE_START_DOCSTRING,
)
class LukeForQuestionAnswering(LukePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels

        self.luke = LukeModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=LukeQuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        entity_ids: Optional[torch.LongTensor] = None,
        entity_attention_mask: Optional[torch.FloatTensor] = None,
        entity_token_type_ids: Optional[torch.LongTensor] = None,
        entity_position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, LukeQuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        """
        # [Body lost to bytecode compilation. Recoverable structure: self.luke(..., return_dict=True) is
        # called, qa_outputs produces start/end logits that are split and squeezed, start_positions and
        # end_positions are clamped to the sequence length, CrossEntropyLoss(ignore_index=ignored_index) is
        # applied to each and averaged into total_loss, and a LukeQuestionAnsweringModelOutput (or plain
        # tuple) is returned.]
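
# Hedged usage sketch, not part of the original module: decoding an answer span from the start/end logits
# produced by the question-answering head above. The checkpoint name is an assumption (an untuned base
# checkpoint yields arbitrary spans), and the question/context pair encoding assumes the tokenizer's
# standard text-pair behaviour.
def _example_question_answering():
    from transformers import AutoTokenizer, LukeForQuestionAnswering

    tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-base")
    model = LukeForQuestionAnswering.from_pretrained("studio-ousia/luke-base")

    question, context = "Where does Beyoncé live?", "Beyoncé lives in Los Angeles."
    inputs = tokenizer(question, context, return_tensors="pt")
    outputs = model(**inputs)
    start = outputs.start_logits.argmax(-1).item()
    end = outputs.end_logits.argmax(-1).item()
    return tokenizer.decode(inputs["input_ids"][0][start : end + 1])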


@add_start_docstrings(
    """
    The LUKE Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    LUKE_START_DOCSTRING,
)
class LukeForMultipleChoice(LukePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.luke = LukeModel(config)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=LukeMultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        entity_ids: Optional[torch.LongTensor] = None,
        entity_attention_mask: Optional[torch.FloatTensor] = None,
        entity_token_type_ids: Optional[torch.LongTensor] = None,
        entity_position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, LukeMultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        # [Body lost to bytecode compilation. Recoverable structure: every input is flattened from
        # (batch_size, num_choices, ...) to (batch_size * num_choices, ...), self.luke(..., return_dict=True)
        # is called, pooler_output goes through dropout and the classifier, the logits are reshaped back to
        # (batch_size, num_choices), CrossEntropyLoss is applied against `labels`, and a
        # LukeMultipleChoiceModelOutput (or plain tuple) is returned.]
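
# Hedged usage sketch, not part of the original module: scoring two candidate continuations with the
# multiple-choice head above. Per-choice encodings are stacked and reshaped to
# (batch_size, num_choices, sequence_length); `labels` holds the index of the correct choice. The checkpoint
# name is an assumption, and the paired prompt/choice encoding assumes the tokenizer's text-pair behaviour.
def _example_multiple_choice():
    import torch
    from transformers import AutoTokenizer, LukeForMultipleChoice

    tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-base")
    model = LukeForMultipleChoice.from_pretrained("studio-ousia/luke-base")

    prompt = "Beyoncé lives in"
    choices = ["Los Angeles.", "a submarine."]
    encoding = tokenizer([prompt, prompt], choices, return_tensors="pt", padding=True)
    inputs = {key: value.unsqueeze(0) for key, value in encoding.items()}  # (1, num_choices, seq_len)
    outputs = model(**inputs, labels=torch.tensor([0]))
    return outputs.loss, outputs.logits  # logits: (batch_size, num_choices)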