""" PyTorch MobileViTV2 model."""

from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutputWithNoAttention,
    BaseModelOutputWithPoolingAndNoAttention,
    ImageClassifierOutputWithNoAttention,
    SemanticSegmenterOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_mobilevitv2 import MobileViTV2Config


logger = logging.get_logger(__name__)


# General docstring
_CONFIG_FOR_DOC = "MobileViTV2Config"

# Base docstring
_CHECKPOINT_FOR_DOC = "apple/mobilevitv2-1.0-imagenet1k-256"
_EXPECTED_OUTPUT_SHAPE = [1, 512, 8, 8]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "apple/mobilevitv2-1.0-imagenet1k-256"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"


MOBILEVITV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "apple/mobilevitv2-1.0-imagenet1k-256",
]


def make_divisible(value: int, divisor: int = 8, min_value: Optional[int] = None) -> int:
    """
    Ensure that all layers have a channel count that is divisible by `divisor`. This function is taken from the
    original TensorFlow repo. It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_value is None:
        min_value = divisor
    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
    # make sure that rounding down does not reduce the value by more than 10%
    if new_value < 0.9 * value:
        new_value += divisor
    return int(new_value)


def clip(value: float, min_val: float = float("-inf"), max_val: float = float("inf")) -> float:
    return max(min_val, min(max_val, value))

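# Worked examples for the two helpers above (illustrative comments only, not executed):
#   make_divisible(64, 8) -> 64   (already a multiple of 8)
#   make_divisible(20, 8) -> 24   (rounds to the nearest multiple of 8)
#   make_divisible(19, 8) -> 24   (16 would undershoot 19 by more than 10%, so one extra divisor is added)
#   clip(32 * 0.5, min_val=16, max_val=64) -> 16.0   (used below to bound the stem width)
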
class MobileViTV2ConvLayer(nn.Module):
    def __init__(
        self,
        config: MobileViTV2Config,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        groups: int = 1,
        bias: bool = False,
        dilation: int = 1,
        use_normalization: bool = True,
        use_activation: Union[bool, str] = True,
    ) -> None:
        super().__init__()
        padding = int((kernel_size - 1) / 2) * dilation

        if in_channels % groups != 0:
            raise ValueError(f"Input channels ({in_channels}) are not divisible by {groups} groups.")
        if out_channels % groups != 0:
            raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.")

        self.convolution = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode="zeros",
        )

        if use_normalization:
            self.normalization = nn.BatchNorm2d(
                num_features=out_channels, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True
            )
        else:
            self.normalization = None

        if use_activation:
            if isinstance(use_activation, str):
                self.activation = ACT2FN[use_activation]
            elif isinstance(config.hidden_act, str):
                self.activation = ACT2FN[config.hidden_act]
            else:
                self.activation = config.hidden_act
        else:
            self.activation = None

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        features = self.convolution(features)
        if self.normalization is not None:
            features = self.normalization(features)
        if self.activation is not None:
            features = self.activation(features)
        return features

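# Note on the block above: padding is int((kernel_size - 1) / 2) * dilation, so stride-1
# instances of this conv/norm/activation block preserve the spatial resolution; only a
# stride greater than 1 shrinks the feature map.
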
class MobileViTV2InvertedResidual(nn.Module):
    """
    Inverted residual block (MobileNetv2): https://arxiv.org/abs/1801.04381
    """

    def __init__(
        self, config: MobileViTV2Config, in_channels: int, out_channels: int, stride: int, dilation: int = 1
    ) -> None:
        super().__init__()
        expanded_channels = make_divisible(int(round(in_channels * config.expand_ratio)), 8)

        if stride not in [1, 2]:
            raise ValueError(f"Invalid stride {stride}.")

        self.use_residual = (stride == 1) and (in_channels == out_channels)

        self.expand_1x1 = MobileViTV2ConvLayer(
            config, in_channels=in_channels, out_channels=expanded_channels, kernel_size=1
        )

        self.conv_3x3 = MobileViTV2ConvLayer(
            config,
            in_channels=expanded_channels,
            out_channels=expanded_channels,
            kernel_size=3,
            stride=stride,
            groups=expanded_channels,
            dilation=dilation,
        )

        self.reduce_1x1 = MobileViTV2ConvLayer(
            config,
            in_channels=expanded_channels,
            out_channels=out_channels,
            kernel_size=1,
            use_activation=False,
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        residual = features

        features = self.expand_1x1(features)
        features = self.conv_3x3(features)
        features = self.reduce_1x1(features)

        return residual + features if self.use_residual else features

class MobileViTV2MobileNetLayer(nn.Module):
    def __init__(
        self, config: MobileViTV2Config, in_channels: int, out_channels: int, stride: int = 1, num_stages: int = 1
    ) -> None:
        super().__init__()

        self.layer = nn.ModuleList()
        for i in range(num_stages):
            layer = MobileViTV2InvertedResidual(
                config,
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride if i == 0 else 1,
            )
            self.layer.append(layer)
            in_channels = out_channels

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        for layer_module in self.layer:
            features = layer_module(features)
        return features

class MobileViTV2LinearSelfAttention(nn.Module):
    """
    This layer applies a self-attention with linear complexity, as described in MobileViTV2 paper:
    https://arxiv.org/abs/2206.02680

    Args:
        config (`MobileViTV2Config`):
             Model configuration object
        embed_dim (`int`):
            `input_channels` from an expected input of size :math:`(batch_size, input_channels, height, width)`
    """

    def __init__(self, config: MobileViTV2Config, embed_dim: int) -> None:
        super().__init__()

        self.qkv_proj = MobileViTV2ConvLayer(
            config=config,
            in_channels=embed_dim,
            out_channels=1 + (2 * embed_dim),
            bias=True,
            kernel_size=1,
            use_normalization=False,
            use_activation=False,
        )

        self.attn_dropout = nn.Dropout(p=config.attn_dropout)
        self.out_proj = MobileViTV2ConvLayer(
            config=config,
            in_channels=embed_dim,
            out_channels=embed_dim,
            bias=True,
            kernel_size=1,
            use_normalization=False,
            use_activation=False,
        )
        self.embed_dim = embed_dim

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        qkv = self.qkv_proj(hidden_states)

        # split into a 1-channel query and embed_dim-channel key and value tensors
        query, key, value = torch.split(qkv, split_size_or_sections=[1, self.embed_dim, self.embed_dim], dim=1)

        # apply softmax along the num_patches dimension to get per-patch context scores
        context_scores = torch.nn.functional.softmax(query, dim=-1)
        context_scores = self.attn_dropout(context_scores)

        # compute the context vector by aggregating the keys with the context scores
        context_vector = key * context_scores
        context_vector = torch.sum(context_vector, dim=-1, keepdim=True)

        # combine the context vector with the values and project back to embed_dim channels
        out = torch.nn.functional.relu(value) * context_vector.expand_as(value)
        out = self.out_proj(out)
        return out

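# Concrete shape walkthrough for the linear attention above, assuming embed_dim=128 and an
# input of shape (1, 128, 4, 256), i.e. 4 pixels per patch and 256 patches (illustrative values):
#   qkv_proj -> (1, 257, 4, 256), split into query (1, 1, 4, 256) and key/value (1, 128, 4, 256)
#   softmax(query) over the 256 patches gives per-patch context scores
#   sum(key * scores) over patches -> context vector (1, 128, 4, 1)
#   relu(value) * context (broadcast back over patches) -> (1, 128, 4, 256)
# The cost grows linearly with the number of patches, instead of the quadratic score matrix
# of standard self-attention.
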
class MobileViTV2FFN(nn.Module):
    def __init__(
        self,
        config: MobileViTV2Config,
        embed_dim: int,
        ffn_latent_dim: int,
        ffn_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.conv1 = MobileViTV2ConvLayer(
            config=config,
            in_channels=embed_dim,
            out_channels=ffn_latent_dim,
            kernel_size=1,
            stride=1,
            bias=True,
            use_normalization=False,
            use_activation=True,
        )
        self.dropout1 = nn.Dropout(ffn_dropout)

        self.conv2 = MobileViTV2ConvLayer(
            config=config,
            in_channels=ffn_latent_dim,
            out_channels=embed_dim,
            kernel_size=1,
            stride=1,
            bias=True,
            use_normalization=False,
            use_activation=False,
        )
        self.dropout2 = nn.Dropout(ffn_dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.dropout1(hidden_states)
        hidden_states = self.conv2(hidden_states)
        hidden_states = self.dropout2(hidden_states)
        return hidden_states

class MobileViTV2TransformerLayer(nn.Module):
    def __init__(
        self,
        config: MobileViTV2Config,
        embed_dim: int,
        ffn_latent_dim: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.layernorm_before = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=config.layer_norm_eps)
        self.attention = MobileViTV2LinearSelfAttention(config, embed_dim)
        self.dropout1 = nn.Dropout(p=dropout)
        self.layernorm_after = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=config.layer_norm_eps)
        self.ffn = MobileViTV2FFN(config, embed_dim, ffn_latent_dim, config.ffn_dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        layernorm_1_out = self.layernorm_before(hidden_states)
        attention_output = self.attention(layernorm_1_out)
        hidden_states = attention_output + hidden_states

        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.ffn(layer_output)

        layer_output = layer_output + hidden_states
        return layer_output

class MobileViTV2Transformer(nn.Module):
    def __init__(self, config: MobileViTV2Config, n_layers: int, d_model: int) -> None:
        super().__init__()

        ffn_multiplier = config.ffn_multiplier

        ffn_dims = [ffn_multiplier * d_model] * n_layers

        # ensure that the feed-forward dims are multiples of 16
        ffn_dims = [int((d // 16) * 16) for d in ffn_dims]

        self.layer = nn.ModuleList()
        for block_idx in range(n_layers):
            transformer_layer = MobileViTV2TransformerLayer(
                config, embed_dim=d_model, ffn_latent_dim=ffn_dims[block_idx]
            )
            self.layer.append(transformer_layer)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states)
        return hidden_states

class MobileViTV2Layer(nn.Module):
    """
    MobileViTV2 layer: https://arxiv.org/abs/2206.02680
    """

    def __init__(
        self,
        config: MobileViTV2Config,
        in_channels: int,
        out_channels: int,
        attn_unit_dim: int,
        n_attn_blocks: int = 2,
        dilation: int = 1,
        stride: int = 2,
    ) -> None:
        super().__init__()
        self.patch_width = config.patch_size
        self.patch_height = config.patch_size

        cnn_out_dim = attn_unit_dim

        if stride == 2:
            self.downsampling_layer = MobileViTV2InvertedResidual(
                config,
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride if dilation == 1 else 1,
                dilation=dilation // 2 if dilation > 1 else 1,
            )
            in_channels = out_channels
        else:
            self.downsampling_layer = None

        # local representations
        self.conv_kxk = MobileViTV2ConvLayer(
            config,
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=config.conv_kernel_size,
            groups=in_channels,
        )
        self.conv_1x1 = MobileViTV2ConvLayer(
            config,
            in_channels=in_channels,
            out_channels=cnn_out_dim,
            kernel_size=1,
            use_normalization=False,
            use_activation=False,
        )

        # global representations
        self.transformer = MobileViTV2Transformer(config, d_model=cnn_out_dim, n_layers=n_attn_blocks)

        self.layernorm = nn.GroupNorm(num_groups=1, num_channels=attn_unit_dim, eps=config.layer_norm_eps)

        # fusion
        self.conv_projection = MobileViTV2ConvLayer(
            config,
            in_channels=cnn_out_dim,
            out_channels=in_channels,
            kernel_size=1,
            use_normalization=True,
            use_activation=False,
        )

    def unfolding(self, feature_map: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
        batch_size, in_channels, img_height, img_width = feature_map.shape
        patches = nn.functional.unfold(
            feature_map,
            kernel_size=(self.patch_height, self.patch_width),
            stride=(self.patch_height, self.patch_width),
        )
        patches = patches.reshape(batch_size, in_channels, self.patch_height * self.patch_width, -1)

        return patches, (img_height, img_width)

    def folding(self, patches: torch.Tensor, output_size: Tuple[int, int]) -> torch.Tensor:
        batch_size, in_dim, patch_size, n_patches = patches.shape
        patches = patches.reshape(batch_size, in_dim * patch_size, n_patches)

        feature_map = nn.functional.fold(
            patches,
            output_size=output_size,
            kernel_size=(self.patch_height, self.patch_width),
            stride=(self.patch_height, self.patch_width),
        )

        return feature_map

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # reduce spatial dimensions if needed
        if self.downsampling_layer:
            features = self.downsampling_layer(features)

        # local representation
        features = self.conv_kxk(features)
        features = self.conv_1x1(features)

        # convert the feature map to patches
        patches, output_size = self.unfolding(features)

        # learn global representations
        patches = self.transformer(patches)
        patches = self.layernorm(patches)

        # convert the patches back to a feature map and fuse
        features = self.folding(patches, output_size)
        features = self.conv_projection(features)
        return features

   =$r   c                       sD   e Zd Zedd fddZd
ejeeee	e
f ddd	Z  ZS )MobileViTV2EncoderNr,   r   c                    s  t    || _t | _d| _d }}|jdkr<d}d}n|jdkrJd}d}tt	d|j
 dddddd	}td|j
 dd
}td|j
 dd
}td|j
 dd
}td|j
 dd
}	td|j
 dd
}
t|||ddd}| j| t|||ddd}| j| t|||t|jd |j
 dd
|jd d}| j| |rH|d9 }t|||	t|jd |j
 dd
|jd |d}| j| |r|d9 }t||	|
t|jd |j
 dd
|jd |d}| j| d S )NFr   Tr   r       @   r)   r   r   r         i  r   )r-   r.   r0   rY   r   r   )r-   r.   r   r   )r-   r.   r   r   r3   )r8   r9   r,   r   rZ   r[   gradient_checkpointingZoutput_strider#   r*   width_multiplierrX   r]   r   Zbase_attn_unit_dimsr   )rA   r,   Zdilate_layer_4Zdilate_layer_5r3   layer_0_dimZlayer_1_dimZlayer_2_dimZlayer_3_dimZlayer_4_dimZlayer_5_dimZlayer_1Zlayer_2Zlayer_3Zlayer_4Zlayer_5rB   r!   r"   r9     s    



  zMobileViTV2Encoder.__init__FT)rk   output_hidden_statesreturn_dictr   c                 C   s   |rdnd }t | jD ]H\}}| jrH| jrHdd }tjj|||}n||}|r||f }q|sztdd ||fD S t||dS )Nr!   c                    s    fdd}|S )Nc                     s    |  S r'   r!   )inputsmoduler!   r"   custom_forwardK  s    zQMobileViTV2Encoder.forward.<locals>.create_custom_forward.<locals>.custom_forwardr!   )r   r   r!   r   r"   create_custom_forwardJ  s    z9MobileViTV2Encoder.forward.<locals>.create_custom_forwardc                 s   s   | ]}|d k	r|V  qd S r'   r!   )r   vr!   r!   r"   	<genexpr>[  s      z-MobileViTV2Encoder.forward.<locals>.<genexpr>)last_hidden_staterk   )		enumerater[   r   ZtrainingrK   utils
checkpointtupler   )rA   rk   r   r   Zall_hidden_statesr^   r`   r   r!   r!   r"   rF   ?  s    zMobileViTV2Encoder.forward)FT)rG   rH   rI   r   r9   rK   rL   rJ   r   r   r   rF   rM   r!   r!   rB   r"   r     s   T  
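# Dense-prediction note on the dilation bookkeeping above: with config.output_stride == 8,
# layers 4 and 5 skip their 2x downsampling, and with 16 only layer 5 does, so segmentation
# heads receive 1/8- or 1/16-resolution features instead of the default 1/32.
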
class MobileViTV2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = MobileViTV2Config
    base_model_prefix = "mobilevitv2"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, MobileViTV2Encoder):
            module.gradient_checkpointing = value

MOBILEVITV2_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
    behavior.

    Parameters:
        config ([`MobileViTV2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
MOBILEVITV2_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`MobileViTImageProcessor.__call__`] for details.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

@add_start_docstrings(
    "The bare MobileViTV2 model outputting raw hidden-states without any specific head on top.",
    MOBILEVITV2_START_DOCSTRING,
)
class MobileViTV2Model(MobileViTV2PreTrainedModel):
    def __init__(self, config: MobileViTV2Config, expand_output: bool = True):
        super().__init__(config)
        self.config = config
        self.expand_output = expand_output

        layer_0_dim = make_divisible(
            clip(value=32 * config.width_multiplier, min_val=16, max_val=64), divisor=8, min_value=16
        )

        self.conv_stem = MobileViTV2ConvLayer(
            config,
            in_channels=config.num_channels,
            out_channels=layer_0_dim,
            kernel_size=3,
            stride=2,
            use_normalization=True,
            use_activation=True,
        )
        self.encoder = MobileViTV2Encoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel
        """
        for layer_index, heads in heads_to_prune.items():
            mobilevitv2_layer = self.encoder.layer[layer_index]
            if isinstance(mobilevitv2_layer, MobileViTV2Layer):
                for transformer_layer in mobilevitv2_layer.transformer.layer:
                    transformer_layer.attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(MOBILEVITV2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndNoAttention,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.conv_stem(pixel_values)

        encoder_outputs = self.encoder(
            embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict
        )

        if self.expand_output:
            last_hidden_state = encoder_outputs[0]
            # global average pooling: (batch_size, channels, height, width) -> (batch_size, channels)
            pooled_output = torch.mean(last_hidden_state, dim=[-2, -1], keepdim=False)
        else:
            last_hidden_state = encoder_outputs[0]
            pooled_output = None

        if not return_dict:
            output = (last_hidden_state, pooled_output) if pooled_output is not None else (last_hidden_state,)
            return output + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )

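# Minimal usage sketch for the bare model above (mirrors the checkpoints referenced in the
# docstrings of this file; illustrative comments only, not executed):
#   >>> from transformers import AutoImageProcessor, MobileViTV2Model
#   >>> processor = AutoImageProcessor.from_pretrained("apple/mobilevitv2-1.0-imagenet1k-256")
#   >>> model = MobileViTV2Model.from_pretrained("apple/mobilevitv2-1.0-imagenet1k-256")
#   >>> inputs = processor(images=image, return_tensors="pt")  # `image` is a PIL image you supply
#   >>> outputs = model(**inputs)
#   >>> list(outputs.last_hidden_state.shape)  # [1, 512, 8, 8], per _EXPECTED_OUTPUT_SHAPE above
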
@add_start_docstrings(
    """
    MobileViTV2 model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """,
    MOBILEVITV2_START_DOCSTRING,
)
class MobileViTV2ForImageClassification(MobileViTV2PreTrainedModel):
    def __init__(self, config: MobileViTV2Config) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.mobilevitv2 = MobileViTV2Model(config)

        out_channels = make_divisible(512 * config.width_multiplier, divisor=8)  # layer 5 output dimension
        # Classifier head
        self.classifier = (
            nn.Linear(in_features=out_channels, out_features=config.num_labels)
            if config.num_labels > 0
            else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MOBILEVITV2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.mobilevitv2(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)

        pooled_output = outputs.pooler_output if return_dict else outputs[1]

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )

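# Minimal usage sketch for the classification head above (same checkpoint as before;
# illustrative comments only, not executed):
#   >>> from transformers import AutoImageProcessor, MobileViTV2ForImageClassification
#   >>> processor = AutoImageProcessor.from_pretrained("apple/mobilevitv2-1.0-imagenet1k-256")
#   >>> model = MobileViTV2ForImageClassification.from_pretrained("apple/mobilevitv2-1.0-imagenet1k-256")
#   >>> logits = model(**processor(images=image, return_tensors="pt")).logits  # (batch_size, num_labels)
#   >>> model.config.id2label[int(logits.argmax(-1))]  # e.g. "tabby, tabby cat"
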
class MobileViTV2ASPPPooling(nn.Module):
    def __init__(self, config: MobileViTV2Config, in_channels: int, out_channels: int) -> None:
        super().__init__()

        self.global_pool = nn.AdaptiveAvgPool2d(output_size=1)

        self.conv_1x1 = MobileViTV2ConvLayer(
            config,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            use_normalization=True,
            use_activation="relu",
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        spatial_size = features.shape[-2:]
        features = self.global_pool(features)
        features = self.conv_1x1(features)
        features = nn.functional.interpolate(features, size=spatial_size, mode="bilinear", align_corners=False)
        return features

class MobileViTV2ASPP(nn.Module):
    """
    ASPP module defined in DeepLab papers: https://arxiv.org/abs/1606.00915, https://arxiv.org/abs/1706.05587
    """

    def __init__(self, config: MobileViTV2Config) -> None:
        super().__init__()

        encoder_out_channels = make_divisible(512 * config.width_multiplier, divisor=8)  # layer 5 output dimension
        in_channels = encoder_out_channels
        out_channels = config.aspp_out_channels

        if len(config.atrous_rates) != 3:
            raise ValueError("Expected 3 values for atrous_rates")

        self.convs = nn.ModuleList()

        in_projection = MobileViTV2ConvLayer(
            config,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            use_activation="relu",
        )
        self.convs.append(in_projection)

        self.convs.extend(
            [
                MobileViTV2ConvLayer(
                    config,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    dilation=rate,
                    use_activation="relu",
                )
                for rate in config.atrous_rates
            ]
        )

        pool_layer = MobileViTV2ASPPPooling(config, in_channels, out_channels)
        self.convs.append(pool_layer)

        self.project = MobileViTV2ConvLayer(
            config, in_channels=5 * out_channels, out_channels=out_channels, kernel_size=1, use_activation="relu"
        )

        self.dropout = nn.Dropout(p=config.aspp_dropout_prob)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        pyramid = []
        for conv in self.convs:
            pyramid.append(conv(features))
        pyramid = torch.cat(pyramid, dim=1)

        pooled_features = self.project(pyramid)
        pooled_features = self.dropout(pooled_features)
        return pooled_features

class MobileViTV2DeepLabV3(nn.Module):
    """
    DeepLabv3 architecture: https://arxiv.org/abs/1706.05587
    """

    def __init__(self, config: MobileViTV2Config) -> None:
        super().__init__()
        self.aspp = MobileViTV2ASPP(config)

        self.dropout = nn.Dropout2d(config.classifier_dropout_prob)

        self.classifier = MobileViTV2ConvLayer(
            config,
            in_channels=config.aspp_out_channels,
            out_channels=config.num_labels,
            kernel_size=1,
            use_normalization=False,
            use_activation=False,
            bias=True,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        features = self.aspp(hidden_states[-1])
        features = self.dropout(features)
        features = self.classifier(features)
        return features

@add_start_docstrings(
    """
    MobileViTV2 model with a semantic segmentation head on top, e.g. for Pascal VOC.
    """,
    MOBILEVITV2_START_DOCSTRING,
)
class MobileViTV2ForSemanticSegmentation(MobileViTV2PreTrainedModel):
    def __init__(self, config: MobileViTV2Config) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.mobilevitv2 = MobileViTV2Model(config, expand_output=False)
        self.segmentation_head = MobileViTV2DeepLabV3(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MOBILEVITV2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, SemanticSegmenterOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:

        ```python
        >>> import requests
        >>> import torch
        >>> from PIL import Image
        >>> from transformers import AutoImageProcessor, MobileViTV2ForSemanticSegmentation

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("apple/mobilevitv2-1.0-imagenet1k-256")
        >>> model = MobileViTV2ForSemanticSegmentation.from_pretrained("apple/mobilevitv2-1.0-imagenet1k-256")

        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> # logits are of shape (batch_size, num_labels, height, width)
        >>> logits = outputs.logits
        ```"""
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.mobilevitv2(
            pixel_values,
            output_hidden_states=True,  # we need the intermediate hidden states
            return_dict=return_dict,
        )

        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]

        logits = self.segmentation_head(encoder_hidden_states)

        loss = None
        if labels is not None:
            if self.config.num_labels == 1:
                raise ValueError("The number of labels should be greater than one")
            else:
                # upsample logits to the images' original size
                upsampled_logits = nn.functional.interpolate(
                    logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
                )
                loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
                loss = loss_fct(upsampled_logits, labels)

        if not return_dict:
            if output_hidden_states:
                output = (logits,) + outputs[1:]
            else:
                output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=None,
        )