U
    9%e                     @   sJ  d Z ddlZddlZddlmZ ddlmZmZm	Z	m
Z
 ddlZddlZddlmZ ddlmZmZmZ ddlmZ dd	lmZmZmZmZ dd
lmZ ddlmZmZ ddlmZm Z m!Z!m"Z"m#Z#m$Z$ ddl%m&Z& e#'e(Z)dZ*dZ+dddgZ,dZ-dZ.dgZ/G dd dej0Z1G dd dej0Z2G dd dej0Z3G dd dej0Z4G dd dej0Z5G dd dej0Z6G d d! d!ej0Z7G d"d# d#ej0Z8G d$d% d%ej0Z9G d&d' d'eZ:d(Z;d)Z<e!d*e;G d+d, d,e:Z=G d-d. d.ej0Z>e!d/e;G d0d1 d1e:Z?e!d2e;G d3d4 d4e:Z@eG d5d6 d6eZAe!d7e;G d8d9 d9e:ZBdS ):z PyTorch DeiT model.    N)	dataclass)OptionalSetTupleUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)BaseModelOutputBaseModelOutputWithPoolingImageClassifierOutputMaskedImageModelingOutput)PreTrainedModel) find_pruneable_heads_and_indicesprune_linear_layer)ModelOutputadd_code_sample_docstringsadd_start_docstrings%add_start_docstrings_to_model_forwardloggingreplace_return_docstrings   )
DeiTConfigr   z(facebook/deit-base-distilled-patch16-224   i   ztabby, tabby catc                       sJ   e Zd ZdZd
eedd fddZdeje	ej
 ejddd	Z  ZS )DeiTEmbeddingszv
    Construct the CLS token, distillation token, position and patch embeddings. Optionally, also the mask token.
    FN)configuse_mask_tokenreturnc                    s   t    ttdd|j| _ttdd|j| _|rTttdd|jnd | _	t
|| _| jj}ttd|d |j| _t|j| _d S )Nr      )super__init__r   	ParametertorchZzeroshidden_size	cls_tokendistillation_token
mask_tokenDeiTPatchEmbeddingspatch_embeddingsnum_patchesposition_embeddingsDropouthidden_dropout_probdropout)selfr   r   r,   	__class__ e/var/www/html/Darija-Ai-API/env/lib/python3.8/site-packages/transformers/models/deit/modeling_deit.pyr#   I   s    
 
zDeiTEmbeddings.__init__)pixel_valuesbool_masked_posr    c                 C   s   |  |}| \}}}|d k	rT| j||d}|d|}|d|  ||  }| j|dd}	| j|dd}
tj	|	|
|fdd}|| j
 }| |}|S )N      ?r   dim)r+   sizer)   expand	unsqueezeZtype_asr'   r(   r%   catr-   r0   )r1   r6   r7   
embeddings
batch_sizeZ
seq_length_Zmask_tokensmaskZ
cls_tokensZdistillation_tokensr4   r4   r5   forwardT   s    


zDeiTEmbeddings.forward)F)N)__name__
__module____qualname____doc__r   boolr#   r%   Tensorr   
BoolTensorrD   __classcell__r4   r4   r2   r5   r   D   s   r   c                       s4   e Zd ZdZ fddZejejdddZ  ZS )r*   z
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    c                    s   t    |j|j }}|j|j }}t|tjj	r8|n||f}t|tjj	rR|n||f}|d |d  |d |d   }|| _|| _|| _|| _
tj||||d| _d S )Nr   r   )kernel_sizeZstride)r"   r#   
image_size
patch_sizenum_channelsr&   
isinstancecollectionsabcIterabler,   r   Conv2d
projection)r1   r   rN   rO   rP   r&   r,   r2   r4   r5   r#   m   s    
 zDeiTPatchEmbeddings.__init__)r6   r    c              
   C   s   |j \}}}}|| jkr td|| jd ks<|| jd krjtd| d| d| jd  d| jd  d	| |ddd}|S )	NzeMake sure that the channel dimension of the pixel values match with the one set in the configuration.r   r   zInput image size (*z) doesn't match model (z).r!   )shaperP   
ValueErrorrN   rV   flatten	transpose)r1   r6   rA   rP   heightwidthxr4   r4   r5   rD   |   s    
(zDeiTPatchEmbeddings.forward)	rE   rF   rG   rH   r#   r%   rJ   rD   rL   r4   r4   r2   r5   r*   f   s   r*   c                       sl   e Zd Zedd fddZejejdddZdeej e	e
eejejf eej f d	d
dZ  ZS )DeiTSelfAttentionNr   r    c                    s   t    |j|j dkr@t|ds@td|jf d|j d|j| _t|j|j | _| j| j | _t	j
|j| j|jd| _t	j
|j| j|jd| _t	j
|j| j|jd| _t	|j| _d S )Nr   Zembedding_sizezThe hidden size z4 is not a multiple of the number of attention heads .)bias)r"   r#   r&   num_attention_headshasattrrY   intattention_head_sizeall_head_sizer   LinearZqkv_biasquerykeyvaluer.   Zattention_probs_dropout_probr0   r1   r   r2   r4   r5   r#      s    
zDeiTSelfAttention.__init__)r^   r    c                 C   s6   |  d d | j| jf }||}|ddddS )Nr8   r   r!   r   r   )r<   rc   rf   viewpermute)r1   r^   Znew_x_shaper4   r4   r5   transpose_for_scores   s    
z&DeiTSelfAttention.transpose_for_scoresF)	head_maskoutput_attentionsr    c                 C   s   |  |}| | |}| | |}| |}t||dd}|t| j	 }t
jj|dd}	| |	}	|d k	r|	| }	t|	|}
|
dddd }
|
 d d | jf }|
|}
|r|
|	fn|
f}|S )Nr8   r:   r   r!   r   r   )ri   ro   rj   rk   r%   matmulr[   mathsqrtrf   r   
functionalZsoftmaxr0   rn   
contiguousr<   rg   rm   )r1   hidden_statesrp   rq   Zmixed_query_layerZ	key_layerZvalue_layerZquery_layerZattention_scoresZattention_probsZcontext_layerZnew_context_layer_shapeoutputsr4   r4   r5   rD      s     



zDeiTSelfAttention.forward)NF)rE   rF   rG   r   r#   r%   rJ   ro   r   rI   r   r   rD   rL   r4   r4   r2   r5   r_      s       r_   c                       s@   e Zd ZdZedd fddZejejejdddZ  Z	S )	DeiTSelfOutputz
    The residual connection is defined in DeiTLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    Nr`   c                    s.   t    t|j|j| _t|j| _d S N)	r"   r#   r   rh   r&   denser.   r/   r0   rl   r2   r4   r5   r#      s    
zDeiTSelfOutput.__init__rx   input_tensorr    c                 C   s   |  |}| |}|S r{   r|   r0   r1   rx   r~   r4   r4   r5   rD      s    

zDeiTSelfOutput.forward)
rE   rF   rG   rH   r   r#   r%   rJ   rD   rL   r4   r4   r2   r5   rz      s   rz   c                       sp   e Zd Zedd fddZee ddddZdej	e
ej	 eeeej	ej	f eej	 f d	d
dZ  ZS )DeiTAttentionNr`   c                    s*   t    t|| _t|| _t | _d S r{   )r"   r#   r_   	attentionrz   outputsetpruned_headsrl   r2   r4   r5   r#      s    


zDeiTAttention.__init__)headsr    c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   r:   )lenr   r   rc   rf   r   r   ri   rj   rk   r   r|   rg   union)r1   r   indexr4   r4   r5   prune_heads   s       zDeiTAttention.prune_headsFrx   rp   rq   r    c                 C   s4   |  |||}| |d |}|f|dd   }|S )Nr   r   )r   r   )r1   rx   rp   rq   Zself_outputsattention_outputry   r4   r4   r5   rD      s    zDeiTAttention.forward)NF)rE   rF   rG   r   r#   r   re   r   r%   rJ   r   rI   r   r   rD   rL   r4   r4   r2   r5   r      s     r   c                       s8   e Zd Zedd fddZejejdddZ  ZS )DeiTIntermediateNr`   c                    sB   t    t|j|j| _t|jt	r6t
|j | _n|j| _d S r{   )r"   r#   r   rh   r&   intermediate_sizer|   rQ   Z
hidden_actstrr   intermediate_act_fnrl   r2   r4   r5   r#     s
    
zDeiTIntermediate.__init__)rx   r    c                 C   s   |  |}| |}|S r{   )r|   r   )r1   rx   r4   r4   r5   rD     s    

zDeiTIntermediate.forward	rE   rF   rG   r   r#   r%   rJ   rD   rL   r4   r4   r2   r5   r     s   r   c                       s<   e Zd Zedd fddZejejejdddZ  ZS )
DeiTOutputNr`   c                    s.   t    t|j|j| _t|j| _	d S r{   )
r"   r#   r   rh   r   r&   r|   r.   r/   r0   rl   r2   r4   r5   r#     s    
zDeiTOutput.__init__r}   c                 C   s    |  |}| |}|| }|S r{   r   r   r4   r4   r5   rD     s    

zDeiTOutput.forwardr   r4   r4   r2   r5   r     s   r   c                       s`   e Zd ZdZedd fddZd
ejeej e	e
eejejf eej f ddd	Z  ZS )	DeiTLayerz?This corresponds to the Block class in the timm implementation.Nr`   c                    sb   t    |j| _d| _t|| _t|| _t|| _	t
j|j|jd| _t
j|j|jd| _d S )Nr   Zeps)r"   r#   Zchunk_size_feed_forwardZseq_len_dimr   r   r   intermediater   r   r   	LayerNormr&   layer_norm_epslayernorm_beforelayernorm_afterrl   r2   r4   r5   r#   '  s    



zDeiTLayer.__init__Fr   c                 C   s`   | j | |||d}|d }|dd  }|| }| |}| |}| ||}|f| }|S )Nrq   r   r   )r   r   r   r   r   )r1   rx   rp   rq   Zself_attention_outputsr   ry   Zlayer_outputr4   r4   r5   rD   1  s    


zDeiTLayer.forward)NF)rE   rF   rG   rH   r   r#   r%   rJ   r   rI   r   r   rD   rL   r4   r4   r2   r5   r   $  s     r   c                	       sN   e Zd Zedd fddZd
ejeej eeee	e
ef ddd	Z  ZS )DeiTEncoderNr`   c                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r4   )r   ).0rB   r   r4   r5   
<listcomp>S  s     z(DeiTEncoder.__init__.<locals>.<listcomp>F)	r"   r#   r   r   Z
ModuleListrangenum_hidden_layerslayergradient_checkpointingrl   r2   r   r5   r#   P  s    
 zDeiTEncoder.__init__FT)rx   rp   rq   output_hidden_statesreturn_dictr    c                    s   |rdnd } rdnd }t | jD ]\}}	|r8||f }|d k	rH|| nd }
| jr|| jr| fdd}tjj||	||
}n|	||
 }|d } r"||d f }q"|r||f }|stdd |||fD S t|||dS )	Nr4   c                    s    fdd}|S )Nc                     s    | f S r{   r4   )inputs)modulerq   r4   r5   custom_forwardj  s    zJDeiTEncoder.forward.<locals>.create_custom_forward.<locals>.custom_forwardr4   )r   r   r   )r   r5   create_custom_forwardi  s    z2DeiTEncoder.forward.<locals>.create_custom_forwardr   r   c                 s   s   | ]}|d k	r|V  qd S r{   r4   )r   vr4   r4   r5   	<genexpr>  s      z&DeiTEncoder.forward.<locals>.<genexpr>)last_hidden_staterx   
attentions)		enumerater   r   Ztrainingr%   utils
checkpointtupler   )r1   rx   rp   rq   r   r   Zall_hidden_statesZall_self_attentionsiZlayer_moduleZlayer_head_maskr   Zlayer_outputsr4   r   r5   rD   V  s4    

zDeiTEncoder.forward)NFFT)rE   rF   rG   r   r#   r%   rJ   r   rI   r   r   r   rD   rL   r4   r4   r2   r5   r   O  s   	    
r   c                   @   sZ   e Zd ZdZeZdZdZdZdgZ	e
ejejejf dddd	ZdeeddddZdS )DeiTPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    deitr6   Tr   N)r   r    c                 C   s   t |tjtjfrZtjj|jjt	j
d| jjd|jj|j_|jdk	r|jj  n&t |tjr|jj  |jjd dS )zInitialize the weightsg        )ZmeanZstdNr9   )rQ   r   rh   rU   initZtrunc_normal_weightdatator%   Zfloat32r   Zinitializer_rangedtyperb   Zzero_r   Zfill_)r1   r   r4   r4   r5   _init_weights  s      
z!DeiTPreTrainedModel._init_weightsF)r   rk   r    c                 C   s   t |tr||_d S r{   )rQ   r   r   )r1   r   rk   r4   r4   r5   _set_gradient_checkpointing  s    
z/DeiTPreTrainedModel._set_gradient_checkpointing)F)rE   rF   rG   rH   r   config_classZbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingZ_no_split_modulesr   r   rh   rU   r   r   r   rI   r   r4   r4   r4   r5   r     s    r   aF  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`DeiTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
aL  
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`DeiTImageProcessor.__call__`] for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
z^The bare DeiT Model transformer outputting raw hidden-states without any specific head on top.c                       s   e Zd Zdeeedd fddZeddd	Zd
d Ze	e
eeeededdeej eej eej ee ee ee eeef dddZ  ZS )	DeiTModelTFN)r   add_pooling_layerr   r    c                    s\   t  | || _t||d| _t|| _tj|j	|j
d| _|rJt|nd | _|   d S )N)r   r   )r"   r#   r   r   r@   r   encoderr   r   r&   r   	layernorm
DeiTPoolerpooler	post_init)r1   r   r   r   r2   r4   r5   r#     s    
zDeiTModel.__init__)r    c                 C   s   | j jS r{   )r@   r+   )r1   r4   r4   r5   get_input_embeddings  s    zDeiTModel.get_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr   r   r   r   )r1   Zheads_to_pruner   r   r4   r4   r5   _prune_heads  s    zDeiTModel._prune_headsZvision)r   output_typer   Zmodalityexpected_outputr6   r7   rp   rq   r   r   r    c                 C   s  |dk	r|n| j j}|dk	r |n| j j}|dk	r4|n| j j}|dkrLtd| || j j}| jjj	j
j}|j|kr~||}| j||d}| j|||||d}	|	d }
| |
}
| jdk	r| |
nd}|s|dk	r|
|fn|
f}||	dd  S t|
||	j|	jdS )z
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        Nz You have to specify pixel_values)r7   rp   rq   r   r   r   r   )r   Zpooler_outputrx   r   )r   rq   r   use_return_dictrY   Zget_head_maskr   r@   r+   rV   r   r   r   r   r   r   r   rx   r   )r1   r6   r7   rp   rq   r   r   Zexpected_dtypeZembedding_outputZencoder_outputssequence_outputpooled_outputZhead_outputsr4   r4   r5   rD     s<    


zDeiTModel.forward)TF)NNNNNN)rE   rF   rG   r   rI   r#   r*   r   r   r   DEIT_INPUTS_DOCSTRINGr   _CHECKPOINT_FOR_DOCr   _CONFIG_FOR_DOC_EXPECTED_OUTPUT_SHAPEr   r%   rJ   rK   r   r   rD   rL   r4   r4   r2   r5   r     s4   	      
r   c                       s*   e Zd Zed fddZdd Z  ZS )r   r   c                    s*   t    t|j|j| _t | _d S r{   )r"   r#   r   rh   r&   r|   ZTanh
activationrl   r2   r4   r5   r#   +  s    
zDeiTPooler.__init__c                 C   s(   |d d df }|  |}| |}|S )Nr   )r|   r   )r1   rx   Zfirst_token_tensorr   r4   r4   r5   rD   0  s    

zDeiTPooler.forward)rE   rF   rG   r   r#   rD   rL   r4   r4   r2   r5   r   *  s   r   aW  DeiT Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://arxiv.org/abs/2111.09886).

    <Tip>

    Note that we provide a script to pre-train this model on custom data in our [examples
    directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).

    </Tip>
    c                       sz   e Zd Zedd fddZeeeee	dd	e
ej e
ej e
ej e
e e
e e
e eeef dddZ  ZS )
DeiTForMaskedImageModelingNr`   c                    sX   t  | t|ddd| _ttj|j|jd |j	 ddt
|j| _|   d S )NFT)r   r   r!   r   )Zin_channelsZout_channelsrM   )r"   r#   r   r   r   Z
SequentialrU   r&   Zencoder_striderP   ZPixelShuffledecoderr   rl   r2   r4   r5   r#   F  s    

z#DeiTForMaskedImageModeling.__init__r   r   r   c                 C   sN  |dk	r|n| j j}| j||||||d}|d }|ddddf }|j\}	}
}t|
d  }}|ddd|	|||}| |}d}|dk	r
| j j| j j	 }|d||}|
| j j	d
| j j	dd }tjj||dd	}||  | d
  | j j }|s:|f|dd  }|dk	r6|f| S |S t|||j|jdS )aM  
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Returns:

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, DeiTForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
        >>> model = DeiTForMaskedImageModeling.from_pretrained("facebook/deit-base-distilled-patch16-224")

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 224, 224]
        ```N)r7   rp   rq   r   r   r   r   r8   g      ?r!   none)Z	reductiongh㈵>)lossZreconstructionrx   r   )r   r   r   rX   re   rn   Zreshaper   rN   rO   Zrepeat_interleaver>   rw   r   rv   Zl1_losssumrP   r   rx   r   )r1   r6   r7   rp   rq   r   r   ry   r   rA   Zsequence_lengthrP   r\   r]   Zreconstructed_pixel_valuesZmasked_im_lossr<   rC   Zreconstruction_lossr   r4   r4   r5   rD   W  sJ    (	

  z"DeiTForMaskedImageModeling.forward)NNNNNN)rE   rF   rG   r   r#   r   r   r   r   r   r   r%   rJ   rK   rI   r   r   rD   rL   r4   r4   r2   r5   r   9  s$   
      
r   z
    DeiT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.
    c                       sz   e Zd Zedd fddZeeeee	dd	e
ej e
ej e
ej e
e e
e e
e eeef dddZ  ZS )
DeiTForImageClassificationNr`   c                    sR   t  | |j| _t|dd| _|jdkr<t|j|jnt | _	| 
  d S NF)r   r   )r"   r#   
num_labelsr   r   r   rh   r&   Identity
classifierr   rl   r2   r4   r5   r#     s
    $z#DeiTForImageClassification.__init__r   )r6   rp   labelsrq   r   r   r    c                 C   s  |dk	r|n| j j}| j|||||d}|d }| |dddddf }	d}
|dk	rD||	j}| j jdkr| jdkrd| j _n4| jdkr|jt	j
ks|jt	jkrd| j _nd| j _| j jdkrt }| jdkr||	 | }
n
||	|}
nN| j jdkr&t }||	d| j|d}
n| j jdkrDt }||	|}
|st|	f|dd  }|
dk	rp|
f| S |S t|
|	|j|jd	S )
aM  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, DeiTForImageClassification
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> torch.manual_seed(3)  # doctest: +IGNORE_RESULT
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> # note: we are loading a DeiTForImageClassificationWithTeacher from the hub here,
        >>> # so the head will be randomly initialized, hence the predictions will be random
        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
        >>> model = DeiTForImageClassification.from_pretrained("facebook/deit-base-distilled-patch16-224")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> # model predicts one of the 1000 ImageNet classes
        >>> predicted_class_idx = logits.argmax(-1).item()
        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
        Predicted class: magpie
        ```Nr   r   r   Z
regressionZsingle_label_classificationZmulti_label_classificationr8   )r   logitsrx   r   )r   r   r   r   r   ZdeviceZproblem_typer   r   r%   longre   r
   Zsqueezer	   rm   r   r   rx   r   )r1   r6   rp   r   rq   r   r   ry   r   r   r   Zloss_fctr   r4   r4   r5   rD     sN    ,


"


z"DeiTForImageClassification.forward)NNNNNN)rE   rF   rG   r   r#   r   r   r   r   r   r   r%   rJ   rI   r   r   rD   rL   r4   r4   r2   r5   r     s$   
      
r   c                   @   sh   e Zd ZU dZdZejed< dZejed< dZ	ejed< dZ
eeej  ed< dZeeej  ed< dS )+DeiTForImageClassificationWithTeacherOutputa5  
    Output type of [`DeiTForImageClassificationWithTeacher`].

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores as the average of the cls_logits and distillation logits.
        cls_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the
            class token).
        distillation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the
            distillation token).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    Nr   
cls_logitsdistillation_logitsrx   r   )rE   rF   rG   rH   r   r%   ZFloatTensor__annotations__r   r   rx   r   r   r   r4   r4   r4   r5   r      s   
r   a  
    DeiT Model transformer with image classification heads on top (a linear layer on top of the final hidden state of
    the [CLS] token and a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet.

    .. warning::

           This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet
           supported.
    c                       sv   e Zd Zedd fddZeeeee	e
edd	eej eej ee ee ee eee	f dddZ  ZS )
%DeiTForImageClassificationWithTeacherNr`   c                    sv   t  | |j| _t|dd| _|jdkr<t|j|jnt | _	|jdkr`t|j|jnt | _
|   d S r   )r"   r#   r   r   r   r   rh   r&   r   cls_classifierdistillation_classifierr   rl   r2   r4   r5   r#   L  s      z.DeiTForImageClassificationWithTeacher.__init__)r   r   r   r   )r6   rp   rq   r   r   r    c                 C   s   |d k	r|n| j j}| j|||||d}|d }| |d d dd d f }| |d d dd d f }	||	 d }
|s|
||	f|dd   }|S t|
||	|j|jdS )Nr   r   r   r!   )r   r   r   rx   r   )r   r   r   r   r   r   rx   r   )r1   r6   rp   rq   r   r   ry   r   r   r   r   r   r4   r4   r5   rD   ]  s,    z-DeiTForImageClassificationWithTeacher.forward)NNNNN)rE   rF   rG   r   r#   r   r   r   _IMAGE_CLASS_CHECKPOINTr   r   _IMAGE_CLASS_EXPECTED_OUTPUTr   r%   rJ   rI   r   r   rD   rL   r4   r4   r2   r5   r   ?  s*        
r   )CrH   collections.abcrR   rt   dataclassesr   typingr   r   r   r   r%   Ztorch.utils.checkpointr   Ztorch.nnr   r	   r
   Zactivationsr   Zmodeling_outputsr   r   r   r   Zmodeling_utilsr   Zpytorch_utilsr   r   r   r   r   r   r   r   r   Zconfiguration_deitr   Z
get_loggerrE   loggerr   r   r   r   r   Z"DEIT_PRETRAINED_MODEL_ARCHIVE_LISTModuler   r*   r_   rz   r   r   r   r   r   r   ZDEIT_START_DOCSTRINGr   r   r   r   r   r   r   r4   r4   r4   r5   <module>   sr    

"%=(+9]	ik	