U
    9%eJ                  
   @   s  d Z ddlZddlmZ ddlmZ ddlmZmZm	Z	m
Z
 ddlZddlZddlmZ ddlmZ dd	lmZ dd
lmZmZmZmZmZmZmZmZ ddlmZ eeZdZ dZ!ddddddddddg
Z"da#dd Z$G dd dej%j&Z'd5ddZ(d6dd Z)G d!d" d"ej*Z+G d#d$ d$ej*Z,G d%d& d&ej*Z-G d'd( d(eZ.eG d)d* d*eZ/eG d+d, d,eZ0d-Z1d.Z2ed/e1G d0d1 d1e.Z3ed2e1G d3d4 d4e.Z4dS )7zPyTorch RWKV model.    N)	dataclass)Path)ListOptionalTupleUnion)nn)CrossEntropyLoss   )PreTrainedModel)ModelOutputadd_code_sample_docstringsadd_start_docstrings%add_start_docstrings_to_model_forwardis_bitsandbytes_availableis_ninja_availableis_torch_cuda_availablelogging   )
RwkvConfigzRWKV/rwkv-4-169m-piler   zRWKV/rwkv-4-430m-pilezRWKV/rwkv-4-1b5-pilezRWKV/rwkv-4-3b-pilezRWKV/rwkv-4-7b-pilezRWKV/rwkv-4-14b-pilezRWKV/rwkv-raven-1b5zRWKV/rwkv-raven-3bzRWKV/rwkv-raven-7bzRWKV/rwkv-raven-14bc                    s   ddl m} tt jjjd d   fdddD }td k	rNtj| krNd S t	d|  d	 d
dddddd|  g}|d|  |t
 t
jk|da| t_d S )Nr   )loadZkernelsrwkvc                    s   g | ]} | qS  r   ).0fZkernel_folderr   e/var/www/html/Darija-Ai-API/env/lib/python3.8/site-packages/transformers/models/rwkv/modeling_rwkv.py
<listcomp>G   s     z(load_wkv_cuda_kernel.<locals>.<listcomp>)z
wkv_op.cppzwkv_cuda.cuzwkv_cuda_bf16.cuz2Loading CUDA kernel for RWKV at context length of .z
-res-usagez--maxrregcount 60z--use_fast_mathz-O3z-Xptxas -O3z--extra-device-vectorizationz-DTmax=Zwkv_)namesourcesverboseZextra_cuda_cflags)Ztorch.utils.cpp_extensionr   r   __file__resolveparentrwkv_cuda_kernelmax_seq_lengthloggerinfor   Zget_verbosityDEBUG)context_lengthZload_kernelZcuda_kernel_filesflagsr   r   r   load_wkv_cuda_kernelA   s*    	r,   c                   @   s(   e Zd ZedddZedddZdS )	RwkvLinearAttentionNFc              	   C   s  |  \}}}	|tjkr0td| dtj d||	 t|	d dkrhtd| d|	 dt|	d d	|j| _|jjd
ks|jjd
ks|jjd
ks|jjd
krtdt	
|   }|jt	jkr| }| }| }| }| }| }t	j|t	jd}
|s|d k	r|d kr^t	j||	dt	j|jt	jd}|d d d d df  d8  < nt	jdd |D dd }|jt	jkrtj}ntj}||||||
| n*|jt	jkrtjntj}||||||
 | |||||
 |d k	r
dd t	j|dddD }|
| j|fS )NzCannot process a batch with z+ tokens at the same time, use a maximum of z with this model.    r   zThe product of batch size (z) and hidden size (z") needs to be a round multiple of r   cudazUCalling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices.memory_formatr
   )dtypedevicer1      籡*Gc                 S   s   g | ]}| d qS r4   Z	unsqueezer   sr   r   r   r      s     z/RwkvLinearAttention.forward.<locals>.<listcomp>)dimc                 S   s   g | ]}| d qS r6   )Zsqueezer8   r   r   r   r      s     )sizer%   r&   
ValueErrorminr2   input_dtyper3   typetorchexpfloat
contiguousfloat16
empty_likecontiguous_formatzerosfloat32catbfloat16Zforward_with_state_bf16Zforward_with_stateZforward_bf16forwardZsave_for_backwardchunkto)ctx
time_decay
time_firstkeyvaluestatereturn_stateZ
batch_sizeZseq_lenhidden_sizeoutputZforward_funcr   r   r   rK   b   sd    





 
zRwkvLinearAttention.forwardc                 C   s   | j }| j\}}}}}tj|tj|tjkr0tjntjd}	tj|tjd}
tj|tjd}tj|tjd}|tjkr|| }|tjkrt	j
nt	j}||||||| |	|
||
 |	||
|||||d d fS )N)r1   r2   r0   )r>   Zsaved_tensorsr@   rE   rF   rJ   rH   rD   rB   r%   Zbackward_bf16backwardrC   rM   )rN   Zg_outputZg_stater>   rO   rP   rQ   rR   rV   Zg_time_decayZg_time_firstZg_keyZg_valueZbackward_funcr   r   r   rW      s@    
zRwkvLinearAttention.backward)NF)N)__name__
__module____qualname__staticmethodrK   rW   r   r   r   r   r-   a   s   >r-   Fc                 C   s  |  \}}}t|}|d krztj|d d df tjd}	tj|d d df tjd}
tj|d d df tjdd }n
|\}	}
}t|  } t|D ]}|d d |f  }|d d |f }t||| }t|| }t|| | }||	 ||  }||
 | }|| |j	|d d |f< t||  |}t||  | }t|| }||	 ||  }	||
 | }
|}q|s|d k	r|	|
|g}||fS )Nr   )r2   r5   )
r;   r@   Z
zeros_likerH   rA   rangerB   maximumrM   r2   )rO   rP   rQ   rR   rS   rT   _Z
seq_lengthrV   Z	num_stateZ	den_stateZ	max_stateZcurrent_indexcurrent_keycurrent_valueZmax_for_outpute1e2	numeratordenominatorZmax_for_stater   r   r   rwkv_linear_attention_cpu   s4    
"

re   c                 C   sd   t dd | |||fD }|ddk}td ks8|s8|rLt| |||||dS t| |||||S d S )Nc                 s   s   | ]}|j jd kV  qdS )r/   N)r3   r?   )r   tr   r   r   	<genexpr>   s     z(rwkv_linear_attention.<locals>.<genexpr>r   rS   rT   )anyr;   r%   re   r-   apply)rO   rP   rQ   rR   rS   rT   Zno_cudaZ	one_tokenr   r   r   rwkv_linear_attention   s
    rk   c                       s2   e Zd Zd
 fdd	ZdddZddd	Z  ZS )RwkvSelfAttentionr   c                    sF  t    || _td k	o"tj|jk}t rbt rb|sbzt|j W n t	k
r`   t
d Y nX || _|j}|jd k	r~|jn|}|| _tt|| _tt|| _ttdd|| _ttdd|| _ttdd|| _td| _tj||dd| _tj||dd| _tj||dd| _tj||dd| _d S )Nz9Could not load the custom CUDA kernel for RWKV attention.r   r   r   r   FZbias)super__init__configr%   r&   r*   r   r   r,   	Exceptionr'   r(   layer_idrU   attention_hidden_sizer   	Parameterr@   emptyrO   rP   time_mix_keytime_mix_valuetime_mix_receptance	ZeroPad2d
time_shiftLinearrQ   rR   
receptancerV   )selfrr   rt   Zkernel_loadedrU   ru   	__class__r   r   rq     s.    
zRwkvSelfAttention.__init__Nc                 C   s  | ddkr4|d k	r4|d d d d d | jf }n:| |}|d k	rn|d d d d d | jf |d d df< || j |d| j   }|| j |d| j   }|| j |d| j   }| |}| |}t	| 
|}|d k	r|d d df |d d d d d | jf< ||||fS Nr   r   rn   )r;   rt   r|   rx   ry   rz   rQ   rR   r@   sigmoidr~   )r   hiddenrS   shiftedrQ   rR   r~   r   r   r   extract_key_value!  s    
(


(z#RwkvSelfAttention.extract_key_valueFc           	         s    j ||d\}}}}|d k	r<t fdd|dd  D nd }t j j||||d\}}|d k	r|d |d d d d d  jf< |d |d d d d d  jf< |d |d	 d d d d  jf<  || |fS )
NrS   c                 3   s&   | ]}|d d d d  j f V  qd S Nrt   r8   r   r   r   rg   6  s     z,RwkvSelfAttention.forward.<locals>.<genexpr>r4   rh   r   r   r
      )r   tuplerk   rO   rP   rt   rV   )	r   r   rS   	use_cacher~   rQ   rR   Zlayer_stater   r   r   r   rK   4  s    *
	   zRwkvSelfAttention.forward)r   )N)NF)rX   rY   rZ   rq   r   rK   __classcell__r   r   r   r   rl     s   
rl   c                       s(   e Zd Zd fdd	ZdddZ  ZS )	RwkvFeedForwardr   c                    s   t    || _|| _|j}|jd k	r,|jnd|j }td| _t	t
dd|| _t	t
dd|| _tj||dd| _tj||dd| _tj||dd| _d S )Nr   rm   r   Fro   )rp   rq   rr   rt   rU   intermediate_sizer   r{   r|   rv   r@   rw   rx   rz   r}   rQ   r~   rR   )r   rr   rt   rU   r   r   r   r   rq   I  s    
zRwkvFeedForward.__init__Nc                 C   s
  | ddkr4|d k	r4|d d d d d | jf }n:| |}|d k	rn|d d d d d | jf |d d df< || j |d| j   }|| j |d| j   }tt| |}| 	|}t
| |}|d k	r|d d df |d d d d d | jf< || |fS r   )r;   rt   r|   rx   rz   r@   ZsquareZrelurQ   rR   r   r~   )r   r   rS   r   rQ   r~   rR   r   r   r   rK   Z  s    
(
(zRwkvFeedForward.forward)r   )NrX   rY   rZ   rq   rK   r   r   r   r   r   r   H  s   r   c                       s&   e Zd Z fddZdddZ  ZS )	RwkvBlockc                    sv   t    || _|| _|dkr2tj|j|jd| _tj|j|jd| _	tj|j|jd| _
t||| _t||| _d S )Nr   )Zeps)rp   rq   rr   rt   r   	LayerNormrU   Zlayer_norm_epsilonpre_lnln1ln2rl   	attentionr   feed_forward)r   rr   rt   r   r   r   rq   o  s    
zRwkvBlock.__init__NFc                 C   sz   | j dkr| |}| j| |||d\}}|| }| j| ||d\}}|| }||f}|rn||f7 }n|d7 }|S )Nr   )rS   r   r   r   )rt   r   r   r   r   r   )r   r   rS   r   output_attentionsr   r   outputsr   r   r   rK   }  s    

zRwkvBlock.forward)NFFr   r   r   r   r   r   n  s   r   c                   @   s<   e Zd ZdZeZdZdgZddgZdZ	dd Z
dd
dZdS )RwkvPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    r   r   rO   rP   Tc              	      s  t |trN|j}|jj}|jj|j ||d  d||  }tjfddt	D |j
j|j
jd}|ddddf } fddt	 D }tj||jj|jjd}tjdd t	 D |jj|jjdd	 }t h ||j_t|jtd
 | |j_t|||j
_t||d
  |j_t|d	| |j_W 5 Q R X nt |tr|j}|jj}|jjd||  }tjfddt	D |j
j|j
jd}|ddddf }t & t|||j
_t|||j_W 5 Q R X dS )zInitialize the weights.r   g      ?c                    s   g | ]}|  qS r   r   r   irU   r   r   r     s     z5RwkvPreTrainedModel._init_weights.<locals>.<listcomp>r2   r3   Nc                    s,   g | ]$}d d| d  dd     qS )   r   gffffff?g?r   )r   h)ru   ratio_0_to_1r   r   r     s   c                 S   s   g | ]}|d  d d  qS )r   r
   r   r   r   r   r   r     s     g      ?g333333?c                    s   g | ]}|  qS r   r   r   r   r   r   r     s     )
isinstancerl   rt   rr   num_hidden_layersrU   ru   r@   Ztensorr\   rx   r2   r3   rO   rP   no_graddataZ	ones_likemathlogpowry   rz   r   )r   modulert   r   Zratio_1_to_almost0Ztime_weightZdecay_speedZzigzagr   )ru   rU   r   r   _init_weights  sZ    	
 
z!RwkvPreTrainedModel._init_weightsFc                 C   s   t |tr||_d S r   )r   	RwkvModelgradient_checkpointing)r   r   rR   r   r   r   _set_gradient_checkpointing  s    
z/RwkvPreTrainedModel._set_gradient_checkpointingN)F)rX   rY   rZ   __doc__r   config_classZbase_model_prefixZ_no_split_modulesZ_keep_in_fp32_modulesZsupports_gradient_checkpointingr   r   r   r   r   r   r     s   9r   c                   @   sb   e Zd ZU dZdZejed< dZe	e
ej  ed< dZe	eej  ed< dZe	eej  ed< dS )
RwkvOutputa  
    Class for the RWKV model outputs.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nlast_hidden_staterS   hidden_states
attentions)rX   rY   rZ   r   r   r@   FloatTensor__annotations__rS   r   r   r   r   r   r   r   r   r   r     s
   
r   c                   @   st   e Zd ZU dZdZeej ed< dZ	ejed< dZ
eeej  ed< dZeeej  ed< dZeeej  ed< dS )RwkvCausalLMOutputa|  
    Base class for causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    NlosslogitsrS   r   r   )rX   rY   rZ   r   r   r   r@   r   r   r   rS   r   r   r   r   r   r   r   r   r     s   
r   a>  

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`RwkvConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
a
  
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            This is currently not used by `RwkvModel`, but will be supported in the future.

            [What are attention masks?](../glossary#attention-mask)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
        use_cache (`bool`, *optional*):
            If set to `True`, the last state is returned and can be used to quickly generate the next logits.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
z^The bare RWKV Model transformer outputting raw hidden-states without any specific head on top.c                       s   e Zd Z fddZdd Zdd Zeeee	e
eddeej eej eej eeej  ee ee ee ee eee
f d		d
dZdd Zdd Z  ZS )r   c                    sd   t    t j j| _t fddt j	D | _
t j| _d| _d| _|   d S )Nc                    s   g | ]}t  |d qS )r   )r   )r   idxrr   r   r   r   ]  s     z&RwkvModel.__init__.<locals>.<listcomp>F)rp   rq   r   Z	Embedding
vocab_sizerU   
embeddingsZ
ModuleListr\   r   blocksr   ln_outlayers_are_rescaledr   	post_initr   rr   r   r   r   rq   Y  s     zRwkvModel.__init__c                 C   s   | j S r   r   r   r   r   r   get_input_embeddingsg  s    zRwkvModel.get_input_embeddingsc                 C   s
   || _ d S r   r   r   Znew_embeddingsr   r   r   set_input_embeddingsj  s    zRwkvModel.set_input_embeddings
checkpointoutput_typer   N)		input_idsattention_maskinputs_embedsrS   r   r   output_hidden_statesreturn_dictreturnc	                    sH  d k	rn| j j|d k	r |n| j j}d k	r4n| jsB| j jnd|d k	rR|n| j j}| j| jkrn|   |d k	r d k	rtdn|d kr d krtd d kr| 	| r|d kr 
d| j j| j jf fddtdD }|d  d	8  < | jr"| jr"r"td
 d }	r0dnd }
|r>dnd }t| jD ]\}}| jr| jrfdd}tjj|||	|\}	}}n||	|d\}	}}| jr| j jdkr|d | j j dkr|	d }	|r||	f }rL|
|f }
qL| |	}	|r||	f }|s8tdd |	|||
fD S t|	|||
dS )NFzDYou cannot specify both input_ids and inputs_embeds at the same timez5You have to specify either input_ids or inputs_embedsr   c                    s0   g | ](}t j|d kr jnt j jdqS )r   r   )r@   rG   r2   rH   r3   r   )r   shaper   r   r     s     z%RwkvModel.forward.<locals>.<listcomp>   r   gꌠ9Y>)FzZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...r   c                    s    fdd}|S )Nc                     s    | dS )N)r   r   r   )inputs)r   r   r   r   r   custom_forward  s    zHRwkvModel.forward.<locals>.create_custom_forward.<locals>.custom_forwardr   )r   r   )r   r   )r   r   create_custom_forward  s    z0RwkvModel.forward.<locals>.create_custom_forward)rS   r   r   r   r4   c                 s   s   | ]}|d k	r|V  qd S r   r   )r   xr   r   r   rg     s      z$RwkvModel.forward.<locals>.<genexpr>)r   rS   r   r   )rr   r   r   trainingr   use_return_dictr   _rescale_layersr<   r   r;   rU   r   r\   r   r'   Zwarning_once	enumerater   r@   utilsr   rescale_everyr   r   r   )r   r   r   r   rS   r   r   r   r   r   Zall_self_attentionsZall_hidden_statesr   blockr   r   r   )r   r   r   r   r   rK   m  s    

     



zRwkvModel.forwardc              	   C   sl  | j | j krd S | jjdkr^t . t| jD ]\}}| jr|jj	j
dt|| jj   |jjj
dt|| jj   q6t|jj	j
dr|jj	j
jdt|| jj   |jjj
jdt|| jj   q6t|jj	j
dr| |jj	| | |jj| q6|jj	j
dt|| jj   |jjj
dt|| jj   q6W 5 Q R X | j | _ d S )Nr   r4   SCBquant_state)r   r   rr   r   r@   r   r   r   r   rV   weightZmul_intr   rR   hasattrr   div_ _bnb_4bit_dequantize_and_rescale)r   block_idr   r   r   r   r     s"     ""$ ,zRwkvModel._rescale_layersc                 C   st   t  stdddl}|j|jj|jj}|dt	|| j
j   |jj|ddd|j}t|d| dS )	z
        Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will
        be quantized again.
        z/Please install bitsandbytes to use this method.r   Nr4   cpuF)Zrequires_gradr   )r   ImportErrorZbitsandbytesZ
functionalZdequantize_4bitr   r   r   r   r   rr   r   r   Z
Params4bitrM   r3   setattr)r   Ztarget_layerr   ZbnbZdequant_weightsZquant_weightr   r   r   r     s    z*RwkvModel._bnb_4bit_dequantize_and_rescale)NNNNNNNN)rX   rY   rZ   rq   r   r   r   RWKV_INPUTS_DOCSTRINGr   _CHECKPOINT_FOR_DOCr   _CONFIG_FOR_DOCr   r@   
LongTensorr   r   boolr   r   rK   r   r   r   r   r   r   r   r   T  s<           
`r   z
    The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    c                       s   e Zd ZdgZ fddZdd Zdd Zdd	d
Zee	e
eeeddeej eej eej eeej  eej ee ee ee ee eeef d
ddZ  ZS )RwkvForCausalLMzhead.weightc                    s8   t  | t|| _tj|j|jdd| _| 	  d S )NFro   )
rp   rq   r   r   r   r}   rU   r   headr   r   r   r   r   rq     s    
zRwkvForCausalLM.__init__c                 C   s   | j S r   r   r   r   r   r   get_output_embeddings  s    z%RwkvForCausalLM.get_output_embeddingsc                 C   s
   || _ d S r   r   r   r   r   r   set_output_embeddings  s    z%RwkvForCausalLM.set_output_embeddingsNc                 K   sL   |d k	r|d d df  d}|d k	r8|d kr8d|i}nd|i}||d< |S )Nrn   r   r   rS   r7   )r   r   rS   r   kwargsZmodel_inputsr   r   r   prepare_inputs_for_generation  s    
z-RwkvForCausalLM.prepare_inputs_for_generationr   )
r   r   r   rS   labelsr   r   r   r   r   c
              	   C   s   |	dk	r|	n| j j}	| j|||||||	d}
|
d }| |}d}|dk	r||j}|dddddf  }|dddf  }t }||d|	d|d}|	s|f|
dd  }|dk	r|f| S |S t
|||
j|
j|
jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        N)r   rS   r   r   r   r   r   .rn   r   )r   r   rS   r   r   )rr   r   r   r   rM   r3   rC   r	   viewr;   r   rS   r   r   )r   r   r   r   rS   r   r   r   r   r   Zrwkv_outputsr   r   r   Zshift_logitsZshift_labelsZloss_fctrV   r   r   r   rK   '  s:    	
zRwkvForCausalLM.forward)NN)	NNNNNNNNN)rX   rY   rZ   Z_tied_weights_keysrq   r   r   r   r   r   r   r   r   r   r   r@   r   r   r   r   r   r   rK   r   r   r   r   r   r     s@   
         
r   )NF)NF)5r   r   dataclassesr   pathlibr   typingr   r   r   r   r@   Ztorch.utils.checkpointr   Ztorch.nnr	   Zmodeling_utilsr   r   r   r   r   r   r   r   r   r   Zconfiguration_rwkvr   Z
get_loggerrX   r'   r   r   Z"RWKV_PRETRAINED_MODEL_ARCHIVE_LISTr%   r,   ZautogradFunctionr-   re   rk   Modulerl   r   r   r   r   r   ZRWKV_START_DOCSTRINGr   r   r   r   r   r   r   <module>   sf   (

 j
,
F&"J + *