U
    9%e                     @   sZ  d Z ddlZddlZddlmZmZmZ ddlZddlZddlm	Z	 ddl
mZmZmZ ddlmZ ddlmZmZmZ dd	lmZmZmZmZmZmZ dd
lmZmZmZmZ ddl m!Z! ddl"m#Z# e!$e%Z&dZ'dZ(ddgZ)G dd de	j*Z+d=ddZ,G dd de	j*Z-G dd de	j*Z.G dd de	j*Z/G dd de	j*Z0G dd de	j*Z1G dd  d e	j*Z2G d!d" d"e	j*Z3G d#d$ d$e	j*Z4G d%d& d&e	j*Z5G d'd( d(e	j*Z6G d)d* d*e	j*Z7G d+d, d,e	j*Z8G d-d. d.eZ9d/Z:d0Z;ed1e:G d2d3 d3e9Z<ed4e:G d5d6 d6e9Z=ed7e:G d8d9 d9e9Z>ed:e:G d;d< d<e9Z?dS )>z PyTorch MarkupLM model.    N)OptionalTupleUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)add_start_docstrings%add_start_docstrings_to_model_forwardreplace_return_docstrings))BaseModelOutputWithPastAndCrossAttentions,BaseModelOutputWithPoolingAndCrossAttentionsMaskedLMOutputQuestionAnsweringModelOutputSequenceClassifierOutputTokenClassifierOutput)PreTrainedModelapply_chunking_to_forward find_pruneable_heads_and_indicesprune_linear_layer)logging   )MarkupLMConfigzmicrosoft/markuplm-baser   zmicrosoft/markuplm-largec                       s*   e Zd ZdZ fddZdddZ  ZS )XPathEmbeddingszConstruct the embeddings from xpath tags and subscripts.

    We drop tree-id in this version, as its info can be covered by xpath.
    c                    s   t t|    j| _t j| j  j| _t	 j
| _t | _t j| j d j | _td j  j| _t fddt| jD | _t fddt| jD | _d S )N   c                    s   g | ]}t  j jqS  )r   	EmbeddingZmax_xpath_tag_unit_embeddingsxpath_unit_hidden_size.0_configr   m/var/www/html/Darija-Ai-API/env/lib/python3.8/site-packages/transformers/models/markuplm/modeling_markuplm.py
<listcomp>P   s   z,XPathEmbeddings.__init__.<locals>.<listcomp>c                    s   g | ]}t  j jqS r   )r   r   Zmax_xpath_subs_unit_embeddingsr   r    r#   r   r%   r&   W   s   )superr   __init__	max_depthr   Linearr   hidden_sizeZxpath_unitseq2_embeddingsDropouthidden_dropout_probdropoutZReLU
activationxpath_unitseq2_inner	inner2emb
ModuleListrangexpath_tag_sub_embeddingsxpath_subs_sub_embeddingsselfr$   	__class__r#   r%   r(   C   s"    


zXPathEmbeddings.__init__Nc              	   C   s   g }g }t | jD ]P}|| j| |d d d d |f  || j| |d d d d |f  qtj|dd}tj|dd}|| }| | | 	| 
|}|S )Ndim)r3   r)   appendr4   r5   torchcatr1   r.   r/   r0   )r7   xpath_tags_seqxpath_subs_seqZxpath_tags_embeddingsZxpath_subs_embeddingsixpath_embeddingsr   r   r%   forward]   s    &(zXPathEmbeddings.forward)NN)__name__
__module____qualname____doc__r(   rD   __classcell__r   r   r8   r%   r   =   s   r   c                 C   s6   |  | }tj|dd|| | }| | S )a  
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        x: torch.Tensor x:

    Returns: torch.Tensor
    r   r;   )neintr>   ZcumsumZtype_aslong)	input_idspadding_idxpast_key_values_lengthmaskZincremental_indicesr   r   r%   "create_position_ids_from_input_idsp   s    rQ   c                       s2   e Zd ZdZ fddZdd Zd
dd	Z  ZS )MarkupLMEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    s   t t|   || _tj|j|j|jd| _	t|j
|j| _|j| _t|| _t|j|j| _tj|j|jd| _t|j| _| jdt|j
ddd |j| _tj|j
|j| jd| _d S )N)rN   Zepsposition_ids)r   r:   F)
persistent)r'   rR   r(   r$   r   r   
vocab_sizer+   Zpad_token_idword_embeddingsmax_position_embeddingsposition_embeddingsr)   r   rC   Ztype_vocab_sizetoken_type_embeddings	LayerNormlayer_norm_epsr,   r-   r.   Zregister_bufferr>   arangeexpandrN   r6   r8   r   r%   r(      s(    
    zMarkupLMEmbeddings.__init__c                 C   sN   |  dd }|d }tj| jd || j d tj|jd}|d|S )z
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        Nr:   r   dtypedevicer   )sizer>   r]   rN   rL   ra   	unsqueezer^   )r7   inputs_embedsinput_shapeZsequence_lengthrT   r   r   r%   &create_position_ids_from_inputs_embeds   s    	   z9MarkupLMEmbeddings.create_position_ids_from_inputs_embedsNr   c                 C   s<  |d k	r|  }n|  d d }|d k	r0|jn|j}	|d kr`|d k	rVt|| j|}n
| |}|d krztj|tj|	d}|d kr| |}|d kr| j	j
tjtt|| jg tj|	d }|d kr| j	jtjtt|| jg tj|	d }|}
| |}| |}| ||}|
| | | }| |}| |}|S )Nr:   r_   )rb   ra   rQ   rN   rf   r>   zerosrL   rW   r$   Z
tag_pad_idonestuplelistr)   Zsubs_pad_idrY   rZ   rC   r[   r.   )r7   rM   r@   rA   token_type_idsrT   rd   rO   re   ra   Zwords_embeddingsrY   rZ   rC   
embeddingsr   r   r%   rD      s@    




  
  



zMarkupLMEmbeddings.forward)NNNNNNr   )rE   rF   rG   rH   r(   rf   rD   rI   r   r   r8   r%   rR      s          rR   c                       s4   e Zd Z fddZejejejdddZ  ZS )MarkupLMSelfOutputc                    sB   t    t|j|j| _tj|j|jd| _t|j	| _
d S NrS   )r'   r(   r   r*   r+   denser[   r\   r,   r-   r.   r6   r8   r   r%   r(      s    
zMarkupLMSelfOutput.__init__hidden_statesinput_tensorreturnc                 C   s&   |  |}| |}| || }|S Nro   r.   r[   r7   rq   rr   r   r   r%   rD      s    

zMarkupLMSelfOutput.forwardrE   rF   rG   r(   r>   TensorrD   rI   r   r   r8   r%   rm      s   rm   c                       s0   e Zd Z fddZejejdddZ  ZS )MarkupLMIntermediatec                    sB   t    t|j|j| _t|jt	r6t
|j | _n|j| _d S rt   )r'   r(   r   r*   r+   intermediate_sizero   
isinstance
hidden_actstrr
   intermediate_act_fnr6   r8   r   r%   r(      s
    
zMarkupLMIntermediate.__init__rq   rs   c                 C   s   |  |}| |}|S rt   )ro   r~   r7   rq   r   r   r%   rD      s    

zMarkupLMIntermediate.forwardrw   r   r   r8   r%   ry      s   ry   c                       s4   e Zd Z fddZejejejdddZ  ZS )MarkupLMOutputc                    sB   t    t|j|j| _tj|j|jd| _t	|j
| _d S rn   )r'   r(   r   r*   rz   r+   ro   r[   r\   r,   r-   r.   r6   r8   r   r%   r(     s    
zMarkupLMOutput.__init__rp   c                 C   s&   |  |}| |}| || }|S rt   ru   rv   r   r   r%   rD   	  s    

zMarkupLMOutput.forwardrw   r   r   r8   r%   r     s   r   c                       s0   e Zd Z fddZejejdddZ  ZS )MarkupLMPoolerc                    s*   t    t|j|j| _t | _d S rt   )r'   r(   r   r*   r+   ro   ZTanhr/   r6   r8   r   r%   r(     s    
zMarkupLMPooler.__init__r   c                 C   s(   |d d df }|  |}| |}|S )Nr   )ro   r/   )r7   rq   Zfirst_token_tensorpooled_outputr   r   r%   rD     s    

zMarkupLMPooler.forwardrw   r   r   r8   r%   r     s   r   c                       s0   e Zd Z fddZejejdddZ  ZS )MarkupLMPredictionHeadTransformc                    sV   t    t|j|j| _t|jtr6t	|j | _
n|j| _
tj|j|jd| _d S rn   )r'   r(   r   r*   r+   ro   r{   r|   r}   r
   transform_act_fnr[   r\   r6   r8   r   r%   r(   "  s    
z(MarkupLMPredictionHeadTransform.__init__r   c                 C   s"   |  |}| |}| |}|S rt   )ro   r   r[   r   r   r   r%   rD   +  s    


z'MarkupLMPredictionHeadTransform.forwardrw   r   r   r8   r%   r   !  s   	r   c                       s$   e Zd Z fddZdd Z  ZS )MarkupLMLMPredictionHeadc                    sL   t    t|| _tj|j|jdd| _t	t
|j| _| j| j_d S )NF)bias)r'   r(   r   	transformr   r*   r+   rV   decoder	Parameterr>   rg   r   r6   r8   r   r%   r(   4  s
    

z!MarkupLMLMPredictionHead.__init__c                 C   s   |  |}| |}|S rt   )r   r   r   r   r   r%   rD   A  s    

z MarkupLMLMPredictionHead.forward)rE   rF   rG   r(   rD   rI   r   r   r8   r%   r   3  s   r   c                       s0   e Zd Z fddZejejdddZ  ZS )MarkupLMOnlyMLMHeadc                    s   t    t|| _d S rt   )r'   r(   r   predictionsr6   r8   r   r%   r(   I  s    
zMarkupLMOnlyMLMHead.__init__)sequence_outputrs   c                 C   s   |  |}|S rt   )r   )r7   r   prediction_scoresr   r   r%   rD   M  s    
zMarkupLMOnlyMLMHead.forwardrw   r   r   r8   r%   r   H  s   r   c                
       s   e Zd Zd fdd	ZejejdddZdejeej eej eej eej ee	e	ej   ee
 e	ej dd	d
Z  ZS )MarkupLMSelfAttentionNc                    s   t    |j|j dkr>t|ds>td|j d|j d|j| _t|j|j | _| j| j | _t	
|j| j| _t	
|j| j| _t	
|j| j| _t	|j| _|pt|dd| _| jdks| jd	kr|j| _t	d
|j d | j| _|j| _d S )Nr   Zembedding_sizezThe hidden size (z6) is not a multiple of the number of attention heads ()position_embedding_typeabsoluterelative_keyrelative_key_query   r   )r'   r(   r+   num_attention_headshasattr
ValueErrorrK   attention_head_sizeall_head_sizer   r*   querykeyvaluer,   Zattention_probs_dropout_probr.   getattrr   rX   r   distance_embedding
is_decoderr7   r$   r   r8   r   r%   r(   T  s*    
  zMarkupLMSelfAttention.__init__)xrs   c                 C   s6   |  d d | j| jf }||}|ddddS )Nr:   r   r   r   r	   )rb   r   r   viewpermute)r7   r   Znew_x_shaper   r   r%   transpose_for_scoresn  s    
z*MarkupLMSelfAttention.transpose_for_scoresFrq   attention_mask	head_maskencoder_hidden_statesencoder_attention_maskpast_key_valueoutput_attentionsrs   c                 C   s  |  |}|d k	}	|	r4|d k	r4|d }
|d }|}n|	r^| | |}
| | |}|}nv|d k	r| | |}
| | |}tj|d |
gdd}
tj|d |gdd}n | | |}
| | |}| |}|d k	}| jr|
|f}t||
dd}| j	dks | j	dkr|j
d |
j
d  }}|r^tj|d tj|jd	dd}ntj|tj|jd	dd}tj|tj|jd	dd}|| }| || j d }|j|jd
}| j	dkrtd||}|| }n4| j	dkrtd||}td|
|}|| | }|t| j }|d k	r:|| }tjj|dd}| |}|d k	rf|| }t||}|dddd }| d d | jf }||}|r||fn|f}| jr||f }|S )Nr   r   r   r;   r:   r   r   r_   r`   zbhld,lrd->bhlrzbhrd,lrd->bhlrr	   ) r   r   r   r   r>   r?   r   matmulZ	transposer   shapeZtensorrL   ra   r   r]   r   rX   tor`   Zeinsummathsqrtr   r   Z
functionalZsoftmaxr.   r   
contiguousrb   r   )r7   rq   r   r   r   r   r   r   Zmixed_query_layerZis_cross_attentionZ	key_layerZvalue_layerZquery_layer	use_cacheZattention_scoresZquery_lengthZ
key_lengthZposition_ids_lZposition_ids_rZdistanceZpositional_embeddingZrelative_position_scoresZrelative_position_scores_queryZrelative_position_scores_keyZattention_probsZcontext_layerZnew_context_layer_shapeoutputsr   r   r%   rD   s  sp    


 





zMarkupLMSelfAttention.forward)N)NNNNNF)rE   rF   rG   r(   r>   rx   r   r   FloatTensorr   boolrD   rI   r   r   r8   r%   r   S  s$         r   c                
       sv   e Zd Zd
 fdd	Zdd Zdejeej eej eej eej ee	e	ej   ee
 e	ej ddd	Z  ZS )MarkupLMAttentionNc                    s.   t    t||d| _t|| _t | _d S )Nr   )r'   r(   r   r7   rm   outputsetpruned_headsr   r8   r   r%   r(     s    

zMarkupLMAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   r;   )lenr   r7   r   r   r   r   r   r   r   r   ro   r   union)r7   headsindexr   r   r%   prune_heads  s       zMarkupLMAttention.prune_headsFr   c              	   C   s<   |  |||||||}| |d |}	|	f|dd   }
|
S )Nr   r   )r7   r   )r7   rq   r   r   r   r   r   r   Zself_outputsattention_outputr   r   r   r%   rD     s    
	zMarkupLMAttention.forward)N)NNNNNF)rE   rF   rG   r(   r   r>   rx   r   r   r   r   rD   rI   r   r   r8   r%   r     s$         r   c                
       st   e Zd Z fddZd
ejeej eej eej eej eeeej   ee	 eej dddZ
dd	 Z  ZS )MarkupLMLayerc                    sr   t    |j| _d| _t|| _|j| _|j| _| jrZ| jsLt|  dt|dd| _	t
|| _t|| _d S )Nr   z> should be used as a decoder model if cross attention is addedr   r   )r'   r(   chunk_size_feed_forwardseq_len_dimr   	attentionr   add_cross_attentionr   crossattentionry   intermediater   r   r6   r8   r   r%   r(     s    


zMarkupLMLayer.__init__NFr   c              	   C   s  |d k	r|d d nd }| j |||||d}	|	d }
| jrP|	dd }|	d }n|	dd  }d }| jr|d k	rt| dstd|  d|d k	r|d	d  nd }| |
||||||}|d }
||dd  }|d }|| }t| j| j| j|
}|f| }| jr||f }|S )
Nr   r   r   r   r   r:   r   z'If `encoder_hidden_states` are passed, z` has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`r   )	r   r   r   r   r   r   feed_forward_chunkr   r   )r7   rq   r   r   r   r   r   r   Zself_attn_past_key_valueZself_attention_outputsr   r   Zpresent_key_valueZcross_attn_present_key_valueZcross_attn_past_key_valueZcross_attention_outputslayer_outputr   r   r%   rD     sV    


	   

zMarkupLMLayer.forwardc                 C   s   |  |}| ||}|S rt   )r   r   )r7   r   Zintermediate_outputr   r   r   r%   r   \  s    
z MarkupLMLayer.feed_forward_chunk)NNNNNF)rE   rF   rG   r(   r>   rx   r   r   r   r   rD   r   rI   r   r   r8   r%   r     s$         Ar   c                       s   e Zd Z fddZd	ejeej eej eej eej eeeej   ee	 ee	 ee	 ee	 e
eej ef dddZ  ZS )
MarkupLMEncoderc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r   )r   r    r#   r   r%   r&   g  s     z,MarkupLMEncoder.__init__.<locals>.<listcomp>F)	r'   r(   r$   r   r2   r3   num_hidden_layerslayergradient_checkpointingr6   r8   r#   r%   r(   d  s    
 zMarkupLMEncoder.__init__NFT)rq   r   r   r   r   past_key_valuesr   r   output_hidden_statesreturn_dictrs   c              	      st  |	rdnd } rdnd } r(| j jr(dnd }| jrJ| jrJ|rJtd d}|rRdnd }t| jD ]\}}|	rv||f }|d k	r|| nd }|d k	r|| nd | jr| jrև fdd}tj	j

|||||||}n|||||| }|d }|r||d f7 } r`||d f }| j jr`||d	 f }q`|	r@||f }|
sbtd
d |||||fD S t|||||dS )Nr   zZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...Fc                    s    fdd}|S )Nc                     s    | f S rt   r   )inputs)moduler   r   r   r%   custom_forward  s    zNMarkupLMEncoder.forward.<locals>.create_custom_forward.<locals>.custom_forwardr   )r   r   r   )r   r%   create_custom_forward  s    z6MarkupLMEncoder.forward.<locals>.create_custom_forwardr   r:   r   r   c                 s   s   | ]}|d k	r|V  qd S rt   r   )r!   vr   r   r%   	<genexpr>  s   z*MarkupLMEncoder.forward.<locals>.<genexpr>)last_hidden_stater   rq   
attentionscross_attentions)r$   r   r   ZtrainingloggerZwarning_once	enumerater   r>   utils
checkpointri   r   )r7   rq   r   r   r   r   r   r   r   r   r   Zall_hidden_statesZall_self_attentionsZall_cross_attentionsZnext_decoder_cacherB   Zlayer_moduleZlayer_head_maskr   Zlayer_outputsr   r   r%   rD   j  sv    
	

zMarkupLMEncoder.forward)	NNNNNNFFT)rE   rF   rG   r(   r>   rx   r   r   r   r   r   r   rD   rI   r   r   r8   r%   r   c  s.   	         r   c                       sL   e Zd ZdZeZeZdZdd Z	e
eeeejf  d fddZ  ZS )MarkupLMPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    markuplmc                 C   s   t |tjr:|jjjd| jjd |jdk	r|jj	  nft |tj
rz|jjjd| jjd |jdk	r|jj|j 	  n&t |tjr|jj	  |jjd dS )zInitialize the weightsg        )ZmeanZstdN      ?)r{   r   r*   weightdataZnormal_r$   Zinitializer_ranger   Zzero_r   rN   r[   Zfill_)r7   r   r   r   r%   _init_weights  s    

z%MarkupLMPreTrainedModel._init_weights)pretrained_model_name_or_pathc                    s   t t| j|f||S rt   )r'   r   from_pretrained)clsr   Z
model_argskwargsr8   r   r%   r     s    
z'MarkupLMPreTrainedModel.from_pretrained)rE   rF   rG   rH   r   config_class&MARKUPLM_PRETRAINED_MODEL_ARCHIVE_LISTZpretrained_model_archive_mapZbase_model_prefixr   classmethodr   r   r}   osPathLiker   rI   r   r   r8   r%   r     s   r   aK  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`MarkupLMConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
a  
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

        xpath_tags_seq (`torch.LongTensor` of shape `({0}, config.max_depth)`, *optional*):
            Tag IDs for each token in the input sequence, padded up to config.max_depth.

        xpath_subs_seq (`torch.LongTensor` of shape `({0}, config.max_depth)`, *optional*):
            Subscript IDs for each token in the input sequence, padded up to config.max_depth.

        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: `1` for
            tokens that are NOT MASKED, `0` for MASKED tokens.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`: `0` corresponds to a *sentence A* token, `1` corresponds to a *sentence B* token

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: `1`
            indicates the head is **not masked**, `0` indicates the head is **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            If set to `True`, the attentions tensors of all attention layers are returned. See `attentions` under
            returned tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            If set to `True`, the hidden states of all layers are returned. See `hidden_states` under returned tensors
            for more detail.
        return_dict (`bool`, *optional*):
            If set to `True`, the model will return a [`~file_utils.ModelOutput`] instead of a plain tuple.
zbThe bare MarkupLM Model transformer outputting raw hidden-states without any specific head on top.c                       s   e Zd Zd fdd	Zdd Zdd Zdd	 Zee	d
e
eeddeej eej eej eej eej eej eej eej ee ee ee eeef dddZdddZdd Z  ZS )MarkupLMModelTc                    sD   t  | || _t|| _t|| _|r2t|nd | _| 	  d S rt   )
r'   r(   r$   rR   rl   r   encoderr   pooler	post_init)r7   r$   add_pooling_layerr8   r   r%   r(   )  s    

zMarkupLMModel.__init__c                 C   s   | j jS rt   rl   rW   )r7   r   r   r%   get_input_embeddings5  s    z"MarkupLMModel.get_input_embeddingsc                 C   s   || j _d S rt   r   )r7   r   r   r   r%   set_input_embeddings8  s    z"MarkupLMModel.set_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr   r   r   r   )r7   Zheads_to_pruner   r   r   r   r%   _prune_heads;  s    zMarkupLMModel._prune_headsbatch_size, sequence_lengthoutput_typer   N)rM   r@   rA   r   rk   rT   r   rd   r   r   r   rs   c                 C   s  |	dk	r|	n| j j}	|
dk	r |
n| j j}
|dk	r4|n| j j}|dk	rV|dk	rVtdn@|dk	rt| || | }n"|dk	r| dd }ntd|dk	r|jn|j}|dkrtj	||d}|dkrtj
|tj|d}|dd}|j| jd	}d
| d }|dk	r| dkrP|dddd}|| j jdddd}n$| dkrt|ddd}|jt|  jd	}ndg| j j }| j||||||d}| j||||	|
|d}|d }| jdk	r| |nd}|s||f|dd  S t|||j|j|jdS )a`  
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoProcessor, MarkupLMModel

        >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base")
        >>> model = MarkupLMModel.from_pretrained("microsoft/markuplm-base")

        >>> html_string = "<html> <head> <title>Page Title</title> </head> </html>"

        >>> encoding = processor(html_string, return_tensors="pt")

        >>> outputs = model(**encoding)
        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 4, 768]
        ```NzDYou cannot specify both input_ids and inputs_embeds at the same timer:   z5You have to specify either input_ids or inputs_embeds)ra   r_   r   r   r   r   g     r   )rM   r@   rA   rT   rk   rd   )r   r   r   r   )r   Zpooler_outputrq   r   r   )r$   r   r   use_return_dictr   Z%warn_if_padding_and_no_attention_maskrb   ra   r>   rh   rg   rL   rc   r   r`   r<   r^   r   next
parametersrl   r   r   r   rq   r   r   )r7   rM   r@   rA   r   rk   rT   r   rd   r   r   r   re   ra   Zextended_attention_maskZembedding_outputZencoder_outputsr   r   r   r   r%   rD   C  sn    $


zMarkupLMModel.forwardc                 K   sB   |j }|d kr||}|d k	r4|d d dd f }||||dS )Nr:   )rM   r   r   r   )r   Znew_ones)r7   rM   r   r   r   Zmodel_kwargsre   r   r   r%   prepare_inputs_for_generation  s    
z+MarkupLMModel.prepare_inputs_for_generationc                    s.   d}|D ] }|t  fdd|D f7 }q|S )Nr   c                 3   s"   | ]}| d  |jV  qdS )r   N)Zindex_selectr   ra   )r!   Z
past_statebeam_idxr   r%   r     s     z/MarkupLMModel._reorder_cache.<locals>.<genexpr>)ri   )r7   r   r  Zreordered_pastZ
layer_pastr   r  r%   _reorder_cache  s    zMarkupLMModel._reorder_cache)T)NNNNNNNNNNN)NNT)rE   rF   rG   r(   r   r   r   r   MARKUPLM_INPUTS_DOCSTRINGformatr   r   _CONFIG_FOR_DOCr   r>   Z
LongTensorr   r   r   r   rD   r  r  rI   r   r   r8   r%   r   #  sH   
           
i     
r   z
    MarkupLM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    c                       s   e Zd Z fddZeedeee	dd	e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
e e
e e
e eeej ef dddZ  ZS )
MarkupLMForQuestionAnsweringc                    s@   t  | |j| _t|dd| _t|j|j| _| 	  d S NF)r   )
r'   r(   
num_labelsr   r   r   r*   r+   
qa_outputsr   r6   r8   r   r%   r(     s
    z%MarkupLMForQuestionAnswering.__init__r   r   N)rM   r@   rA   r   rk   rT   r   rd   start_positionsend_positionsr   r   r   rs   c                 C   sT  |dk	r|n| j j}| j|||||||||||d}|d }| |}|jddd\}}|d }|d }d}|	dk	r|
dk	rt|	 dkr|	d}	t|
 dkr|
d}
|d}|		d| |
	d| t
|d}|||	}|||
}|| d }|s>||f|dd  }|dk	r:|f| S |S t||||j|jd	S )
a  
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoProcessor, MarkupLMForQuestionAnswering
        >>> import torch

        >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base-finetuned-websrc")
        >>> model = MarkupLMForQuestionAnswering.from_pretrained("microsoft/markuplm-base-finetuned-websrc")

        >>> html_string = "<html> <head> <title>My name is Niels</title> </head> </html>"
        >>> question = "What's his name?"

        >>> encoding = processor(html_string, questions=question, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**encoding)

        >>> answer_start_index = outputs.start_logits.argmax()
        >>> answer_end_index = outputs.end_logits.argmax()

        >>> predict_answer_tokens = encoding.input_ids[0, answer_start_index : answer_end_index + 1]
        >>> processor.decode(predict_answer_tokens).strip()
        'Niels'
        ```N
r@   rA   r   rk   rT   r   rd   r   r   r   r   r   r:   r;   )Zignore_indexr   )lossstart_logits
end_logitsrq   r   )r$   r   r   r  splitsqueezer   r   rb   Zclamp_r   r   rq   r   )r7   rM   r@   rA   r   rk   rT   r   rd   r  r  r   r   r   r   r   logitsr  r  Z
total_lossZignored_indexloss_fctZ
start_lossZend_lossr   r   r   r%   rD     sT    6






z$MarkupLMForQuestionAnswering.forward)NNNNNNNNNNNNN)rE   rF   rG   r(   r   r  r  r   r   r  r   r>   rx   r   r   r   rD   rI   r   r   r8   r%   r    s@   	

             r  z9MarkupLM Model with a `token_classification` head on top.c                       s   e Zd Z fddZeedeee	dd	e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
e e
e e
e eeej ef dddZ  ZS )
MarkupLMForTokenClassificationc                    sb   t  | |j| _t|dd| _|jd k	r2|jn|j}t|| _	t
|j|j| _|   d S r	  )r'   r(   r
  r   r   classifier_dropoutr-   r   r,   r.   r*   r+   
classifierr   r7   r$   r  r8   r   r%   r(   K  s    z'MarkupLMForTokenClassification.__init__r   r   NrM   r@   rA   r   rk   rT   r   rd   labelsr   r   r   rs   c                 C   s   |dk	r|n| j j}| j|||||||||
||d}|d }| |}d}|	dk	rtt }||d| j j|	d}|s|f|dd  }|dk	r|f| S |S t|||j|j	dS )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoProcessor, AutoModelForTokenClassification
        >>> import torch

        >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base")
        >>> processor.parse_html = False
        >>> model = AutoModelForTokenClassification.from_pretrained("microsoft/markuplm-base", num_labels=7)

        >>> nodes = ["hello", "world"]
        >>> xpaths = ["/html/body/div/li[1]/div/span", "/html/body/div/li[1]/div/span"]
        >>> node_labels = [1, 2]
        >>> encoding = processor(nodes=nodes, xpaths=xpaths, node_labels=node_labels, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**encoding)

        >>> loss = outputs.loss
        >>> logits = outputs.logits
        ```Nr  r   r:   r   r  r  rq   r   )
r$   r   r   r  r   r   r
  r   rq   r   )r7   rM   r@   rA   r   rk   rT   r   rd   r  r   r   r   r   r   r   r  r  r   r   r   r%   rD   Y  s@    ,
z&MarkupLMForTokenClassification.forward)NNNNNNNNNNNN)rE   rF   rG   r(   r   r  r  r   r   r  r   r>   rx   r   r   r   rD   rI   r   r   r8   r%   r  H  s<   
            r  z
    MarkupLM Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    c                       s   e Zd Z fddZeedeee	dd	e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
e e
e e
e eeej ef dddZ  ZS )
!MarkupLMForSequenceClassificationc                    sd   t  | |j| _|| _t|| _|jd k	r4|jn|j}t	|| _
t|j|j| _|   d S rt   )r'   r(   r
  r$   r   r   r  r-   r   r,   r.   r*   r+   r  r   r  r8   r   r%   r(     s    
z*MarkupLMForSequenceClassification.__init__r   r   Nr  c                 C   s  |dk	r|n| j j}| j|||||||||
||d}|d }| |}| |}d}|	dk	r<| j jdkr| jdkr~d| j _n4| jdkr|	jtj	ks|	jtj
krd| j _nd| j _| j jdkrt }| jdkr|| |	 }n
|||	}nN| j jdkrt }||d| j|	d}n| j jdkr<t }|||	}|sl|f|dd  }|dk	rh|f| S |S t|||j|jd	S )
a&  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoProcessor, AutoModelForSequenceClassification
        >>> import torch

        >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base")
        >>> model = AutoModelForSequenceClassification.from_pretrained("microsoft/markuplm-base", num_labels=7)

        >>> html_string = "<html> <head> <title>Page Title</title> </head> </html>"
        >>> encoding = processor(html_string, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**encoding)

        >>> loss = outputs.loss
        >>> logits = outputs.logits
        ```Nr  r   Z
regressionZsingle_label_classificationZmulti_label_classificationr:   r   r  )r$   r   r   r.   r  Zproblem_typer
  r`   r>   rL   rK   r   r  r   r   r   r   rq   r   )r7   rM   r@   rA   r   rk   rT   r   rd   r  r   r   r   r   r   r  r  r  r   r   r   r%   rD     sZ    +




"


z)MarkupLMForSequenceClassification.forward)NNNNNNNNNNNN)rE   rF   rG   r(   r   r  r  r   r   r  r   r>   rx   r   r   r   rD   rI   r   r   r8   r%   r    s<   	
            r  )r   )@rH   r   r   typingr   r   r   r>   Ztorch.utils.checkpointr   Ztorch.nnr   r   r   Zactivationsr
   Z
file_utilsr   r   r   Zmodeling_outputsr   r   r   r   r   r   Zmodeling_utilsr   r   r   r   r   r   Zconfiguration_markuplmr   Z
get_loggerrE   r   Z_CHECKPOINT_FOR_DOCr  r   Moduler   rQ   rR   rm   ry   r   r   r   r   r   r   r   r   r   r   ZMARKUPLM_START_DOCSTRINGr  r   r  r  r  r   r   r   r%   <module>   sn    
3
c 2Wb"1 $wc