U
    ,-es                     @   s  d Z ddlZddlmZ ddlmZmZmZ ddlZddl	Zddlm
Z
 ddlmZmZmZ ddlmZmZmZmZ dd	lmZ dd
lmZmZmZmZ ddlmZ eeZdZ dZ!dddgZ"dZ#dZ$dgZ%eG dd deZ&G dd de
j'Z(G dd de
j'Z)G dd de
j'Z*G dd de
j'Z+G dd de
j'Z,G dd de
j'Z-G d d! d!e
j'Z.G d"d# d#e
j'Z/G d$d% d%e
j'Z0G d&d' d'e
j'Z1G d(d) d)e
j'Z2G d*d+ d+eZ3d,Z4d-Z5ed.e4G d/d0 d0e3Z6ed1e4G d2d3 d3e3Z7ed4e4G d5d6 d6e3Z8dS )7z PyTorch LeViT model.    N)	dataclass)OptionalTupleUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )BaseModelOutputWithNoAttention(BaseModelOutputWithPoolingAndNoAttention$ImageClassifierOutputWithNoAttentionModelOutput)PreTrainedModel)add_code_sample_docstringsadd_start_docstrings%add_start_docstrings_to_model_forwardlogging   )LevitConfigr   zfacebook/levit-128S   i  ztabby, tabby catc                   @   sR   e Zd ZU dZdZejed< dZejed< dZ	ejed< dZ
eeej  ed< dS ),LevitForImageClassificationWithTeacherOutputa  
    Output type of [`LevitForImageClassificationWithTeacher`].

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores as the average of the `cls_logits` and `distillation_logits`.
        cls_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the
            class token).
        distillation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the
            distillation token).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
    Nlogits
cls_logitsdistillation_logitshidden_states)__name__
__module____qualname____doc__r   torchFloatTensor__annotations__r   r   r   r   r    r#   r#   i/var/www/html/Darija-Ai-Train/env/lib/python3.8/site-packages/transformers/models/levit/modeling_levit.pyr   8   s
   
r   c                       s*   e Zd ZdZd fdd	Zdd Z  ZS )LevitConvEmbeddingsz[
    LeViT Conv Embeddings with Batch Norm, used in the initial patch embedding layer.
    r   c	           	   
      s6   t    tj|||||||dd| _t|| _d S )NF)dilationgroupsbias)super__init__r   Conv2dconvolutionBatchNorm2d
batch_norm)	selfZin_channelsZout_channelskernel_sizestridepaddingr&   r'   bn_weight_init	__class__r#   r$   r*   W   s    
       zLevitConvEmbeddings.__init__c                 C   s   |  |}| |}|S N)r,   r.   )r/   
embeddingsr#   r#   r$   forward`   s    

zLevitConvEmbeddings.forward)r   r   r   r   r   r   r   r*   r8   __classcell__r#   r#   r4   r$   r%   R   s        	r%   c                       s(   e Zd ZdZ fddZdd Z  ZS )LevitPatchEmbeddingsz
    LeViT patch embeddings, for final embeddings to be passed to transformer blocks. It consists of multiple
    `LevitConvEmbeddings`.
    c                    s   t    t|j|jd d |j|j|j| _t	
 | _t|jd d |jd d |j|j|j| _t	
 | _t|jd d |jd d |j|j|j| _t	
 | _t|jd d |jd |j|j|j| _|j| _d S )Nr            )r)   r*   r%   num_channelshidden_sizesr0   r1   r2   embedding_layer_1r   	Hardswishactivation_layer_1embedding_layer_2activation_layer_2embedding_layer_3activation_layer_3embedding_layer_4r/   configr4   r#   r$   r*   l   sB    
    
    
    
    zLevitPatchEmbeddings.__init__c                 C   st   |j d }|| jkrtd| |}| |}| |}| |}| |}| |}| 	|}|
dddS )Nr   zeMake sure that the channel dimension of the pixel values match with the one set in the configuration.r>   )shaper?   
ValueErrorrA   rC   rD   rE   rF   rG   rH   flatten	transpose)r/   pixel_valuesr?   r7   r#   r#   r$   r8      s    








zLevitPatchEmbeddings.forwardr9   r#   r#   r4   r$   r;   f   s   r;   c                       s&   e Zd Zd fdd	Zdd Z  ZS )MLPLayerWithBNr   c                    s,   t    tj||dd| _t|| _d S )NF)Zin_featuresZout_featuresr(   )r)   r*   r   LinearlinearBatchNorm1dr.   )r/   	input_dim
output_dimr3   r4   r#   r$   r*      s    
zMLPLayerWithBN.__init__c                 C   s&   |  |}| |dd|}|S )Nr   r   )rR   r.   rM   Z
reshape_asr/   hidden_stater#   r#   r$   r8      s    
zMLPLayerWithBN.forward)r   r   r   r   r*   r8   r:   r#   r#   r4   r$   rP      s   rP   c                       s$   e Zd Z fddZdd Z  ZS )LevitSubsamplec                    s   t    || _|| _d S r6   )r)   r*   r1   
resolution)r/   r1   rZ   r4   r#   r$   r*      s    
zLevitSubsample.__init__c                 C   sL   |j \}}}||| j| j|d d d d | jd d | jf |d|}|S )N)rK   viewrZ   r1   reshape)r/   rW   
batch_size_Zchannelsr#   r#   r$   r8      s      zLevitSubsample.forwardrX   r#   r#   r4   r$   rY      s   rY   c                       sB   e Zd Z fddZe d
 fdd	Zdd Zdd	 Z  Z	S )LevitAttentionc                    sB  t    || _|d | _|| _|| _|| | || d  | _|| | | _t|| j| _	t
 | _t| j|dd| _ttt|t|}t|}i g  }}	|D ]X}
|D ]N}t|
d |d  t|
d |d  f}||krt|||< |	||  qqi | _tj
t|t|| _| jdt|	||dd d S )	N      r>   r   )r3   r   attention_bias_idxsF
persistent)r)   r*   num_attention_headsscalekey_dimattention_ratioout_dim_keys_valuesout_dim_projectionrP   queries_keys_valuesr   rB   
activation
projectionlist	itertoolsproductrangelenabsappendattention_bias_cacher    	Parameterzerosattention_biasesregister_buffer
LongTensorr\   )r/   r@   rg   re   rh   rZ   points
len_pointsattention_offsetsindicesp1p2offsetr4   r#   r$   r*      s4    



(  zLevitAttention.__init__Tc                    s    t  | |r| jri | _d S r6   r)   trainru   r/   moder4   r#   r$   r      s    
zLevitAttention.trainc                 C   sT   | j r| jd d | jf S t|}|| jkrF| jd d | jf | j|< | j| S d S r6   trainingrx   rb   strru   r/   deviceZ
device_keyr#   r#   r$   get_attention_biases   s    
z#LevitAttention.get_attention_biasesc           
      C   s   |j \}}}| |}|||| jdj| j| j| j| j gdd\}}}|dddd}|dddd}|dddd}||dd | j	 | 
|j }	|	jdd}	|	| dd||| j}| | |}|S Nr[   r
   dimr   r>   r   )rK   rk   r\   re   splitrg   rh   permuterN   rf   r   r   softmaxr]   rj   rm   rl   )
r/   rW   r^   
seq_lengthr_   rk   querykeyvalue	attentionr#   r#   r$   r8      s    
 "zLevitAttention.forward)T
r   r   r   r*   r    Zno_gradr   r   r8   r:   r#   r#   r4   r$   r`      s
   	r`   c                       sB   e Zd Z fddZe d
 fdd	Zdd Zdd	 Z  Z	S )LevitAttentionSubsamplec	                    s  t    || _|d | _|| _|| _|| | ||  | _|| | | _|| _t	|| j| _
t||| _t	||| | _t | _t	| j|| _i | _ttt|t|}	ttt|t|}
t|	t|
 }}i g  }}|
D ]~}|	D ]t}d}t|d | |d  |d d  t|d | |d  |d d  f}||krVt|||< |||  qqtjt|t|| _| jdt| ||dd d S )Nra   r   r   r>   rb   Frc   )!r)   r*   re   rf   rg   rh   ri   rj   resolution_outrP   keys_valuesrY   queries_subsamplequeriesr   rB   rl   rm   ru   rn   ro   rp   rq   rr   rs   rt   r    rv   rw   rx   ry   rz   r\   )r/   rT   rU   rg   re   rh   r1   resolution_inr   r{   Zpoints_r|   Zlen_points_r}   r~   r   r   sizer   r4   r#   r$   r*      s>    



H
  z LevitAttentionSubsample.__init__Tc                    s    t  | |r| jri | _d S r6   r   r   r4   r#   r$   r     s    
zLevitAttentionSubsample.trainc                 C   sT   | j r| jd d | jf S t|}|| jkrF| jd d | jf | j|< | j| S d S r6   r   r   r#   r#   r$   r     s    
z,LevitAttentionSubsample.get_attention_biasesc           	      C   s   |j \}}}| |||| jdj| j| j| j gdd\}}|dddd}|dddd}| | 	|}||| j
d | j| jdddd}||dd | j | |j }|jdd}|| dd|d| j}| | |}|S r   )rK   r   r\   re   r   rg   rh   r   r   r   r   rN   rf   r   r   r   r]   rj   rm   rl   )	r/   rW   r^   r   r_   r   r   r   r   r#   r#   r$   r8   '  s2    
       "zLevitAttentionSubsample.forward)Tr   r#   r#   r4   r$   r      s
   -	r   c                       s(   e Zd ZdZ fddZdd Z  ZS )LevitMLPLayerzE
    MLP Layer with `2X` expansion in contrast to ViT with `4X`.
    c                    s0   t    t||| _t | _t||| _d S r6   )r)   r*   rP   	linear_upr   rB   rl   linear_down)r/   rT   
hidden_dimr4   r#   r$   r*   B  s    

zLevitMLPLayer.__init__c                 C   s"   |  |}| |}| |}|S r6   )r   rl   r   rV   r#   r#   r$   r8   H  s    


zLevitMLPLayer.forwardr9   r#   r#   r4   r$   r   =  s   r   c                       s(   e Zd ZdZ fddZdd Z  ZS )LevitResidualLayerz"
    Residual Block for LeViT
    c                    s   t    || _|| _d S r6   )r)   r*   module	drop_rate)r/   r   r   r4   r#   r$   r*   T  s    
zLevitResidualLayer.__init__c                 C   sr   | j r\| jdkr\tj|ddd|jd}|| jd| j  }|| 	||  }|S || 	| }|S d S )Nr   r   )r   )
r   r   r    Zrandr   r   Zge_divdetachr   )r/   rW   Zrndr#   r#   r$   r8   Y  s    zLevitResidualLayer.forwardr9   r#   r#   r4   r$   r   O  s   r   c                       s0   e Zd ZdZ fddZdd Zdd Z  ZS )
LevitStagezP
    LeViT Stage consisting of `LevitMLPLayer` and `LevitAttention` layers.
    c                    sH  t    g | _|| _|
| _t|D ]R}| jtt|||||
| jj	 |dkr$|| }| jtt
||| jj	 q$|	d dkr6| jd |	d  d | _| jt| jj||d  |	d |	d |	d |	d |
| jd | j| _|	d dkr6| jj|d  |	d  }| jtt
| jj|d  || jj	 t| j| _d S )	Nr   Z	Subsampler      r>   r
   )rg   re   rh   r1   r   r   r=   )r)   r*   layersrJ   r   rq   rt   r   r`   Zdrop_path_rater   r   r   r@   r   
ModuleList)r/   rJ   idxr@   rg   depthsre   rh   	mlp_ratiodown_opsr   r_   r   r4   r#   r$   r*   i  sN    
 zLevitStage.__init__c                 C   s   | j S r6   )r   )r/   r#   r#   r$   get_resolution  s    zLevitStage.get_resolutionc                 C   s   | j D ]}||}q|S r6   )r   )r/   rW   layerr#   r#   r$   r8     s    

zLevitStage.forward)r   r   r   r   r*   r   r8   r:   r#   r#   r4   r$   r   d  s   7r   c                       s*   e Zd ZdZ fddZdddZ  ZS )	LevitEncoderzC
    LeViT Encoder consisting of multiple `LevitStage` stages.
    c                    s   t    || _| jj| jj }g | _| jjdg tt	|j
D ]\}t|||j| |j| |j
| |j| |j| |j| |j| |
}| }| j| qDt| j| _d S )N )r)   r*   rJ   Z
image_sizeZ
patch_sizestagesr   rt   rq   rr   r   r   r@   rg   re   rh   r   r   r   r   )r/   rJ   rZ   Z	stage_idxstager4   r#   r$   r*     s*    
zLevitEncoder.__init__FTc                 C   sb   |rdnd }| j D ]}|r$||f }||}q|r<||f }|sVtdd ||fD S t||dS )Nr#   c                 s   s   | ]}|d k	r|V  qd S r6   r#   ).0vr#   r#   r$   	<genexpr>  s      z'LevitEncoder.forward.<locals>.<genexpr>)last_hidden_stater   )r   tupler   )r/   rW   output_hidden_statesreturn_dictZall_hidden_statesr   r#   r#   r$   r8     s    



zLevitEncoder.forward)FTr9   r#   r#   r4   r$   r     s   r   c                       s(   e Zd ZdZ fddZdd Z  ZS )LevitClassificationLayerz$
    LeViT Classification Layer
    c                    s(   t    t|| _t||| _d S r6   )r)   r*   r   rS   r.   rQ   rR   )r/   rT   rU   r4   r#   r$   r*     s    
z!LevitClassificationLayer.__init__c                 C   s   |  |}| |}|S r6   )r.   rR   )r/   rW   r   r#   r#   r$   r8     s    

z LevitClassificationLayer.forwardr9   r#   r#   r4   r$   r     s   r   c                   @   s2   e Zd ZdZeZdZdZdZdd Z	ddd	Z
d
S )LevitPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    levitrO   Tc                 C   sp   t |tjtjfr@|jjjd| jjd |j	dk	rl|j	j
  n,t |tjtjfrl|j	j
  |jjd dS )zInitialize the weightsg        )meanZstdNg      ?)
isinstancer   rQ   r+   weightdataZnormal_rJ   Zinitializer_ranger(   Zzero_rS   r-   Zfill_)r/   r   r#   r#   r$   _init_weights  s    
z"LevitPreTrainedModel._init_weightsFc                 C   s   t |tr||_d S r6   )r   
LevitModelZgradient_checkpointing)r/   r   r   r#   r#   r$   _set_gradient_checkpointing  s    
z0LevitPreTrainedModel._set_gradient_checkpointingN)F)r   r   r   r   r   config_classZbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingr   r   r#   r#   r#   r$   r     s   r   aG  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`LevitConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
aC  
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`LevitImageProcessor.__call__`] for details.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
zNThe bare Levit model outputting raw features without any specific head on top.c                	       s^   e Zd Z fddZeeeeee	de
dd	ejee ee eeef dddZ  ZS )
r   c                    s2   t  | || _t|| _t|| _|   d S r6   )r)   r*   rJ   r;   patch_embeddingsr   encoder	post_initrI   r4   r#   r$   r*   !  s
    

zLevitModel.__init__Zvision)
checkpointoutput_typer   Zmodalityexpected_outputNrO   r   r   returnc                 C   s   |d k	r|n| j j}|d k	r |n| j j}|d kr8td| |}| j|||d}|d }|jdd}|s~||f|dd   S t|||jdS )Nz You have to specify pixel_valuesr   r   r   r   r   )r   Zpooler_outputr   )	rJ   r   use_return_dictrL   r   r   r   r   r   )r/   rO   r   r   r7   Zencoder_outputsr   Zpooled_outputr#   r#   r$   r8   )  s(    
zLevitModel.forward)NNN)r   r   r   r*   r   LEVIT_INPUTS_DOCSTRINGr   _CHECKPOINT_FOR_DOCr   _CONFIG_FOR_DOC_EXPECTED_OUTPUT_SHAPEr    r!   r   boolr   r   r8   r:   r#   r#   r4   r$   r     s$   	   
r   z
    Levit Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    c                
       sd   e Zd Z fddZeeeeee	e
ddejeej ee ee eeef dddZ  ZS )	LevitForImageClassificationc                    sX   t  | || _|j| _t|| _|jdkr@t|jd |jntj	
 | _|   d S Nr   r[   )r)   r*   rJ   
num_labelsr   r   r   r@   r    r   Identity
classifierr   rI   r4   r#   r$   r*   ]  s    
z$LevitForImageClassification.__init__r   r   r   r   N)rO   labelsr   r   r   c                 C   sl  |dk	r|n| j j}| j|||d}|d }|d}| |}d}|dk	r,| j jdkr| jdkrnd| j _n4| jdkr|jtj	ks|jtj
krd| j _nd| j _| j jdkrt }	| jdkr|	| | }n
|	||}nN| j jdkrt }	|	|d| j|d}n| j jdkr,t }	|	||}|s\|f|d	d  }
|dk	rX|f|
 S |
S t|||jd
S )a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr   r   r   Z
regressionZsingle_label_classificationZmulti_label_classificationr[   r>   )lossr   r   )rJ   r   r   r   r   Zproblem_typer   Zdtyper    longintr	   Zsqueezer   r\   r   r   r   )r/   rO   r   r   r   outputssequence_outputr   r   Zloss_fctoutputr#   r#   r$   r8   m  s@    




"


z#LevitForImageClassification.forward)NNNN)r   r   r   r*   r   r   r   _IMAGE_CLASS_CHECKPOINTr   r   _IMAGE_CLASS_EXPECTED_OUTPUTr    r!   r   rz   r   r   r   r8   r:   r#   r#   r4   r$   r   U  s&       
r   ap  
    LeViT Model transformer with image classification heads on top (a linear layer on top of the final hidden state and
    a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet. .. warning::
           This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet
           supported.
    c                	       s\   e Zd Z fddZeeeeee	e
ddejee ee eeef dddZ  ZS )	&LevitForImageClassificationWithTeacherc                    s   t  | || _|j| _t|| _|jdkr@t|jd |jntj	
 | _|jdkrht|jd |jntj	
 | _|   d S r   )r)   r*   rJ   r   r   r   r   r@   r    r   r   r   classifier_distillr   rI   r4   r#   r$   r*     s    
z/LevitForImageClassificationWithTeacher.__init__r   Nr   c           
      C   s   |d k	r|n| j j}| j|||d}|d }|d}| || | }}|| d }|sv|||f|dd   }	|	S t||||jdS )Nr   r   r   r>   )r   r   r   r   )rJ   r   r   r   r   r   r   r   )
r/   rO   r   r   r   r   r   Zdistill_logitsr   r   r#   r#   r$   r8     s    
z.LevitForImageClassificationWithTeacher.forward)NNN)r   r   r   r*   r   r   r   r   r   r   r   r    r!   r   r   r   r   r8   r:   r#   r#   r4   r$   r     s"   
   
r   )9r   ro   dataclassesr   typingr   r   r   r    Ztorch.utils.checkpointr   Ztorch.nnr   r   r	   Zmodeling_outputsr   r   r   r   Zmodeling_utilsr   utilsr   r   r   r   Zconfiguration_levitr   Z
get_loggerr   loggerr   r   r   r   r   Z#LEVIT_PRETRAINED_MODEL_ARCHIVE_LISTr   Moduler%   r;   rP   rY   r`   r   r   r   r   r   r   r   ZLEVIT_START_DOCSTRINGr   r   r   r   r#   r#   r#   r$   <module>   sd   

,>SE.5N	