# coding=utf-8
""" PyTorch BioGPT model."""

import math
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)
from .configuration_biogpt import BioGptConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "microsoft/biogpt"
_CONFIG_FOR_DOC = "BioGptConfig"

BIOGPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "microsoft/biogpt",
    "microsoft/BioGPT-Large",
]


def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
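
# Illustration (a standalone sketch, not part of the original module): for
# `input_ids_shape = (1, 3)` and no cache, `_make_causal_mask` returns a
# `(1, 1, 3, 3)` additive mask in which row `i` is `0.0` at columns `<= i` and
# `torch.finfo(dtype).min` everywhere else:
#
#     mask = _make_causal_mask(torch.Size([1, 3]), torch.float32, torch.device("cpu"))
#     # mask[0, 0] ==
#     # [[0.0, MIN, MIN],
#     #  [0.0, 0.0, MIN],
#     #  [0.0, 0.0, 0.0]]    where MIN = torch.finfo(torch.float32).min
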
class BioGptLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # BioGpt offsets the embedding ids by 2 and adjusts num_embeddings accordingly
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
        """`input_ids_shape` is expected to be [bsz x seqlen]."""
        attention_mask = attention_mask.long()

        # create positions depending on attention_mask
        positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1

        # cut positions if `past_key_values_length` is > 0
        positions = positions[:, past_key_values_length:]

        return super().forward(positions + self.offset)
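
# Worked example (illustrative, not part of the original module): positions are
# derived from the attention mask so that every padding slot maps to the same
# unused row of the table. For `attention_mask = [[1, 1, 1, 0]]` and
# `past_key_values_length = 0`:
#
#     torch.cumsum(mask, dim=1) * mask - 1  ->  [[0, 1, 2, -1]]
#     + self.offset (2)                     ->  [[2, 3, 4,  1]]
#
# Real tokens get positions starting at `offset`; padding collapses onto index
# `offset - 1`, which is why `__init__` allocates `num_embeddings + offset` rows.
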
class BioGptAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key/value proj: reuse the cache for cross-attention when the cached
        # length matches; otherwise (re)project and, for self-attention, append to
        # the cached states
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # keep an unmerged view so the per-head weights can be returned
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
ejeej eej eeej  ee	 ee	 eej
eeej
ej
f  f ddd	Z  ZS )BioGptDecoderLayerconfigc                    s   t    |j| _t| j|j|jdd| _|j| _	t
|j | _|j| _t| j| _t| j|j| _t|j| j| _t| j| _d S )NT)rK   rL   rM   rN   )r8   r9   hidden_sizerK   rI   Znum_attention_headsZattention_probs_dropout_prob	self_attnhidden_dropout_probrM   r
   Z
hidden_actactivation_fnactivation_dropoutr   	LayerNormself_attn_layer_normrT   Zintermediate_sizefc1fc2final_layer_normr:   rn   r;   r,   r-   r9     s    
zBioGptDecoderLayer.__init__NFT)r^   r=   ra   r`   rb   	use_cacherc   c                 C   s   |}|  |}|dk	r"|dd nd}| j|||||d\}}	}
tjj|| j| jd}|| }|}| |}| |}| |}tjj|| j	| jd}| 
|}tjj|| j| jd}|| }|f}|r||	f7 }|r||
f7 }|S )a  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
        Nr6   )r^   r`   r=   ra   rb   rd   )ru   rp   r   rh   rM   rf   rx   rv   rr   rs   rw   )r:   r^   r=   ra   r`   rb   rz   ZresidualZself_attn_past_key_valueZself_attn_weightsZpresent_key_valueoutputsr,   r,   r-   r@   "  s4    






zBioGptDecoderLayer.forward)NNNFT)rB   rC   rD   r   r9   r   rk   r   r   r0   FloatTensorr@   rH   r,   r,   r;   r-   rl     s        rl   c                   @   s.   e Zd ZdZeZdZdZdd Zd
ddZ	d	S )BioGptPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    biogptTc                 C   s   t |tjr:|jjjd| jjd |jdk	r|jj	  nft |tj
rz|jjjd| jjd |jdk	r|jj|j 	  n&t |tjr|jj	  |jjd dS )zInitialize the weightsrJ   )ZmeanZstdNr/   )
isinstancer   rT   weightdataZnormal_rn   Zinitializer_rangerO   Zzero_	Embeddingpadding_idxrt   Zfill_)r:   moduler,   r,   r-   _init_weightsk  s    

z#BioGptPreTrainedModel._init_weightsFc                 C   s   t |tr||_d S N)r   BioGptModelgradient_checkpointing)r:   r   valuer,   r,   r-   _set_gradient_checkpointing{  s    
z1BioGptPreTrainedModel._set_gradient_checkpointingN)F)
rB   rC   rD   rE   r   config_classZbase_model_prefixZsupports_gradient_checkpointingr   r   r,   r,   r,   r-   r}   a  s   r}   aJ  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`~BioGptConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

BIOGPT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare BioGPT Model transformer outputting raw hidden-states without any specific head on top.",
    BIOGPT_START_DOCSTRING,
)
class BioGptModel(BioGptPreTrainedModel):
    def __init__(self, config: BioGptConfig):
        super().__init__(config)
        self.config = config
        self.layerdrop = config.layerdrop
        self.dropout = config.hidden_dropout_prob
        self.embed_dim = config.hidden_size
        self.padding_idx = config.pad_token_id
        self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0

        self.embed_tokens = nn.Embedding(config.vocab_size, self.embed_dim, self.padding_idx)
        self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim)

        self.layers = nn.ModuleList([BioGptDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.layer_norm = nn.LayerNorm(self.embed_dim)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask: [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    @add_start_docstrings_to_model_forward(BIOGPT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input = input_ids
            input_shape = input.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            input = inputs_embeds[:, :, -1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input) * self.embed_scale

        if attention_mask is None:
            attention_mask = torch.ones(
                (inputs_embeds.shape[0], inputs_embeds.shape[1] + past_key_values_length),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        elif attention_mask.shape[1] != past_key_values_length + input_shape[1]:
            raise ValueError(
                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
                f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)"
            )

        # embed positions
        positions = self.embed_positions(attention_mask, past_key_values_length)

        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds + positions
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, use_cache)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        hidden_states = self.layer_norm(hidden_states)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
r   zHBioGPT Model with a `language modeling` head on top for CLM fine-tuning.c                       s   e Zd ZdgZ fddZdd Zdd Zee	de
eeed	deej eej eej eej eeeej   eej ee ee ee ee eeef dddZdddZedd Z  ZS )BioGptForCausalLMzoutput_projection.weightc                    s8   t  | t|| _tj|j|jdd| _| 	  d S NFrP   )
r8   r9   r   r~   r   rT   ro   r   output_projectionr   ry   r;   r,   r-   r9     s    
zBioGptForCausalLM.__init__c                 C   s   | j S r   r   r   r,   r,   r-   get_output_embeddings  s    z'BioGptForCausalLM.get_output_embeddingsc                 C   s
   || _ d S r   r   )r:   Znew_embeddingsr,   r,   r-   set_output_embeddings  s    z'BioGptForCausalLM.set_output_embeddingsr   r   N)r   r=   r   r   r   labelsrz   rb   r   r   rc   c                 C   s   |
dk	r|
n| j j}
| j||||||||	|
d	}|d }| |}d}|dk	r|ddddddf  }|ddddf  }t }||d| j j|d}|
s|f|dd  }|dk	r|f| S |S t|||j	|j
|j|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        N)r=   r   r   r   rz   rb   r   r   r   r   r   )losslogitsr   r^   r   r   )rn   r   r~   r   r\   r   r%   r   r   r   r^   r   r   )r:   r   r=   r   r   r   r   rz   rb   r   r   r{   Zsequence_outputZprediction_scoresZlm_lossZshifted_prediction_scoresloss_fctoutputr,   r,   r-   r@     s>    
zBioGptForCausalLM.forwardc                 K   sX   |r|d d df  d}|d k	r4|d kr4d|i}nd|i}||||dd |S )Nr   r   r   rz   )r=   r   rz   )Z	unsqueezeupdateget)r:   r   r=   r   r   kwargsZmodel_inputsr,   r,   r-   prepare_inputs_for_generation  s    
z/BioGptForCausalLM.prepare_inputs_for_generationc                    s.   d}| D ] }|t  fdd|D f7 }q|S )Nr,   c                 3   s"   | ]}| d  |jV  qdS )r   N)Zindex_selectr&   r   )r   Z
past_statebeam_idxr,   r-   r     s     z3BioGptForCausalLM._reorder_cache.<locals>.<genexpr>)r   )r   r   Zreordered_pastZ
layer_pastr,   r   r-   _reorder_cache  s    z BioGptForCausalLM._reorder_cache)
NNNNNNNNNN)NN)rB   rC   rD   Z_tied_weights_keysr9   r   r   r   r   r   r   r   r   r   r   r   rG   r|   r   rk   r0   r   r@   r   staticmethodr   rH   r,   r,   r;   r-   r     sL   	          
:   
r   z
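
# Minimal generation sketch (hypothetical, not part of the original module): the
# LM head plugs into `generate()` via `prepare_inputs_for_generation` above.
#
#     import torch
#     from transformers import BioGptForCausalLM, BioGptTokenizer
#
#     tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt")
#     model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")
#     inputs = tokenizer("COVID-19 is", return_tensors="pt")
#     with torch.no_grad():
#         out = model.generate(**inputs, max_new_tokens=20)
#     print(tokenizer.decode(out[0], skip_special_tokens=True))
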
@add_start_docstrings(
    """
    BioGPT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    BIOGPT_START_DOCSTRING,
)
class BioGptForTokenClassification(BioGptPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.biogpt = BioGptModel(config)
        if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        else:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(BIOGPT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.biogpt(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # only keep the active (non-padding) parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + transformer_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@add_start_docstrings(
    """
    The BioGpt Model transformer with a sequence classification head on top (linear layer).

    [`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it is required to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    BIOGPT_START_DOCSTRING,
)
class BioGptForSequenceClassification(BioGptPreTrainedModel):
    def __init__(self, config: BioGptConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.biogpt = BioGptModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(BIOGPT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.biogpt(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        # find the last non-padding token of each sequence for pooling
        if self.config.pad_token_id is None:
            sequence_length = -1
        else:
            if input_ids is not None:
                sequence_length = torch.ne(input_ids, self.config.pad_token_id).sum(-1).to(logits.device) - 1
            else:
                sequence_length = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_length]

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def get_input_embeddings(self):
        return self.biogpt.embed_tokens

    def set_input_embeddings(self, value):
        self.biogpt.embed_tokens = value