# coding=utf-8
"""PyTorch MPT model."""

import math
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from torch.nn import functional as F

from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_mpt import MptConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "mosaicml/mpt-7b"
_CONFIG_FOR_DOC = "MptConfig"

MPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "mosaicml/mpt-7b",
    "mosaicml/mpt-7b-storywriter",
    "mosaicml/mpt-7b-instruct",
    "mosaicml/mpt-7b-8k",
    "mosaicml/mpt-7b-8k-instruct",
    "mosaicml/mpt-7b-8k-chat",
    "mosaicml/mpt-30b",
    "mosaicml/mpt-30b-instruct",
    "mosaicml/mpt-30b-chat",
]


def _make_causal_mask(
    input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
    """
    Make causal mask used for self-attention.
    """
    batch_size, target_length = input_ids_shape
    mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
    seq_ids = torch.arange(target_length, device=device)
    # True marks positions that must not be attended to (future tokens).
    mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]

    if past_key_values_length > 0:
        mask[:, :past_key_values_length] = False

    expanded_mask = mask[None, None, :, :].expand(
        batch_size, 1, target_length, target_length + past_key_values_length
    )
    return expanded_mask
def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
    """
    Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
    """
    batch_size, src_length = mask.shape
    tgt_length = tgt_length if tgt_length is not None else src_length

    # Invert so that `True` marks padding positions that must not be attended to.
    expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
    return expanded_mask.expand(batch_size, 1, tgt_length, src_length)


def build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max=8, device=None):
    r"""
    Link to paper: https://arxiv.org/abs/2108.12409 - Alibi tensor is not causal as the original paper mentions, it
    relies on a translation invariance of softmax for quick implementation. This implementation has been copied from
    the alibi implementation of MPT source code that led to slightly different results than the Bloom alibi:
    https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292
    """
    alibi = torch.arange(1 - sequence_length, 1, dtype=torch.int32, device=device).view(1, 1, 1, sequence_length)
    num_heads_power_of_2 = 2 ** math.ceil(math.log2(num_heads))

    base = torch.arange(1, num_heads_power_of_2 + 1, dtype=torch.float32, device=device)
    base = base * (alibi_bias_max / num_heads_power_of_2)

    slopes = 1.0 / torch.pow(2, base)
    slopes = slopes.view(1, num_heads_power_of_2, 1, 1)

    if num_heads_power_of_2 != num_heads:
        # When the head count is not a power of two, interleave the slopes and keep `num_heads` of them.
        slopes = torch.concat([slopes[:, 1::2, ...], slopes[:, ::2, ...]], dim=1)[:, :num_heads, ...]

    alibi = alibi * slopes
    return alibi.squeeze(0)
class MptAttention(nn.Module):
    """Multi-head self attention.
    Using torch or triton attention implementation enables user to also use additive bias.
    """

    def __init__(self, config: MptConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.n_heads = config.n_heads
        self.max_seq_length = config.max_seq_len
        self.head_dim = self.hidden_size // self.n_heads
        self.softmax_scale = config.attn_config.softmax_scale
        if self.softmax_scale is None:
            self.softmax_scale = 1 / math.sqrt(self.hidden_size / self.n_heads)

        self.attn_dropout_p = config.attn_config.attn_pdrop
        self.Wqkv = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_bias: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        batch_size, seq_length = hidden_states.shape[:2]

        mixed_qkv = self.Wqkv(hidden_states)
        query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2)
        query_states = query_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)

        if past_key_value is not None:
            if len(past_key_value) != 0:
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)
            past_key_value = (key_states, value_states)
        else:
            past_key_value = (key_states, value_states)

        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale

        query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]

        if position_bias is not None:
            if len(position_bias.shape) != 3:
                raise ValueError(f"Expecting position_bias shape to be 3 dimensions, got {len(position_bias.shape)}")
            key_length = key_states.shape[-2]

            position_bias_query_index = max(0, position_bias.size(1) - query_length)
            position_bias_key_index = max(0, position_bias.size(2) - key_length)

            position_bias = position_bias[:, position_bias_query_index:, position_bias_key_index:]

            attention_scores = attention_scores + position_bias

        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(attention_mask, torch.finfo(query_states.dtype).min)

        # (batch_size, n_heads, seq_length, key_length): softmax in float32 for numerical stability.
        attn_weights = nn.functional.softmax(attention_scores.float(), dim=-1).to(value_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attn_dropout_p, training=self.training)

        context_states = torch.matmul(attn_weights, value_states)
        context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)
        attn_output = self.out_proj(context_states)

        return attn_output, attn_weights, past_key_value


class MptMLP(nn.Module):
    def __init__(self, config: MptConfig):
        super().__init__()
        hidden_size = config.hidden_size

        self.up_proj = nn.Linear(hidden_size, 4 * hidden_size, bias=False)
        self.act = nn.GELU(approximate="none")
        self.down_proj = nn.Linear(4 * hidden_size, hidden_size, bias=False)
        self.hidden_dropout = config.attn_config.attn_pdrop

    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        hidden_states = self.act(self.up_proj(hidden_states))

        intermediate_output = self.down_proj(hidden_states)

        output = F.dropout(intermediate_output, p=self.hidden_dropout, training=self.training)
        output = output + residual

        return output
class MptBlock(nn.Module):
    def __init__(self, config: MptConfig):
        super().__init__()
        hidden_size = config.hidden_size

        self.norm_1 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        # backward compatibility with weights on the Hub
        self.norm_1.bias = None

        self.num_heads = config.n_heads
        self.attn = MptAttention(config)

        self.norm_2 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.norm_2.bias = None

        self.ffn = MptMLP(config)

        self.dropout_rate = config.attn_config.attn_pdrop
        self.resid_attn_dropout = nn.Dropout(self.dropout_rate)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_bias: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        # hidden_states: [batch_size, seq_length, hidden_size]
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.norm_1(hidden_states)

        residual = hidden_states

        # Self attention.
        attn_outputs, attn_weights, past_key_value = self.attn(
            layernorm_output,
            position_bias=position_bias,
            attention_mask=attention_mask,
            past_key_value=layer_past,
        )

        hidden_states = self.resid_attn_dropout(attn_outputs) + residual

        layernorm_output = self.norm_2(hidden_states)

        # Get residual
        residual = hidden_states

        # MLP.
        output = self.ffn(layernorm_output, residual)
        outputs = (output,)

        if use_cache:
            outputs += (past_key_value,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # hidden_states, present, attentions
jddd	Zde
jedddZeeeejejf  eeejejf  dddZ  ZS )MptPreTrainedModeltransformerTr   z
class MptPreTrainedModel(PreTrainedModel):
    config_class = MptConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MptBlock"]
    _keys_to_ignore_on_load_missing = [r"lm_head.*."]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, LayerNorm):
            if module.bias is not None:
                module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
        if isinstance(module, MptModel):
            module.gradient_checkpointing = value

    @staticmethod
    def _convert_to_mpt_cache(
        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Converts the cache to the format expected by Mpt, i.e. to tuple(tuple([batch_size * num_heads, ...]))
        """
        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
        batch_size_times_num_heads = batch_size * num_heads
        # key:   [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
        return tuple(
            (
                layer_past[0].reshape(batch_size_times_num_heads, head_dim, seq_length),
                layer_past[1].reshape(batch_size_times_num_heads, seq_length, head_dim),
            )
            for layer_past in past_key_value
        )
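
# A small sketch, not part of the original module: `_convert_to_mpt_cache` flattens a
# per-layer cache with (batch_size=2, num_heads=4, head_dim=16, seq_length=8) into
# (batch_size * num_heads, ...) tuples.
def _demo_convert_cache():
    key = torch.randn(2, 4, 16, 8)    # [batch_size, num_heads, head_dim, seq_length]
    value = torch.randn(2, 4, 8, 16)  # [batch_size, num_heads, seq_length, head_dim]
    cache = ((key, value), (key, value))

    mpt_cache = MptPreTrainedModel._convert_to_mpt_cache(cache)
    assert mpt_cache[0][0].shape == (8, 16, 8)
    assert mpt_cache[0][1].shape == (8, 8, 16)
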

MPT_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`MptConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
MPT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
            Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
            `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
            their past given to this model should not be passed as `input_ids` as they have already been computed.

            Each element of `past_key_values` is a tuple (past_key, past_value):
            - past_key: [batch_size * num_heads, head_dim, kv_length]
            - past_value: [batch_size * num_heads, kv_length, head_dim]
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.

            If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
            `past_key_values`).
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    "The bare Mpt Model transformer outputting raw hidden-states without any specific head on top.",
    MPT_START_DOCSTRING,
)
class MptModel(MptPreTrainedModel):
    def __init__(self, config: MptConfig):
        super().__init__(config)

        self.hidden_size = config.hidden_size
        self.num_heads = config.n_heads

        # Embedding + LN Embedding
        self.wte = nn.Embedding(config.vocab_size, self.hidden_size)

        # Transformer blocks
        self.blocks = nn.ModuleList([MptBlock(config) for _ in range(config.n_layers)])

        # Final Layer Norm
        self.norm_f = LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
        # backward compatibility with weights on the Hub
        self.norm_f.bias = None

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings: torch.Tensor):
        self.wte = new_embeddings

    def build_mpt_alibi_tensor(self, num_heads, sequence_length, alibi_bias_max=8, device=None):
        return build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max, device)

    def _prepare_attn_mask(
        self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
    ) -> torch.BoolTensor:
        # create causal mask: [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        if attention_mask.shape[-1] != input_shape[-1] + past_key_values_length:
            raise ValueError(
                "Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
                f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
                f" {past_key_values_length}."
            )
        combined_attention_mask = None
        device = attention_mask.device
        _, src_length = input_shape

        if src_length > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape, device=device, past_key_values_length=past_key_values_length
            )

        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
        combined_attention_mask = (
            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
        )

        return combined_attention_mask

    @add_start_docstrings_to_model_forward(MPT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.blocks))

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)

        hidden_states = inputs_embeds

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # Compute alibi tensor: check build_mpt_alibi_tensor documentation
        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values[0] is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        alibi = self.build_mpt_alibi_tensor(self.num_heads, self.config.max_seq_len, device=hidden_states.device)

        causal_mask = self._prepare_attn_mask(
            attention_mask,
            input_shape=(batch_size, seq_length),
            past_key_values_length=past_key_values_length,
        )

        for block, layer_past in zip(self.blocks, past_key_values):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    alibi,
                    causal_mask,
                    layer_past,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=causal_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    position_bias=alibi,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        # Add last hidden state
        hidden_states = self.norm_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
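
# A hedged end-to-end sketch, not part of the original module: a forward pass through a
# randomly initialized, toy two-layer MptModel.
def _demo_model_forward():
    config = MptConfig(d_model=64, n_heads=4, n_layers=2, max_seq_len=32, vocab_size=128)
    model = MptModel(config)

    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    outputs = model(input_ids, use_cache=True, return_dict=True)
    assert outputs.last_hidden_state.shape == (1, 8, 64)
    assert len(outputs.past_key_values) == config.n_layers
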
@add_start_docstrings(
    """
    The MPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    MPT_START_DOCSTRING,
)
class MptForCausalLM(MptPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: MptConfig):
        super().__init__(config)
        self.transformer = MptModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: torch.Tensor):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> dict:
        # only last token for input_ids if past is not None
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @add_start_docstrings_to_model_forward(MPT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            batch_size, seq_length, vocab_size = shift_logits.shape
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
            )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def _reorder_cache(
        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Output shares the same memory storage as `past`.
        """
        # Get a copy of `beam_idx` on all the devices where we need those indices.
        device_to_beam_idx = {
            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
        }
        reordered_past = tuple(
            (
                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
            )
            for layer_past in past
        )
        return reordered_past
@add_start_docstrings(
    """
    The MPT Model transformer with a sequence classification head on top (linear layer).

    [`MptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-1) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    MPT_START_DOCSTRING,
)
class MptForSequenceClassification(MptPreTrainedModel):
    def __init__(self, config: MptConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = MptModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MPT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
            else:
                sequence_lengths = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
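
# A small sketch, not part of the original module: last-token pooling with a hypothetical
# pad_token_id so batches larger than one are accepted.
def _demo_sequence_classification():
    config = MptConfig(
        d_model=64, n_heads=4, n_layers=2, max_seq_len=32, vocab_size=128, num_labels=3, pad_token_id=0
    )
    model = MptForSequenceClassification(config)

    input_ids = torch.randint(1, config.vocab_size, (2, 8))  # no padding tokens in this toy batch
    outputs = model(input_ids, labels=torch.tensor([0, 2]), return_dict=True)
    assert outputs.logits.shape == (2, 3)
    assert outputs.loss is not None
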
@add_start_docstrings(
    """
    MPT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    MPT_START_DOCSTRING,
)
class MptForTokenClassification(MptPreTrainedModel):
    def __init__(self, config: MptConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = MptModel(config)
        if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MPT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            batch_size, seq_length = labels.shape
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
            )

        if not return_dict:
            output = (logits,) + transformer_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
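
# A small sketch, not part of the original module: per-token logits over a toy batch.
def _demo_token_classification():
    config = MptConfig(d_model=64, n_heads=4, n_layers=2, max_seq_len=32, vocab_size=128, num_labels=5)
    model = MptForTokenClassification(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 8))
    outputs = model(input_ids, return_dict=True)
    assert outputs.logits.shape == (2, 8, 5)
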
@add_start_docstrings(
    """
    The MPT Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD
    (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    MPT_START_DOCSTRING,
)
class MptForQuestionAnswering(MptPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = MptModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MPT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
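
# A small sketch, not part of the original module: start/end span logits for extractive QA
# on a toy random model.
def _demo_question_answering():
    config = MptConfig(d_model=64, n_heads=4, n_layers=2, max_seq_len=32, vocab_size=128)
    model = MptForQuestionAnswering(config)

    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    outputs = model(
        input_ids,
        start_positions=torch.tensor([1]),
        end_positions=torch.tensor([3]),
        return_dict=True,
    )
    assert outputs.start_logits.shape == outputs.end_logits.shape == (1, 8)
    assert outputs.loss is not None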