""" PyTorch DETR model."""


import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor, nn

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_scipy_available,
    is_timm_available,
    is_vision_available,
    logging,
    replace_return_docstrings,
    requires_backends,
)
from ..auto import AutoBackbone
from .configuration_detr import DetrConfig


if is_scipy_available():
    from scipy.optimize import linear_sum_assignment

if is_timm_available():
    from timm import create_model

if is_vision_available():
    from transformers.image_transforms import center_to_corners_format


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "DetrConfig"
_CHECKPOINT_FOR_DOC = "facebook/detr-resnet-50"

DETR_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/detr-resnet-50",
    # See all DETR models at https://huggingface.co/models?filter=detr
]


@dataclass
class DetrDecoderOutput(BaseModelOutputWithCrossAttentions):
    """
    Base class for outputs of the DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions,
    namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
    gone through a layernorm. This is useful when training the model with auxiliary decoding losses.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
            used to compute the weighted average in the cross-attention heads.
        intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
            Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
            layernorm.
    """

    intermediate_hidden_states: Optional[torch.FloatTensor] = None


@dataclass
class DetrModelOutput(Seq2SeqModelOutput):
    """
    Base class for outputs of the DETR encoder-decoder model. This class adds one attribute to Seq2SeqModelOutput,
    namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
    gone through a layernorm. This is useful when training the model with auxiliary decoding losses.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the decoder of the model.
        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each
            layer plus the initial embedding outputs.
        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the
            weighted average in the self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
            used to compute the weighted average in the cross-attention heads.
        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each
            layer plus the initial embedding outputs.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the
            weighted average in the self-attention heads.
        intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, sequence_length, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
            Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
            layernorm.
    """

    intermediate_hidden_states: Optional[torch.FloatTensor] = None


@dataclass
class DetrObjectDetectionOutput(ModelOutput):
    """
    Output type of [`DetrForObjectDetection`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
            Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
            bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
            scale-invariant IoU loss.
        loss_dict (`Dict`, *optional*):
            A dictionary containing the individual losses. Useful for logging.
        logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
            Classification logits (including no-object) for all queries.
        pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
            Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
            values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
            possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the
            unnormalized bounding boxes.
        auxiliary_outputs (`list[Dict]`, *optional*):
            Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
            and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
            `pred_boxes`) for each decoder layer.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the decoder of the model.
        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each
            layer plus the initial embedding outputs.
        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the
            weighted average in the self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
            used to compute the weighted average in the cross-attention heads.
        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each
            layer plus the initial embedding outputs.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the
            weighted average in the self-attention heads.
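
    A minimal sketch of how these fields are typically consumed (the `outputs` object and the manual conversion below
    are illustrative assumptions, not part of this class; [`~DetrImageProcessor.post_process_object_detection`] does
    the same work for you):

    ```python
    >>> # drop the trailing "no object" class and take the best label per query
    >>> probabilities = outputs.logits.softmax(-1)[..., :-1]
    >>> scores, labels = probabilities.max(-1)

    >>> # pred_boxes are normalized (center_x, center_y, width, height); convert to corner format
    >>> center_x, center_y, width, height = outputs.pred_boxes.unbind(-1)
    >>> boxes = torch.stack(
    ...     [center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height], dim=-1
    ... )
    ```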
    Nloss	loss_dictlogits
pred_boxesauxiliary_outputslast_hidden_statedecoder_hidden_statesdecoder_attentionscross_attentionsencoder_last_hidden_stateencoder_hidden_statesencoder_attentions)r$   r%   r&   r'   r/   r   r(   r)   r*   r0   r   r1   r2   r3   r   r4   r5   r   r6   r7   r8   r9   r:   r+   r+   r+   r,   r.      s   
/r.   c                   @   s  e Zd ZU dZdZeej ed< dZ	ee
 ed< dZejed< dZejed< dZejed< dZeee
  ed< dZeej ed	< dZeeej  ed
< dZeeej  ed< dZeeej  ed< dZeej ed< dZeeej  ed< dZeeej  ed< dS )DetrSegmentationOutputaP  
    Output type of [`DetrForSegmentation`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
            Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
            bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
            scale-invariant IoU loss.
        loss_dict (`Dict`, *optional*):
            A dictionary containing the individual losses. Useful for logging.
        logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
            Classification logits (including no-object) for all queries.
        pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
            Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
            values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
            possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the
            unnormalized bounding boxes.
        pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`):
            Segmentation masks logits for all queries. See also
            [`~DetrImageProcessor.post_process_semantic_segmentation`] or
            [`~DetrImageProcessor.post_process_instance_segmentation`] or
            [`~DetrImageProcessor.post_process_panoptic_segmentation`] to evaluate semantic, instance and panoptic
            segmentation masks respectively.
        auxiliary_outputs (`list[Dict]`, *optional*):
            Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
            and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
            `pred_boxes`) for each decoder layer.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the decoder of the model.
        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each
            layer plus the initial embedding outputs.
        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the
            weighted average in the self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
            used to compute the weighted average in the cross-attention heads.
        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each
            layer plus the initial embedding outputs.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the
            weighted average in the self-attention heads.
    Nr/   r0   r1   r2   
pred_masksr3   r4   r5   r6   r7   r8   r9   r:   )r$   r%   r&   r'   r/   r   r(   r)   r*   r0   r   r1   r2   r<   r3   r   r4   r5   r   r6   r7   r8   r9   r:   r+   r+   r+   r,   r;      s   
5r;   c                       s4   e Zd ZdZ fddZ fddZdd Z  ZS )DetrFrozenBatchNorm2dz
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rsqrt, without which any other models than
    torchvision.models.resnet[18,34,50,101] produce nans.
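
    The layer applies `y = (x - running_mean) / sqrt(running_var + epsilon) * weight + bias` with all four buffers
    kept fixed. A small illustrative check (shapes are arbitrary and chosen only for this sketch):

    ```python
    >>> import torch

    >>> frozen = DetrFrozenBatchNorm2d(3)
    >>> hidden = torch.randn(1, 3, 4, 4)
    >>> expected = (hidden - frozen.running_mean.view(1, -1, 1, 1)) / torch.sqrt(
    ...     frozen.running_var.view(1, -1, 1, 1) + 1e-5
    ... ) * frozen.weight.view(1, -1, 1, 1) + frozen.bias.view(1, -1, 1, 1)
    >>> torch.allclose(frozen(hidden), expected, atol=1e-6)
    True
    ```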
    """

    def __init__(self, n):
        super().__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # frozen batch norm keeps no `num_batches_tracked` buffer, so drop it from incoming state dicts
        num_batches_tracked_key = prefix + "num_batches_tracked"
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, x):
        # reshape the buffers once up front so the affine transform broadcasts over (batch, channels, height, width)
        weight = self.weight.reshape(1, -1, 1, 1)
        bias = self.bias.reshape(1, -1, 1, 1)
        running_var = self.running_var.reshape(1, -1, 1, 1)
        running_mean = self.running_mean.reshape(1, -1, 1, 1)
        epsilon = 1e-5
        scale = weight * (running_var + epsilon).rsqrt()
        bias = bias - running_mean * scale
        return x * scale + bias


def replace_batch_norm(model):
    r"""
    Recursively replace all `torch.nn.BatchNorm2d` with `DetrFrozenBatchNorm2d`.

    Args:
        model (torch.nn.Module):
            input model
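
    Example (a hypothetical usage sketch; any module with `nn.BatchNorm2d` children works the same way):

    ```python
    >>> from torch import nn

    >>> tiny_backbone = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())
    >>> replace_batch_norm(tiny_backbone)
    >>> isinstance(tiny_backbone[1], DetrFrozenBatchNorm2d)
    True
    ```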
    """
    for name, module in model.named_children():
        if isinstance(module, nn.BatchNorm2d):
            new_module = DetrFrozenBatchNorm2d(module.num_features)

            new_module.weight.data.copy_(module.weight)
            new_module.bias.data.copy_(module.bias)
            new_module.running_mean.data.copy_(module.running_mean)
            new_module.running_var.data.copy_(module.running_var)

            model._modules[name] = new_module

        if len(list(module.children())) > 0:
            replace_batch_norm(module)


class DetrConvEncoder(nn.Module):
    """
    Convolutional backbone, using either the AutoBackbone API or one from the timm library.

    nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above.

    c              	      s"  t    || _|jrVt| dg i }|jr4d|d< t|jf|jdd|j	d|}nt
|j}t  t| W 5 Q R X || _|jr| jj n| jj| _|jr|jn|jj}d|kr| j D ]T\}}|jrd|krd	|krd
|kr|d qd|krd|krd|kr|d qd S )Ntimm   Zoutput_strideT)r   r   r
      )Z
pretrainedZfeatures_onlyZout_indicesZin_chansZresnetZlayer2Zlayer3Zlayer4Fzstage.1zstage.2zstage.3)rB   rC   configuse_timm_backboner   Zdilationr   backboneZuse_pretrained_backbonenum_channelsr   from_configZbackbone_configr(   no_gradr\   r]   Zfeature_infoZchannelsintermediate_channel_sizesZ
model_typenamed_parametersrequires_grad_)rG   rd   kwargsrf   Zbackbone_model_typer^   Z	parameterrI   r+   r,   rC   R  s>    
	

zDetrConvEncoder.__init__)pixel_values
pixel_maskc                 C   sl   | j jr| |n
| |j}g }|D ]@}tjj|d   |jdd  d	t
jd }|||f q&|S )N)sizer   )rd   re   r]   Zfeature_mapsr	   
functionalinterpolatefloatshapetor(   boolappend)rG   rn   ro   featuresoutfeature_mapmaskr+   r+   r,   rS   y  s    .zDetrConvEncoder.forward)	r$   r%   r&   r'   rC   r(   r   rS   rT   r+   r+   rI   r,   r`   J  s   'r`   c                       s(   e Zd ZdZ fddZdd Z  ZS )DetrConvModelzp
    This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder.
    c                    s   t    || _|| _d S N)rB   rC   conv_encoderposition_embedding)rG   r   r   rI   r+   r,   rC     s    
zDetrConvModel.__init__c                 C   s@   |  ||}g }|D ]"\}}|| |||j q||fS r~   )r   rx   r   rv   dtype)rG   rn   ro   rz   posr{   r|   r+   r+   r,   rS     s
    zDetrConvModel.forwardr$   r%   r&   r'   rC   rS   rT   r+   r+   rI   r,   r}     s   r}   )r|   r   
target_lenc                 C   sf   |   \}}|dk	r|n|}| ddddddf |d|||}d| }|| t|jS )zs
    Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, target_seq_len, source_seq_len]`.
    Nr   g      ?)rq   expandrv   Zmasked_fillrw   r(   finfomin)r|   r   r   
batch_size
source_lenZexpanded_maskZinverted_maskr+   r+   r,   _expand_mask  s
    *r   c                       s*   e Zd ZdZd
 fdd	Zdd	 Z  ZS )DetrSinePositionEmbeddingz
    This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
    need paper, generalized to work on images.
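
    A rough sketch of the idea (simplified to one spatial axis; the actual implementation below works on the cumulative
    sums of the pixel mask, normalizes them, and interleaves sine and cosine channels):

    ```python
    >>> import torch

    >>> embedding_dim, temperature = 64, 10000
    >>> positions = torch.arange(8, dtype=torch.float32)  # e.g. 8 positions along one axis
    >>> dim_t = temperature ** (2 * (torch.arange(embedding_dim) // 2) / embedding_dim)
    >>> angles = positions[:, None] / dim_t[None, :]
    >>> pos_embed = torch.stack((angles[:, 0::2].sin(), angles[:, 1::2].cos()), dim=2).flatten(1)
    >>> pos_embed.shape
    torch.Size([8, 64])
    ```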
    @   '  FNc                    sP   t    || _|| _|| _|d k	r4|dkr4td|d krFdtj }|| _d S )NFz+normalize should be True if scale is passedr   )	rB   rC   embedding_dimtemperature	normalize
ValueErrormathpirR   )rG   r   r   r   rR   rI   r+   r,   rC     s    

z"DetrSinePositionEmbedding.__init__c           	   	   C   s  |d krt d|jdtjd}|jdtjd}| jr||d d dd d d f d  | j }||d d d d dd f d  | j }tj| jtj|jd}| j	dtj
|ddd	 | j  }|d d d d d d d f | }|d d d d d d d f | }tj|d d d d d d d
d df  |d d d d d d dd df  fddd}tj|d d d d d d d
d df  |d d d d d d dd df  fddd}tj||fddd
ddd}|S )NzNo pixel mask providedr   r   r   rN   gư>r   devicefloor)Zrounding_moder   rc   dimr
   )r   Zcumsumr(   float32r   rR   aranger   r   r   divstacksincosflattencatpermute)	rG   rn   ro   Zy_embedZx_embedZdim_tZpos_xZpos_yr   r+   r+   r,   rS     s    ((   \\z!DetrSinePositionEmbedding.forward)r   r   FNr   r+   r+   rI   r,   r     s   r   c                       s,   e Zd ZdZd fdd	Zd	ddZ  ZS )
DetrLearnedPositionEmbeddingzN
    This module learns positional embeddings up to a fixed maximum size.
       c                    s*   t    td|| _td|| _d S )N2   )rB   rC   r	   	Embeddingrow_embeddingscolumn_embeddings)rG   r   rI   r+   r,   rC     s    
z%DetrLearnedPositionEmbedding.__init__Nc           
      C   s   |j dd  \}}tj||jd}tj||jd}| |}| |}tj|d|dd|dd|dgdd}	|		ddd}	|	d}	|	|j d ddd}	|	S )Nrp   r   r   r   rN   r   r   )
ru   r(   r   r   r   r   r   	unsqueezerepeatr   )
rG   rn   ro   heightwidthZwidth_valuesZheight_valuesZx_embZy_embr   r+   r+   r,   rS     s    

2
z$DetrLearnedPositionEmbedding.forward)r   )Nr   r+   r+   rI   r,   r     s   r   c                 C   sJ   | j d }| jdkr"t|dd}n$| jdkr6t|}ntd| j |S )Nr   ZsineT)r   ZlearnedzNot supported )d_modelZposition_embedding_typer   r   r   )rd   Zn_stepsr   r+   r+   r,   build_position_encoding  s    



r   c                       s   e Zd ZdZdeeeed fddZej	eeddd	Z
ej	ee	 d
ddZdej	eej	 eej	 eej	 eej	 eeej	eej	 eeej	  f dddZ  ZS )DetrAttentionz
    Multi-headed attention from 'Attention Is All You Need' paper.

    Here, we add position embeddings to the queries and keys (as explained in the DETR paper).
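
    A minimal, standalone sketch of that difference (made-up shapes and a single head; the real layer below also
    handles masking, dropout, multiple heads and cross-attention):

    ```python
    >>> import torch
    >>> from torch import nn

    >>> hidden_size, seq_len = 256, 10
    >>> q_proj, k_proj, v_proj = (nn.Linear(hidden_size, hidden_size) for _ in range(3))
    >>> hidden_states = torch.randn(1, seq_len, hidden_size)
    >>> object_queries = torch.randn(1, seq_len, hidden_size)

    >>> # position information is added to the queries and keys only, never to the values
    >>> queries = q_proj(hidden_states + object_queries)
    >>> keys = k_proj(hidden_states + object_queries)
    >>> values = v_proj(hidden_states)
    >>> attention = (queries @ keys.transpose(-1, -2) / hidden_size**0.5).softmax(-1)
    >>> context = attention @ values
    >>> context.shape
    torch.Size([1, 10, 256])
    ```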
            T)	embed_dim	num_headsdropoutr?   c                    s   t    || _|| _|| _|| | _| j| | jkrNtd| j d| d| jd | _tj	|||d| _
tj	|||d| _tj	|||d| _tj	|||d| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).      ࿩r?   )rB   rC   r   r   r   head_dimr   scalingr	   Lineark_projv_projq_projout_proj)rG   r   r   r   r?   rI   r+   r,   rC     s    

zDetrAttention.__init__)tensorseq_lenr   c                 C   s    | ||| j| jdd S )Nr   r   )viewr   r   	transpose
contiguous)rG   r   r   r   r+   r+   r,   _shape  s    zDetrAttention._shape)r   object_queriesc                 K   sd   | dd }|r"td|  |d k	r:|d k	r:td|d k	rPtd |}|d kr\|S || S )Nposition_embeddingsUnexpected arguments ZCannot specify both position_embeddings and object_queries. Please use just object_queriesgposition_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead)popr   keysloggerwarning_once)rG   r   r   rm   r   r+   r+   r,   with_pos_embed  s    zDetrAttention.with_pos_embedNF)hidden_statesattention_maskr   key_value_statesspatial_position_embeddingsoutput_attentionsreturnc                 K   s  | dd}| dd}	|r.td|  |dk	rF|dk	rFtd|	dk	r^|dk	r^td|dk	rttd |}|	dk	rtd |	}|dk	}
| \}}}|dk	r|}| ||}|dk	r|}| ||}| || j }|
r| 	| 
|d	|}| 	| |d	|}n(| 	| 
|d	|}| 	| |d	|}|| j d	| jf}| 	|||j| }|j| }|j| }|d
}t||d
d}| || j ||fkrtd|| j ||f d|  |dk	r8| |d
||fkrtd|d
||f d|  ||| j||| }||| j ||}tjj|d	d}|rv||| j||}||| j ||}nd}tjj|| j| jd}t||}| || j || jfkrtd|| j|| jf d|  ||| j|| j}|d
d}||||}| |}||fS )z#Input shape: Batch x Time x ChannelZposition_ebmeddingsNkey_value_position_embeddingsr   r   z~Cannot specify both key_value_position_embeddings and spatial_position_embeddings. Please use just spatial_position_embeddingsr   z~key_value_position_embeddings has been deprecated and will be removed in v4.34. Please use spatial_position_embeddings insteadrN   r   r   z$Attention weights should be of size z	, but is z!Attention mask should be of size r   ptrainingz `attn_output` should be of size )r   r   r   r   r   rq   r   r   r   r   r   r   r   r   r   r(   Zbmmr   r	   rr   softmaxr   r   rO   r   )rG   r   r   r   r   r   r   rm   r   r   Zis_cross_attentionr   r   r   Zhidden_states_originalZkey_value_states_originalZquery_statesZ
key_statesZvalue_statesZ
proj_shaper   attn_weightsZattn_weights_reshapedZ
attn_probsZattn_outputr+   r+   r,   rS   &  s    




zDetrAttention.forward)r   T)NNNNF)r$   r%   r&   r'   intrt   rw   rC   r(   r   r   r   r   r   rS   rT   r+   r+   rI   r,   r     s2   
       r   c                       s>   e Zd Zed fddZd	ejejejedddZ  Z	S )
DetrEncoderLayerrd   c                    s   t    |j| _t| j|j|jd| _t	| j| _
|j| _t|j | _|j| _t| j|j| _t|j| j| _t	| j| _d S )Nr   r   r   )rB   rC   r   r   r   encoder_attention_headsattention_dropout	self_attnr	   	LayerNormself_attn_layer_normr   r   activation_functionactivation_fnactivation_dropoutr   Zencoder_ffn_dimfc1fc2final_layer_normrG   rd   rI   r+   r,   rC     s    
zDetrEncoderLayer.__init__NFr   r   r   r   c                 K   sT  | dd}|r"td|  |dk	r:|dk	r:td|dk	rPtd |}|}| j||||d\}}tjj|| j| j	d}|| }| 
|}|}| | |}tjj|| j| j	d}| |}tjj|| j| j	d}|| }| |}| j	r:t| st| r:t|jjd }	tj||	 |	d	}|f}
|rP|
|f7 }
|
S )
a  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            object_queries (`torch.FloatTensor`, *optional*):
                Object queries (also called content embeddings), to be added to the hidden states.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        r   Nr   r   r   r   r   i  )r   max)r   r   r   r   r   r   r	   rr   r   r   r   r   r   r   r   r   r(   isinfanyisnanr   r   r   clamp)rG   r   r   r   r   rm   r   residualr   Zclamp_valueoutputsr+   r+   r,   rS     sJ    



 
zDetrEncoderLayer.forward)NF)
r$   r%   r&   r   rC   r(   r   rw   rS   rT   r+   r+   rI   r,   r     s     r   c                	       sb   e Zd Zed fddZd	ejeej eej eej eej eej ee dddZ	  Z
S )
DetrDecoderLayerr   c                    s   t    |j| _t| j|j|jd| _|j| _t	|j
 | _|j| _t| j| _t| j|j|jd| _t| j| _t| j|j| _t|j| j| _t| j| _d S )Nr   )r   )rB   rC   r   r   r   Zdecoder_attention_headsr   r   r   r   r   r   r   r	   r   r   encoder_attnencoder_attn_layer_normr   Zdecoder_ffn_dimr   r   r   r   rI   r+   r,   rC     s(    
zDetrDecoderLayer.__init__NF)r   r   r   query_position_embeddingsr9   encoder_attention_maskr   c                 K   s\  | dd}	|r"td|  |	dk	r:|dk	r:td|	dk	rPtd |	}|}
| j||||d\}}tjj|| j| j	d}|
| }| 
|}d}|dk	r|}
| j||||||d\}}tjj|| j| j	d}|
| }| |}|}
| | |}tjj|| j| j	d}| |}tjj|| j| j	d}|
| }| |}|f}|rX|||f7 }|S )	a  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            object_queries (`torch.FloatTensor`, *optional*):
                object_queries that are added to the hidden states
            in the cross-attention layer.
            query_position_embeddings (`torch.FloatTensor`, *optional*):
                position embeddings that are added to the queries and keys
            in the self-attention layer.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        r   Nr   r   r   )r   r   r   r   r   )r   r   r   r   r   r   )r   r   r   r   r   r   r	   rr   r   r   r   r   r   r   r   r   r   r   )rG   r   r   r   r   r9   r   r   rm   r   r   Zself_attn_weightsZcross_attn_weightsr   r+   r+   r,   rS     s^     


	


zDetrDecoderLayer.forward)NNNNNF)r$   r%   r&   r   rC   r(   r   r   rw   rS   rT   r+   r+   rI   r,   r     s          r   c                       s<   e Zd ZdZeeeed fddZejdddZ	  Z
S )DetrClassificationHeadz-Head for sentence-level classification tasks.)	input_dim	inner_dimnum_classespooler_dropoutc                    s8   t    t||| _tj|d| _t||| _d S )Nr   )rB   rC   r	   r   denseDropoutr   r   )rG   r   r   r   r   rI   r+   r,   rC   u  s    
zDetrClassificationHead.__init__)r   c                 C   s6   |  |}| |}t|}|  |}| |}|S r~   )r   r   r(   tanhr   )rG   r   r+   r+   r,   rS   {  s    




zDetrClassificationHead.forward)r$   r%   r&   r'   r   rt   rC   r(   r   rS   rT   r+   r+   rI   r,   r   r  s   r   c                   @   s*   e Zd ZeZdZdZdd Zd	ddZdS )
DetrPreTrainedModelr]   rn   c                 C   s  | j j}| j j}t|trdtj|jj	 tj|j
j	 tjj|jj|d tjj|j
j|d n*t|trtj|jj tj|jj t|tjtjtjfr|jjjd|d |j	d k	r|j	j  n>t|tjr|jjjd|d |jd k	r|jj|j   d S )N)Zgainr   )meanstd)rd   Zinit_stdinit_xavier_stdrU   DetrMHAttentionMapr	   initZzeros_k_linearr?   q_linearZxavier_uniform_r>   r   Zuniform_r   r   r   Conv2drV   rW   Znormal_Zzero_r   Zpadding_idx)rG   r_   r  Z
xavier_stdr+   r+   r,   _init_weights  s$    


z!DetrPreTrainedModel._init_weightsFc                 C   s   t |tr||_d S r~   )rU   DetrDecodergradient_checkpointing)rG   r_   valuer+   r+   r,   _set_gradient_checkpointing  s    
z/DetrPreTrainedModel._set_gradient_checkpointingN)F)	r$   r%   r&   r   config_classZbase_model_prefixZmain_input_namer	  r  r+   r+   r+   r,   r     s
   r   aI  
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`DetrConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
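
    A short illustration of that distinction (a sketch; `from_pretrained` downloads the trained DETR weights, while
    building from a configuration only sets up the architecture):

    ```python
    >>> from transformers import DetrConfig, DetrModel

    >>> # initializing from a configuration does not load the DETR weights
    >>> configuration = DetrConfig()
    >>> model = DetrModel(configuration)

    >>> # loading both the configuration and the trained weights
    >>> model = DetrModel.from_pretrained("facebook/detr-resnet-50")
    ```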
a`	  
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it.

            Pixel values can be obtained using [`AutoImageProcessor`]. See [`DetrImageProcessor.__call__`] for details.

        pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:

            - 1 for pixels that are real (i.e. **not masked**),
            - 0 for pixels that are padding (i.e. **masked**).

            [What are attention masks?](../glossary#attention-mask)

        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
            Not used by default. Can be used to mask object queries.
        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
            can choose to directly pass a flattened representation of an image.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
            embedded representation.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
c                       s0   e Zd ZdZed fddZdddZ  ZS )	DetrEncoderaU  
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`DetrEncoderLayer`].

    The encoder updates the flattened feature map through multiple self-attention layers.

    Small tweak for DETR:

    - object_queries are added to the forward pass.

    Args:
        config: DetrConfig
    r   c                    sH   t     j| _ j| _t fddt jD | _	| 
  d S )Nc                    s   g | ]}t  qS r+   )r   .0_r   r+   r,   
<listcomp>  s     z(DetrEncoder.__init__.<locals>.<listcomp>)rB   rC   r   Zencoder_layerdrop	layerdropr	   
ModuleListrangeZencoder_layerslayers	post_initr   rI   r   r,   rC     s
     zDetrEncoder.__init__Nc                 K   s  | dd}|r"td|  |dk	r:|dk	r:td|dk	rPtd |}|dk	r\|n| jj}|dk	rp|n| jj}|dk	r|n| jj}|}	t	j
j|	| j| jd}	|dk	rt||j}|rdnd}
|rdnd}t| jD ]r\}}|r|
|	f }
d}| jrtg }|| jk rd	}|r$d
}n||	|||d}|d }	|r||d f }q|r`|
|	f }
|s~tdd |	|
|fD S t|	|
|dS )a  
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.

            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:

                - 1 for pixel features that are real (i.e. **not masked**),
                - 0 for pixel features that are padding (i.e. **masked**).

                [What are attention masks?](../glossary#attention-mask)

            object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Object queries that are added to the queries in each self-attention layer.

            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        r   Nr   r   r   r   r+   FT)NN)r   r   r   r   c                 s   s   | ]}|d k	r|V  qd S r~   r+   r  vr+   r+   r,   	<genexpr>X  s      z&DetrEncoder.forward.<locals>.<genexpr>r4   r   
attentions)r   r   r   r   r   rd   r   output_hidden_statesuse_return_dictr	   rr   r   r   r   r   	enumerater  r(   randr  tupler   )rG   inputs_embedsr   r   r   r  return_dictrm   r   r   Zencoder_statesZall_attentionsiZencoder_layerZto_dropdropout_probabilitylayer_outputsr+   r+   r,   rS     sd    #


  zDetrEncoder.forward)NNNNNNr$   r%   r&   r'   r   rC   rS   rT   r+   r+   rI   r,   r    s         r  c                	       s0   e Zd ZdZed fddZdddZ  ZS )	r
  a  
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetrDecoderLayer`].

    The decoder updates the query embeddings through multiple self-attention and cross-attention layers.

    Some small tweaks for DETR:

    - object_queries and query_position_embeddings are added to the forward pass.
    - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.

    Args:
        config: DetrConfig
    r   c                    s\   t     j| _ j| _t fddt jD | _	t
 j| _d| _|   d S )Nc                    s   g | ]}t  qS r+   )r   r  r   r+   r,   r  r  s     z(DetrDecoder.__init__.<locals>.<listcomp>F)rB   rC   r   Zdecoder_layerdropr  r	   r  r  decoder_layersr  r   r   	layernormr  r  r   rI   r   r,   rC   m  s     zDetrDecoder.__init__Nc
              
      s  |
 dd}|
r"td|
  |dk	r:|dk	r:td|dk	rPtd |} dk	r\ n| jj |dk	rp|n| jj}|	dk	r|	n| jj}	|dk	r|}|	 dd }d}|dk	r|dk	r|t
||j|d d }|dk	r|dk	rt
||j|d d}| jjrdnd}|rdnd} r"dnd} r:|dk	r:dnd}t| jD ]\}}|r`||f7 }| jrtg }|| jk rqH| jr| jr fd	d
}tjj||||||d}n||||||| d}|d }| jjr| |}||f7 } rH||d f7 }|dk	rH||d f7 }qH| |}|rB||f7 }| jjrVt|}|	sxtdd |||||fD S t|||||dS )a  
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                The query embeddings that are passed into the decoder.

            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on certain queries. Mask values selected in `[0, 1]`:

                - 1 for queries that are **not masked**,
                - 0 for queries that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
                in `[0, 1]`:

                - 1 for pixels that are real (i.e. **not masked**),
                - 0 for pixels that are padding (i.e. **masked**).

            object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Object queries that are added to the queries and keys in each cross-attention layer.
            query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
                Position embeddings that are added to the queries and keys in each self-attention layer.

            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        r   Nr   r   r   rN   )r   r+   c                    s    fdd}|S )Nc                     s    | f S r~   r+   )inputs)r_   r   r+   r,   custom_forward  s    zJDetrDecoder.forward.<locals>.create_custom_forward.<locals>.custom_forwardr+   )r_   r,  r   )r_   r,   create_custom_forward  s    z2DetrDecoder.forward.<locals>.create_custom_forward)r   r   r   r9   r   r   r   r   r   c                 s   s   | ]}|d k	r|V  qd S r~   r+   r  r+   r+   r,   r    s   z&DetrDecoder.forward.<locals>.<genexpr>)r4   r   r  r7   r"   )r   r   r   r   r   rd   r   r  r  rq   r   r   auxiliary_lossr   r  r   r(   r!  r  r  utils
checkpointr*  r   r"  r!   )rG   r#  r   r9   r   r   r   r   r  r$  rm   r   r   Zinput_shapeZcombined_attention_maskintermediateZall_hidden_statesZall_self_attnsZall_cross_attentionsidxZdecoder_layerr&  r.  r'  r+   r-  r,   rS   z  s    1    

	








zDetrDecoder.forward)	NNNNNNNNNr(  r+   r+   rI   r,   r
  ^  s            r
  z
    The bare DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw hidden-states without
    any specific head on top.
    c                       s   e Zd Zed fddZdd Zdd Zdd	 Zd
d Ze	e
eeeddejeej eej eej eej eej ee ee ee eeej ef d
ddZ  ZS )	DetrModelr   c                    st   t  | t|}t|}t||| _tj|jd |j	dd| _
t|j|j	| _t|| _t|| _|   d S )NrN   r   )Zkernel_size)rB   rC   r`   r   r}   rf   r	   r  rj   r   input_projectionr   num_queriesr   r  encoderr
  decoderr  )rG   rd   rf   r   rI   r+   r,   rC   .  s    

zDetrModel.__init__c                 C   s   | j S r~   )r7  rG   r+   r+   r,   get_encoderA  s    zDetrModel.get_encoderc                 C   s   | j S r~   )r8  r9  r+   r+   r,   get_decoderD  s    zDetrModel.get_decoderc                 C   s&   | j jj D ]\}}|d qd S )NFrf   r   r]   rk   rl   rG   r^   paramr+   r+   r,   freeze_backboneG  s    zDetrModel.freeze_backbonec                 C   s&   | j jj D ]\}}|d qd S )NTr<  r=  r+   r+   r,   unfreeze_backboneK  s    zDetrModel.unfreeze_backboneoutput_typer  N)
rn   ro   decoder_attention_maskencoder_outputsr#  decoder_inputs_embedsr   r  r$  r   c
                 C   s  |dk	r|n| j j}|dk	r |n| j j}|	dk	r4|	n| j j}	|j\}
}}}|j}|dkrltj|
||f|d}| ||\}}|d \}}|dkrt	d| 
|}|dddd}|d dddd}|d}|dkr| j||||||	d}nP|	rHt|tsHt|d t|dkr(|d ndt|dkr@|d ndd	}| jjd|
dd}t|}| j|d|||d ||||	d
	}|	s|| S t|j|j|j|j|j|j|j|jdS )a  
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, DetrModel
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
        >>> model = DetrModel.from_pretrained("facebook/detr-resnet-50")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**inputs)

        >>> # the last hidden states are the final query embeddings of the Transformer decoder
        >>> # these are of shape (batch_size, num_queries, hidden_size)
        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 100, 256]
        ```Nr   rN   z/Backbone does not return downsampled pixel maskr   r   r   r#  r   r   r   r  r$  r  	r#  r   r   r   r9   r   r   r  r$  )r4   r5   r6   r7   r8   r9   r:   r"   )rd   r   r  r  ru   r   r(   rE   rf   r   r5  r   r   r7  rU   r   rY   r   r>   r   r   
zeros_liker8  r-   r4   r   r  r7   r"   )rG   rn   ro   rC  rD  r#  rE  r   r  r$  r   rg   r   r   r   ry   object_queries_listr{   r|   projected_feature_mapflattened_featuresr   flattened_maskr   queriesdecoder_outputsr+   r+   r,   rS   O  sp    *

	
zDetrModel.forward)NNNNNNNN)r$   r%   r&   r   rC   r:  r;  r?  r@  r   DETR_INPUTS_DOCSTRINGr   r-   _CONFIG_FOR_DOCr(   r)   r   
LongTensorrw   r   r   rS   rT   r+   r+   rI   r,   r4  &  s6   
        r4  z
    DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on top, for tasks
    such as COCO detection.
    c                       s   e Zd Zed fddZejjdd Ze	e
eeeddejeej eej eej eej eej eee  ee ee ee eeej ef dd	d
Z  ZS )DetrForObjectDetectionr   c                    sN   t  | t|| _t|j|jd | _t	|j|jddd| _
|   d S )Nr   rc   r
   )r   
hidden_dim
output_dim
num_layers)rB   rC   r4  r]   r	   r   r   
num_labelsclass_labels_classifierDetrMLPPredictionHeadbbox_predictorr  r   rI   r+   r,   rC     s    
    zDetrForObjectDetection.__init__c                 C   s$   dd t |d d |d d D S )Nc                 S   s   g | ]\}}||d qS ))r1   r2   r+   )r  abr+   r+   r,   r    s     z8DetrForObjectDetection._set_aux_loss.<locals>.<listcomp>rN   )zip)rG   outputs_classoutputs_coordr+   r+   r,   _set_aux_loss  s    z$DetrForObjectDetection._set_aux_lossrA  Nrn   ro   rC  rD  r#  rE  labelsr   r  r$  r   c                    s
  |
dk	r|
n| j j}
| j||||||||	|
d	}|d }| |}| | }d\}}|dk	rt| j j| j j| j j	d}dddg}t
|| j j| j j|d	}|| j i }||d
< ||d< | j jr|
r|jn|d }| |}| | }| ||}||d< |||d| j jd| j jd< | j jrxi }t| j jd D ]" | fdd D  qJ| tfdd D }|
s|dk	r||f| | }n||f| }|dk	r|f| S |S t|||||j|j|j|j|j|j |j!dS )a  
        labels (`List[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
            following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
            respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
            in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, DetrForObjectDetection
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
        >>> model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> # convert outputs (bounding boxes and class logits) to COCO API
        >>> target_sizes = torch.tensor([image.size[::-1]])
        >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
        ...     0
        ... ]

        >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        ...     box = [round(i, 2) for i in box.tolist()]
        ...     print(
        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
        ...         f"{round(score.item(), 3)} at location {box}"
        ...     )
        Detected remote with confidence 0.998 at location [40.16, 70.81, 175.55, 117.98]
        Detected remote with confidence 0.996 at location [333.24, 72.55, 368.33, 187.66]
        Detected couch with confidence 0.995 at location [-0.02, 1.15, 639.73, 473.76]
        Detected cat with confidence 0.999 at location [13.24, 52.05, 314.02, 470.93]
        Detected cat with confidence 0.999 at location [345.4, 23.85, 640.37, 368.72]
        ```N)ro   rC  rD  r#  rE  r   r  r$  r   NNN
class_cost	bbox_cost	giou_costra  boxescardinalitymatcherr   eos_coeflossesr1   r2   rc   r3   r   loss_ce	loss_bbox	loss_giouc                    s    i | ]\}}|d    |qS r  r+   r  kr  r%  r+   r,   
<dictcomp>`  s      z2DetrForObjectDetection.forward.<locals>.<dictcomp>c                 3   s&   | ]}|kr | |  V  qd S r~   r+   r  rs  r0   weight_dictr+   r,   r  b  s      z1DetrForObjectDetection.forward.<locals>.<genexpr>)r/   r0   r1   r2   r3   r4   r5   r6   r7   r8   r9   r:   )"rd   r  r]   rW  rY  sigmoidDetrHungarianMatcherrd  re  rf  DetrLossrV  eos_coefficientrv   r   r/  r"   r_  bbox_loss_coefficientgiou_loss_coefficientr  r)  updateitemssumr   r.   r4   r5   r6   r7   r8   r9   r:   )rG   rn   ro   rC  rD  r#  rE  ra  r   r  r$  r   sequence_outputr1   r2   r/   r3   rj  rl  	criterionoutputs_lossr2  r]  r^  aux_weight_dictoutputr+   r%  r0   rx  r,   rS     s    ;


  




 

zDetrForObjectDetection.forward)	NNNNNNNNN)r$   r%   r&   r   rC   r(   ZjitZunusedr_  r   rO  r   r.   rP  r)   r   rQ  r   dictrw   r   r   rS   rT   r+   r+   rI   r,   rR    s6   

         
rR  z
    DETR Model (consisting of a backbone and encoder-decoder Transformer) with a segmentation head on top, for tasks
    such as COCO panoptic.

    """,
    DETR_START_DOCSTRING,
)
class DetrForSegmentation(DetrPreTrainedModel):
    def __init__(self, config: DetrConfig):
        super().__init__(config)

        # object detection model
        self.detr = DetrForObjectDetection(config)

        # segmentation head
        hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads
        intermediate_channel_sizes = self.detr.model.backbone.conv_encoder.intermediate_channel_sizes

        self.mask_head = DetrMaskHeadSmallConv(
            hidden_size + number_of_heads, intermediate_channel_sizes[::-1][-3:], hidden_size
        )

        self.bbox_attention = DetrMHAttentionMap(
            hidden_size, hidden_size, number_of_heads, dropout=0.0, std=config.init_xavier_std
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DETR_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DetrSegmentationOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.FloatTensor] = None,
        encoder_outputs: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[List[dict]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], DetrSegmentationOutput]:
        r"""
        labels (`List[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss, DICE/F-1 loss and Focal loss. List of dicts, each
            dictionary containing at least the following 3 keys: 'class_labels', 'boxes' and 'masks' (the class labels,
            bounding boxes and segmentation masks of an image in the batch respectively). The class labels themselves
            should be a `torch.LongTensor` of len `(number of bounding boxes in the image,)`, the boxes a
            `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)` and the masks a
            `torch.FloatTensor` of shape `(number of bounding boxes in the image, height, width)`.

        Returns:

        Examples:

        ```python
        >>> import io
        >>> import requests
        >>> from PIL import Image
        >>> import torch
        >>> import numpy

        >>> from transformers import AutoImageProcessor, DetrForSegmentation
        >>> from transformers.image_transforms import rgb_to_id

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
        >>> model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**inputs)

        >>> # Use the `post_process_panoptic_segmentation` method of the `image_processor` to retrieve post-processed panoptic segmentation maps
        >>> # Segmentation results are returned as a list of dictionaries
        >>> result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[(300, 500)])

        >>> # A tensor of shape (height, width) where each value denotes a segment id, filled with -1 if no segment is found
        >>> panoptic_seg = result[0]["segmentation"]
        >>> # Get prediction score and segment_id to class_id mapping of each segment
        >>> panoptic_segments_info = result[0]["segments_info"]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, num_channels, height, width = pixel_values.shape
        device = pixel_values.device

        if pixel_mask is None:
            pixel_mask = torch.ones((batch_size, height, width), device=device)

        # First, get list of feature maps and position embeddings
        features, object_queries_list = self.detr.model.backbone(pixel_values, pixel_mask=pixel_mask)

        # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
        feature_map, mask = features[-1]
        batch_size, num_channels, height, width = feature_map.shape
        projected_feature_map = self.detr.model.input_projection(feature_map)

        # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
        flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
        object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)

        flattened_mask = mask.flatten(1)

        # Fourth, send flattened_features + flattened_mask + position embeddings through the encoder
        if encoder_outputs is None:
            encoder_outputs = self.detr.model.encoder(
                inputs_embeds=flattened_features,
                attention_mask=flattened_mask,
                object_queries=object_queries,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # Fifth, send query embeddings + position embeddings through the decoder
        query_position_embeddings = self.detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
            batch_size, 1, 1
        )
        queries = torch.zeros_like(query_position_embeddings)

        decoder_outputs = self.detr.model.decoder(
            inputs_embeds=queries,
            attention_mask=None,
            object_queries=object_queries,
            query_position_embeddings=query_position_embeddings,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=flattened_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Sixth, compute logits, pred_boxes and pred_masks
        logits = self.detr.class_labels_classifier(sequence_output)
        pred_boxes = self.detr.bbox_predictor(sequence_output).sigmoid()

        memory = encoder_outputs[0].permute(0, 2, 1).view(batch_size, self.config.d_model, height, width)
        mask = flattened_mask.view(batch_size, height, width)

        # important: we need to reverse the mask, since in the original implementation the mask works reversed
        # bbox_mask is of shape (batch_size, num_queries, number_of_heads, height/32, width/32)
        bbox_mask = self.bbox_attention(sequence_output, memory, mask=~mask)

        seg_masks = self.mask_head(projected_feature_map, bbox_mask, [features[2][0], features[1][0], features[0][0]])

        pred_masks = seg_masks.view(
            batch_size, self.detr.config.num_queries, seg_masks.shape[-2], seg_masks.shape[-1]
        )

        loss, loss_dict, auxiliary_outputs = None, None, None
        if labels is not None:
            # The losses are computed as in DetrForObjectDetection, with the mask losses ("masks") added
            matcher = DetrHungarianMatcher(
                class_cost=self.config.class_cost, bbox_cost=self.config.bbox_cost, giou_cost=self.config.giou_cost
            )
            losses = ["labels", "boxes", "cardinality", "masks"]
            criterion = DetrLoss(
                matcher=matcher,
                num_classes=self.config.num_labels,
                eos_coef=self.config.eos_coefficient,
                losses=losses,
            )
            criterion.to(self.device)
            outputs_loss = {}
            outputs_loss["logits"] = logits
            outputs_loss["pred_boxes"] = pred_boxes
            outputs_loss["pred_masks"] = pred_masks
            if self.config.auxiliary_loss:
                intermediate = decoder_outputs.intermediate_hidden_states if return_dict else decoder_outputs[-1]
                outputs_class = self.detr.class_labels_classifier(intermediate)
                outputs_coord = self.detr.bbox_predictor(intermediate).sigmoid()
                auxiliary_outputs = self.detr._set_aux_loss(outputs_class, outputs_coord)
                outputs_loss["auxiliary_outputs"] = auxiliary_outputs

            loss_dict = criterion(outputs_loss, labels)
            weight_dict = {"loss_ce": 1, "loss_bbox": self.config.bbox_loss_coefficient}
            weight_dict["loss_giou"] = self.config.giou_loss_coefficient
            weight_dict["loss_mask"] = self.config.mask_loss_coefficient
            weight_dict["loss_dice"] = self.config.dice_loss_coefficient
            if self.config.auxiliary_loss:
                aux_weight_dict = {}
                for i in range(self.config.decoder_layers - 1):
                    aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
                weight_dict.update(aux_weight_dict)
            loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        if not return_dict:
            if auxiliary_outputs is not None:
                output = (logits, pred_boxes, pred_masks) + auxiliary_outputs + decoder_outputs + encoder_outputs
            else:
                output = (logits, pred_boxes, pred_masks) + decoder_outputs + encoder_outputs
            return ((loss, loss_dict) + output) if loss is not None else output

        return DetrSegmentationOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=logits,
            pred_boxes=pred_boxes,
            pred_masks=pred_masks,
            auxiliary_outputs=auxiliary_outputs,
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


def _expand(tensor, length: int):
    return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)


class DetrMaskHeadSmallConv(nn.Module):
    """
    Simple convolutional head, using group norm. Upsampling is done using an FPN approach.
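
    The head consumes the projected feature map concatenated with the per-query attention maps produced by
    `DetrMHAttentionMap`, together with three intermediate backbone feature maps (`fpns`); each stage upsamples
    and adds an adapted FPN feature map, ending in a single-channel heatmap per query.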
    """

    def __init__(self, dim, fpn_dims, context_dim):
        super().__init__()

        if dim % 8 != 0:
            raise ValueError(
                "The hidden_size + number of attention heads must be divisible by 8 as the number of groups in"
                " GroupNorm is set to 8"
            )

        inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]

        self.lay1 = nn.Conv2d(dim, dim, 3, padding=1)
        self.gn1 = nn.GroupNorm(8, dim)
        self.lay2 = nn.Conv2d(dim, inter_dims[1], 3, padding=1)
        self.gn2 = nn.GroupNorm(min(8, inter_dims[1]), inter_dims[1])
        self.lay3 = nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
        self.gn3 = nn.GroupNorm(min(8, inter_dims[2]), inter_dims[2])
        self.lay4 = nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
        self.gn4 = nn.GroupNorm(min(8, inter_dims[3]), inter_dims[3])
        self.lay5 = nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
        self.gn5 = nn.GroupNorm(min(8, inter_dims[4]), inter_dims[4])
        self.out_lay = nn.Conv2d(inter_dims[4], 1, 3, padding=1)

        self.dim = dim

        self.adapter1 = nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
        self.adapter2 = nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
        self.adapter3 = nn.Conv2d(fpn_dims[2], inter_dims[3], 1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: Tensor, bbox_mask: Tensor, fpns: List[Tensor]):
        # Concatenate the projected feature map of shape (batch_size, d_model, height/32, width/32) with the
        # attention maps of shape (batch_size, num_queries, num_heads, height/32, width/32), expanding the
        # feature map to one copy per query.
        x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)

        x = self.lay1(x)
        x = self.gn1(x)
        x = nn.functional.relu(x)
        x = self.lay2(x)
        x = self.gn2(x)
        x = nn.functional.relu(x)

        cur_fpn = self.adapter1(fpns[0])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
        x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay3(x)
        x = self.gn3(x)
        x = nn.functional.relu(x)

        cur_fpn = self.adapter2(fpns[1])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
        x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay4(x)
        x = self.gn4(x)
        x = nn.functional.relu(x)

        cur_fpn = self.adapter3(fpns[2])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
        x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay5(x)
        x = self.gn5(x)
        x = nn.functional.relu(x)

        x = self.out_lay(x)
        return x
zDetrMaskHeadSmallConv.forward)	r$   r%   r&   r'   rC   r   r   rS   rT   r+   r+   rI   r,   r  j  s   "r  c                       s6   e Zd ZdZd
 fdd	Zdee ddd	Z  ZS )r  zdThis is a 2D attention module, which only returns the attention softmax (no multiplication by value)r   TNc                    s^   t    || _|| _t|| _tj|||d| _tj|||d| _	t
|| j d | _d S )Nr   r   )rB   rC   r   rS  r	   r   r   r   r  r  rt   normalize_fact)rG   Z	query_dimrS  r   r   r?   r  rI   r+   r,   rC     s    
zDetrMHAttentionMap.__init__r  c                 C   s   |  |}tj|| jjdd| jj}||j	d |j	d | j
| j| j
 }||j	d | j
| j| j
 |j	d |j	d }td|| j |}|d k	r||ddt|jj tjj|ddd| }| |}|S )NrN   r   r   rp   zbqnc,bnchw->bqnhwr   r   )r  r	   rr   Zconv2dr  r>   r   r?   r   ru   r   rS  r(   Zeinsumr  Zmasked_fill_r   r   r   r   r   rq   r   )rG   qrs  r|   Zqueries_per_headZkeys_per_headweightsr+   r+   r,   rS     s    
$&." 
zDetrMHAttentionMap.forward)r   TN)N)	r$   r%   r&   r'   rC   r   r   rS   rT   r+   r+   rI   r,   r    s   r  c                 C   sX   |   } | d} d| | d }| d|d }d|d |d   }| | S )a  
    Compute the DICE loss, similar to generalized IoU, for masks.

    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs (0 for the negative class and 1 for the positive
                 class).
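
    For each example the loss is `1 - (2 * sum(p * t) + 1) / (sum(p) + sum(t) + 1)`, where `p` is the sigmoid of
    the predictions and `t` the targets; the per-example losses are summed and divided by `num_boxes`.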
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    numerator = 2 * (inputs * targets).sum(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_boxes


def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: int = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs (`torch.FloatTensor` of arbitrary shape):
            The predictions for each example.
        targets (`torch.FloatTensor` with the same shape as `inputs`)
            A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class
            and 1 for the positive class).
        alpha (`float`, *optional*, defaults to `0.25`):
            Optional weighting factor in the range (0,1) to balance positive vs. negative examples.
        gamma (`int`, *optional*, defaults to `2`):
            Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples.

    Returns:
        Loss tensor
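
    Concretely, with `p_t = p * t + (1 - p) * (1 - t)` and `alpha_t = alpha * t + (1 - alpha) * (1 - t)`, the
    per-element loss is `alpha_t * (1 - p_t) ** gamma * BCE(inputs, targets)`; it is averaged per prediction,
    summed, and divided by `num_boxes`.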
    noneZ	reductionr   r   )ry  r	   rr   Z binary_cross_entropy_with_logitsr  r  )
r+  r  r  r  r  ZprobZce_lossZp_tr/   Zalpha_tr+   r+   r,   sigmoid_focal_loss  s    r  c                       sh   e Zd ZdZ fddZdd Ze dd Zdd	 Z	d
d Z
dd Zdd Zdd Zdd Z  ZS )r{  a  
    This class computes the losses for DetrForObjectDetection/DetrForSegmentation. The process happens in two steps: 1)
    we compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair
    of matched ground-truth / prediction (supervise class and box).

    A note on the `num_classes` argument (copied from original repo in detr.py): "the naming of the `num_classes`
    parameter of the criterion is somewhat misleading. It indeed corresponds to `max_obj_id` + 1, where `max_obj_id` is
    the maximum id for a class in your dataset. For example, COCO has a `max_obj_id` of 90, so we pass `num_classes` to
    be 91. As another example, for a dataset that has a single class with `id` 1, you should pass `num_classes` to be 2
    (`max_obj_id` + 1). For more details on this, check the following discussion
    https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"


    Args:
        matcher (`DetrHungarianMatcher`):
            Module able to compute a matching between targets and proposals.
        num_classes (`int`):
            Number of object categories, omitting the special no-object category.
        eos_coef (`float`):
            Relative classification weight applied to the no-object category.
        losses (`List[str]`):
            List of all the losses to be applied. See `get_loss` for a list of all available losses.
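
    Example (an illustrative sketch with random tensors; 91 classes and the default cost weights as in COCO detection):

    ```python
    matcher = DetrHungarianMatcher(class_cost=1, bbox_cost=5, giou_cost=2)
    criterion = DetrLoss(matcher=matcher, num_classes=91, eos_coef=0.1, losses=["labels", "boxes", "cardinality"])
    outputs = {"logits": torch.randn(2, 100, 92), "pred_boxes": torch.rand(2, 100, 4)}
    targets = [
        {"class_labels": torch.tensor([17]), "boxes": torch.rand(1, 4)},
        {"class_labels": torch.tensor([2, 5]), "boxes": torch.rand(2, 4)},
    ]
    loss_dict = criterion(outputs, targets)  # keys: "loss_ce", "loss_bbox", "loss_giou", "cardinality_error"
    ```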
    """

    def __init__(self, matcher, num_classes, eos_coef, losses):
        super().__init__()
        self.matcher = matcher
        self.num_classes = num_classes
        self.eos_coef = eos_coef
        self.losses = losses
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer("empty_weight", empty_weight)

    def loss_labels(self, outputs, targets, indices, num_boxes):
        """
        Classification loss (NLL). Targets dicts must contain the key "class_labels" containing a tensor of dim
        [nb_target_boxes].
        """
        if "logits" not in outputs:
            raise KeyError("No logits were found in the outputs")
        source_logits = outputs["logits"]

        idx = self._get_source_permutation_idx(indices)
        target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(
            source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
        )
        target_classes[idx] = target_classes_o

        loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
        losses = {"loss_ce": loss_ce}

        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """
        Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.

        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
        """
        logits = outputs["logits"]
        device = logits.device
        target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
        card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
        losses = {"cardinality_error": card_err}
        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """
        Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.

        Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes
        are expected in format (center_x, center_y, w, h), normalized by the image size.
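
        For example, a normalized box `(center_x, center_y, w, h) = (0.5, 0.5, 0.2, 0.4)` corresponds to the
        corner format `(x0, y0, x1, y1) = (0.4, 0.3, 0.6, 0.7)` used when computing the GIoU term.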
        """
        if "pred_boxes" not in outputs:
            raise KeyError("No predicted boxes found in outputs")
        idx = self._get_source_permutation_idx(indices)
        source_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")

        losses = {}
        losses["loss_bbox"] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(
            generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
        )
        losses["loss_giou"] = loss_giou.sum() / num_boxes
        return losses

    def loss_masks(self, outputs, targets, indices, num_boxes):
        """
        Compute the losses related to the masks: the focal loss and the dice loss.

        Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w].
        """
        if "pred_masks" not in outputs:
            raise KeyError("No predicted masks found in outputs")

        source_idx = self._get_source_permutation_idx(indices)
        target_idx = self._get_target_permutation_idx(indices)
        source_masks = outputs["pred_masks"]
        source_masks = source_masks[source_idx]
        masks = [t["masks"] for t in targets]

        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(source_masks)
        target_masks = target_masks[target_idx]

        # upsample predictions to the target size
        source_masks = nn.functional.interpolate(
            source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
        )
        source_masks = source_masks[:, 0].flatten(1)

        target_masks = target_masks.flatten(1)
        target_masks = target_masks.view(source_masks.shape)
        losses = {
            "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
            "loss_dice": dice_loss(source_masks, target_masks, num_boxes),
        }
        return losses

    def _get_source_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
        source_idx = torch.cat([source for (source, _) in indices])
        return batch_idx, source_idx

    def _get_target_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
        target_idx = torch.cat([target for (_, target) in indices])
        return batch_idx, target_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes):
        loss_map = {
            "labels": self.loss_labels,
            "cardinality": self.loss_cardinality,
            "boxes": self.loss_boxes,
            "masks": self.loss_masks,
        }
        if loss not in loss_map:
            raise ValueError(f"Loss {loss} not supported")
        return loss_map[loss](outputs, targets, indices, num_boxes)

    def forward(self, outputs, targets):
        """
        This performs the loss computation.

        Args:
             outputs (`dict`, *optional*):
                Dictionary of tensors, see the output specification of the model for the format.
             targets (`List[dict]`, *optional*):
                List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depend on the
                losses applied; see each loss' doc.
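
        A single target dict for an image with two annotated boxes might look like this (illustrative only;
        boxes are `(center_x, center_y, w, h)` normalized to the image size):

        ```python
        target = {
            "class_labels": torch.tensor([3, 17]),
            "boxes": torch.tensor([[0.5, 0.5, 0.2, 0.3], [0.25, 0.4, 0.1, 0.1]]),
        }
        ```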
        """
        # Retrieve the matching between the outputs of the last layer and the targets
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes across the batch, for normalization purposes
        num_boxes = sum(len(t["class_labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        num_boxes = torch.clamp(num_boxes, min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, repeat this process with the output of each intermediate decoder layer
        if "auxiliary_outputs" in outputs:
            for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
                indices = self.matcher(auxiliary_outputs, targets)
                for loss in self.losses:
                    if loss == "masks":
                        # Intermediate masks losses are too costly to compute, we ignore them
                        continue
                    l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
                    l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses


class DetrMLPPredictionHead(nn.Module):
    """
    Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
    height and width of a bounding box w.r.t. an image.

    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
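
    Example (illustrative; the default DETR configuration uses 256-dimensional hidden states and a 3-layer head
    predicting 4 box coordinates):

    ```python
    bbox_head = DetrMLPPredictionHead(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
    boxes = bbox_head(torch.randn(2, 100, 256)).sigmoid()  # (batch_size, num_queries, 4)
    ```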

    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class DetrHungarianMatcher(nn.Module):
    """
    This class computes an assignment between the targets and the predictions of the network.

    For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
    predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
    un-matched (and thus treated as non-objects).

    Args:
        class_cost:
            The relative weight of the classification error in the matching cost.
        bbox_cost:
            The relative weight of the L1 error of the bounding box coordinates in the matching cost.
        giou_cost:
            The relative weight of the giou loss of the bounding box in the matching cost.
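
    Example (an illustrative sketch with random tensors; requires `scipy`):

    ```python
    matcher = DetrHungarianMatcher(class_cost=1, bbox_cost=5, giou_cost=2)
    outputs = {"logits": torch.randn(1, 100, 92), "pred_boxes": torch.rand(1, 100, 4)}
    targets = [{"class_labels": torch.tensor([1, 5]), "boxes": torch.rand(2, 4)}]
    indices = matcher(outputs, targets)
    # indices[0] is a pair (prediction_indices, target_indices), each a LongTensor of length 2
    ```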
    """

    def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
        super().__init__()
        requires_backends(self, ["scipy"])

        self.class_cost = class_cost
        self.bbox_cost = bbox_cost
        self.giou_cost = giou_cost
        if class_cost == 0 and bbox_cost == 0 and giou_cost == 0:
            raise ValueError("All costs of the Matcher can't be 0")

    @torch.no_grad()
    def forward(self, outputs, targets):
        """
        Args:
            outputs (`dict`):
                A dictionary that contains at least these entries:
                * "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                * "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
            targets (`List[dict]`):
                A list of targets (len(targets) = batch_size), where each target is a dict containing:
                * "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
                  ground-truth
                 objects in the target) containing the class labels
                * "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.

        Returns:
            `List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
            - index_i is the indices of the selected predictions (in order)
            - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        batch_size, num_queries = outputs["logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        target_ids = torch.cat([v["class_labels"] for v in targets])
        target_bbox = torch.cat([v["boxes"] for v in targets])

        # Compute the classification cost. Contrary to the loss, we don't use the NLL, but approximate it by
        # 1 - proba[target class]; the 1 is a constant that doesn't change the matching and can be omitted.
        class_cost = -out_prob[:, target_ids]

        # Compute the L1 cost between boxes
        bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)

        # Compute the giou cost between boxes
        giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))

        # Final cost matrix
        cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
        cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()

        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]


def _upcast(t: Tensor) -> Tensor:
    # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
    if t.is_floating_point():
        return t if t.dtype in (torch.float32, torch.float64) else t.float()
    else:
        return t if t.dtype in (torch.int32, torch.int64) else t.int()


def box_area(boxes: Tensor) -> Tensor:
    """
    Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.

    Args:
        boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
            Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
            < x2` and `0 <= y1 < y2`.

    Returns:
        `torch.FloatTensor`: a tensor containing the area for each box.
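
    Example:

    ```python
    box_area(torch.tensor([[0.0, 0.0, 2.0, 3.0]]))  # tensor([6.])
    ```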
    """
    boxes = _upcast(boxes)
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


def box_iou(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    width_height = (right_bottom - left_top).clamp(min=0)  # [N,M,2]
    inter = width_height[:, :, 0] * width_height[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union


def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.

    Returns:
        `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
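
    For example, for the disjoint unit boxes `[0, 0, 1, 1]` and `[1, 1, 2, 2]` the IoU is 0, the union area is 2 and
    the smallest enclosing box has area 4, so the generalized IoU is `0 - (4 - 2) / 4 = -0.5`.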
    """
    # degenerate boxes give inf / nan results, so do an early check
    if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
        raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
    if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
        raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
    iou, union = box_iou(boxes1, boxes2)

    top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    width_height = (bottom_right - top_left).clamp(min=0)  # [N,M,2]
    area = width_height[:, :, 0] * width_height[:, :, 1]

    return iou - (area - union) / area


def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)


def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    if tensor_list[0].ndim == 3:
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        batch_shape = [len(tensor_list)] + max_size
        batch_size, num_channels, height, width = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], : img.shape[2]] = False
    else:
        raise ValueError("Only 3-dimensional tensors are supported")
    return NestedTensor(tensor, mask)