""" PyTorch EfficientNet model."""

import math
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutputWithNoAttention,
    BaseModelOutputWithPoolingAndNoAttention,
    ImageClassifierOutputWithNoAttention,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)
from .configuration_efficientnet import EfficientNetConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "EfficientNetConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "google/efficientnet-b7"
_EXPECTED_OUTPUT_SHAPE = [1, 768, 7, 7]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "google/efficientnet-b7"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"

EFFICIENTNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google/efficientnet-b7",
]


EFFICIENTNET_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
    behavior.

    Parameters:
        config ([`EfficientNetConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

EFFICIENTNET_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`AutoImageProcessor.__call__`] for details.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


def round_filters(config: EfficientNetConfig, num_channels: int):
    r"""
    Round number of filters based on depth multiplier.
    """
    divisor = config.depth_divisor
    num_channels *= config.width_coefficient
    new_dim = max(divisor, int(num_channels + divisor / 2) // divisor * divisor)

    # Make sure that rounding down does not reduce the channel count by more than 10%.
    if new_dim < 0.9 * num_channels:
        new_dim += divisor

    return int(new_dim)


def correct_pad(kernel_size: Union[int, Tuple], adjust: bool = True):
    r"""
    Utility function to get the tuple padding value for the depthwise convolution.

    Args:
        kernel_size (`int` or `tuple`):
            Kernel size of the convolution layers.
        adjust (`bool`, *optional*, defaults to `True`):
            Adjusts padding value to apply to right and bottom sides of the input.
    """
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)

    correct = (kernel_size[0] // 2, kernel_size[1] // 2)
    if adjust:
        return (correct[1] - 1, correct[1], correct[0] - 1, correct[0])
    else:
        return (correct[1], correct[1], correct[0], correct[0])

class EfficientNetEmbeddings(nn.Module):
    r"""
    A module that corresponds to the stem module of the original work.
    """

    def __init__(self, config: EfficientNetConfig):
        super().__init__()

        self.out_dim = round_filters(config, 32)
        self.padding = nn.ZeroPad2d(padding=(0, 1, 0, 1))
        self.convolution = nn.Conv2d(
            config.num_channels, self.out_dim, kernel_size=3, stride=2, padding="valid", bias=False
        )
        self.batchnorm = nn.BatchNorm2d(self.out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum)
        self.activation = ACT2FN[config.hidden_act]

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        features = self.padding(pixel_values)
        features = self.convolution(features)
        features = self.batchnorm(features)
        features = self.activation(features)

        return features


class EfficientNetDepthwiseConv2d(nn.Conv2d):
    def __init__(
        self,
        in_channels,
        depth_multiplier=1,
        kernel_size=3,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
        padding_mode="zeros",
    ):
        out_channels = in_channels * depth_multiplier
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
            padding_mode=padding_mode,
        )

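
# Note (added for clarity): because `groups=in_channels`, every input channel is convolved with
# its own set of `depth_multiplier` filters. For example, a (1, 32, 56, 56) input passed through
# EfficientNetDepthwiseConv2d(32, kernel_size=3) keeps 32 channels and, with the default padding
# of 0, comes out as (1, 32, 54, 54).
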
class EfficientNetExpansionLayer(nn.Module):
    r"""
    This corresponds to the expansion phase of each block in the original implementation.
    """

    def __init__(self, config: EfficientNetConfig, in_dim: int, out_dim: int, stride: int):
        super().__init__()
        self.expand_conv = nn.Conv2d(
            in_channels=in_dim,
            out_channels=out_dim,
            kernel_size=1,
            padding="same",
            bias=False,
        )
        self.expand_bn = nn.BatchNorm2d(num_features=out_dim, eps=config.batch_norm_eps)
        self.expand_act = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        # Expand phase
        hidden_states = self.expand_conv(hidden_states)
        hidden_states = self.expand_bn(hidden_states)
        hidden_states = self.expand_act(hidden_states)

        return hidden_states

class EfficientNetDepthwiseLayer(nn.Module):
    r"""
    This corresponds to the depthwise convolution phase of each block in the original implementation.
    """

    def __init__(
        self,
        config: EfficientNetConfig,
        in_dim: int,
        stride: int,
        kernel_size: int,
        adjust_padding: bool,
    ):
        super().__init__()
        self.stride = stride
        conv_pad = "valid" if self.stride == 2 else "same"
        padding = correct_pad(kernel_size, adjust=adjust_padding)

        self.depthwise_conv_pad = nn.ZeroPad2d(padding=padding)
        self.depthwise_conv = EfficientNetDepthwiseConv2d(
            in_dim, kernel_size=kernel_size, stride=stride, padding=conv_pad, bias=False
        )
        self.depthwise_norm = nn.BatchNorm2d(
            num_features=in_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
        )
        self.depthwise_act = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        # Depthwise convolution; explicit zero-padding is only needed for the stride-2 case
        if self.stride == 2:
            hidden_states = self.depthwise_conv_pad(hidden_states)

        hidden_states = self.depthwise_conv(hidden_states)
        hidden_states = self.depthwise_norm(hidden_states)
        hidden_states = self.depthwise_act(hidden_states)

        return hidden_states

class EfficientNetSqueezeExciteLayer(nn.Module):
    r"""
    This corresponds to the Squeeze and Excitement phase of each block in the original implementation.
    """

    def __init__(self, config: EfficientNetConfig, in_dim: int, expand_dim: int, expand: bool = False):
        super().__init__()
        self.dim = expand_dim if expand else in_dim
        self.dim_se = max(1, int(in_dim * config.squeeze_expansion_ratio))

        self.squeeze = nn.AdaptiveAvgPool2d(output_size=1)
        self.reduce = nn.Conv2d(
            in_channels=self.dim,
            out_channels=self.dim_se,
            kernel_size=1,
            padding="same",
        )
        self.expand = nn.Conv2d(
            in_channels=self.dim_se,
            out_channels=self.dim,
            kernel_size=1,
            padding="same",
        )
        self.act_reduce = ACT2FN[config.hidden_act]
        self.act_expand = nn.Sigmoid()

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        inputs = hidden_states
        hidden_states = self.squeeze(hidden_states)
        hidden_states = self.reduce(hidden_states)
        hidden_states = self.act_reduce(hidden_states)

        hidden_states = self.expand(hidden_states)
        hidden_states = self.act_expand(hidden_states)
        hidden_states = torch.mul(inputs, hidden_states)

        return hidden_states

class EfficientNetFinalBlockLayer(nn.Module):
    r"""
    This corresponds to the final phase of each block in the original implementation.
    """

    def __init__(
        self, config: EfficientNetConfig, in_dim: int, out_dim: int, stride: int, drop_rate: float, id_skip: bool
    ):
        super().__init__()
        self.apply_dropout = stride == 1 and not id_skip
        self.project_conv = nn.Conv2d(
            in_channels=in_dim,
            out_channels=out_dim,
            kernel_size=1,
            padding="same",
            bias=False,
        )
        self.project_bn = nn.BatchNorm2d(
            num_features=out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
        )
        self.dropout = nn.Dropout(p=drop_rate)

    def forward(self, embeddings: torch.FloatTensor, hidden_states: torch.FloatTensor) -> torch.Tensor:
        hidden_states = self.project_conv(hidden_states)
        hidden_states = self.project_bn(hidden_states)

        if self.apply_dropout:
            hidden_states = self.dropout(hidden_states)
            hidden_states = hidden_states + embeddings

        return hidden_states

class EfficientNetBlock(nn.Module):
    r"""
    This corresponds to the expansion and depthwise convolution phase of each block in the original implementation.

    Args:
        config ([`EfficientNetConfig`]):
            Model configuration class.
        in_dim (`int`):
            Number of input channels.
        out_dim (`int`):
            Number of output channels.
        stride (`int`):
            Stride size to be used in convolution layers.
        expand_ratio (`int`):
            Expand ratio to set the output dimensions for the expansion and squeeze-excite layers.
        kernel_size (`int`):
            Kernel size for the depthwise convolution layer.
        drop_rate (`float`):
            Dropout rate to be used in the final phase of each block.
        id_skip (`bool`):
            Whether to apply dropout and sum the final hidden states with the input embeddings during the final phase
            of each block. Set to `True` for the first block of each stage.
        adjust_padding (`bool`):
            Whether to apply padding to only the right and bottom sides of the input kernel before the depthwise
            convolution operation, set to `True` for inputs with odd input sizes.
    """

    def __init__(
        self,
        config: EfficientNetConfig,
        in_dim: int,
        out_dim: int,
        stride: int,
        expand_ratio: int,
        kernel_size: int,
        drop_rate: float,
        id_skip: bool,
        adjust_padding: bool,
    ):
        super().__init__()
        self.expand_ratio = expand_ratio
        self.expand = True if self.expand_ratio != 1 else False
        expand_in_dim = in_dim * expand_ratio

        if self.expand:
            self.expansion = EfficientNetExpansionLayer(
                config=config, in_dim=in_dim, out_dim=expand_in_dim, stride=stride
            )

        self.depthwise_conv = EfficientNetDepthwiseLayer(
            config=config,
            in_dim=expand_in_dim if self.expand else in_dim,
            stride=stride,
            kernel_size=kernel_size,
            adjust_padding=adjust_padding,
        )
        self.squeeze_excite = EfficientNetSqueezeExciteLayer(
            config=config, in_dim=in_dim, expand_dim=expand_in_dim, expand=self.expand
        )
        self.projection = EfficientNetFinalBlockLayer(
            config=config,
            in_dim=expand_in_dim if self.expand else in_dim,
            out_dim=out_dim,
            stride=stride,
            drop_rate=drop_rate,
            id_skip=id_skip,
        )

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        embeddings = hidden_states
        # Expansion and depthwise convolution phase
        if self.expand_ratio != 1:
            hidden_states = self.expansion(hidden_states)
        hidden_states = self.depthwise_conv(hidden_states)

        # Squeeze and excite phase
        hidden_states = self.squeeze_excite(hidden_states)
        hidden_states = self.projection(embeddings, hidden_states)

        return hidden_states

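
# Shape walk-through (illustrative; the concrete numbers assume in_dim=32, out_dim=16,
# expand_ratio=6, kernel_size=3, stride=1 and the default squeeze_expansion_ratio of 0.25):
#   expansion:       32 -> 32 * 6 = 192 channels (skipped entirely when expand_ratio == 1)
#   depthwise conv:  192 -> 192 channels, spatial size unchanged at stride 1
#   squeeze-excite:  global-pools to a max(1, int(32 * 0.25)) = 8 channel bottleneck, then
#                    rescales the 192-channel tensor
#   projection:      192 -> 16 channels; the residual sum with the block input is only taken
#                    for repeated blocks (stride == 1, id_skip False), where in_dim == out_dim
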
class EfficientNetEncoder(nn.Module):
    r"""
    Forward propagates the embeddings through each EfficientNet block.

    Args:
        config ([`EfficientNetConfig`]):
            Model configuration class.
    """

    def __init__(self, config: EfficientNetConfig):
        super().__init__()
        self.config = config
        self.depth_coefficient = config.depth_coefficient

        def round_repeats(repeats):
            # Round number of block repeats based on depth multiplier.
            return int(math.ceil(self.depth_coefficient * repeats))

        num_base_blocks = len(config.in_channels)
        num_blocks = sum(round_repeats(n) for n in config.num_block_repeats)

        curr_block_num = 0
        blocks = []
        for i in range(num_base_blocks):
            in_dim = round_filters(config, config.in_channels[i])
            out_dim = round_filters(config, config.out_channels[i])
            stride = config.strides[i]
            kernel_size = config.kernel_sizes[i]
            expand_ratio = config.expand_ratios[i]

            for j in range(round_repeats(config.num_block_repeats[i])):
                id_skip = True if j == 0 else False
                stride = 1 if j > 0 else stride
                in_dim = out_dim if j > 0 else in_dim
                adjust_padding = False if curr_block_num in config.depthwise_padding else True
                drop_rate = config.drop_connect_rate * curr_block_num / num_blocks

                block = EfficientNetBlock(
                    config=config,
                    in_dim=in_dim,
                    out_dim=out_dim,
                    stride=stride,
                    kernel_size=kernel_size,
                    expand_ratio=expand_ratio,
                    drop_rate=drop_rate,
                    id_skip=id_skip,
                    adjust_padding=adjust_padding,
                )
                blocks.append(block)
                curr_block_num += 1

        self.blocks = nn.ModuleList(blocks)
        self.top_conv = nn.Conv2d(
            in_channels=out_dim,
            out_channels=round_filters(config, 1280),
            kernel_size=1,
            padding="same",
            bias=False,
        )
        self.top_bn = nn.BatchNorm2d(
            num_features=config.hidden_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
        )
        self.top_activation = ACT2FN[config.hidden_act]

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> BaseModelOutputWithNoAttention:
        all_hidden_states = (hidden_states,) if output_hidden_states else None

        for block in self.blocks:
            hidden_states = block(hidden_states)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

        hidden_states = self.top_conv(hidden_states)
        hidden_states = self.top_bn(hidden_states)
        hidden_states = self.top_activation(hidden_states)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
        )

class EfficientNetPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = EfficientNetConfig
    base_model_prefix = "efficientnet"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, EfficientNetEncoder):
            module.gradient_checkpointing = value


@add_start_docstrings(
    "The bare EfficientNet model outputting raw features without any specific head on top.",
    EFFICIENTNET_START_DOCSTRING,
)
class EfficientNetModel(EfficientNetPreTrainedModel):
    def __init__(self, config: EfficientNetConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = EfficientNetEmbeddings(config)
        self.encoder = EfficientNetEncoder(config)

        # Final pooling layer
        if config.pooling_type == "mean":
            self.pooler = nn.AvgPool2d(config.hidden_dim, ceil_mode=True)
        elif config.pooling_type == "max":
            self.pooler = nn.MaxPool2d(config.hidden_dim, ceil_mode=True)
        else:
            raise ValueError(f"config.pooling must be one of ['mean', 'max'] got {config.pooling}")

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(EFFICIENTNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndNoAttention,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.embeddings(pixel_values)
        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Apply pooling and flatten the spatial dimensions of the pooled output
        last_hidden_state = encoder_outputs[0]
        pooled_output = self.pooler(last_hidden_state)
        pooled_output = pooled_output.reshape(pooled_output.shape[:2])

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )

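
# Usage sketch (illustrative only, not executed on import; the image path is a placeholder and
# the checkpoint is the `_CHECKPOINT_FOR_DOC` referenced above):
#
#     from PIL import Image
#     import torch
#     from transformers import AutoImageProcessor, EfficientNetModel
#
#     image = Image.open("path/to/image.jpg")
#     processor = AutoImageProcessor.from_pretrained("google/efficientnet-b7")
#     model = EfficientNetModel.from_pretrained("google/efficientnet-b7")
#
#     inputs = processor(images=image, return_tensors="pt")
#     with torch.no_grad():
#         outputs = model(**inputs)
#     outputs.last_hidden_state  # (batch_size, channels, height, width)
#     outputs.pooler_output      # (batch_size, channels)
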
@add_start_docstrings(
    """
    EfficientNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g.
    for ImageNet.
    """,
    EFFICIENTNET_START_DOCSTRING,
)
class EfficientNetForImageClassification(EfficientNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.efficientnet = EfficientNetModel(config)
        # Classifier head
        self.dropout = nn.Dropout(p=config.dropout_rate)
        self.classifier = nn.Linear(config.hidden_dim, self.num_labels) if self.num_labels > 0 else nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(EFFICIENTNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.efficientnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)

        pooled_output = outputs.pooler_output if return_dict else outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )