# coding=utf-8
# Copyright 2021 Google Research and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch FNet model."""

import warnings
from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...utils import is_scipy_available


if is_scipy_available():
    from scipy import linalg

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    MaskedLMOutput,
    ModelOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_fnet import FNetConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "google/fnet-base"
_CONFIG_FOR_DOC = "FNetConfig"

FNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google/fnet-base",
    "google/fnet-large",
    # See all FNet models at https://huggingface.co/models?filter=fnet
]


def _two_dim_matmul(x, matrix_dim_one, matrix_dim_two):
    """Applies 2D matrix multiplication to 3D input arrays."""
    seq_length = x.shape[1]
    matrix_dim_one = matrix_dim_one[:seq_length, :seq_length]
    x = x.type(torch.complex64)
    return torch.einsum("bij,jk,ni->bnk", x, matrix_dim_two, matrix_dim_one)


# Adapted from https://github.com/google-research/google-research/blob/master/f_net/fourier.py
def two_dim_matmul(x, matrix_dim_one, matrix_dim_two):
    return _two_dim_matmul(x, matrix_dim_one, matrix_dim_two)


# Adapted from https://github.com/google-research/google-research/blob/master/f_net/fourier.py
def fftn(x):
    """
    Applies n-dimensional Fast Fourier Transform (FFT) to input array.

    Args:
        x: Input n-dimensional array.

    Returns:
        n-dimensional Fourier transform of input n-dimensional array.
    r   N)axis)reversedrangendimr#   fft)r%   outr.   r)   r)   r*   fftnT   s    
class FNetEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with the TensorFlow model variable name
        # and be able to load any TensorFlow checkpoint file.
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # NOTE: original FNet allows embedding and model dimensions to differ; this projection bridges them.
        self.projection = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        # Fall back to the all-zeros registered buffer when token_type_ids are not passed,
        # which also helps users tracing the model without providing them.
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings

        position_embeddings = self.position_embeddings(position_ids)
        embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.projection(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
class FNetBasicFourierTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self._init_fourier_transform(config)

    def _init_fourier_transform(self, config):
        if not config.use_tpu_fourier_optimizations:
            self.fourier_transform = partial(torch.fft.fftn, dim=(1, 2))
        elif config.max_position_embeddings <= 4096:
            if is_scipy_available():
                self.register_buffer(
                    "dft_mat_hidden", torch.tensor(linalg.dft(config.hidden_size), dtype=torch.complex64)
                )
                self.register_buffer(
                    "dft_mat_seq", torch.tensor(linalg.dft(config.tpu_short_seq_length), dtype=torch.complex64)
                )
                self.fourier_transform = partial(
                    two_dim_matmul, matrix_dim_one=self.dft_mat_seq, matrix_dim_two=self.dft_mat_hidden
                )
            else:
                logger.warning(
                    "SciPy is needed for DFT matrix calculation and is not found. Using TPU optimized fast fourier"
                    " transform instead."
                )
                self.fourier_transform = fftn
        else:
            self.fourier_transform = fftn

    def forward(self, hidden_states):
        # Token mixing: keep only the real part of the n-dimensional Fourier transform.
        outputs = self.fourier_transform(hidden_states).real
        return (outputs,)


class FNetBasicOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.LayerNorm(input_tensor + hidden_states)
        return hidden_states


class FNetFourierTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self = FNetBasicFourierTransform(config)
        self.output = FNetBasicOutput(config)

    def forward(self, hidden_states):
        self_outputs = self.self(hidden_states)
        fourier_output = self.output(self_outputs[0], hidden_states)
        outputs = (fourier_output,)
        return outputs


class FNetIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class FNetOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class FNetLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1  # The dimension which has the sequence length
        self.fourier = FNetFourierTransform(config)
        self.intermediate = FNetIntermediate(config)
        self.output = FNetOutput(config)

    def forward(self, hidden_states):
        self_fourier_outputs = self.fourier(hidden_states)
        fourier_output = self_fourier_outputs[0]

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, fourier_output
        )
        outputs = (layer_output,)
        return outputs

    def feed_forward_chunk(self, fourier_output):
        intermediate_output = self.intermediate(fourier_output)
        layer_output = self.output(intermediate_output, fourier_output)
        return layer_output


class FNetEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([FNetLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(self, hidden_states, output_hidden_states=False, return_dict=True):
        all_hidden_states = () if output_hidden_states else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(layer_module), hidden_states)
            else:
                layer_outputs = layer_module(hidden_states)

            hidden_states = layer_outputs[0]

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
class FNetPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class FNetPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class FNetLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = FNetPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states

    def _tie_weights(self):
        # Re-tie the two weights if they get disconnected (on TPU or when the bias is resized).
        self.bias = self.decoder.bias


class FNetOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = FNetLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class FNetOnlyNSPHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score


class FNetPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = FNetLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


class FNetPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = FNetConfig
    base_model_prefix = "fnet"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version, which uses truncated_normal for initialization.
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, FNetEncoder):
            module.gradient_checkpointing = value
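
# Editor's note (not part of the upstream module): every concrete model below
# inherits `from_pretrained`/`save_pretrained` from `FNetPreTrainedModel`, and
# `_init_weights` above fills in any head weights that a checkpoint does not
# provide. A minimal loading sketch, assuming Hub access and the checkpoint name
# used throughout this file; the helper itself is hypothetical and unused.
def _from_pretrained_sketch():
    model = FNetModel.from_pretrained(_CHECKPOINT_FOR_DOC)  # "google/fnet-base"
    model.eval()  # disable dropout for inference
    return model
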
ejed< dZeeej  ed< dS )FNetForPreTrainingOutputa  
    Output type of [`FNetForPreTraining`].

    Args:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Total loss as the sum of the masked language modeling loss and the next sequence prediction
            (classification) loss.
        prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
    Nlossprediction_logitsseq_relationship_logitsrt   )rb   rc   rd   re   r   r   r#   FloatTensor__annotations__r   r   rt   r   r)   r)   r)   r*   r     s
   
r   aG  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`FNetConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

FNET_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare FNet Model transformer outputting raw hidden-states without any specific head on top.",
    FNET_START_DOCSTRING,
)
class FNetModel(FNetPreTrainedModel):
    """

    The model can behave as an encoder, following the architecture described in [FNet: Mixing Tokens with Fourier
    Transforms](https://arxiv.org/abs/2105.03824) by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon.

    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = FNetEmbeddings(config)
        self.encoder = FNetEncoder(config)

        self.pooler = FNetPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    @add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size, seq_length = input_shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if (
            self.config.use_tpu_fourier_optimizations
            and seq_length <= 4096
            and self.config.tpu_short_seq_length != seq_length
        ):
            raise ValueError(
                "The `tpu_short_seq_length` in FNetConfig should be set equal to the sequence length being passed to"
                " the model when using TPU optimizations."
            )

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )


@add_start_docstrings(
    """
    FNet Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head.
    """,
    FNET_START_DOCSTRING,
)
class FNetForPreTraining(FNetPreTrainedModel):
    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)

        self.fnet = FNetModel(config)
        self.cls = FNetPreTrainingHeads(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    @add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=FNetForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        next_sentence_label: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, FNetForPreTrainingOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring) Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.
        kwargs (`Dict[str, any]`, optional, defaults to *{}*):
            Used to hide legacy arguments that have been deprecated.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, FNetForPreTraining
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
        >>> model = FNetForPreTraining.from_pretrained("google/fnet-base")
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> prediction_logits = outputs.prediction_logits
        >>> seq_relationship_logits = outputs.seq_relationship_logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        total_loss = None
        if labels is not None and next_sentence_label is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss

        if not return_dict:
            output = (prediction_scores, seq_relationship_score) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return FNetForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
        )


@add_start_docstrings("""FNet Model with a `language modeling` head on top.""", FNET_START_DOCSTRING)
class FNetForMaskedLM(FNetPreTrainedModel):
    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)

        self.fnet = FNetModel(config)
        self.cls = FNetOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    @add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(loss=masked_lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states)


@add_start_docstrings(
    """FNet Model with a `next sentence prediction (classification)` head on top.""",
    FNET_START_DOCSTRING,
)
class FNetForNextSentencePrediction(FNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.fnet = FNetModel(config)
        self.cls = FNetOnlyNSPHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, NextSentencePredictorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring). Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, FNetForNextSentencePrediction
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
        >>> model = FNetForNextSentencePrediction.from_pretrained("google/fnet-base")
        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
        >>> logits = outputs.logits
        >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
        ```"""
        if "next_sentence_label" in kwargs:
            warnings.warn(
                "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
                " `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("next_sentence_label")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        pooled_output = outputs[1]
        seq_relationship_scores = self.cls(pooled_output)

        next_sentence_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))

        if not return_dict:
            output = (seq_relationship_scores,) + outputs[2:]
            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output

        return NextSentencePredictorOutput(
            loss=next_sentence_loss,
            logits=seq_relationship_scores,
            hidden_states=outputs.hidden_states,
        )


@add_start_docstrings(
    """
    FNet Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    FNET_START_DOCSTRING,
)
class FNetForSequenceClassification(FNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.fnet = FNetModel(config)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)


@add_start_docstrings(
    """
    FNet Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    FNET_START_DOCSTRING,
)
class FNetForMultipleChoice(FNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.fnet = FNetModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states)


@add_start_docstrings(
    """
    FNet Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    FNET_START_DOCSTRING,
)
class FNetForTokenClassification(FNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.fnet = FNetModel(config)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)


@add_start_docstrings(
    """
    FNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    FNET_START_DOCSTRING,
)
class FNetForQuestionAnswering(FNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels

        self.fnet = FNetModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split adds a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss, start_logits=start_logits, end_logits=end_logits, hidden_states=outputs.hidden_states
        )