U
    9%e                     @   s  d Z ddlZddlmZmZmZmZ ddlZddlm	  m
Z ddlZddlm	Z	 ddlmZ ddlmZ ddlmZmZ dd	lmZ dd
lmZmZmZmZ ddlmZ eeZdZ dZ!dgZ"G dd de	j#Z$G dd de	j#Z%G dd de	j#Z&G dd de	j#Z'G dd de	j#Z(G dd de	j#Z)G dd de	j#Z*G dd de	j#Z+G dd  d e	j#Z,G d!d" d"e	j#Z-G d#d$ d$e	j#Z.G d%d& d&eZ/d'Z0d(Z1ed)e0G d*d+ d+e/Z2ed,e0G d-d. d.e/Z3dS )/z PyTorch CPMAnt    N)ListOptionalTupleUnion)nn)CrossEntropyLoss   )ACT2FN)BaseModelOutputWithPastCausalLMOutputWithPast)PreTrainedModel)add_code_sample_docstringsadd_start_docstrings%add_start_docstrings_to_model_forwardlogging   )CpmAntConfigzopenbmb/cpm-ant-10br   c                       s6   e Zd ZdZed fddZejdddZ  Z	S )CpmAntLayerNormzv
    We use Root Mean Square (RMS) Layer Normalization, please see https://arxiv.org/abs/1910.07467 for details."
    configc                    s2   t    |j| _|j| _tt|j| _	d S N)
super__init__epshidden_sizedim_normr   	Parametertorchemptyweightselfr   	__class__ i/var/www/html/Darija-Ai-API/env/lib/python3.8/site-packages/transformers/models/cpmant/modeling_cpmant.pyr   2   s    
zCpmAntLayerNorm.__init__hidden_statesc                 C   s^   | d| jkrtd|j}|tjdjddd}|t	|| j
  || j }|S )f
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
        z'hidden_states.size(-1) != self.dim_norm   T)dimZkeepdim)sizer   AssertionErrordtypetor   Zfloat32powmeanZrsqrtr   r   )r!   r'   Z	old_dtypeZvariancer$   r$   r%   forward9   s     zCpmAntLayerNorm.forward)
__name__
__module____qualname____doc__r   r   r   Tensorr2   __classcell__r$   r$   r"   r%   r   -   s   r   c                
       s^   e Zd Zed fddZd	ejejejejee	 ee
ejejf  ee	 dddZ  ZS )
CpmAntAttentionr   c                    s   t    |j| _|j| _|j| _tj| j| j| j dd| _	tj| j| j| j dd| _
tj| j| j| j dd| _tj| j| j | jdd| _tjjdd| _|jd k	rtjj|jd| _nd | _d S )NFbiasr)   r+   )p)r   r   r   Z	dim_modelnum_attention_heads	num_headsdim_headr   Linear	project_q	project_k	project_vattention_outr   ZSoftmaxsoftmax	dropout_pDropoutdropoutr    r"   r$   r%   r   G   s    

zCpmAntAttention.__init__FN)hidden_q	hidden_kvattention_maskposition_biasoutput_attentionspast_key_values	use_cachec              	   C   s
  | d}| d}	| d}
| |}| |}| |}|||	| j| jdddd}|||
| j| jdddd}|||
| j| jdddd}|dk	rtj	|d |gdd}tj	|d |gdd}| d}
t
||ddt| j }|| }t|||d|	|
td	ktjtd
|j|jd}| |}t|||d|	|
td	ktjd|j|jd}|r|}nd}| jdk	r| |}t
||}||| j|	| jdddd}| ||	| j| j }| |}d}|r ||f}|||fS )a  
        Args:
            hidden_q (`torch.Tensor`):
                Input of transformer block(self-attention block). It can be the raw embedding of a batch of sequences.
            hidden_kv (`torch.Tensor` of shape `(batch, len_k, dim_model)`)):
                Tensor *key_value* and *query* of shape `(batch, len_k, dim_model)`
            attention_mask (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
                Avoid invalid areas to participate in the calculation of self-attention.
            position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
                Provide positional information to self-attention block.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            past_key_values (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
                Cached past key and value projection states.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
        r   r   r*   r   Nr<   r)   Fz-inf)devicer.   )r,   rB   rC   rD   viewr?   r@   permuter   catmatmulZ	transposemathsqrtZmasked_filltensorZscalar_tensorfloatrR   r.   rF   rI   
contiguousrE   )r!   rJ   rK   rL   rM   rN   rO   rP   Z
batch_sizelen_qZlen_kquerykeyvalueZscoreattn_weightsr$   r$   r%   r2   Z   sN    





   
 

 
zCpmAntAttention.forward)FNN)r3   r4   r5   r   r   r   r7   Z
BoolTensorr   boolr   r2   r8   r$   r$   r"   r%   r9   F   s      r9   c                	       s^   e Zd Zed fddZd	ejejeej ee ee	ejejf  ee dddZ
  ZS )
CpmAntSelfAttentionBlockr   c                    s@   t    t|| _t|| _|jr6tj	|j| _
nd | _
d S r   )r   r   r   layernorm_before_attentionr9   self_attentionrG   r   r   rH   rI   r    r"   r$   r%   r      s    


z!CpmAntSelfAttentionBlock.__init__NFr'   rL   rM   rN   rO   rP   c           
   	   C   sP   |  |}| |||||||}|\}}}	| jdk	r>| |}|| }|||	fS )a  
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
                Input of transformer block(self-attention block). It can be the raw embedding of a batch of sequences.
            attention_mask (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
                Avoid invalid areas to participate in the calculation of self-attention.
            position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
                Provide positional information to self-attention block.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            past_key_values (`Tuple(torch.FloatTensor)`, *optional*):
                Cached past key and value projection states.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
        N)rc   rd   rI   )
r!   r'   rL   rM   rN   rO   rP   outputsr`   current_key_valuer$   r$   r%   r2      s    
      


z CpmAntSelfAttentionBlock.forward)NFNNr3   r4   r5   r   r   r   r7   r   ra   r   r2   r8   r$   r$   r"   r%   rb      s       rb   c                       s2   e Zd Zed fddZejdddZ  ZS )CpmAntDenseGatedACTr   c                    sF   t    tj|j|jdd| _tj|j|jdd| _tj	 | _
d S NFr:   )r   r   r   rA   r   dim_ffw_0w_1r   ZGELUactr    r"   r$   r%   r      s    
zCpmAntDenseGatedACT.__init__r&   c                 C   s&   |  | |}| |}|| }|S )zTransform an input tensor from one feature space to another via a nonlinear operation

        Args:
            hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
        )rn   rl   rm   )r!   r'   Z
gate_scorer$   r$   r%   r2      s    
zCpmAntDenseGatedACT.forward	r3   r4   r5   r   r   r   r7   r2   r8   r$   r$   r"   r%   ri      s   ri   c                       s2   e Zd Zed fddZejdddZ  ZS )CpmAntFeedForwardr   c                    sP   t    t|| _|jd k	r0tj|j| _nd | _tj	|j
|jdd| _d S rj   )r   r   ri   w_inrG   r   r   rH   rI   rA   rk   r   w_outr    r"   r$   r%   r      s    


zCpmAntFeedForward.__init__r&   c                 C   s,   |  |}| jdk	r| |}| |}|S )r(   N)rq   rI   rr   r!   r'   r$   r$   r%   r2      s
    



zCpmAntFeedForward.forwardro   r$   r$   r"   r%   rp      s   
rp   c                       s2   e Zd Zed fddZejdddZ  ZS )CpmAntFFNBlockr   c                    s@   t    t|| _t|| _|jr6tj	|j| _
nd | _
d S r   )r   r   r   layernorm_before_ffnrp   ffnrG   r   r   rH   rI   r    r"   r$   r%   r     s    


zCpmAntFFNBlock.__init__r&   c                 C   s4   |  |}| |}| jdk	r(| |}|| }|S )z
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
                Hidden states before feed forward layer.
        N)ru   rv   rI   )r!   r'   Z
ln_outputsrf   r$   r$   r%   r2     s    	



zCpmAntFFNBlock.forwardro   r$   r$   r"   r%   rt     s   rt   c                	       s^   e Zd Zed fddZd	ejejeej ee ee	ejejf  ee dddZ
  ZS )
CpmAntTransformerBlockr   c                    s"   t    t|| _t|| _d S r   )r   r   rb   self_attrt   rv   r    r"   r$   r%   r   )  s    

zCpmAntTransformerBlock.__init__NFre   c           	      C   s4   | j ||||||d}|\}}}| |}|||fS )a  
        Args:
            hidden_states (`torch.Tensor`):
                Input to the layer of shape `(batch, seq_len, dim_model)`
            attention_mask (`torch.Tensor`):
                Avoid invalid areas to participate in the calculation of shape `(batch, seq_len, seq_len)`
            position_bias (`torch.Tensor`):
                Provides position information to attention mechanism of shape `(num_heads, seq_len, seq_len)`
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            past_key_values (`Tuple[torch.Tensor, torch.Tensor])`, *optional*):
                Cached past key and value projection states
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
        )rL   rM   rN   rO   rP   )rx   rv   )	r!   r'   rL   rM   rN   rO   rP   r`   rg   r$   r$   r%   r2   .  s    	

zCpmAntTransformerBlock.forward)NFNNrh   r$   r$   r"   r%   rw   (  s   	    rw   c                
       s`   e Zd Zed fddZdejejejee ee ee	ejejf  ee dddZ
  ZS )	CpmAntEncoderr   c                    s@   t     j| _t fddt| jD | _t | _	d S )Nc                    s   g | ]}t  qS r$   )rw   ).0Zithr   r$   r%   
<listcomp>[  s     z*CpmAntEncoder.__init__.<locals>.<listcomp>)
r   r   Znum_hidden_layers
num_layersr   Z
ModuleListrangelayersr   output_layernormr    r"   r   r%   r   X  s    
 zCpmAntEncoder.__init__N)r'   rL   rM   rN   output_hidden_statesrO   rP   c              	   C   s   |rdnd}|rdnd}	|r dnd}
t | jD ]`\}}|rD||f7 }||||||rZ|| nd|d}|\}}}|r||	|f7 }	|dk	r.|
|f }
q.| |}|r||f7 }||
||	fS )a%  
        Args:
            hidden_states (`torch.Tensor`):
                Input to the layer of shape `(batch, seq_len, dim_model)`
            attention_mask (`torch.Tensor`):
                Avoid invalid areas to participate in the calculation of shape `(batch, seq_len, seq_len)`
            position_bias (`torch.Tensor`):
                Provides position information to attention mechanism of shape `(num_heads, seq_len, seq_len)`
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            past_key_values (`Tuple[torch.Tensor, torch.Tensor])`, *optional*):
                Cached past key and value projection states
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
        r$   N)rN   rO   rP   )	enumerater~   r   )r!   r'   rL   rM   rN   r   rO   rP   all_hidden_statesZall_self_attnsZcurrent_key_valuesilayerZlayer_outputsr`   rg   r$   r$   r%   r2   _  s.    




zCpmAntEncoder.forward)NNNNrh   r$   r$   r"   r%   ry   W  s       ry   c                       s0   e Zd Z fddZejejdddZ  ZS )CpmAntIntermediatec                    sB   t    t|j|j| _t|jt	r6t
|j | _n|j| _d S r   )r   r   r   rA   r   intermediate_sizedense
isinstanceZ
hidden_actstrr	   intermediate_act_fnr    r"   r$   r%   r     s
    
zCpmAntIntermediate.__init__)r'   returnc                 C   s   |  |}| |}|S r   )r   r   rs   r$   r$   r%   r2     s    

zCpmAntIntermediate.forwardr3   r4   r5   r   r   r7   r2   r8   r$   r$   r"   r%   r     s   r   c                       sP   e Zd Zed fddZejejejejdddZdd ZdddZ	  Z
S )CpmAntSegmentPositionEmbeddingr   c                    sR   t    |j| _|j| _|j| _|j| _	t
t|j|j |j |j| _d S r   )r   r   r>   r?   Zposition_bias_num_bucketsnum_bucketsZposition_bias_max_distancemax_distancesegment_typesnum_segmentsr   r   r   r   relative_attention_biasr    r"   r$   r%   r     s    
z'CpmAntSegmentPositionEmbedding.__init__)key_pos	query_poskey_segmentquery_segmentc              	   C   s  t  ~ |d}|d}|d}|d|dkr`td|d d|d d||dks|||dkrtd| d|d d||dkrtd| d|d d||d|}|||d}||d|}|||d}| ||}|| j }| jt j	|t j
|jd	d d d f t j	|t j
|jd	d d d f  | j| jd
}	t ||k|	d d d d d f |}W 5 Q R X t|| j}
|
dddd }
|
S )Nr   r   z>key_pos.size(0) should be equal to query_pos.size(0), but got z and !z7keylen should be equal to key_segment.size(1), but got z;querylen should be equal to query_segment.size(1), but got r)   r.   rR   )r   r   r   r*   )r   Zno_gradr,   r-   ZszierS   !_segment_relative_position_bucketr   _position_bucketarangeint32rR   r   whereFZ	embeddingr   rT   r[   )r!   r   r   r   r   batchZkeylenZquerylenZrelative_position_bucketZabsolute_position_bucketZembedsr$   r$   r%   r2     sJ    



z&CpmAntSegmentPositionEmbedding.forwardc                 C   s   || j  | S r   )r   )r!   r   r   r$   r$   r%   r     s    z@CpmAntSegmentPositionEmbedding._segment_relative_position_bucket       c                 C   s   d}|d }|dk tj| }t|}|d }||k }|t| | t||  ||   tj }t|t||d }|t	|| tj|7 }|S )Nr   r*   r   )
r/   r   r   abslogrZ   rW   minZ	full_liker   )r!   Zrelative_positionr   r   Zrelative_bucketsZ	max_exactZis_smallZrelative_postion_if_larger$   r$   r%   r     s*    
z/CpmAntSegmentPositionEmbedding._position_bucket)r   r   )r3   r4   r5   r   r   r   r7   r2   r   r   r8   r$   r$   r"   r%   r     s   4r   c                       s4   e Zd Z fddZejejejdddZ  ZS )CpmAntOutputc                    sB   t    t|j|j| _tj|j|jd| _t	|j
| _d S )N)r   )r   r   r   rA   r   r   r   	LayerNormZlayer_norm_epsrH   Zhidden_dropout_probrI   r    r"   r$   r%   r     s    
zCpmAntOutput.__init__)r'   input_tensorr   c                 C   s&   |  |}| |}| || }|S r   )r   rI   r   )r!   r'   r   r$   r$   r%   r2     s    

zCpmAntOutput.forwardr   r$   r$   r"   r%   r     s   r   c                   @   s.   e Zd ZdZeZdZdZdd Zd
ddZ	d	S )CpmAntPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    cpmantTc                 C   s   t |tjr:|jjjd| jjd |jdk	r|jj	  nt |tj
rz|jjjd| jjd |jdk	r|jj|j 	  nbt |tjr|jj	  |jjd n:t |tr|jjd n t |tr|jjjd| jjd dS )zInitialize the weightsg        )r1   ZstdNg      ?)r   r   rA   r   dataZnormal_r   Zinit_stdr;   Zzero_	EmbeddingZpadding_idxr   Zfill_r   r   r   )r!   moduler$   r$   r%   _init_weights  s    



z#CpmAntPreTrainedModel._init_weightsFc                 C   s   t |tr||_d S r   )r   ry   Zgradient_checkpointing)r!   r   r_   r$   r$   r%   _set_gradient_checkpointing/  s    
z1CpmAntPreTrainedModel._set_gradient_checkpointingN)F)
r3   r4   r5   r6   r   config_classZbase_model_prefixZsupports_gradient_checkpointingr   r   r$   r$   r$   r%   r     s   r   aB  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters
        config ([`~CpmAntConfig`]): Model configuration class with all the parameters of the
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
a  
    Args:
        input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`CPMAntTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
zTThe bare CPMAnt Model outputting raw hidden-states without any specific head on top.c                       s   e Zd Zed fddZdd Zdd Zdd	 Zee	e
eeed
deej ee ee eeeej   ee ee eeej ef dddZ  ZS )CpmAntModelr   c                    sl   t  | t|| _t|j|j| _t|j	|j
|j  |j| _t|| _|j| _|j	| _	|   d S r   )r   r   ry   encoderr   r   r   r   segment_embedding
vocab_sizeprompt_typesprompt_lengthinput_embeddingr   rM   	post_initr    r"   r$   r%   r   \  s    
 
zCpmAntModel.__init__c                 C   s   | j S r   r   r!   r$   r$   r%   get_input_embeddingsi  s    z CpmAntModel.get_input_embeddingsc                 K   s
   || _ d S r   r   )r!   
embeddingskwargsr$   r$   r%   set_input_embeddingsl  s    z CpmAntModel.set_input_embeddingsc                 C   s>  | d}| d}|j}tj||dtj||dddk}|d d d d d f |d d d d d f  |d||@ B }	|	|d d d d d f |d d d d d f k@ }	tjtt|| j	 d d d |dd d d f 
|d|d d d f k }
tjtj|| j	|d |
fdd}
|
||d|
|d|@ |	@ }	|	S )Nr   r   )rR   r)   r<   )r,   rR   r   r   rS   Zlogical_notrY   listr}   r   repeatrU   Zonesra   )r!   	input_idsspancontextlengthr   ZseqlenrR   Zdirectional_mask_2drL   Zmask_1dr$   r$   r%   _prepare_attention_masko  s    

$&08$ z#CpmAntModel._prepare_attention_mask
checkpointoutput_typer   N)r   rN   r   rO   rP   return_dictr   c              	   K   sX  |d k	r|n| j j}|d k	r |n| j j}|d k	r4|n| j j}|d k	rH|n| j j}|jtjkrh|tj}|j|j	 }}	t
|dkddj||	d}
|
dkdj||	d}tjtj| jd | j | jd | j ||	d|dd|fdd}| \}}tjtj|| j||	d|
fdd}
tj||fd||	d}tj|||	d|d}tj||fd||	d}|d krd}td g| jj }| }| |}| |
}|| }n@|d d d}| |
}| ||d d dd d d f  }| ||||}| |||
|
}|d d |d d d f }|d d d d |d d d f }|d d |d d d f }| |||||||\}}}}|dkr(|d d | jd d d f }|d k	rd	}|D ]0}||d d d d | jd | jd f f7 }q|}|d k	r(d	}|D ](}||d d | jd d d f f7 }q|}|sHtd
d ||||fD S t||||dS )Nr   r*   r   r)   r   r   r<   rQ   r$   c                 s   s   | ]}|d k	r|V  qd S r   r$   )rz   vr$   r$   r%   	<genexpr>  s     z&CpmAntModel.forward.<locals>.<genexpr>)last_hidden_staterO   r'   
attentions)r   rN   r   use_return_dictrP   r.   r   r   r/   rR   r   sumrU   r   r   r   r   r,   zerosfulltupler   r|   r[   r   r   r   rM   r
   )r!   r   rN   r   rO   rP   r   r   r.   rR   segmentr   r   Z
seq_lengthr   positionr   Zpast_lengthr'   Zsegment_statesrL   rM   Zpresent_key_valuesr   Zall_attentionsZnew_attentionsZ	attentionZnew_hidden_statesZhidden_stater$   r$   r%   r2     s     	"




$ 


.
&
zCpmAntModel.forward)NNNNNN)r3   r4   r5   r   r   r   r   r   r   CPMANT_INPUTS_DOCSTRINGr   _CHECKPOINT_FOR_DOCr
   _CONFIG_FOR_DOCr   r   r7   ra   r   r   r2   r8   r$   r$   r"   r%   r   W  s2         r   zy
    The CPMAnt Model with a language modeling head on top (linear layer with weights tied to the input embeddings).
    c                       s   e Zd ZdgZed fddZeeee	e
eddeej eeeejejf   ee ee ee eej ee eej eee
f d	dd	Zd
d Zdd Zdd Zdd Zdd Zdd Z  ZS )CpmAntForCausalLMzlm_head.weightr   c                    sD   t  | t|| _tj|j|j|j|j	  dd| _
|   d S rj   )r   r   r   r   r   rA   r   r   r   r   lm_headr   r    r"   r$   r%   r     s    
  zCpmAntForCausalLM.__init__r   N)	r   rO   rP   rN   r   labelsr   rL   r   c	                 K   s   |dk	r|n| j j}| ||||||}
|r2|
jn|
d }| |}d}|dk	rtt }||d|d|d}|s|f|
dd  }|dk	r|f| S |S t|||
j	|
j
|
jdS )u;
  
        Args:
            input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
                Indices of input sequence tokens in the vocabulary.

                Indices can be obtained using [`CPMAntTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                CPMAnt will process attention mask automatically, this parameter is a dummy parameter for
                text-generation pipeline.

        Example:

        Text Generation with CpmAntForCausalLM.
        ```python
        >>> from transformers import CPMAntTokenizer, CpmAntForCausalLM

        >>> texts = "今天天气不错，"
        >>> model = CpmAntForCausalLM.from_pretrained("openbmb/cpm-ant-10b")
        >>> tokenizer = CPMAntTokenizer.from_pretrained("openbmb/cpm-ant-10b")
        >>> input_ids = tokenizer(texts, return_tensors="pt")
        >>> outputs = model.generate(**input_ids)
        >>> output_texts = tokenizer.batch_decode(outputs)
        >>> print(output_texts)
        ['今天天气不错，阳光明媚，我和妈妈一起去超市买东西。\n在超市里，我看到了一个很好玩的玩具，它的名字叫“机器人”。它有一个圆圆的脑袋，两只圆圆的眼睛，还有一个圆圆的']
        ```
        Nr   r)   r   )losslogitsrO   r'   r   )r   r   r   r   r   r   rS   r,   r   rO   r'   r   )r!   r   rO   rP   rN   r   r   r   rL   r   Zmodel_outputr'   r   r   Z	loss_funcoutputr$   r$   r%   r2     s2    =     
zCpmAntForCausalLM.forwardc                 C   s   | j jS r   r   r   r   r$   r$   r%   r   R  s    z&CpmAntForCausalLM.get_input_embeddingsc                 C   s   || j _d S r   r   )r!   r   r$   r$   r%   r   U  s    z&CpmAntForCausalLM.set_input_embeddingsc                 C   s   | j S r   r   r   r$   r$   r%   get_output_embeddingsX  s    z'CpmAntForCausalLM.get_output_embeddingsc                 C   s
   || _ d S r   r   )r!   Znew_embeddingsr$   r$   r%   set_output_embeddings[  s    z'CpmAntForCausalLM.set_output_embeddingsc                 K   s8   |  }d|kr tdd|d< ||d |dd dS )NrL   r   rP   rO   )r   rP   rO   )intr   r   get)r!   r   r   r$   r$   r%   prepare_inputs_for_generation^  s    
z/CpmAntForCausalLM.prepare_inputs_for_generationc                 C   s<   dd |D }|D ]$}|d | |d< |d | |d< q|S )Nc                 S   s    g | ]}|d k	rt |n|qS r   )r   )rz   Zeachr$   r$   r%   r{   k  s     z4CpmAntForCausalLM._reorder_cache.<locals>.<listcomp>r   r   r$   )r!   rO   Zbeam_idxZkey_value_layerr$   r$   r%   _reorder_cachej  s
    z CpmAntForCausalLM._reorder_cache)NNNNNNNN)r3   r4   r5   Z_tied_weights_keysr   r   r   r   r   r   r   r   r   r   r7   r   r   ra   r   r2   r   r   r   r   r   r   r8   r$   r$   r"   r%   r     sB   
        
Qr   )4r6   rW   typingr   r   r   r   r   Ztorch.nn.functionalr   Z
functionalr   Ztorch.utils.checkpointZtorch.nnr   Zactivationsr	   Zmodeling_outputsr
   r   Zmodeling_utilsr   utilsr   r   r   r   Zconfiguration_cpmantr   Z
get_loggerr3   loggerr   r   Z$CPMANT_PRETRAINED_MODEL_ARCHIVE_LISTModuler   r9   rb   ri   rp   rt   rw   ry   r   r   r   r   ZCPMANT_START_DOCSTRINGr   r   r   r$   r$   r$   r%   <module>   sT   
h1/B]! 