U
    9%e1}                     @   s^  d Z ddlZddlZddlmZmZ ddlZddlZddl	Zddlm
Z
mZ ddlmZmZmZ ddlmZ ddlmZmZmZmZ dd	lmZ dd
lmZmZmZmZmZ ddlm Z  ddl!m"Z" e#e$Z%dZ&dZ'ddddgZ(dZ)dZ*dgZ+dAeee,f dddZ-G dd dej.Z/G dd dej0Z1G dd dej2Z3G dd dej4Z5G dd dej2Z6dBej
e7e,ej
d"d#d$Z8G d%d& d&ej2Z9dCd(d)Z:G d*d+ d+ej2Z;G d,d- d-ej2Z<G d.d/ d/ej2Z=G d0d1 d1ej2Z>G d2d3 d3ej2Z?G d4d5 d5eZ@d6ZAd7ZBed8eAG d9d: d:e@ZCed;eAG d<d= d=e@ZDed>eAG d?d@ d@e@e ZEdS )Dz: PyTorch BiT model. Also supports backbone for ViT hybrid.    N)OptionalTuple)Tensornn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)BackboneOutputBaseModelOutputWithNoAttention(BaseModelOutputWithPoolingAndNoAttention$ImageClassifierOutputWithNoAttention)PreTrainedModel)add_code_sample_docstringsadd_start_docstrings%add_start_docstrings_to_model_forwardloggingreplace_return_docstrings)BackboneMixin   )	BitConfigr   zgoogle/bit-50i      z	tiger catreturnc                 C   s   d}| dkr,|d ||d   d } | |fS t | tr|  } | dkr|dkr|||d  d dkr||d ||d   d } qd} d}n&| dkrd} n|d ||d   d } | |fS )	al  
    Utility function to get the tuple padding value given the kernel_size and padding.

    Args:
        padding (Union[`str`, `int`], *optional*):
            Padding value, can be either `"same"`, `"valid"`. If a different value is provided the default padding from
            PyTorch is used.
        kernel_size (`int`, *optional*, defaults to 7):
            Kernel size of the convolution layers.
        stride (`int`, *optional*, defaults to 1):
            Stride value of the convolution layers.
        dilation (`int`, *optional*, defaults to 1):
            Dilation value of the convolution layers.
    FNr      Zsamer   TZvalid)
isinstancestrlower)paddingkernel_sizestridedilationZdynamic r#   c/var/www/html/Darija-Ai-API/env/lib/python3.8/site-packages/transformers/models/bit/modeling_bit.pyget_padding_valueA   s    
r%   c                       s*   e Zd ZdZd
 fdd	Zdd	 Z  ZS )WeightStandardizedConv2dzConv2d with Weight Standardization. Includes TensorFlow compatible SAME padding. Used for ViT Hybrid model.

    Paper: [Micro-Batch Training with Batch-Channel Normalization and Weight
    Standardization](https://arxiv.org/abs/1903.10520v2)
    r   SAMEFư>c
              
      sT   t ||||d\}}
t j||||||||d |
rDt|||| _nd | _|	| _d S )N)r!   r"   )r!   r   r"   groupsbias)r%   super__init__DynamicPad2dpadeps)selfZ
in_channelout_channelsr    r!   r   r"   r)   r*   r/   Z
is_dynamic	__class__r#   r$   r,   q   s    
z!WeightStandardizedConv2d.__init__c              	   C   sj   | j d k	r|  |}tjj| jd| jdd d dd| jd| j}tj	||| j
| j| j| j| j}|S )Nr   T        )trainingZmomentumr/   )r.   r   
functionalZ
batch_normweightZreshaper1   r/   Z
reshape_asZconv2dr*   r!   r   r"   r)   )r0   hidden_stater8   r#   r#   r$   forward   s,    

           z WeightStandardizedConv2d.forward)r   r'   r   r   Fr(   __name__
__module____qualname____doc__r,   r:   __classcell__r#   r#   r2   r$   r&   j   s         r&   c                       s*   e Zd ZdZd fdd	Zdd Z  ZS )	BitGroupNormActivationzQ
    A module that combines group normalization with an activation function.
    h㈵>Tc                    s:   t t| j|j|||d |r,t|j | _n
t | _d S )N)r/   affine)	r+   rA   r,   
num_groupsr
   
hidden_act
activationr   Identity)r0   confignum_channelsr/   rC   apply_activationr2   r#   r$   r,      s    zBitGroupNormActivation.__init__c                 C   s*   t j|| j| j| j| j}| |}|S N)r   r7   Z
group_normrD   r8   r*   r/   rF   )r0   r9   r#   r#   r$   r:      s    
zBitGroupNormActivation.forward)rB   TTr;   r#   r#   r2   r$   rA      s   rA   c                       s*   e Zd ZdZd fdd	Zdd Z  ZS )r-   z
    A module that wraps dynamic padding of any input, given the parameters of the convolutional layer and the input
    hidden states.
    r   c                    sj   t    t|tr||f}t|tr.||f}t|tr@||f}|| _|| _|| _|| _dd }|| _d S )Nc                 S   s0   t t| | d | |d |  d |  dS )Nr   r   )maxmathceil)xr    r!   r"   r#   r#   r$   compute_padding   s    z.DynamicPad2d.__init__.<locals>.compute_padding)	r+   r,   r   intr    r!   r"   valuerP   )r0   r    r!   r"   rR   rP   r2   r#   r$   r,      s    



zDynamicPad2d.__init__c                 C   s   |  dd  \}}| || jd | jd | jd }| || jd | jd | jd }|dksh|dkrtjj||d ||d  |d ||d  g| jd}|S )Nr   r   r   )rR   )	sizerP   r    r!   r"   r   r7   r.   rR   )r0   inputZinput_heightZinput_widthpadding_heightpadding_widthr#   r#   r$   __call__   s    ""


zDynamicPad2d.__call__)r   )r<   r=   r>   r?   r,   rX   r@   r#   r#   r2   r$   r-      s   r-   c                       s0   e Zd ZdZded fd	d
Zdd Z  ZS )BitMaxPool2dz1Tensorflow like 'SAME' wrapper for 2D max poolingNr   Fr   r   r   T)r    c                    s   t |tjjr|n||f}t |tjjr,|n||f}t |tjjrF|n||f}t ||||| |rxt||||| _n
t	 | _d S rK   )
r   collectionsabcIterabler+   r,   r-   r.   r   rG   )r0   r    r!   r"   	ceil_moder   Zpadding_valueuse_dynamic_paddingr2   r#   r$   r,      s    
zBitMaxPool2d.__init__c                 C   s*   |  |}tj|| j| j| j| j| jS rK   )	r.   r   r7   Z
max_pool2dr    r!   r   r"   r^   r0   hidden_statesr#   r#   r$   r:      s    
     zBitMaxPool2d.forward)Nr   FrZ   r   T)r<   r=   r>   r?   rQ   r,   r:   r@   r#   r#   r2   r$   rY      s         rY   c                       s6   e Zd ZdZed fddZeedddZ  ZS )BitEmbeddingszL
    BiT Embeddings (stem) composed of a single aggressive convolution.
    rH   c                    s   t    t|j|jddd|jd| _tdd|jd| _	|jd k	r\|j
 dkr\t | _ntjdd	d
| _|jdkst||jd| _n
t | _|j| _d S )Nr   r   :0yE>)r    r!   r/   r   r	   )r    r!   r_   r'   )r   r   r   r   r5   )r   rR   preactivationrI   )r+   r,   r&   rI   embedding_sizeglobal_paddingconvolutionrY   Zembedding_dynamic_paddingpoolerupperr   rG   r.   ZConstantPad2d
layer_typerA   normr0   rH   r2   r#   r$   r,     s"    
	

zBitEmbeddings.__init__)pixel_valuesr   c                 C   sH   |j d }|| jkrtd| |}| |}| |}| |}|S )Nr   zeMake sure that the channel dimension of the pixel values match with the one set in the configuration.)shaperI   
ValueErrorri   r.   rm   rj   )r0   ro   rI   Z	embeddingr#   r#   r$   r:     s    





zBitEmbeddings.forward)	r<   r=   r>   r?   r   r,   r   r:   r@   r#   r#   r2   r$   rb      s   rb   r5   F)rU   	drop_probr6   r   c                 C   sd   |dks|s| S d| }| j d fd| jd   }|tj|| j| jd }|  | || }|S )aF  
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    r5   r   r   )r   )dtypedevice)rp   ndimtorchZrandrs   rt   Zfloor_div)rU   rr   r6   Z	keep_probrp   Zrandom_tensoroutputr#   r#   r$   	drop_path/  s    
ry   c                       sP   e Zd ZdZdee dd fddZejejdddZ	e
d	d
dZ  ZS )BitDropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).N)rr   r   c                    s   t    || _d S rK   )r+   r,   rr   )r0   rr   r2   r#   r$   r,   G  s    
zBitDropPath.__init__)ra   r   c                 C   s   t || j| jS rK   )ry   rr   r6   r`   r#   r#   r$   r:   K  s    zBitDropPath.forwardr   c                 C   s   d | jS )Nzp={})formatrr   )r0   r#   r#   r$   
extra_reprN  s    zBitDropPath.extra_repr)N)r<   r=   r>   r?   r   floatr,   rv   r   r:   r   r|   r@   r#   r#   r2   r$   rz   D  s   rz      c                 C   s:   |}t |t| |d  | | }|d|  k r6||7 }|S )Nr   g?)rL   rQ   )rR   ZdivisorZ	min_value	new_valuer#   r#   r$   make_divR  s
    r   c                       s*   e Zd ZdZd fdd	Zd	d
 Z  ZS )BitPreActivationBottleneckLayera  Pre-activation (v2) bottleneck block.
    Follows the implementation of "Identity Mappings in Deep Residual Networks":
    https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua

    Except it puts the stride on 3x3 conv when available.
    N      ?r   r5   Fc              	      s   t    |p|}|p|}t|| }|
r@t||||dd| _nd | _t||| _t||dd|jd| _	t||d| _
t||d||d|jd| _t||| _t||dd|jd| _|	d	krt|	nt | _d S )
NTr!   preactr   rd   r/   r   rf   r	   )r!   r)   r/   r   r   )r+   r,   r   BitDownsampleConv
downsamplerA   norm1r&   rh   conv1norm2conv2norm3conv3rz   r   rG   ry   )r0   rH   in_channelsr1   bottle_ratior!   r"   first_dilationr)   drop_path_rateis_first_layerZmid_channelsr2   r#   r$   r,   b  s8    

      z(BitPreActivationBottleneckLayer.__init__c                 C   s^   |  |}|}| jd k	r"| |}| |}| | |}| | |}| |}|| S rK   )r   r   r   r   r   r   r   ry   )r0   ra   Zhidden_states_preactshortcutr#   r#   r$   r:     s    




z'BitPreActivationBottleneckLayer.forward)Nr   r   r   Nr   r5   Fr;   r#   r#   r2   r$   r   Z  s           ,r   c                       s*   e Zd ZdZd fdd	Zd	d
 Z  ZS )BitBottleneckLayerz\Non Pre-activation bottleneck block, equivalent to V1.5/V1b bottleneck. Used for ViT Hybrid.Nr   r   r5   Fc              
      s   t    |p|}|p|}t|| }|
r@t||||dd| _nd | _t||dd|jd| _t||d| _	t||d|||d|jd| _
t||d| _t||dd|jd| _t||dd	| _|	d
krt|	nt | _t|j | _d S )NFr   r   rd   r   rf   r	   )r!   r"   r)   r/   r   rI   rJ   r   )r+   r,   r   r   r   r&   rh   r   rA   r   r   r   r   r   rz   r   rG   ry   r
   rE   rF   )r0   rH   r   r1   r   r!   r"   r   r)   r   r   Zmid_chsr2   r#   r$   r,     s<    


zBitBottleneckLayer.__init__c                 C   sp   |}| j d k	r|  |}| |}| |}| |}| |}| |}| |}| |}| || }|S rK   )	r   r   r   r   r   r   r   ry   rF   )r0   ra   r   r#   r#   r$   r:     s    








zBitBottleneckLayer.forward)Nr   r   r   Nr   r5   Fr;   r#   r#   r2   r$   r     s           1r   c                       s&   e Zd Zd fdd	Zdd Z  ZS )r   r   Tc                    sB   t    t||d|d|jd| _|r.t nt||dd| _d S )Nr   rd   )r!   r/   r   Fr   )	r+   r,   r&   rh   convr   rG   rA   rm   )r0   rH   r   r1   r!   r   r2   r#   r$   r,     s    
     
zBitDownsampleConv.__init__c                 C   s   |  | |S rK   )rm   r   )r0   rO   r#   r#   r$   r:     s    zBitDownsampleConv.forward)r   T)r<   r=   r>   r,   r:   r@   r#   r#   r2   r$   r     s     r   c                       s:   e Zd ZdZd fdd	Zdd Zeedd	d
Z  ZS )BitStagez7
    A ResNet v2 stage composed by stacked layers.
    r   Nc	                    s   t    |dkrdnd}	|jdkr*t}
nt}
|}t | _t|D ]H}| 	|||\}}}| j
t||
|||||||	||d	 |}|}	qDd S )N)r   r   r   r   Z
bottleneck)r!   r"   r   r   r   r   )r+   r,   rl   r   r   r   
Sequentiallayersrange_get_updated_hyperparameters
add_moduler   )r0   rH   r   r1   r!   r"   depthr   layer_dropoutr   Z	layer_clsprev_chs	layer_idxr   r   r2   r#   r$   r,     s:    


  
zBitStage.__init__c                 C   s0   |r|| }nd}|dkrd}|dk}|||fS )zt
        Get the new hyper-parameters with respect to the previous ones and the index of the current layer.
        r5   r   r   r#   )r0   r   r!   r   r   r   r#   r#   r$   r   1  s    
z%BitStage._get_updated_hyperparameters)rU   r   c                 C   s$   |}t | jD ]\}}||}q|S rK   )	enumerater   )r0   rU   r9   _layerr#   r#   r$   r:   A  s    
zBitStage.forward)r   N)	r<   r=   r>   r?   r,   r   r   r:   r@   r#   r#   r2   r$   r     s     .r   c                       s@   e Zd Zed fddZdd Zdeeeedd	d
Z	  Z
S )
BitEncoderrc   c              
      s   t    tg | _|j}d}d}dd tt	d|j
t|j|jD }tt|j|j|D ]Z\}\}}}	| |||||\}
}}t|||
||||	d}|
}||9 }| jt|| qfd S )N   r   c                 S   s   g | ]}|  qS r#   )tolist).0rO   r#   r#   r$   
<listcomp>S  s   z'BitEncoder.__init__.<locals>.<listcomp>r   )r!   r"   r   r   )r+   r,   r   Z
ModuleListstagesrg   rv   r   npZlinspacer   sumZdepthssplitr   ziphidden_sizesr   r   r   r   )r0   rH   r   current_strider"   Zlayer_dropouts	stage_idxZcurrent_depthcurrent_hidden_sizer   r1   r!   stager2   r#   r$   r,   I  s<    
"    

zBitEncoder.__init__c                 C   s>   t ||j }|dkrdnd}||jkr4||9 }d}|||fS )Nr   r   r   )r   Zwidth_factorZoutput_stride)r0   r   r   r   r"   rH   r1   r!   r#   r#   r$   r   o  s    
z'BitEncoder._get_updated_hyperparametersFT)r9   output_hidden_statesreturn_dictr   c                 C   sb   |rdnd }| j D ]}|r$||f }||}q|r<||f }|sVtdd ||fD S t||dS )Nr#   c                 s   s   | ]}|d k	r|V  qd S rK   r#   )r   vr#   r#   r$   	<genexpr>  s      z%BitEncoder.forward.<locals>.<genexpr>)last_hidden_statera   )r   tupler   )r0   r9   r   r   ra   Zstage_moduler#   r#   r$   r:   w  s    



zBitEncoder.forward)FT)r<   r=   r>   r   r,   r   r   boolr   r:   r@   r#   r#   r2   r$   r   H  s   &	     r   c                   @   s2   e Zd ZdZeZdZdZdZdd Z	ddd	Z
d
S )BitPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    bitro   Tc                 C   sX   t |tjr"tjj|jddd n2t |tjtjfrTtj|jd tj|j	d d S )NZfan_outZrelu)modeZnonlinearityr   r   )
r   r   Conv2dinitZkaiming_normal_r8   ZBatchNorm2d	GroupNormZ	constant_r*   )r0   moduler#   r#   r$   _init_weights  s
    z BitPreTrainedModel._init_weightsFc                 C   s   t |tr||_d S rK   )r   BitModelZgradient_checkpointing)r0   r   rR   r#   r#   r$   _set_gradient_checkpointing  s    
z.BitPreTrainedModel._set_gradient_checkpointingN)F)r<   r=   r>   r?   r   config_classZbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingr   r   r#   r#   r#   r$   r     s   r   aE  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`BitConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
aA  
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`BitImageProcessor.__call__`]
            for details.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
zLThe bare BiT model outputting raw features without any specific head on top.c                       sT   e Zd Z fddZeeeeee	de
dd	eee ee edddZ  ZS )
r   c                    sd   t  | || _t|| _t|| _|jdkrBt||j	d dnt
 | _t
d| _|   d S )Nre   r4   rf   )r   r   )r+   r,   rH   rb   embedderr   encoderrl   rA   r   r   rG   rm   ZAdaptiveAvgPool2drj   	post_initrn   r2   r#   r$   r,     s    

zBitModel.__init__Zvision)
checkpointoutput_typer   Zmodalityexpected_outputNro   r   r   r   c                 C   s   |d k	r|n| j j}|d k	r |n| j j}| |}| j|||d}|d }| |}| |}|sv||f|dd   S t|||jdS )Nr   r   r   r   )r   pooler_outputra   )	rH   r   use_return_dictr   r   rm   rj   r   ra   )r0   ro   r   r   Zembedding_outputZencoder_outputsr   pooled_outputr#   r#   r$   r:     s&    
  

zBitModel.forward)NN)r<   r=   r>   r,   r   BIT_INPUTS_DOCSTRINGr   _CHECKPOINT_FOR_DOCr   _CONFIG_FOR_DOC_EXPECTED_OUTPUT_SHAPEr   r   r   r:   r@   r#   r#   r2   r$   r     s"        r   z
    BiT Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    c                	       s`   e Zd Z fddZeeeeee	e
ddeej eej ee ee edddZ  ZS )	BitForImageClassificationc                    s^   t  | |j| _t|| _tt |jdkrFt|j	d |jnt
 | _|   d S )Nr   r4   )r+   r,   
num_labelsr   r   r   r   ZFlattenZLinearr   rG   
classifierr   rn   r2   r#   r$   r,     s    
$z"BitForImageClassification.__init__)r   r   r   r   N)ro   labelsr   r   r   c                 C   sl  |dk	r|n| j j}| j|||d}|r.|jn|d }| |}d}|dk	r,| j jdkr| jdkrnd| j _n4| jdkr|jtj	ks|jtj
krd| j _nd| j _| j jdkrt }	| jdkr|	| | }n
|	||}nN| j jdkrt }	|	|d| j|d}n| j jdkr,t }	|	||}|s\|f|dd  }
|dk	rX|f|
 S |
S t|||jd	S )
a0  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr   r   Z
regressionZsingle_label_classificationZmulti_label_classificationr4   r   )losslogitsra   )rH   r   r   r   r   Zproblem_typer   rs   rv   longrQ   r   Zsqueezer   viewr   r   ra   )r0   ro   r   r   r   outputsr   r   r   Zloss_fctrx   r#   r#   r$   r:     s6    



"


z!BitForImageClassification.forward)NNNN)r<   r=   r>   r,   r   r   r   _IMAGE_CLASS_CHECKPOINTr   r   _IMAGE_CLASS_EXPECTED_OUTPUTr   rv   ZFloatTensorZ
LongTensorr   r:   r@   r#   r#   r2   r$   r     s&       r   zL
    BiT backbone, to be used with frameworks like DETR and MaskFormer.
    c                       sN   e Zd Z fddZeeeeedde	e
e e
e edddZ  ZS )	BitBackbonec                    s>   t  | t  | t|| _|jg|j | _|   d S rK   )	r+   r,   Z_init_backboner   r   rg   r   Znum_featuresr   rn   r2   r#   r$   r,   N  s
    
zBitBackbone.__init__)r   r   Nr   c           
      C   s   |dk	r|n| j j}|dk	r |n| j j}| j|ddd}|j}d}t| jD ] \}}|| jkrL||| f7 }qL|s|f}	|r|	|jf7 }	|	S t||r|jndddS )al  
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("google/resnetnv2-50")
        >>> model = AutoBackbone.from_pretrained("google/resnetnv2-50")

        >>> inputs = processor(image, return_tensors="pt")
        >>> outputs = model(**inputs)
        ```NTr   r#   )feature_mapsra   Z
attentions)	rH   r   r   r   ra   r   Zstage_namesZout_featuresr   )
r0   ro   r   r   r   ra   r   idxr   rx   r#   r#   r$   r:   X  s&    
zBitBackbone.forward)NN)r<   r=   r>   r,   r   r   r   r   r   r   r   r   r:   r@   r#   r#   r2   r$   r   G  s   

     r   )Nr   r   r   )r5   F)r~   )Fr?   r[   rM   typingr   r   numpyr   rv   Ztorch.utils.checkpointr   r   Ztorch.nnr   r   r   Zactivationsr
   Zmodeling_outputsr   r   r   r   Zmodeling_utilsr   utilsr   r   r   r   r   Zutils.backbone_utilsr   Zconfiguration_bitr   Z
get_loggerr<   loggerr   r   r   r   r   Z!BIT_PRETRAINED_MODEL_ARCHIVE_LISTr   r%   r   r&   r   rA   Moduler-   Z	MaxPool2drY   rb   r}   ry   rz   r   r   r   r   r   r   r   ZBIT_START_DOCSTRINGr   r   r   r   r#   r#   r#   r$   <module>   sl   
)033
DIJF8F