U
    9%e                     @   s  d Z ddlZddlZddlmZ ddlmZmZm	Z	m
Z
mZmZ ddlZddlZddlmZmZ ddlmZ ddlmZmZ dd	lmZ dd
lmZmZ ddlmZmZmZmZm Z m!Z!m"Z"m#Z#m$Z$ ddl%m&Z& e  rddl'm(Z( e! rddl)m*Z* e"+e,Z-dZ.dZ/dddgZ0dgZ1eG dd deZ2G dd dej3Z4G dd dej3Z5G dd dej3Z6G dd dej3Z7G dd dej3Z8G d d! d!ej3Z9G d"d# d#ej3Z:G d$d% d%ej3Z;G d&d' d'ej3Z<G d(d) d)ej3Z=G d*d+ d+ej3Z>G d,d- d-eZ?d.Z@d/ZAed0e@G d1d2 d2e?ZBG d3d4 d4ej3ZCed5e@G d6d7 d7e?ZDd8d9 ZEdVeFeFd<d=d>ZGG d?d@ d@ej3ZHG dAdB dBej3ZIG dCdD dDej3ZJeedEdFdGZKeedHdIdJZLdKdL ZMdMdN ZNdOdP ZOG dQdR dRePZQee dSdTdUZRdS )Wz PyTorch YOLOS model.    N)	dataclass)DictListOptionalSetTupleUnion)Tensornn   )ACT2FN)BaseModelOutputBaseModelOutputWithPooling)PreTrainedModel) find_pruneable_heads_and_indicesprune_linear_layer)	ModelOutputadd_code_sample_docstringsadd_start_docstrings%add_start_docstrings_to_model_forwardis_scipy_availableis_vision_availableloggingreplace_return_docstringsrequires_backends   )YolosConfiglinear_sum_assignment)center_to_corners_formatr   zhustvl/yolos-smalliI  i  c                   @   s   e Zd ZU dZdZeej ed< dZ	ee
 ed< dZejed< dZejed< dZeee
  ed< dZeej ed< dZeeej  ed	< dZeeej  ed
< dS )YolosObjectDetectionOutputaG
  
    Output type of [`YolosForObjectDetection`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
            Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a
            bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
            scale-invariant IoU loss.
        loss_dict (`Dict`, *optional*):
            A dictionary containing the individual losses. Useful for logging.
        logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
            Classification logits (including no-object) for all queries.
        pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
            Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
            values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
            possible padding). You can use [`~YolosImageProcessor.post_process`] to retrieve the unnormalized bounding
            boxes.
        auxiliary_outputs (`list[Dict]`, *optional*):
            Optional, only returned when auxilary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
            and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
            `pred_boxes`) for each decoder layer.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the decoder of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
            the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    Nloss	loss_dictlogits
pred_boxesauxiliary_outputslast_hidden_statehidden_states
attentions)__name__
__module____qualname____doc__r!   r   torchFloatTensor__annotations__r"   r   r#   r$   r%   r   r&   r'   r   r(    r0   r0   g/var/www/html/Darija-Ai-API/env/lib/python3.8/site-packages/transformers/models/yolos/modeling_yolos.pyr    D   s   
!r    c                       s<   e Zd ZdZedd fddZejejdddZ  Z	S )	YolosEmbeddingszT
    Construct the CLS token, detection tokens, position and patch embeddings.

    Nconfigreturnc                    s   t    ttdd|j| _ttd|j|j| _	t
|| _| jj}ttd||j d |j| _t|j| _t|| _|| _d S Nr   )super__init__r
   	Parameterr-   zeroshidden_size	cls_tokennum_detection_tokensdetection_tokensYolosPatchEmbeddingspatch_embeddingsnum_patchesposition_embeddingsDropouthidden_dropout_probdropout$InterpolateInitialPositionEmbeddingsinterpolationr4   )selfr4   rA   	__class__r0   r1   r8   w   s    


zYolosEmbeddings.__init__pixel_valuesr5   c                 C   s   |j \}}}}| |}| \}}}| j|dd}	| j|dd}
tj|	||
fdd}| | j	||f}|| }| 
|}|S )Nr   dim)shaper@   sizer<   expandr>   r-   catrG   rB   rE   )rH   rL   
batch_sizenum_channelsheightwidth
embeddingsseq_len_Z
cls_tokensr>   rB   r0   r0   r1   forward   s    

zYolosEmbeddings.forward
r)   r*   r+   r,   r   r8   r-   r	   r[   __classcell__r0   r0   rI   r1   r2   q   s   r2   c                       s4   e Zd Zdd fddZdejdddZ  ZS )	rF   Nr5   c                    s   t    || _d S Nr7   r8   r4   rH   r4   rI   r0   r1   r8      s    
z-InterpolateInitialPositionEmbeddings.__init__i   i@  c                 C   s  |d d dd d f }|d d d f }|d d | j j d d d f }|d d d| j j d d f }|dd}|j\}}}| j jd | j j | j jd | j j  }	}
||||	|
}|\}}|| j j || j j  }}tjj	|||fddd}|
ddd}tj|||fdd}|S )Nr   r      bicubicFrQ   modeZalign_cornersrN   )r4   r=   	transposerP   
image_size
patch_sizeviewr
   
functionalinterpolateflattenr-   rS   )rH   	pos_embedimg_sizecls_pos_embeddet_pos_embedpatch_pos_embedrT   r;   rY   patch_heightpatch_widthrV   rW   Znew_patch_heigthnew_patch_widthscale_pos_embedr0   r0   r1   r[      s*         z,InterpolateInitialPositionEmbeddings.forward)rb   r)   r*   r+   r8   r-   r	   r[   r]   r0   r0   rI   r1   rF      s   rF   c                       s4   e Zd Zdd fddZdejdddZ  ZS )	 InterpolateMidPositionEmbeddingsNr^   c                    s   t    || _d S r_   r`   ra   rI   r0   r1   r8      s    
z)InterpolateMidPositionEmbeddings.__init__rb   c                 C   sH  |d d d d dd d f }|d d d f }|d d d d | j j d d d f }|d d d d d| j j d d f }|dd}|j\}}}}	| j jd | j j | j jd | j j  }
}||| ||
|}|\}}|| j j || j j  }}tjj	|||fddd}|
ddd |||| |}tj|||fdd}|S )	Nr   r   rc   r   rd   Fre   rN   )r4   r=   rg   rP   rh   ri   rj   r
   rk   rl   rm   
contiguousr-   rS   )rH   rn   ro   rp   rq   rr   depthrT   r;   rY   rs   rt   rV   rW   Znew_patch_heightru   rv   r0   r0   r1   r[      s<    &&   
    z(InterpolateMidPositionEmbeddings.forward)rb   rw   r0   r0   rI   r1   rx      s   rx   c                       s4   e Zd ZdZ fddZejejdddZ  ZS )r?   z
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    c                    s   t    |j|j }}|j|j }}t|tjj	r8|n||f}t|tjj	rR|n||f}|d |d  |d |d   }|| _|| _|| _|| _
tj||||d| _d S )Nr   r   )Zkernel_sizeZstride)r7   r8   rh   ri   rU   r;   
isinstancecollectionsabcIterablerA   r
   Conv2d
projection)rH   r4   rh   ri   rU   r;   rA   rI   r0   r1   r8      s    
 zYolosPatchEmbeddings.__init__rK   c                 C   s<   |j \}}}}|| jkr td| |ddd}|S )NzeMake sure that the channel dimension of the pixel values match with the one set in the configuration.rc   r   )rP   rU   
ValueErrorr   rm   rg   )rH   rL   rT   rU   rV   rW   rX   r0   r0   r1   r[      s    
zYolosPatchEmbeddings.forward)	r)   r*   r+   r,   r8   r-   r	   r[   r]   r0   r0   rI   r1   r?      s   r?   c                       sl   e Zd Zedd fddZejejdddZdeej e	e
eejejf eej f d	d
dZ  ZS )YolosSelfAttentionNr3   c                    s   t    |j|j dkr@t|ds@td|jf d|j d|j| _t|j|j | _| j| j | _t	j
|j| j|jd| _t	j
|j| j|jd| _t	j
|j| j|jd| _t	|j| _d S )Nr   Zembedding_sizezThe hidden size z4 is not a multiple of the number of attention heads .)bias)r7   r8   r;   num_attention_headshasattrr   intattention_head_sizeall_head_sizer
   LinearZqkv_biasquerykeyvaluerC   Zattention_probs_dropout_probrE   ra   rI   r0   r1   r8      s    
zYolosSelfAttention.__init__)xr5   c                 C   s6   |  d d | j| jf }||}|ddddS )NrM   r   rc   r   r   )rQ   r   r   rj   permute)rH   r   Znew_x_shaper0   r0   r1   transpose_for_scores  s    
z'YolosSelfAttention.transpose_for_scoresF)	head_maskoutput_attentionsr5   c                 C   s   |  |}| | |}| | |}| |}t||dd}|t| j	 }t
jj|dd}	| |	}	|d k	r|	| }	t|	|}
|
dddd }
|
 d d | jf }|
|}
|r|
|	fn|
f}|S )NrM   rN   r   rc   r   r   )r   r   r   r   r-   matmulrg   mathsqrtr   r
   rk   softmaxrE   r   ry   rQ   r   rj   )rH   r'   r   r   Zmixed_query_layerZ	key_layerZvalue_layerZquery_layerZattention_scoresZattention_probsZcontext_layerZnew_context_layer_shapeoutputsr0   r0   r1   r[     s     



zYolosSelfAttention.forward)NF)r)   r*   r+   r   r8   r-   r	   r   r   boolr   r   r[   r]   r0   r0   rI   r1   r      s       r   c                       s@   e Zd ZdZedd fddZejejejdddZ  Z	S )	YolosSelfOutputz
    The residual connection is defined in YolosLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    Nr3   c                    s.   t    t|j|j| _t|j| _d S r_   )	r7   r8   r
   r   r;   denserC   rD   rE   ra   rI   r0   r1   r8   ?  s    
zYolosSelfOutput.__init__r'   input_tensorr5   c                 C   s   |  |}| |}|S r_   r   rE   rH   r'   r   r0   r0   r1   r[   D  s    

zYolosSelfOutput.forwardr\   r0   r0   rI   r1   r   9  s   r   c                       sp   e Zd Zedd fddZee ddddZdej	e
ej	 eeeej	ej	f eej	 f d	d
dZ  ZS )YolosAttentionNr3   c                    s*   t    t|| _t|| _t | _d S r_   )r7   r8   r   	attentionr   outputsetpruned_headsra   rI   r0   r1   r8   M  s    


zYolosAttention.__init__)headsr5   c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   rN   )lenr   r   r   r   r   r   r   r   r   r   r   r   union)rH   r   indexr0   r0   r1   prune_headsS  s       zYolosAttention.prune_headsFr'   r   r   r5   c                 C   s4   |  |||}| |d |}|f|dd   }|S Nr   r   )r   r   )rH   r'   r   r   Zself_outputsattention_outputr   r0   r0   r1   r[   e  s    zYolosAttention.forward)NF)r)   r*   r+   r   r8   r   r   r   r-   r	   r   r   r   r   r[   r]   r0   r0   rI   r1   r   L  s     r   c                       s8   e Zd Zedd fddZejejdddZ  ZS )YolosIntermediateNr3   c                    sB   t    t|j|j| _t|jt	r6t
|j | _n|j| _d S r_   )r7   r8   r
   r   r;   intermediate_sizer   r{   Z
hidden_actstrr   intermediate_act_fnra   rI   r0   r1   r8   u  s
    
zYolosIntermediate.__init__)r'   r5   c                 C   s   |  |}| |}|S r_   )r   r   )rH   r'   r0   r0   r1   r[   }  s    

zYolosIntermediate.forward	r)   r*   r+   r   r8   r-   r	   r[   r]   r0   r0   rI   r1   r   t  s   r   c                       s<   e Zd Zedd fddZejejejdddZ  ZS )YolosOutputNr3   c                    s.   t    t|j|j| _t|j| _	d S r_   )
r7   r8   r
   r   r   r;   r   rC   rD   rE   ra   rI   r0   r1   r8     s    
zYolosOutput.__init__r   c                 C   s    |  |}| |}|| }|S r_   r   r   r0   r0   r1   r[     s    

zYolosOutput.forwardr   r0   r0   rI   r1   r     s   r   c                       s`   e Zd ZdZedd fddZd
ejeej e	e
eejejf eej f ddd	Z  ZS )
YolosLayerz?This corresponds to the Block class in the timm implementation.Nr3   c                    sb   t    |j| _d| _t|| _t|| _t|| _	t
j|j|jd| _t
j|j|jd| _d S )Nr   Zeps)r7   r8   Zchunk_size_feed_forwardZseq_len_dimr   r   r   intermediater   r   r
   	LayerNormr;   layer_norm_epslayernorm_beforelayernorm_afterra   rI   r0   r1   r8     s    



zYolosLayer.__init__Fr   c                 C   s`   | j | |||d}|d }|dd  }|| }| |}| |}| ||}|f| }|S )Nr   r   r   )r   r   r   r   r   )rH   r'   r   r   Zself_attention_outputsr   r   Zlayer_outputr0   r0   r1   r[     s    


zYolosLayer.forward)NF)r)   r*   r+   r,   r   r8   r-   r	   r   r   r   r   r[   r]   r0   r0   rI   r1   r     s     r   c                	       sN   e Zd Zedd fddZd
ejeej eeee	e
ef ddd	Z  ZS )YolosEncoderNr3   c                    s   t     | _t fddt jD | _d| _d j	d  j	d   j
d    j } jrtt jd d| jnd | _ jrt nd | _d S )Nc                    s   g | ]}t  qS r0   )r   ).0rZ   r4   r0   r1   
<listcomp>  s     z)YolosEncoder.__init__.<locals>.<listcomp>Fr   r   rc   )r7   r8   r4   r
   
ModuleListrangenum_hidden_layerslayergradient_checkpointingrh   ri   r=   use_mid_position_embeddingsr9   r-   r:   r;   mid_position_embeddingsrx   rG   )rH   r4   Z
seq_lengthrI   r   r1   r8     s$    
 &	zYolosEncoder.__init__FT)r'   r   r   output_hidden_statesreturn_dictr5   c                    s  |rdnd } rdnd }	| j jr2| | j||f}
t| jD ]\}}|rR||f }|d k	rb|| nd }| jr| jr fdd}tj	j

||||}n||| }|d }| j jr|| j jd k r||
|  } r<|	|d f }	q<|r||f }|stdd |||	fD S t|||	dS )	Nr0   c                    s    fdd}|S )Nc                     s    | f S r_   r0   )inputs)moduler   r0   r1   custom_forward  s    zKYolosEncoder.forward.<locals>.create_custom_forward.<locals>.custom_forwardr0   )r   r   r   )r   r1   create_custom_forward  s    z3YolosEncoder.forward.<locals>.create_custom_forwardr   r   c                 s   s   | ]}|d k	r|V  qd S r_   r0   r   vr0   r0   r1   	<genexpr>  s      z'YolosEncoder.forward.<locals>.<genexpr>)r&   r'   r(   )r4   r   rG   r   	enumerater   r   Ztrainingr-   utils
checkpointr   tupler   )rH   r'   rV   rW   r   r   r   r   Zall_hidden_statesZall_self_attentionsZ$interpolated_mid_position_embeddingsiZlayer_moduleZlayer_head_maskr   Zlayer_outputsr0   r   r1   r[     s>    


zYolosEncoder.forward)NFFT)r)   r*   r+   r   r8   r-   r	   r   r   r   r   r   r[   r]   r0   r0   rI   r1   r     s       
r   c                   @   sT   e Zd ZdZeZdZdZdZe	e
je
je
jf ddddZdeedd
ddZdS )YolosPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    vitrL   TN)r   r5   c                 C   sj   t |tjtjfr@|jjjd| jjd |j	dk	rf|j	j
  n&t |tjrf|j	j
  |jjd dS )zInitialize the weightsg        )meanZstdNg      ?)r{   r
   r   r   weightdataZnormal_r4   Zinitializer_ranger   Zzero_r   Zfill_)rH   r   r0   r0   r1   _init_weights  s    
z"YolosPreTrainedModel._init_weightsF)r   r   r5   c                 C   s   t |tr||_d S r_   )r{   r   r   )rH   r   r   r0   r0   r1   _set_gradient_checkpointing*  s    
z0YolosPreTrainedModel._set_gradient_checkpointing)F)r)   r*   r+   r,   r   config_classZbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingr   r
   r   r   r   r   r   r   r   r0   r0   r0   r1   r     s    r   aG  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`YolosConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
aM  
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`YolosImageProcessor.__call__`] for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
z_The bare YOLOS Model transformer outputting raw hidden-states without any specific head on top.c                       s   e Zd Zdeed fddZedddZee	e
e	 f dd	d
dZeeeeeededdeej eej ee ee ee eeef dddZ  ZS )
YolosModelT)r4   add_pooling_layerc                    sX   t  | || _t|| _t|| _tj|j	|j
d| _|rFt|nd | _|   d S )Nr   )r7   r8   r4   r2   rX   r   encoderr
   r   r;   r   	layernormYolosPoolerpooler	post_init)rH   r4   r   rI   r0   r1   r8   V  s    

zYolosModel.__init__r^   c                 C   s   | j jS r_   )rX   r@   rH   r0   r0   r1   get_input_embeddingsc  s    zYolosModel.get_input_embeddingsN)heads_to_pruner5   c                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model.

        Args:
            heads_to_prune (`dict` of {layer_num: list of heads to prune in this layer}):
                See base class `PreTrainedModel`.
        N)itemsr   r   r   r   )rH   r   r   r   r0   r0   r1   _prune_headsf  s    zYolosModel._prune_headsZvision)r   output_typer   ZmodalityZexpected_output)rL   r   r   r   r   r5   c              	   C   s   |d k	r|n| j j}|d k	r |n| j j}|d k	r4|n| j j}|d krLtd| || j j}| |}| j||j	d |j	d ||||d}|d }| 
|}| jd k	r| |nd }	|s|	d k	r||	fn|f}
|
|dd   S t||	|j|jdS )Nz You have to specify pixel_valuesr   rM   )rV   rW   r   r   r   r   r   r   )r&   Zpooler_outputr'   r(   )r4   r   r   use_return_dictr   Zget_head_maskr   rX   r   rP   r   r   r   r'   r(   )rH   rL   r   r   r   r   Zembedding_outputZencoder_outputssequence_outputpooled_outputZhead_outputsr0   r0   r1   r[   q  s:    
	
zYolosModel.forward)T)NNNNN)r)   r*   r+   r   r   r8   r?   r   r   r   r   r   r   YOLOS_INPUTS_DOCSTRINGr   _CHECKPOINT_FOR_DOCr   _CONFIG_FOR_DOC_EXPECTED_OUTPUT_SHAPEr   r-   r	   r   r   r[   r]   r0   r0   rI   r1   r   Q  s0   	     
r   c                       s*   e Zd Zed fddZdd Z  ZS )r   r   c                    s*   t    t|j|j| _t | _d S r_   )r7   r8   r
   r   r;   r   ZTanh
activationra   rI   r0   r1   r8     s    
zYolosPooler.__init__c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r   )rH   r'   Zfirst_token_tensorr   r0   r0   r1   r[     s    

zYolosPooler.forward)r)   r*   r+   r   r8   r[   r]   r0   r0   rI   r1   r     s   r   zy
    YOLOS Model (consisting of a ViT encoder) with object detection heads on top, for tasks such as COCO detection.
    c                       s~   e Zd Zed fddZejjdd Ze	e
eeeddejeee  ee ee ee eeef dd	d
Z  ZS )YolosForObjectDetectionr   c                    sX   t  | t|dd| _t|j|j|jd dd| _t|j|jddd| _| 	  d S )NF)r   r   r   )	input_dim
hidden_dim
output_dim
num_layers   )
r7   r8   r   r   YolosMLPPredictionHeadr;   
num_labelsclass_labels_classifierbbox_predictorr   ra   rI   r0   r1   r8     s          z YolosForObjectDetection.__init__c                 C   s$   dd t |d d |d d D S )Nc                 S   s   g | ]\}}||d qS ))r#   r$   r0   )r   abr0   r0   r1   r     s     z9YolosForObjectDetection._set_aux_loss.<locals>.<listcomp>rM   )zip)rH   outputs_classoutputs_coordr0   r0   r1   _set_aux_loss  s    z%YolosForObjectDetection._set_aux_loss)r   r   N)rL   labelsr   r   r   r5   c              
      s  |dk	r|n| j j}| j||||d}|d }|dd| j j dddf }| |}| | }	d\}
}|dk	rt| j j| j j	| j j
d}dddg}t|| j j| j j|d	}|| j i }||d
< |	|d< | j jr|r|jn|d }| |}| | }| ||}||d< |||d| j jd| j jd< | j jri }t| j jd D ]" | fdd D  q`| tfdd D }
|s|dk	r||	f| | }n||	f| }|
dk	r|
f| S |S t|
||	||j|j|jdS )a  
        labels (`List[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
            following 2 keys: `'class_labels'` and `'boxes'` (the class labels and bounding boxes of an image in the
            batch respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding
            boxes in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image,
            4)`.

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoModelForObjectDetection
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("hustvl/yolos-tiny")
        >>> model = AutoModelForObjectDetection.from_pretrained("hustvl/yolos-tiny")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> # convert outputs (bounding boxes and class logits) to COCO API
        >>> target_sizes = torch.tensor([image.size[::-1]])
        >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
        ...     0
        ... ]

        >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        ...     box = [round(i, 2) for i in box.tolist()]
        ...     print(
        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
        ...         f"{round(score.item(), 3)} at location {box}"
        ...     )
        Detected remote with confidence 0.994 at location [46.96, 72.61, 181.02, 119.73]
        Detected remote with confidence 0.975 at location [340.66, 79.19, 372.59, 192.65]
        Detected cat with confidence 0.984 at location [12.27, 54.25, 319.42, 470.99]
        Detected remote with confidence 0.922 at location [41.66, 71.96, 178.7, 120.33]
        Detected cat with confidence 0.914 at location [342.34, 21.48, 638.64, 372.46]
        ```N)r   r   r   r   )NNN
class_cost	bbox_cost	giou_costr  boxescardinality)matchernum_classeseos_coeflossesr#   r$   r   r%   r   )loss_ce	loss_bbox	loss_giouc                    s    i | ]\}}|d    |qS rZ   r0   r   kr   r   r0   r1   
<dictcomp>G  s      z3YolosForObjectDetection.forward.<locals>.<dictcomp>c                 3   s&   | ]}|kr | |  V  qd S r_   r0   )r   r  )r"   weight_dictr0   r1   r   I  s      z2YolosForObjectDetection.forward.<locals>.<genexpr>)r!   r"   r#   r$   r%   r&   r'   r(   )r4   r   r   r=   r   r   sigmoidYolosHungarianMatcherr  r  r  	YolosLossr   Zeos_coefficienttodeviceZauxiliary_lossZintermediate_hidden_statesr   Zbbox_loss_coefficientZgiou_loss_coefficientr   Zdecoder_layersupdater   sumkeysr    r&   r'   r(   )rH   rL   r  r   r   r   r   r   r#   r$   r!   r%   r  r  	criterionZoutputs_lossr   r   r   Zaux_weight_dictr   r0   )r   r"   r  r1   r[     sv    7 


  




 

zYolosForObjectDetection.forward)NNNN)r)   r*   r+   r   r8   r-   ZjitZunusedr   r   r   r   r    r   r.   r   r   r   r   r   r   r[   r]   r0   r0   rI   r1   r     s"   

    

r   c                 C   sX   |   } | d} d| | d }| d|d }d|d |d   }| | S )a  
    Compute the DICE loss, similar to generalized IOU for masks

    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs (0 for the negative class and 1 for the positive
                 class).
    r   rc   rM   )r  rm   r  )r   targets	num_boxes	numeratordenominatorr!   r0   r0   r1   	dice_loss_  s    
r"        ?rc   )alphagammac           
      C   s|   |   }tjj| |dd}|| d| d|   }|d| |  }|dkrj|| d| d|   }	|	| }|d | S )a  
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs (`torch.FloatTensor` of arbitrary shape):
            The predictions for each example.
        targets (`torch.FloatTensor` with the same shape as `inputs`)
            A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class
            and 1 for the positive class).
        alpha (`float`, *optional*, defaults to `0.25`):
            Optional weighting factor in the range (0,1) to balance positive vs. negative examples.
        gamma (`int`, *optional*, defaults to `2`):
            Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples.

    Returns:
        Loss tensor
    noneZ	reductionr   r   )r  r
   rk   Z binary_cross_entropy_with_logitsr   r  )
r   r  r  r$  r%  ZprobZce_lossZp_tr!   Zalpha_tr0   r0   r1   sigmoid_focal_losss  s    r(  c                       sh   e Zd ZdZ fddZdd Ze dd Zdd	 Z	d
d Z
dd Zdd Zdd Zdd Z  ZS )r  a  
    This class computes the losses for YolosForObjectDetection/YolosForSegmentation. The process happens in two steps:
    1) we compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each
    pair of matched ground-truth / prediction (supervise class and box).

    A note on the `num_classes` argument (copied from original repo in detr.py): "the naming of the `num_classes`
    parameter of the criterion is somewhat misleading. It indeed corresponds to `max_obj_id` + 1, where `max_obj_id` is
    the maximum id for a class in your dataset. For example, COCO has a `max_obj_id` of 90, so we pass `num_classes` to
    be 91. As another example, for a dataset that has a single class with `id` 1, you should pass `num_classes` to be 2
    (`max_obj_id` + 1). For more details on this, check the following discussion
    https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"


    Args:
        matcher (`YolosHungarianMatcher`):
            Module able to compute a matching between targets and proposals.
        num_classes (`int`):
            Number of object categories, omitting the special no-object category.
        eos_coef (`float`):
            Relative classification weight applied to the no-object category.
        losses (`List[str]`):
            List of all the losses to be applied. See `get_loss` for a list of all available losses.
    c                    sL   t    || _|| _|| _|| _t| jd }| j|d< | d| d S )Nr   rM   empty_weight)	r7   r8   r  r	  r
  r  r-   onesZregister_buffer)rH   r  r	  r
  r  r)  rI   r0   r1   r8     s    

zYolosLoss.__init__c                 C   s   d|krt d|d }| |}tdd t||D }tj|jdd | jtj|j	d}|||< t
j|dd|| j}	d	|	i}
|
S )
z
        Classification loss (NLL) targets dicts must contain the key "class_labels" containing a tensor of dim
        [nb_target_boxes]
        r#   z#No logits were found in the outputsc                 S   s    g | ]\}\}}|d  | qS class_labelsr0   )r   trZ   Jr0   r0   r1   r     s    
 z)YolosLoss.loss_labels.<locals>.<listcomp>Nrc   dtyper  r   r  )KeyError_get_source_permutation_idxr-   rS   r   fullrP   r	  int64r  r
   rk   Zcross_entropyrg   r)  )rH   r   r  indicesr  Zsource_logitsidxZtarget_classes_oZtarget_classesr  r  r0   r0   r1   loss_labels  s    
   zYolosLoss.loss_labelsc                 C   sf   |d }|j }tjdd |D |d}|d|jd d kd}tj|	 |	 }	d|	i}
|
S )z
        Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.

        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
        r#   c                 S   s   g | ]}t |d  qS r+  r   r   r0   r0   r1   r     s     z.YolosLoss.loss_cardinality.<locals>.<listcomp>)r  rM   r   Zcardinality_error)
r  r-   	as_tensorZargmaxrP   r  r
   rk   l1_lossfloat)rH   r   r  r5  r  r#   r  Ztarget_lengthsZ	card_predZcard_errr  r0   r0   r1   loss_cardinality  s    zYolosLoss.loss_cardinalityc                 C   s   d|krt d| |}|d | }tjdd t||D dd}tjj||dd}i }	| | |	d	< d
t	t
t|t| }
|
 | |	d< |	S )a<  
        Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.

        Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes
        are expected in format (center_x, center_y, w, h), normalized by the image size.
        r$   z#No predicted boxes found in outputsc                 S   s    g | ]\}\}}|d  | qS r  r0   )r   r-  rZ   r   r0   r0   r1   r     s    
 z(YolosLoss.loss_boxes.<locals>.<listcomp>r   rN   r&  r'  r  r   r  )r1  r2  r-   rS   r   r
   rk   r:  r  Zdiaggeneralized_box_iour   )rH   r   r  r5  r  r6  Zsource_boxesZtarget_boxesr  r  r  r0   r0   r1   
loss_boxes  s    
zYolosLoss.loss_boxesc                 C   s   d|krt d| |}| |}|d }|| }dd |D }t| \}	}
|	|}	|	| }	tjj|dddf |	j	dd ddd	}|ddd
f 
d}|	
d}	|	|j	}	t||	|t||	|d}|S )z
        Compute the losses related to the masks: the focal loss and the dice loss.

        Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w].
        Z
pred_masksz#No predicted masks found in outputsc                 S   s   g | ]}|d  qS )masksr0   r   r-  r0   r0   r1   r     s     z(YolosLoss.loss_masks.<locals>.<listcomp>Nr   ZbilinearFre   r   r   )Z	loss_maskZ	loss_dice)r1  r2  _get_target_permutation_idxnested_tensor_from_tensor_list	decomposer  r
   rk   rl   rP   rm   rj   r(  r"  )rH   r   r  r5  r  
source_idx
target_idxZsource_masksr@  Ztarget_masksZvalidr  r0   r0   r1   
loss_masks  s.    


   


zYolosLoss.loss_masksc                 C   s4   t dd t|D }t dd |D }||fS )Nc                 S   s    g | ]\}\}}t ||qS r0   r-   Z	full_like)r   r   sourcerZ   r0   r0   r1   r     s    
 z9YolosLoss._get_source_permutation_idx.<locals>.<listcomp>c                 S   s   g | ]\}}|qS r0   r0   )r   rI  rZ   r0   r0   r1   r     s     r-   rS   r   )rH   r5  	batch_idxrE  r0   r0   r1   r2    s    z%YolosLoss._get_source_permutation_idxc                 C   s4   t dd t|D }t dd |D }||fS )Nc                 S   s    g | ]\}\}}t ||qS r0   rH  )r   r   rZ   targetr0   r0   r1   r     s    
 z9YolosLoss._get_target_permutation_idx.<locals>.<listcomp>c                 S   s   g | ]\}}|qS r0   r0   )r   rZ   rL  r0   r0   r1   r     s     rJ  )rH   r5  rK  rF  r0   r0   r1   rB    s    z%YolosLoss._get_target_permutation_idxc                 C   s@   | j | j| j| jd}||kr.td| d|| ||||S )N)r  r  r  r@  zLoss z not supported)r7  r<  r?  rG  r   )rH   r!   r   r  r5  r  Zloss_mapr0   r0   r1   get_loss!  s    zYolosLoss.get_lossc           
   
      s  dd |  D }| ||}tdd |D }tj|gtjtt| j	d}tj
|dd }i }| jD ]}|| ||||| qnd|krt|d D ]\\ }| ||}| jD ]@}|d	krq| |||||}	 fd
d|	  D }	||	 qq|S )a  
        This performs the loss computation.

        Args:
             outputs (`dict`, *optional*):
                Dictionary of tensors, see the output specification of the model for the format.
             targets (`List[dict]`, *optional*):
                List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
                losses applied, see each loss' doc.
        c                 S   s   i | ]\}}|d kr||qS )r%   r0   r  r0   r0   r1   r  7  s       z%YolosLoss.forward.<locals>.<dictcomp>c                 s   s   | ]}t |d  V  qdS )r,  Nr8  rA  r0   r0   r1   r   =  s     z$YolosLoss.forward.<locals>.<genexpr>r/  r   minr%   r@  c                    s    i | ]\}}|d    |qS r  r0   r  r  r0   r1   r  S  s      )r   r  r  r-   r9  r;  nextitervaluesr  clampitemr  r  rM  r   )
rH   r   r  Zoutputs_without_auxr5  r  r  r!   r%   Zl_dictr0   r  r1   r[   ,  s$    "

zYolosLoss.forward)r)   r*   r+   r,   r8   r7  r-   no_gradr<  r?  rG  r2  rB  rM  r[   r]   r0   r0   rI   r1   r    s   
!r  c                       s(   e Zd ZdZ fddZdd Z  ZS )r   a  
    Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
    height and width of a bounding box w.r.t. an image.

    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py

    c                    sJ   t    || _|g|d  }tdd t|g| ||g D | _d S )Nr   c                 s   s   | ]\}}t ||V  qd S r_   )r
   r   )r   nr  r0   r0   r1   r   g  s     z2YolosMLPPredictionHead.__init__.<locals>.<genexpr>)r7   r8   r   r
   r   r   layers)rH   r   r   r   r   hrI   r0   r1   r8   c  s    
zYolosMLPPredictionHead.__init__c                 C   s>   t | jD ].\}}|| jd k r0tj||n||}q
|S r6   )r   rW  r   r
   rk   Zrelu)rH   r   r   r   r0   r0   r1   r[   i  s    (zYolosMLPPredictionHead.forward)r)   r*   r+   r,   r8   r[   r]   r0   r0   rI   r1   r   Z  s   r   c                       s<   e Zd ZdZdeeed fddZe dd Z  Z	S )	r  a  
    This class computes an assignment between the targets and the predictions of the network.

    For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
    predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
    un-matched (and thus treated as non-objects).

    Args:
        class_cost:
            The relative weight of the classification error in the matching cost.
        bbox_cost:
            The relative weight of the L1 error of the bounding box coordinates in the matching cost.
        giou_cost:
            The relative weight of the giou loss of the bounding box in the matching cost.
    r   r  c                    sL   t    t| dg || _|| _|| _|dkrH|dkrH|dkrHtdd S )NZscipyr   z#All costs of the Matcher can't be 0)r7   r8   r   r  r  r  r   )rH   r  r  r  rI   r0   r1   r8     s    
zYolosHungarianMatcher.__init__c                 C   s   |d j dd \}}|d ddd}|d dd}tdd	 |D }td
d	 |D }|dd|f  }	tj||dd}
tt|t| }| j|
 | j	|	  | j
|  }|||d }dd	 |D }dd	 t||dD }dd	 |D S )a  
        Args:
            outputs (`dict`):
                A dictionary that contains at least these entries:
                * "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                * "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
            targets (`List[dict]`):
                A list of targets (len(targets) = batch_size), where each target is a dict containing:
                * "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
                  ground-truth
                 objects in the target) containing the class labels
                * "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.

        Returns:
            `List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
            - index_i is the indices of the selected predictions (in order)
            - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        r#   Nrc   r   r   rM   r$   c                 S   s   g | ]}|d  qS r+  r0   r   r0   r0   r1   r     s     z1YolosHungarianMatcher.forward.<locals>.<listcomp>c                 S   s   g | ]}|d  qS r=  r0   r   r0   r0   r1   r     s     )pc                 S   s   g | ]}t |d  qS r=  r8  r   r0   r0   r1   r     s     c                 S   s   g | ]\}}t || qS r0   r   )r   r   cr0   r0   r1   r     s     c                 S   s0   g | ](\}}t j|t jd t j|t jd fqS ))r0  )r-   r9  r4  )r   r   jr0   r0   r1   r     s     )rP   rm   r   r-   rS   Zcdistr>  r   r  r  r  rj   cpur   split)rH   r   r  rT   Znum_queriesZout_probZout_bboxZ
target_idsZtarget_bboxr  r  r  Zcost_matrixsizesr5  r0   r0   r1   r[     s    zYolosHungarianMatcher.forward)r   r   r   )
r)   r*   r+   r,   r;  r8   r-   rU  r[   r]   r0   r0   rI   r1   r  p  s   
r  )r-  r5   c                 C   sH   |   r&| jtjtjfkr| S |  S | jtjtjfkr<| S |  S d S r_   )	Zis_floating_pointr0  r-   Zfloat32Zfloat64r;  Zint32r4  r   )r-  r0   r0   r1   _upcast  s    r_  )r  r5   c                 C   sH   t | } | dddf | dddf  | dddf | dddf   S )a  
    Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.

    Args:
        boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
            Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
            < x2` and `0 <= y1 < y2`.

    Returns:
        `torch.FloatTensor`: a tensor containing the area for each box.
    Nrc   r   r   r   )r_  r=  r0   r0   r1   box_area  s    r`  c           
      C   s   t | }t |}t| d d d d df |d d d df }t| d d d dd f |d d dd f }|| jdd}|d d d d df |d d d d df  }|d d d f | | }|| }	|	|fS )Nrc   r   rN  r   )r`  r-   maxrO  rS  )
boxes1boxes2Zarea1Zarea2Zleft_topZright_bottomwidth_heightZinterr   iour0   r0   r1   box_iou  s    ..,rf  c                 C   s*  | ddddf | ddddf k  s:td|  |ddddf |ddddf k  sttd| t| |\}}t| dddddf |ddddf }t| dddddf |ddddf }|| jdd}|dddddf |dddddf  }||| |  S )z
    Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.

    Returns:
        `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
    Nrc   z<boxes1 must be in [x0, y0, x1, y1] (corner) format, but got z<boxes2 must be in [x0, y0, x1, y1] (corner) format, but got r   rN  r   )allr   rf  r-   rO  ra  rS  )rb  rc  re  r   top_leftbottom_rightrd  Zarear0   r0   r1   r>    s    	,,..,r>  c                 C   sB   | d }| dd  D ](}t |D ]\}}t|| |||< q q|S r   )r   ra  )Zthe_listZmaxesZsublistr   rT  r0   r0   r1   _max_by_axis  s
    rj  c                   @   s6   e Zd Zee dddZdd Zdd Zdd	 Zd
S )NestedTensor)maskc                 C   s   || _ || _d S r_   tensorsrl  )rH   rn  rl  r0   r0   r1   r8     s    zNestedTensor.__init__c                 C   s4   | j |}| j}|d k	r&||}nd }t||S r_   )rn  r  rl  rk  )rH   r  Zcast_tensorrl  Z	cast_maskr0   r0   r1   r    s    zNestedTensor.toc                 C   s   | j | jfS r_   rm  r   r0   r0   r1   rD    s    zNestedTensor.decomposec                 C   s
   t | jS r_   )r   rn  r   r0   r0   r1   __repr__  s    zNestedTensor.__repr__N)	r)   r*   r+   r   r	   r8   r  rD  ro  r0   r0   r0   r1   rk    s   	rk  )tensor_listc                 C   s   | d j dkrtdd | D }t| g| }|\}}}}| d j}| d j}tj|||d}	tj|||ftj|d}
t	| |	|
D ]\\}}}|d |j
d d |j
d d |j
d f | d|d |j
d d |j
d f< qntd	t|	|
S )
Nr   r   c                 S   s   g | ]}t |jqS r0   )listrP   )r   imgr0   r0   r1   r   %  s     z2nested_tensor_from_tensor_list.<locals>.<listcomp>r/  r   rc   Fz(Only 3-dimensional tensors are supported)ndimrj  r   r0  r  r-   r:   r*  r   r   rP   Zcopy_r   rk  )rp  max_sizeZbatch_shaperT   rU   rV   rW   r0  r  Ztensorrl  rr  Zpad_imgmr0   r0   r1   rC  #  s    

2$rC  )r#  rc   )Sr,   collections.abcr|   r   dataclassesr   typingr   r   r   r   r   r   r-   Ztorch.utils.checkpointr	   r
   Zactivationsr   Zmodeling_outputsr   r   Zmodeling_utilsr   Zpytorch_utilsr   r   r   r   r   r   r   r   r   r   r   r   Zconfiguration_yolosr   Zscipy.optimizer   Ztransformers.image_transformsr   Z
get_loggerr)   loggerr   r   r   Z#YOLOS_PRETRAINED_MODEL_ARCHIVE_LISTr    Moduler2   rF   rx   r?   r   r   r   r   r   r   r   r   ZYOLOS_START_DOCSTRINGr   r   r   r   r"  r;  r(  r  r   r  r_  r`  rf  r>  rj  objectrk  rC  r0   r0   r0   r1   <module>   s    ,

,+!"=(*TW   HO	
