""" PyTorch GroupViT model."""

import collections.abc
import math
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_groupvit import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "nvidia/groupvit-gcc-yfcc"

GROUPVIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "nvidia/groupvit-gcc-yfcc",
]
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))


def groupvit_loss(similarity: torch.Tensor) -> torch.Tensor:
    caption_loss = contrastive_loss(similarity)
    image_loss = contrastive_loss(similarity.t())
    return (caption_loss + image_loss) / 2.0


def hard_softmax(logits: torch.Tensor, dim: int):
    y_soft = logits.softmax(dim)
    # Straight-through estimator: hard one-hot in the forward pass, soft gradients in the backward pass.
    index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
    ret = y_hard - y_soft.detach() + y_soft

    return ret


def gumbel_softmax(logits: torch.Tensor, tau: float = 1, hard: bool = False, dim: int = -1) -> torch.Tensor:
    gumbel_dist = torch.distributions.gumbel.Gumbel(
        torch.tensor(0.0, device=logits.device, dtype=logits.dtype),
        torch.tensor(1.0, device=logits.device, dtype=logits.dtype),
    )
    gumbels = gumbel_dist.sample(logits.shape)

    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits, tau)
    y_soft = gumbels.softmax(dim)

    if hard:
        # Straight-through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret
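
# A minimal, self-contained sketch (not part of the original module) of how the two
# assignment operators above behave: during training the grouping block uses
# `gumbel_softmax` with `hard=True`, at inference `hard_softmax`; both return one-hot
# assignments in the forward pass while keeping soft gradients. The tensor shapes are
# illustrative assumptions only.
def _example_hard_assignment():
    torch.manual_seed(0)
    # [batch_size, num_group_tokens, num_image_tokens] raw assignment logits
    logits = torch.randn(2, 8, 196)
    train_assign = gumbel_softmax(logits, tau=1.0, hard=True, dim=-2)
    eval_assign = hard_softmax(logits, dim=-2)
    # Every image token is assigned to exactly one group token.
    assert torch.allclose(train_assign.sum(dim=-2), torch.ones(2, 196))
    assert torch.allclose(eval_assign.sum(dim=-2), torch.ones(2, 196))
    return train_assign, eval_assign
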
def resize_attention_map(attentions, height, width, align_corners=False):
    """
    Args:
        attentions (`torch.Tensor`): attention map of shape [batch_size, groups, feat_height*feat_width]
        height (`int`): height of the output attention map
        width (`int`): width of the output attention map
        align_corners (`bool`, *optional*): the `align_corner` argument for `nn.functional.interpolate`.

    Returns:
        `torch.Tensor`: resized attention map of shape [batch_size, groups, height, width]
    """
    scale = (height * width // attentions.shape[2]) ** 0.5
    if height > width:
        feat_width = int(np.round(width / scale))
        feat_height = attentions.shape[2] // feat_width
    else:
        feat_height = int(np.round(height / scale))
        feat_width = attentions.shape[2] // feat_height

    batch_size = attentions.shape[0]
    groups = attentions.shape[1]  # number of group tokens
    # [batch_size, groups, feat_height * feat_width] -> [batch_size, groups, feat_height, feat_width]
    attentions = attentions.reshape(batch_size, groups, feat_height, feat_width)
    attentions = nn.functional.interpolate(
        attentions, size=(height, width), mode="bilinear", align_corners=align_corners
    )
    return attentions


def get_grouping_from_attentions(attentions, hw_shape):
    """
    Args:
        attentions (`tuple(torch.FloatTensor)`: tuple of attention maps returned by `GroupViTVisionTransformer`
        hw_shape (`tuple(int)`): height and width of the output attention map
    Returns:
        `torch.Tensor`: the attention map of shape [batch_size, groups, height, width]
    """
    attn_maps = []
    with torch.no_grad():
        prev_attn_masks = None
        for attn_masks in attentions:
            # [batch_size, num_groups, height * width] -> [batch_size, height * width, num_groups]
            attn_masks = attn_masks.permute(0, 2, 1).contiguous()
            if prev_attn_masks is None:
                prev_attn_masks = attn_masks
            else:
                prev_attn_masks = prev_attn_masks @ attn_masks
            # [batch_size, height * width, num_groups] -> [batch_size, num_groups, height, width]
            cur_attn_map = resize_attention_map(prev_attn_masks.permute(0, 2, 1).contiguous(), *hw_shape)
            attn_maps.append(cur_attn_map)

    # [batch_size, num_groups, height, width]
    final_grouping = attn_maps[-1]

    return final_grouping
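
# A small illustrative sketch (not from the original file): composing two stages of
# group-assignment attention into a full-resolution grouping map with the helpers above.
# The shapes (196 patches for a 224x224 image with 16x16 patches, 64 and 8 group tokens)
# are assumptions for the example only.
def _example_grouping_map():
    torch.manual_seed(0)
    stage1 = torch.rand(1, 64, 196)  # [batch, groups_stage1, num_patches]
    stage2 = torch.rand(1, 8, 64)    # [batch, groups_stage2, groups_stage1]
    grouping = get_grouping_from_attentions((stage1, stage2), hw_shape=(224, 224))
    # One map per final group token, upsampled to the input resolution.
    assert grouping.shape == (1, 8, 224, 224)
    return grouping
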
attn_masksZcur_attn_mapZfinal_groupingr#   r#   r$   get_grouping_from_attentions   s    	
r^   c                       s*   e Zd Zed fddZdd Z  ZS )GroupViTCrossAttentionLayerconfigc                    sJ   t    t|| _tj|j|jd| _t	|| _
tj|j|jd| _d S NZeps)super__init__GroupViTAttentionattnr   	LayerNormhidden_sizelayer_norm_epsnorm2GroupViTMLPmlp	norm_postselfra   	__class__r#   r$   re      s
    


z$GroupViTCrossAttentionLayer.__init__c                 C   s<   |}|| j ||dd  }|| | | }| |}|S )N)encoder_hidden_statesr   )rg   rm   rk   rn   )rp   querykeyxr#   r#   r$   forward   s
    
z#GroupViTCrossAttentionLayer.forward)__name__
__module____qualname__r   re   rw   __classcell__r#   r#   rq   r$   r_      s   r_   c                       s4   e Zd Zed fddZd	ddZdd Z  ZS )
GroupViTAssignAttentionr`   c                    sj   t    |jd | _t|j|j| _t|j|j| _t|j|j| _t|j|j| _	|j
| _
d S )N      )rd   re   ri   rU   r   Linearq_projk_projv_projproj
assign_epsro   rq   r#   r$   re      s    
z GroupViTAssignAttention.__init__Tc                 C   s@   |r| j rt|d|d}n"|r,t|dd}ntjj|dd}|S )N)r1   rB   r1   )trainingrI   r?   r   r*   r5   )rp   rg   rE   rB   r#   r#   r$   get_attn   s    
z GroupViTAssignAttention.get_attnc                 C   s   |}|  |}| |}| |}||dd | j }| |}| j|ddd}||jddd| j  }|| }| |}||fS )Nr   r@   F)rE   rB   Tr1   r3   )	r   r   r   	transposerU   r   sumr   r   )rp   rt   ru   valueZraw_attnrg   Z	soft_attnoutr#   r#   r$   rw      s    




zGroupViTAssignAttention.forward)TT)rx   ry   rz   r   re   r   rw   r{   r#   r#   rq   r$   r|      s   

r|   c                       s2   e Zd Zed fddZdd Zdd Z  ZS )GroupViTTokenAssignr`   c                    s   t    || _tj j jd| _t j	t
jjr: j	n
 j	 j	f} fdd|D \}}t |||| _tj j jd| _tj j jd| _t | _t | _tj j jd| _t  j| j| _d S )Nrc   c                    s   g | ]}t | j qS r#   )rM   ri   ).0rv   r`   r#   r$   
<listcomp>   s     z0GroupViTTokenAssign.__init__.<locals>.<listcomp>)rd   re   num_output_groupr   rh   ri   rj   norm_tokens
isinstanceassign_mlp_ratiocollectionsabcIterableGroupViTMixerMLP	mlp_internorm_post_tokensnorm_xr_   pre_assign_attnr|   assign
norm_new_xrl   mlp_channels)rp   ra   num_group_tokenr   r   Z
tokens_dimZchannels_dimrq   r`   r$   re      s    



zGroupViTTokenAssign.__init__c                 C   s   |  |}| |}|S )z
        Args:
            group_tokens (torch.Tensor): group tokens, [batch_size, num_group_tokens, channels]

        Returns:
            projected_group_tokens (torch.Tensor): [batch_size, num_output_groups, channels]
        """
        # [batch_size, num_output_groups, channels] <- [batch_size, num_group_tokens, channels]
        projected_group_tokens = self.mlp_inter(group_tokens)
        projected_group_tokens = self.norm_post_tokens(projected_group_tokens)
        return projected_group_tokens

    def forward(self, image_tokens, group_tokens):
        """
        Args:
            image_tokens (`torch.Tensor`): image tokens, of shape [batch_size, input_length, channels]
            group_tokens (`torch.Tensor`): group tokens, [batch_size, num_group_tokens, channels]
        """

        group_tokens = self.norm_tokens(group_tokens)
        image_tokens = self.norm_x(image_tokens)
        # [batch_size, num_output_groups, channels]
        projected_group_tokens = self.project_group_token(group_tokens)
        projected_group_tokens = self.pre_assign_attn(projected_group_tokens, image_tokens)
        new_image_tokens, attention = self.assign(projected_group_tokens, image_tokens)
        new_image_tokens += projected_group_tokens

        new_image_tokens = new_image_tokens + self.mlp_channels(self.norm_new_x(new_image_tokens))

        return new_image_tokens, attention


@dataclass
class GroupViTModelOutput(ModelOutput):
    """
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Contrastive loss for image-text similarity.
        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
            similarity scores.
        segmentation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
            Classification scores for each pixel.

            <Tip warning={true}>

            The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
            to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the
            original image size as post-processing. You should always check your logits shape and resize as needed.

            </Tip>

        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
            The text embeddings obtained by applying the projection layer to the pooled output of
            [`GroupViTTextModel`].
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
            The image embeddings obtained by applying the projection layer to the pooled output of
            [`GroupViTVisionModel`].
        text_model_output (`BaseModelOutputWithPooling`):
            The output of the [`GroupViTTextModel`].
        vision_model_output (`BaseModelOutputWithPooling`):
            The output of the [`GroupViTVisionModel`].
    """

    loss: Optional[torch.FloatTensor] = None
    logits_per_image: torch.FloatTensor = None
    logits_per_text: torch.FloatTensor = None
    segmentation_logits: torch.FloatTensor = None
    text_embeds: torch.FloatTensor = None
    image_embeds: torch.FloatTensor = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )


class GroupViTPatchEmbeddings(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        num_channels: int = 3,
        embed_dim: int = 768,
    ):
        super().__init__()
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )
        x = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return x


class GroupViTVisionEmbeddings(nn.Module):
    def __init__(self, config: GroupViTVisionConfig):
        super().__init__()

        self.patch_embeddings = GroupViTPatchEmbeddings(
            image_size=config.image_size,
            patch_size=config.patch_size,
            num_channels=config.num_channels,
            embed_dim=config.hidden_size,
        )
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches, config.hidden_size))
        self.dropout = nn.Dropout(config.dropout)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.config = config

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """

        npatch = embeddings.shape[1]
        if npatch == self.position_embeddings.shape[1] and height == width:
            return self.position_embeddings

        patch_pos_embed = self.position_embeddings
        num_original_pos_embed = patch_pos_embed.shape[1]
        dim = embeddings.shape[-1]
        feat_height = height // self.config.patch_size
        feat_width = width // self.config.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        feat_height, feat_width = feat_height + 0.1, feat_width + 0.1
        original_height = original_width = math.sqrt(num_original_pos_embed)
        reshaped_patch_pos_embed = patch_pos_embed.reshape(
            1, int(original_height), int(original_width), dim
        ).permute(0, 3, 1, 2)
        scale_factor = (feat_height / original_height, feat_width / original_width)
        patch_pos_embed = nn.functional.interpolate(
            reshaped_patch_pos_embed,
            scale_factor=scale_factor,
            mode="bicubic",
            align_corners=False,
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        embeddings = self.layernorm(embeddings)

        batch_size, seq_len, _ = embeddings.size()
        # add positional encoding to each token
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings


class GroupViTTextEmbeddings(nn.Module):
    def __init__(self, config: GroupViTTextConfig):
        super().__init__()
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings


class GroupViTStage(nn.Module):
    """This corresponds to the `GroupingLayer` class in the GroupViT implementation."""

    def __init__(
        self,
        config: GroupViTVisionConfig,
        depth: int,
        num_prev_group_token: int,
        num_group_token: int,
        num_output_group: int,
    ):
        super().__init__()
        self.depth = depth
        self.num_group_token = num_group_token
        if num_group_token > 0:
            self.group_token = nn.Parameter(torch.zeros(1, num_group_token, config.hidden_size))
        else:
            self.group_token = None
        self.gradient_checkpointing = False
        self.layers = nn.ModuleList([GroupViTEncoderLayer(config) for _ in range(depth)])

        if num_group_token > 0:
            self.downsample = GroupViTTokenAssign(
                config=config,
                num_group_token=num_group_token,
                num_output_group=num_output_group,
            )
        else:
            self.downsample = None

        if num_prev_group_token > 0 and num_group_token > 0:
            self.group_projector = nn.Sequential(
                nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
                GroupViTMixerMLP(config, num_prev_group_token, config.hidden_size // 2, num_group_token),
            )
        else:
            self.group_projector = None

    @property
    def with_group_token(self):
        return self.group_token is not None

    def split_x(self, x):
        if self.with_group_token:
            return x[:, : -self.num_group_token], x[:, -self.num_group_token :]
        else:
            return x, None

    def concat_x(self, x: torch.Tensor, group_token: Optional[torch.Tensor] = None) -> torch.Tensor:
        if group_token is None:
            return x
        return torch.cat([x, group_token], dim=1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        prev_group_token: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                `(config.encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the grouping tensors of Grouping block.
        """
        if self.with_group_token:
            group_token = self.group_token.expand(hidden_states.size(0), -1, -1)
            if self.group_projector is not None:
                group_token = group_token + self.group_projector(prev_group_token)
        else:
            group_token = None

        x = hidden_states

        cat_x = self.concat_x(x, group_token)
        for layer in self.layers:
            layer_out = layer(cat_x, attention_mask=None, causal_attention_mask=None)
            cat_x = layer_out[0]

        x, group_token = self.split_x(cat_x)

        attention = None
        if self.downsample is not None:
            x, attention = self.downsample(x, group_token)

        outputs = (x, group_token)
        if output_attentions:
            outputs = outputs + (attention,)

        return outputs


class GroupViTMLP(nn.Module):
    def __init__(
        self,
        config: GroupViTVisionConfig,
        hidden_size: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        output_size: Optional[int] = None,
    ):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        hidden_size = hidden_size if hidden_size is not None else config.hidden_size
        intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
        output_size = output_size if output_size is not None else hidden_size
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, output_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class GroupViTMixerMLP(GroupViTMLP):
    def forward(self, x):
        x = super().forward(x.transpose(1, 2))
        return x.transpose(1, 2)


class GroupViTAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()
        is_cross_attention = encoder_hidden_states is not None

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scale
        if is_cross_attention:
            key_states = self._shape(self.k_proj(encoder_hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(encoder_hidden_states), -1, bsz)
        else:
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {causal_attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # reshape twice so that attn_weights keeps its gradient and can be returned
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped


class GroupViTEncoderLayer(nn.Module):
    def __init__(self, config: GroupViTConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = GroupViTAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = GroupViTMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                `(config.encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class GroupViTPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GroupViTConfig
    base_model_prefix = "groupvit"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        init_range = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=init_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        factor = self.config.initializer_factor
        if isinstance(module, GroupViTTextEmbeddings):
            module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
            module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
        elif isinstance(module, GroupViTAttention):
            in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            out_proj_std = (module.embed_dim**-0.5) * factor
            nn.init.normal_(module.q_proj.weight, std=in_proj_std)
            nn.init.normal_(module.k_proj.weight, std=in_proj_std)
            nn.init.normal_(module.v_proj.weight, std=in_proj_std)
            nn.init.normal_(module.out_proj.weight, std=out_proj_std)
        elif isinstance(module, GroupViTMLP):
            in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
            nn.init.normal_(module.fc1.weight, std=fc_std)
            nn.init.normal_(module.fc2.weight, std=in_proj_std)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (GroupViTTextEncoder, GroupViTVisionEncoder)):
            module.gradient_checkpointing = value


GROUPVIT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`GroupViTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

GROUPVIT_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

GROUPVIT_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

GROUPVIT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`CLIPImageProcessor.__call__`] for details.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class GroupViTVisionEncoder(nn.Module):
    def __init__(self, config: GroupViTVisionConfig) -> None:
        super().__init__()
        self.config = config
        self.stages = nn.ModuleList(
            [
                GroupViTStage(
                    config=config,
                    depth=config.depths[i],
                    num_group_token=config.num_group_tokens[i],
                    num_output_group=config.num_output_groups[i],
                    num_prev_group_token=config.num_output_groups[i - 1] if i > 0 else 0,
                )
                for i in range(len(config.depths))
            ]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        all_hidden_states = () if output_hidden_states else None
        all_groupings = () if output_attentions else None

        group_tokens = None

        for i, stage in enumerate(self.stages):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = stage(hidden_states, group_tokens, output_attentions)

            hidden_states = layer_outputs[0]
            group_tokens = layer_outputs[1]

            if output_attentions and layer_outputs[2] is not None:
                all_groupings = all_groupings + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_groupings] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_groupings
        )


class GroupViTTextEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self-attention layers. Each layer is a
    [`GroupViTEncoderLayer`].

    Args:
        config: GroupViTTextConfig
    """

    def __init__(self, config: GroupViTTextConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([GroupViTEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )


def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


class GroupViTTextTransformer(nn.Module):
    def __init__(self, config: GroupViTTextConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = GroupViTTextEmbeddings(config)
        self.encoder = GroupViTTextEncoder(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        # For `pooled_output` computation
        self.eos_token_id = config.eos_token_id

    @add_start_docstrings_to_model_forward(GROUPVIT_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=GroupViTTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        # GroupViT's text model uses a causal mask, prepare it here.
        causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        if self.eos_token_id == 2:
            # take features from the EOS embedding (the highest token id in each sequence)
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
            ]
        else:
            # take features from the first occurrence of the `eos_token_id`
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
                (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
                .int()
                .argmax(dim=-1),
            ]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class GroupViTTextModel(GroupViTPreTrainedModel):
    config_class = GroupViTTextConfig

    def __init__(self, config: GroupViTTextConfig):
        super().__init__(config)
        self.text_model = GroupViTTextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value):
        self.text_model.embeddings.token_embedding = value

    @add_start_docstrings_to_model_forward(GROUPVIT_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=GroupViTTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import CLIPTokenizer, GroupViTTextModel

        >>> tokenizer = CLIPTokenizer.from_pretrained("nvidia/groupvit-gcc-yfcc")
        >>> model = GroupViTTextModel.from_pretrained("nvidia/groupvit-gcc-yfcc")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class GroupViTVisionTransformer(nn.Module):
    def __init__(self, config: GroupViTVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = GroupViTVisionEmbeddings(config)
        self.encoder = GroupViTVisionEncoder(config)
        self.layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    @add_start_docstrings_to_model_forward(GROUPVIT_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=GroupViTVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values)

        encoder_outputs = self.encoder(
            hidden_states=hidden_states,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        # normalize the last hidden state
        last_hidden_state = self.layernorm(last_hidden_state)
        pooled_output = last_hidden_state.mean(dim=1)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class GroupViTVisionModel(GroupViTPreTrainedModel):
    config_class = GroupViTVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: GroupViTVisionConfig):
        super().__init__(config)
        self.vision_model = GroupViTVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> GroupViTPatchEmbeddings:
        return self.vision_model.embeddings.patch_embeddings

    @add_start_docstrings_to_model_forward(GROUPVIT_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=GroupViTVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, GroupViTVisionModel

        >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")
        >>> model = GroupViTVisionModel.from_pretrained("nvidia/groupvit-gcc-yfcc")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```"""
        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


@add_start_docstrings(GROUPVIT_START_DOCSTRING)
class GroupViTModel(GroupViTPreTrainedModel):
    config_class = GroupViTConfig

    def __init__(self, config: GroupViTConfig):
        super().__init__(config)

        if not isinstance(config.text_config, GroupViTTextConfig):
            raise ValueError(
                "config.text_config is expected to be of type GroupViTTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, GroupViTVisionConfig):
            raise ValueError(
                "config.vision_config is expected to be of type GroupViTVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.projection_intermediate_dim = config.projection_intermediate_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        self.text_model = GroupViTTextTransformer(text_config)
        self.vision_model = GroupViTVisionTransformer(vision_config)

        self.visual_projection = nn.Sequential(
            nn.Linear(self.vision_embed_dim, self.projection_intermediate_dim, bias=True),
            nn.BatchNorm1d(self.projection_intermediate_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.projection_intermediate_dim, self.projection_dim, bias=True),
        )
        self.text_projection = nn.Sequential(
            nn.Linear(self.text_embed_dim, self.projection_intermediate_dim, bias=True),
            nn.BatchNorm1d(self.projection_intermediate_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.projection_intermediate_dim, self.projection_dim, bias=True),
        )
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(GROUPVIT_TEXT_INPUTS_DOCSTRING)
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`GroupViTTextModel`].

        Examples:

        ```python
        >>> from transformers import CLIPTokenizer, GroupViTModel

        >>> model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
        >>> tokenizer = CLIPTokenizer.from_pretrained("nvidia/groupvit-gcc-yfcc")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
        >>> text_features = model.get_text_features(**inputs)
        ```"""
        # Use GroupViT model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = text_outputs[1]
        text_features = self.text_projection(pooled_output)

        return text_features

    @add_start_docstrings_to_model_forward(GROUPVIT_VISION_INPUTS_DOCSTRING)
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`GroupViTVisionModel`].

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, GroupViTModel

        >>> model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
        >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> image_features = model.get_image_features(**inputs)
        ```"""
        # Use GroupViT model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = vision_outputs[1]
        image_features = self.visual_projection(pooled_output)

        return image_features

    @add_start_docstrings_to_model_forward(GROUPVIT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=GroupViTModelOutput, config_class=GroupViTConfig)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_segmentation: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, GroupViTModelOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, GroupViTModel

        >>> model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
        >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(
        ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
        ... )

        >>> outputs = model(**inputs)
        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
        ```"""
        # Use GroupViT model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_segmentation = (
            output_segmentation if output_segmentation is not None else self.config.output_segmentation
        )
        if output_segmentation:
            output_attentions = True
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[1]
        image_embeds = self.visual_projection(image_embeds)

        text_embeds = text_outputs[1]
        text_embeds = self.text_projection(text_embeds)

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        seg_logits = None
        if output_segmentation:
            # grouped features
            # [batch_size_image, num_group, hidden_size]
            image_group_embeds = vision_outputs[0]
            # [batch_size_image * num_group, hidden_size]
            image_group_embeds = self.visual_projection(image_group_embeds.reshape(-1, image_group_embeds.shape[-1]))
            if output_hidden_states:
                attentions = vision_outputs[3]
            else:
                attentions = vision_outputs[2]
            # [batch_size_image, num_group, height, width]
            grouping = get_grouping_from_attentions(attentions, pixel_values.shape[2:])

            # normalized features
            image_group_embeds = image_group_embeds / image_group_embeds.norm(dim=-1, keepdim=True)
            # [batch_size_image x num_group, batch_size_text]
            logits_per_image_group = torch.matmul(image_group_embeds, text_embeds.t()) * logit_scale
            # [batch_size_image, batch_size_text, num_group]
            logits_per_image_group = logits_per_image_group.reshape(
                image_embeds.shape[0], -1, text_embeds.shape[0]
            ).permute(0, 2, 1)

            # [batch_size_image, batch_size_text, height x width]
            flatten_grouping = grouping.reshape(grouping.shape[0], grouping.shape[1], -1)

            # [batch_size_image, batch_size_text, height, width]
            seg_logits = torch.matmul(logits_per_image_group, flatten_grouping) * logit_scale
            seg_logits = seg_logits.reshape(
                seg_logits.shape[0], seg_logits.shape[1], grouping.shape[2], grouping.shape[3]
            )

        loss = None
        if return_loss:
            loss = groupvit_loss(logits_per_text)

        if not return_dict:
            if seg_logits is not None:
                output = (
                    logits_per_image,
                    logits_per_text,
                    seg_logits,
                    text_embeds,
                    image_embeds,
                    text_outputs,
                    vision_outputs,
                )
            else:
                output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
            return ((loss,) + output) if loss is not None else output

        return GroupViTModelOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            segmentation_logits=seg_logits,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )