""" PyTorch CANINE model."""

import copy
import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutput,
    ModelOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_canine import CanineConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "google/canine-s"
_CONFIG_FOR_DOC = "CanineConfig"

CANINE_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google/canine-s",
    "google/canine-r",
]

# Support up to 16 hash functions.
_PRIMES = [31, 43, 59, 61, 73, 97, 103, 113, 137, 149, 157, 173, 181, 193, 211, 223]


@dataclass
class CanineModelOutputWithPooling(ModelOutput):
    """
    Output type of [`CanineModel`]. Based on [`~modeling_outputs.BaseModelOutputWithPooling`], but with slightly
    different `hidden_states` and `attentions`, as these also include the hidden states and attentions of the shallow
    Transformer encoders.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model (i.e. the output of the final
            shallow Transformer encoder).
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Hidden-state of the first token of the sequence (classification token) at the last layer of the deep
            Transformer encoder, further processed by a Linear layer and a Tanh activation function. The Linear layer
            weights are trained from the next sentence prediction (classification) objective during pretraining.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the input to each encoder + one for the output of each layer of each
            encoder) of shape `(batch_size, sequence_length, hidden_size)` and `(batch_size, sequence_length //
            config.downsampling_rate, hidden_size)`. Hidden-states of the model at the output of each layer plus the
            initial input to each Transformer encoder. The hidden states of the shallow encoders have length
            `sequence_length`, but the hidden states of the deep encoder have length `sequence_length` //
            `config.downsampling_rate`.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of the 3 Transformer encoders of shape `(batch_size,
            num_heads, sequence_length, sequence_length)` and `(batch_size, num_heads, sequence_length //
            config.downsampling_rate, sequence_length // config.downsampling_rate)`. Attention weights after the
            attention softmax, used to compute the weighted average in the self-attention heads.
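
    Example (an illustrative shape sketch only; assumes `outputs` was returned by a [`CanineModel`] call):

    ```python
    >>> outputs.last_hidden_state.shape  # (batch_size, sequence_length, hidden_size)  # doctest: +SKIP
    >>> outputs.pooler_output.shape  # (batch_size, hidden_size)  # doctest: +SKIP
    ```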
    Nlast_hidden_statepooler_outputhidden_states
attentions)__name__
__module____qualname____doc__r.   torchFloatTensor__annotations__r/   r0   r   r   r1    r9   r9   k/var/www/html/Darija-Ai-Train/env/lib/python3.8/site-packages/transformers/models/canine/modeling_canine.pyr-   A   s
   


def load_tf_weights_in_canine(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer and are not required for
        # using the pretrained model; the pretraining-only decoder and output weights are skipped too
        if any(
            n
            in [
                "adam_v",
                "adam_m",
                "AdamWeightDecayOptimizer",
                "AdamWeightDecayOptimizer_1",
                "global_step",
                "cls",
                "autoregressive_decoder",
                "char_output_weights",
            ]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        # if first scope name starts with "bert", change it to "encoder"
        if name[0] == "bert":
            name[0] = "encoder"
        # remove "embeddings" middle name of HashBucketCodepointEmbedders
        elif name[1] == "embeddings":
            name.remove(name[1])
        # rename segment_embeddings to token_type_embeddings
        elif name[1] == "segment_embeddings":
            name[1] = "token_type_embeddings"
        # rename initial convolutional projection layer
        elif name[1] == "initial_char_encoder":
            name = ["chars_to_molecules"] + name[-2:]
        # rename final convolutional projection layer
        elif name[0] == "final_char_encoder" and name[1] in ["LayerNorm", "conv"]:
            name = ["projection"] + name[1:]

        pointer = model
        for m_name in name:
            if (re.fullmatch(r"[A-Za-z]+_\d+", m_name)) and "Embedder" not in m_name:
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]

        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name[-10:] in [f"Embedder_{i}" for i in range(8)]:
            # one embedding shard per hash function
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)

        if pointer.shape != array.shape:
            raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model
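

# A minimal sketch of how the loader above is typically driven (the checkpoint path below is
# hypothetical; in practice the companion TF-to-PyTorch conversion script performs these steps):
#
#     config = CanineConfig()
#     model = CanineModel(config)
#     load_tf_weights_in_canine(model, config, "/path/to/canine/model.ckpt")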


class CanineEmbeddings(nn.Module):
    """Construct the character, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()

        self.config = config

        # character embeddings
        shard_embedding_size = config.hidden_size // config.num_hash_functions
        for i in range(config.num_hash_functions):
            name = f"HashBucketCodepointEmbedder_{i}"
            setattr(self, name, nn.Embedding(config.num_hash_buckets, shard_embedding_size))
        self.char_position_embeddings = nn.Embedding(config.num_hash_buckets, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able
        # to load any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

    def _hash_bucket_tensors(self, input_ids, num_hashes: int, num_buckets: int):
        """
        Converts ids to hash bucket ids via multiple hashing.

        Args:
            input_ids: The codepoints or other IDs to be hashed.
            num_hashes: The number of hash functions to use.
            num_buckets: The number of hash buckets (i.e. embeddings in each table).

        Returns:
            A list of tensors, each of which is the hash bucket IDs from one hash function.
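
        Example (an illustrative sketch of the hashing scheme only; each hash function computes
        `((input_ids + 1) * prime) % num_buckets` with a distinct prime, here with the default
        16384 buckets assumed):

        ```python
        >>> input_ids = [ord(c) for c in "hi"]  # [104, 105]
        >>> [[((cp + 1) * p) % 16384 for cp in input_ids] for p in (31, 43)]
        [[3255, 3286], [4515, 4558]]
        ```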
        z`num_hashes` must be <= Nr   )re   _PRIMESrj   r\   )r   	input_idsr   r   ZprimesZresult_tensorsprimehashedr9   r9   r:   _hash_bucket_tensors   s    z%CanineEmbeddings._hash_bucket_tensors)embedding_sizer   r   c                 C   sx   || dkr"t d| d| d| j|||d}g }t|D ]*\}}d| }	t| |	|}
||
 q>tj|ddS )	zDConverts IDs (e.g. codepoints) into embeddings via multiple hashing.r   zExpected `embedding_size` (z) % `num_hashes` (z) == 0r   rv   ry   dim)rj   r   	enumeraterc   r\   r6   cat)r   r   r   r   r   Zhash_bucket_tensorsZembedding_shardsrO   Zhash_bucket_idsrq   Zshard_embeddingsr9   r9   r:   _embed_hash_buckets   s    
z$CanineEmbeddings._embed_hash_bucketsN)r   token_type_idsrx   inputs_embedsreturnc           
      C   s   |d k	r|  }n|  d d }|d }|d krH| jd d d |f }|d krftj|tj| jjd}|d kr| || jj| jj	| jj
}| |}|| }| jdkr| |}	||	7 }| |}| |}|S )Nry   r   dtypedevicer|   )sizerx   r6   zeroslongr   r   rm   r   r   r   rB   r{   r   rG   r   )
r   r   r   rx   r   input_shape
seq_lengthrB   rA   Zposition_embeddingsr9   r9   r:   forward  s.    
   




zCanineEmbeddings.forward)NNNN)r2   r3   r4   r5   r~   rf   r   r   r   r6   


class CharactersToMolecules(nn.Module):
    """Convert character sequence to initial molecule sequence (i.e. downsample) using strided convolutions."""

    def __init__(self, config):
        super().__init__()

        self.conv = nn.Conv1d(
            in_channels=config.hidden_size,
            out_channels=config.hidden_size,
            kernel_size=config.downsampling_rate,
            stride=config.downsampling_rate,
        )
        self.activation = ACT2FN[config.hidden_act]

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able
        # to load any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, char_encoding: torch.Tensor) -> torch.Tensor:
        # `cls_encoding`: [batch, 1, hidden_size]
        cls_encoding = char_encoding[:, 0:1, :]

        # char_encoding has shape [batch, char_seq, hidden_size]; transpose it to
        # [batch, hidden_size, char_seq] for the convolution, then transpose back
        char_encoding = torch.transpose(char_encoding, 1, 2)
        downsampled = self.conv(char_encoding)
        downsampled = torch.transpose(downsampled, 1, 2)
        downsampled = self.activation(downsampled)

        # Truncate the last molecule in order to reserve a position for [CLS], which is
        # kept as a separate molecule so a position is always reserved for it.
        downsampled_truncated = downsampled[:, 0:-1, :]

        result = torch.cat([cls_encoding, downsampled_truncated], dim=1)

        result = self.LayerNorm(result)

        return result


class ConvProjection(nn.Module):
    """
    Project representations from hidden_size*2 back to hidden_size across a window of w = config.upsampling_kernel_size
    characters.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.conv = nn.Conv1d(
            in_channels=config.hidden_size * 2,
            out_channels=config.hidden_size,
            kernel_size=config.upsampling_kernel_size,
            stride=1,
        )
        self.activation = ACT2FN[config.hidden_act]
        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able
        # to load any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        inputs: torch.Tensor,
        final_seq_char_positions: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # inputs has shape [batch, mol_seq, molecule_hidden_size + char_hidden_final];
        # transpose it to [batch, molecule_hidden_size + char_hidden_final, mol_seq]
        inputs = torch.transpose(inputs, 1, 2)

        # pad the input manually so the convolution preserves the sequence length
        # (equivalent to `padding="same"` in the original TF implementation)
        pad_total = self.config.upsampling_kernel_size - 1
        pad_beg = pad_total // 2
        pad_end = pad_total - pad_beg

        pad = nn.ConstantPad1d((pad_beg, pad_end), 0)
        # `result`: shape (batch_size, char_seq_len, hidden_size)
        result = self.conv(pad(inputs))
        result = torch.transpose(result, 1, 2)
        result = self.activation(result)
        result = self.LayerNorm(result)
        result = self.dropout(result)
        final_char_seq = result

        if final_seq_char_positions is not None:
            # Limit the transformer query sequence and attention mask to these character
            # positions to greatly reduce the compute cost. Typically, this is just done
            # for the MLM training task.
            raise NotImplementedError("CanineForMaskedLM is currently not supported")
        else:
            query_seq = final_char_seq

        return query_seq


class CanineSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads "
                f"({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        from_tensor: torch.Tensor,
        to_tensor: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        mixed_query_layer = self.query(from_tensor)

        key_layer = self.transpose_for_scores(self.key(to_tensor))
        value_layer = self.transpose_for_scores(self.value(to_tensor))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = from_tensor.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=from_tensor.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=from_tensor.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            if attention_mask.ndim == 3:
                # if attention_mask is 3D, make it 4D by adding a head dimension
                attention_mask = torch.unsqueeze(attention_mask, dim=1)
                # Since attention_mask is 1.0 for positions we want to attend and 0.0 for masked
                # positions, this creates a tensor that is 0.0 for positions to attend and the
                # dtype's minimum value for masked positions.
                attention_mask = (1.0 - attention_mask.float()) * torch.finfo(attention_scores.dtype).min
            # Apply the attention mask (precomputed for all layers in CanineModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class CanineSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states: Tuple[torch.FloatTensor], input_tensor: torch.FloatTensor
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class CanineAttention(nn.Module):
    """
    Additional arguments related to local attention:

        - **local** (`bool`, *optional*, defaults to `False`) -- Whether to apply local attention.
        - **always_attend_to_first_position** (`bool`, *optional*, defaults to `False`) -- Should all blocks be able
          to attend to the `to_tensor`'s first position (e.g. a [CLS] position)?
        - **first_position_attends_to_all** (`bool`, *optional*, defaults to `False`) -- Should the *from_tensor*'s
          first position be able to attend to all positions within the *from_tensor*?
        - **attend_from_chunk_width** (`int`, *optional*, defaults to 128) -- The width of each block-wise chunk in
          `from_tensor`.
        - **attend_from_chunk_stride** (`int`, *optional*, defaults to 128) -- The number of elements to skip when
          moving to the next block in `from_tensor`.
        - **attend_to_chunk_width** (`int`, *optional*, defaults to 128) -- The width of each block-wise chunk in
          *to_tensor*.
        - **attend_to_chunk_stride** (`int`, *optional*, defaults to 128) -- The number of elements to skip when
          moving to the next block in *to_tensor*.
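
    Example of how the attention windows are laid out (an illustrative sketch of the chunking
    arithmetic only, using a hypothetical sequence length of 512 with the default width/stride of 128):

    ```python
    >>> seq_len, width, stride = 512, 128, 128
    >>> [(start, min(seq_len, start + width)) for start in range(0, seq_len, stride)]
    [(0, 128), (128, 256), (256, 384), (384, 512)]
    ```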
    """

    def __init__(
        self,
        config,
        local=False,
        always_attend_to_first_position: bool = False,
        first_position_attends_to_all: bool = False,
        attend_from_chunk_width: int = 128,
        attend_from_chunk_stride: int = 128,
        attend_to_chunk_width: int = 128,
        attend_to_chunk_stride: int = 128,
    ):
        super().__init__()
        self.self = CanineSelfAttention(config)
        self.output = CanineSelfOutput(config)
        self.pruned_heads = set()

        # additional arguments related to local attention
        self.local = local
        if attend_from_chunk_width < attend_from_chunk_stride:
            raise ValueError(
                "`attend_from_chunk_width` < `attend_from_chunk_stride` would cause sequence positions to get skipped."
            )
        if attend_to_chunk_width < attend_to_chunk_stride:
            raise ValueError(
                "`attend_to_chunk_width` < `attend_to_chunk_stride` would cause sequence positions to get skipped."
            )
        self.always_attend_to_first_position = always_attend_to_first_position
        self.first_position_attends_to_all = first_position_attends_to_all
        self.attend_from_chunk_width = attend_from_chunk_width
        self.attend_from_chunk_stride = attend_from_chunk_stride
        self.attend_to_chunk_width = attend_to_chunk_width
        self.attend_to_chunk_stride = attend_to_chunk_stride

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: Tuple[torch.FloatTensor],
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        if not self.local:
            self_outputs = self.self(hidden_states, hidden_states, attention_mask, head_mask, output_attentions)
            attention_output = self_outputs[0]
        else:
            from_seq_length = to_seq_length = hidden_states.shape[1]
            from_tensor = to_tensor = hidden_states

            # create chunks (windows) that we will attend *from* and then concatenate them
            from_chunks = []
            if self.first_position_attends_to_all:
                from_chunks.append((0, 1))
                # We must skip this first position so that our output sequence has the
                # correct length (this matters in the *from* sequence only).
                from_start = 1
            else:
                from_start = 0
            for chunk_start in range(from_start, from_seq_length, self.attend_from_chunk_stride):
                chunk_end = min(from_seq_length, chunk_start + self.attend_from_chunk_width)
                from_chunks.append((chunk_start, chunk_end))

            # determine the chunks (windows) that will attend *to*
            to_chunks = []
            if self.first_position_attends_to_all:
                to_chunks.append((0, to_seq_length))
            for chunk_start in range(0, to_seq_length, self.attend_to_chunk_stride):
                chunk_end = min(to_seq_length, chunk_start + self.attend_to_chunk_width)
                to_chunks.append((chunk_start, chunk_end))

            if len(from_chunks) != len(to_chunks):
                raise ValueError(
                    f"Expected to have same number of `from_chunks` ({from_chunks}) and `to_chunks` "
                    f"({to_chunks}). Check strides."
                )

            # next, compute attention scores for each pair of windows and concatenate the results
            attention_output_chunks = []
            attention_probs_chunks = []
            for (from_start, from_end), (to_start, to_end) in zip(from_chunks, to_chunks):
                from_tensor_chunk = from_tensor[:, from_start:from_end, :]
                to_tensor_chunk = to_tensor[:, to_start:to_end, :]
                # `attention_mask`: <float>[batch_size, from_seq, to_seq]
                # `attention_mask_chunk`: <float>[batch_size, from_seq_chunk, to_seq_chunk]
                attention_mask_chunk = attention_mask[:, from_start:from_end, to_start:to_end]
                if self.always_attend_to_first_position:
                    cls_attention_mask = attention_mask[:, from_start:from_end, 0:1]
                    attention_mask_chunk = torch.cat([cls_attention_mask, attention_mask_chunk], dim=2)

                    cls_position = to_tensor[:, 0:1, :]
                    to_tensor_chunk = torch.cat([cls_position, to_tensor_chunk], dim=1)

                attention_outputs_chunk = self.self(
                    from_tensor_chunk, to_tensor_chunk, attention_mask_chunk, head_mask, output_attentions
                )
                attention_output_chunks.append(attention_outputs_chunk[0])
                if output_attentions:
                    attention_probs_chunks.append(attention_outputs_chunk[1])

            attention_output = torch.cat(attention_output_chunks, dim=1)

        attention_output = self.output(attention_output, hidden_states)
        outputs = (attention_output,)
        if not self.local:
            outputs = outputs + self_outputs[1:]  # add attentions if we output them
        else:
            outputs = outputs + tuple(attention_probs_chunks)  # add attentions if we output them
        return outputs


class CanineIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class CanineOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: Tuple[torch.FloatTensor], input_tensor: torch.FloatTensor) -> torch.FloatTensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class CanineLayer(nn.Module):
    def __init__(
        self,
        config,
        local,
        always_attend_to_first_position,
        first_position_attends_to_all,
        attend_from_chunk_width,
        attend_from_chunk_stride,
        attend_to_chunk_width,
        attend_to_chunk_stride,
    ):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = CanineAttention(
            config,
            local,
            always_attend_to_first_position,
            first_position_attends_to_all,
            attend_from_chunk_width,
            attend_from_chunk_stride,
            attend_to_chunk_width,
            attend_to_chunk_stride,
        )
        self.intermediate = CanineIntermediate(config)
        self.output = CanineOutput(config)

    def forward(
        self,
        hidden_states: Tuple[torch.FloatTensor],
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]

        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class CanineEncoder(nn.Module):
    def __init__(
        self,
        config,
        local=False,
        always_attend_to_first_position=False,
        first_position_attends_to_all=False,
        attend_from_chunk_width=128,
        attend_from_chunk_stride=128,
        attend_to_chunk_width=128,
        attend_to_chunk_stride=128,
    ):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [
                CanineLayer(
                    config,
                    local,
                    always_attend_to_first_position,
                    first_position_attends_to_all,
                    attend_from_chunk_width,
                    attend_from_chunk_stride,
                    attend_to_chunk_width,
                    attend_to_chunk_stride,
                )
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: Tuple[torch.FloatTensor],
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                )
            else:
                layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
        )


class CaninePooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: Tuple[torch.FloatTensor]) -> torch.FloatTensor:
        # We "pool" the model by simply taking the hidden state corresponding to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class CaninePredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: Tuple[torch.FloatTensor]) -> torch.FloatTensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class CanineLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = CaninePredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states: Tuple[torch.FloatTensor]) -> torch.FloatTensor:
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class CanineOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = CanineLMPredictionHead(config)

    def forward(self, sequence_output: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class CaninePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = CanineConfig
    load_tf_weights = load_tf_weights_in_canine
    base_model_prefix = "canine"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, CanineEncoder):
            module.gradient_checkpointing = value


CANINE_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
    behavior.

    Parameters:
        config ([`CanineConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

CANINE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare CANINE Model transformer outputting raw hidden-states without any specific head on top.",
    CANINE_START_DOCSTRING,
)
class CanineModel(CaninePreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config
        shallow_config = copy.deepcopy(config)
        shallow_config.num_hidden_layers = 1

        self.char_embeddings = CanineEmbeddings(config)
        # shallow/low-dim transformer encoder to get an initial character encoding
        self.initial_char_encoder = CanineEncoder(
            shallow_config,
            local=True,
            always_attend_to_first_position=False,
            first_position_attends_to_all=False,
            attend_from_chunk_width=config.local_transformer_stride,
            attend_from_chunk_stride=config.local_transformer_stride,
            attend_to_chunk_width=config.local_transformer_stride,
            attend_to_chunk_stride=config.local_transformer_stride,
        )
        self.chars_to_molecules = CharactersToMolecules(config)
        # deep transformer encoder
        self.encoder = CanineEncoder(config)
        self.projection = ConvProjection(config)
        # shallow/low-dim transformer encoder to get a final character encoding
        self.final_char_encoder = CanineEncoder(shallow_config)

        self.pooler = CaninePooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def _create_3d_attention_mask_from_input_mask(self, from_tensor, to_mask):
        """
        Create 3D attention mask from a 2D tensor mask.

        Args:
            from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
            to_mask: int32 Tensor of shape [batch_size, to_seq_length].

        Returns:
            float Tensor of shape [batch_size, from_seq_length, to_seq_length].
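
        Example (an illustrative shape sketch of the broadcast, using hypothetical sizes):

        ```python
        >>> to_mask = torch.ones(2, 8)  # (batch_size, to_seq_length)
        >>> mask = torch.ones(2, 16, 1) * to_mask.reshape(2, 1, 8)
        >>> mask.shape
        torch.Size([2, 16, 8])
        ```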
        """
        batch_size, from_seq_length = from_tensor.shape[0], from_tensor.shape[1]

        to_seq_length = to_mask.shape[1]

        to_mask = torch.reshape(to_mask, (batch_size, 1, to_seq_length)).float()

        # We don't assume that `from_tensor` is a mask (although it could be). We don't actually
        # care if we attend *from* padding tokens (only *to* padding tokens), so we create a
        # tensor of all ones.
        broadcast_ones = torch.ones(size=(batch_size, from_seq_length, 1), dtype=torch.float32, device=to_mask.device)

        # Here we broadcast along two dimensions to create the mask.
        mask = broadcast_ones * to_mask

        return mask

    def _downsample_attention_mask(self, char_attention_mask: torch.Tensor, downsampling_rate: int):
        """Downsample 2D character attention mask to 2D molecule attention mask using MaxPool1d layer."""

        # first, make char_attention_mask 3D by adding a channel dim
        batch_size, char_seq_len = char_attention_mask.shape
        poolable_char_mask = torch.reshape(char_attention_mask, (batch_size, 1, char_seq_len))

        # next, apply MaxPool1d to get pooled_molecule_mask of shape (batch_size, 1, mol_seq_len)
        pooled_molecule_mask = torch.nn.MaxPool1d(kernel_size=downsampling_rate, stride=downsampling_rate)(
            poolable_char_mask.float()
        )

        # finally, squeeze to get tensor of shape (batch_size, mol_seq_len)
        molecule_attention_mask = torch.squeeze(pooled_molecule_mask, dim=-1)

        return molecule_attention_mask

    def _repeat_molecules(self, molecules: torch.Tensor, char_seq_length: torch.Tensor) -> torch.Tensor:
        """Repeats molecules to make them the same length as the char sequence."""

        rate = self.config.downsampling_rate

        molecules_without_extra_cls = molecules[:, 1:, :]
        # `repeated`: [batch_size, almost_char_seq_len, molecule_hidden_size]
        repeated = torch.repeat_interleave(molecules_without_extra_cls, repeats=rate, dim=-2)

        # So far, we've repeated the elements sufficient for any `char_seq_length` that's a
        # multiple of `downsampling_rate`. Now we account for the remainder of the floor
        # division by repeating the last molecule a few extra times.
        last_molecule = molecules[:, -1:, :]
        remainder_length = torch.fmod(torch.tensor(char_seq_length), torch.tensor(rate)).item()
        remainder_repeated = torch.repeat_interleave(
            last_molecule,
            # +1 molecule to compensate for truncation
            repeats=remainder_length + rate,
            dim=-2,
        )

        # `repeated`: [batch_size, char_seq_len, molecule_hidden_size]
        return torch.cat([repeated, remainder_repeated], dim=-2)

    @add_start_docstrings_to_model_forward(CANINE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CanineModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CanineModelOutputWithPooling]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length), device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # extended attention masks for the deep (molecule-level) encoder
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
        molecule_attention_mask = self._downsample_attention_mask(
            attention_mask, downsampling_rate=self.config.downsampling_rate
        )
        extended_molecule_attention_mask = self.get_extended_attention_mask(
            molecule_attention_mask, (batch_size, molecule_attention_mask.shape[-1])
        )

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # `input_char_embeddings`: shape (batch_size, char_seq, char_dim)
        input_char_embeddings = self.char_embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )

        # Contextualize character embeddings using the shallow Transformer.
        # We use a 3D attention mask for the local attention.
        char_attention_mask = self._create_3d_attention_mask_from_input_mask(
            input_ids if input_ids is not None else inputs_embeds, attention_mask
        )
        init_chars_encoder_outputs = self.initial_char_encoder(
            input_char_embeddings,
            attention_mask=char_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        input_char_encoding = init_chars_encoder_outputs.last_hidden_state

        # Downsample chars to molecules. No resnet connection is added here; we rely on the
        # resnet connections from the final char stack back into the deep molecule stack instead.
        init_molecule_encoding = self.chars_to_molecules(input_char_encoding)

        # Deep BERT encoder
        # `molecule_sequence_output`: shape (batch_size, mol_seq_len, mol_dim)
        encoder_outputs = self.encoder(
            init_molecule_encoding,
            attention_mask=extended_molecule_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        molecule_sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(molecule_sequence_output) if self.pooler is not None else None

        # Upsample molecules back to characters.
        # `repeated_molecules`: shape (batch_size, char_seq_len, mol_hidden_size)
        repeated_molecules = self._repeat_molecules(molecule_sequence_output, char_seq_length=input_shape[-1])

        # Concatenate contextualized char embeddings and repeated molecules, then project the
        # representation dimension back to hidden_size.
        concat = torch.cat([input_char_encoding, repeated_molecules], dim=-1)
        sequence_output = self.projection(concat)

        # Apply final shallow Transformer
        # `sequence_output`: shape (batch_size, char_seq_len, hidden_size)
        final_chars_encoder_outputs = self.final_char_encoder(
            sequence_output,
            attention_mask=extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = final_chars_encoder_outputs.last_hidden_state

        if output_hidden_states:
            deep_encoder_hidden_states = encoder_outputs.hidden_states if return_dict else encoder_outputs[1]
            all_hidden_states = (
                all_hidden_states
                + init_chars_encoder_outputs.hidden_states
                + deep_encoder_hidden_states
                + final_chars_encoder_outputs.hidden_states
            )

        if output_attentions:
            deep_encoder_self_attentions = encoder_outputs.attentions if return_dict else encoder_outputs[-1]
            all_self_attentions = (
                all_self_attentions
                + init_chars_encoder_outputs.attentions
                + deep_encoder_self_attentions
                + final_chars_encoder_outputs.attentions
            )

        if not return_dict:
            output = (sequence_output, pooled_output)
            output += tuple(v for v in [all_hidden_states, all_self_attentions] if v is not None)
            return output

        return CanineModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


@add_start_docstrings(
    """
    CANINE Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    CANINE_START_DOCSTRING,
)
class CanineForSequenceClassification(CaninePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.canine = CanineModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(CANINE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
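
        Example (a minimal usage sketch; the classification head on top of `google/canine-s` below is
        randomly initialized, so the predicted label is illustrative only):

        ```python
        >>> from transformers import AutoTokenizer, CanineForSequenceClassification
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/canine-s")
        >>> model = CanineForSequenceClassification.from_pretrained("google/canine-s")

        >>> inputs = tokenizer("CANINE works directly on characters.", return_tensors="pt")
        >>> with torch.no_grad():
        ...     logits = model(**inputs).logits
        >>> predicted_class_id = logits.argmax(-1).item()  # doctest: +SKIP
        ```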
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.canine(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    CANINE Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    CANINE_START_DOCSTRING,
)
class CanineForMultipleChoice(CaninePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.canine = CanineModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(CANINE_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
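
        Example (a minimal usage sketch; the multiple-choice head on top of `google/canine-s` below is
        randomly initialized, so the scores are illustrative only):

        ```python
        >>> from transformers import AutoTokenizer, CanineForMultipleChoice
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/canine-s")
        >>> model = CanineForMultipleChoice.from_pretrained("google/canine-s")

        >>> prompt = "The dog chased"
        >>> choices = ["the cat.", "the theory of relativity."]
        >>> inputs = tokenizer([prompt, prompt], choices, return_tensors="pt", padding=True)
        >>> with torch.no_grad():
        ...     logits = model(**{k: v.unsqueeze(0) for k, v in inputs.items()}).logits  # doctest: +SKIP
        ```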
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.canine(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    CANINE Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    CANINE_START_DOCSTRING,
)
class CanineForTokenClassification(CaninePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.canine = CanineModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(CANINE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, CanineForTokenClassification
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/canine-s")
        >>> model = CanineForTokenClassification.from_pretrained("google/canine-s")

        >>> inputs = tokenizer(
        ...     "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
        ... )

        >>> with torch.no_grad():
        ...     logits = model(**inputs).logits

        >>> predicted_token_class_ids = logits.argmax(-1)

        >>> # Note that tokens are classified rather than input words, which means that
        >>> # there might be more predicted token classes than words.
        >>> # Multiple token classes might account for the same word
        >>> predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
        >>> predicted_tokens_classes  # doctest: +SKIP
        ```

        ```python
        >>> labels = predicted_token_class_ids
        >>> loss = model(**inputs, labels=labels).loss
        >>> round(loss.item(), 2)  # doctest: +SKIP
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.canine(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    CANINE Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    CANINE_START_DOCSTRING,
)
class CanineForQuestionAnswering(CaninePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.canine = CanineModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(CANINE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="Splend1dchan/canine-c-squad",
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output="'nice puppet'",
        expected_loss=8.81,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
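
        Example (a minimal usage sketch; `Splend1dchan/canine-c-squad` is the community checkpoint
        referenced in the code-sample decorator above):

        ```python
        >>> from transformers import AutoTokenizer, CanineForQuestionAnswering
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("Splend1dchan/canine-c-squad")
        >>> model = CanineForQuestionAnswering.from_pretrained("Splend1dchan/canine-c-squad")

        >>> question, text = "Who proposed CANINE?", "CANINE was proposed by researchers at Google."
        >>> inputs = tokenizer(question, text, return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        >>> start = outputs.start_logits.argmax()
        >>> end = outputs.end_logits.argmax()
        >>> tokenizer.decode(inputs.input_ids[0, start : end + 1])  # doctest: +SKIP
        ```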
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.canine(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split adds a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )