""" PyTorch DPT (Dense Prediction Transformers) model.

This implementation is heavily inspired by OpenMMLab's implementation, found here:
https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/dpt_head.py.

"""
import collections.abc
import math
from dataclasses import dataclass
from typing import List, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...file_utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, logging
from ..auto import AutoBackbone
from .configuration_dpt import DPTConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "DPTConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "Intel/dpt-large"
_EXPECTED_OUTPUT_SHAPE = [1, 577, 1024]


DPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "Intel/dpt-large",
    "Intel/dpt-hybrid-midas",
]


@dataclass
class BaseModelOutputWithIntermediateActivations(ModelOutput):
    """
    Base class for model's outputs that also contains intermediate activations that can be used at later stages. Useful
    in the context of Vision models.

    Args:
        last_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        intermediate_activations (`tuple(torch.FloatTensor)`, *optional*):
            Intermediate activations that can be used to compute hidden states of the model at various layers.
    Nlast_hidden_statesintermediate_activations)__name__
__module____qualname____doc__r   torchFloatTensor__annotations__r   r   r    r&   r&   c/var/www/html/Darija-Ai-API/env/lib/python3.8/site-packages/transformers/models/dpt/modeling_dpt.pyr   A   s   
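
# Editor's note (sketch, not upstream code): like every `ModelOutput`
# subclass, the dataclasses in this file support both attribute and
# tuple-style access, which is why the forward methods below can return
# either form depending on `return_dict`:
#
#   out = BaseModelOutputWithIntermediateActivations(
#       last_hidden_states=torch.randn(1, 577, 1024)
#   )
#   assert out[0] is out.last_hidden_states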
r   c                   @   sp   e Zd ZU dZdZejed< dZejed< dZ	e
eej  ed< dZe
eej  ed< dZe
eej  ed< dS )4BaseModelOutputWithPoolingAndIntermediateActivationsa  
    Base class for model's outputs that also contains a pooling of the last hidden states as well as intermediate
    activations that can be used by the model at later stages.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token) after further processing
            through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
            the classification token after processing through a linear layer and a tanh activation function. The linear
            layer weights are trained from the next sentence prediction (classification) objective during pretraining.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        intermediate_activations (`tuple(torch.FloatTensor)`, *optional*):
            Intermediate activations that can be used to compute hidden states of the model at various layers.
    Nlast_hidden_statepooler_outputhidden_states
attentionsr   )r   r    r!   r"   r)   r#   r$   r%   r*   r+   r   r   r,   r   r&   r&   r&   r'   r(   R   s   
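
# Editor's note (sketch, not upstream code): for DPT-hybrid checkpoints,
# `intermediate_activations` carries the residual convolutional feature maps
# produced by the backbone inside `DPTViTHybridEmbeddings`; the depth and
# segmentation models below splice them into the selected ViT hidden states
# (see `backbone_hidden_states` in `DPTForDepthEstimation.forward`).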
r(   c                       sF   e Zd ZdZd fdd	ZdddZdejeeejd	d
dZ	  Z
S )DPTViTHybridEmbeddingsz
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config, feature_size=None):
        super().__init__()

        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])

        self.backbone = AutoBackbone.from_config(config.backbone_config)
        feature_dim = self.backbone.channels[-1]
        if len(self.backbone.out_features) != 3:
            raise ValueError(f"Expected backbone to have 3 output features, got {len(self.backbone.out_features)}")
        # always take the feature maps of the first and second backbone stage as residuals
        self.residual_feature_map_index = [0, 1]

        if feature_size is None:
            feat_map_shape = config.backbone_featmap_shape
            feature_size = feat_map_shape[-2:]
            feature_dim = feat_map_shape[1]
        else:
            feature_size = (
                feature_size if isinstance(feature_size, collections.abc.Iterable) else (feature_size, feature_size)
            )
            feature_dim = self.backbone.channels[-1]

        self.image_size = image_size
        self.patch_size = patch_size[0]
        self.num_channels = num_channels

        self.projection = nn.Conv2d(feature_dim, hidden_size, kernel_size=1)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))

    def _resize_pos_embed(self, posemb, grid_size_height, grid_size_width, start_index=1):
        posemb_tok = posemb[:, :start_index]
        posemb_grid = posemb[0, start_index:]

        old_grid_size = int(math.sqrt(len(posemb_grid)))

        posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
        posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear")
        posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, grid_size_height * grid_size_width, -1)

        posemb = torch.cat([posemb_tok, posemb_grid], dim=1)

        return posemb

    def forward(
        self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False, return_dict: bool = False
    ) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        if not interpolate_pos_encoding:
            if height != self.image_size[0] or width != self.image_size[1]:
                raise ValueError(
                    f"Input image size ({height}*{width}) doesn't match model"
                    f" ({self.image_size[0]}*{self.image_size[1]})."
                )

        position_embeddings = self._resize_pos_embed(
            self.position_embeddings, height // self.patch_size, width // self.patch_size
        )

        backbone_output = self.backbone(pixel_values)

        features = backbone_output.feature_maps[-1]

        # retrieve the intermediate activations to use them at later stages
        output_hidden_states = [backbone_output.feature_maps[index] for index in self.residual_feature_map_index]

        embeddings = self.projection(features).flatten(2).transpose(1, 2)

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        embeddings = embeddings + position_embeddings

        if not return_dict:
            return (embeddings, output_hidden_states)

        return BaseModelOutputWithIntermediateActivations(
            last_hidden_states=embeddings,
            intermediate_activations=output_hidden_states,
        )


class DPTViTEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings.

    """

    def __init__(self, config):
        super().__init__()

        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.patch_embeddings = DPTViTPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def _resize_pos_embed(self, posemb, grid_size_height, grid_size_width, start_index=1):
        posemb_tok = posemb[:, :start_index]
        posemb_grid = posemb[0, start_index:]

        old_grid_size = int(math.sqrt(len(posemb_grid)))

        posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
        posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear")
        posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, grid_size_height * grid_size_width, -1)

        posemb = torch.cat([posemb_tok, posemb_grid], dim=1)

        return posemb

    def forward(self, pixel_values, return_dict=False):
        batch_size, num_channels, height, width = pixel_values.shape

        # possibly interpolate position encodings to handle varying image sizes
        patch_size = self.config.patch_size
        position_embeddings = self._resize_pos_embed(
            self.position_embeddings, height // patch_size, width // patch_size
        )

        embeddings = self.patch_embeddings(pixel_values)

        batch_size, seq_len, _ = embeddings.size()

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        embeddings = embeddings + position_embeddings

        embeddings = self.dropout(embeddings)

        if not return_dict:
            return (embeddings,)

        return BaseModelOutputWithIntermediateActivations(last_hidden_states=embeddings)


class DPTViTPatchEmbeddings(nn.Module):
    """
    Image to Patch Embedding.

    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values):
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return embeddings


class DPTViTSelfAttention(nn.Module):
    def __init__(self, config: DPTConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # take the dot product between "query" and "key" to get the raw attention scores
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # normalize the attention scores to probabilities
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # this drops out entire tokens to attend to, following the original Transformer paper
        attention_probs = self.dropout(attention_probs)

        # mask heads if requested
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class DPTViTSelfOutput(nn.Module):
    """
    The residual connection is defined in DPTLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: DPTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


class DPTViTAttention(nn.Module):
    def __init__(self, config: DPTConfig) -> None:
        super().__init__()
        self.attention = DPTViTSelfAttention(config)
        self.output = DPTViTSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class DPTViTIntermediate(nn.Module):
    def __init__(self, config: DPTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states


class DPTViTOutput(nn.Module):
    def __init__(self, config: DPTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states + input_tensor

        return hidden_states


class DPTViTLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: DPTConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = DPTViTAttention(config)
        self.intermediate = DPTViTIntermediate(config)
        self.output = DPTViTOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in ViT, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in ViT, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs


class DPTViTEncoder(nn.Module):
    def __init__(self, config: DPTConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([DPTViTLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    layer_head_mask,
                )
            else:
                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class DPTReassembleStage(nn.Module):
    """
    resolutions.

    This happens in 3 stages:
    1. Map the N + 1 tokens to a set of N tokens, by taking into account the readout ([CLS]) token according to
       `config.readout_type`.
    2. Project the channel dimension of the hidden states according to `config.neck_hidden_sizes`.
    3. Resizing the spatial dimensions (height, width).

    Args:
        config (`[DPTConfig]`):
            Model configuration class defining the model architecture.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.layers = nn.ModuleList()
        if config.is_hybrid:
            self._init_reassemble_dpt_hybrid(config)
        else:
            self._init_reassemble_dpt(config)

        self.neck_ignore_stages = config.neck_ignore_stages

    def _init_reassemble_dpt_hybrid(self, config):
        r"""
        For DPT-Hybrid the first 2 reassemble layers are set to `nn.Identity()`, please check the official
        implementation: https://github.com/isl-org/DPT/blob/f43ef9e08d70a752195028a51be5e1aff227b913/dpt/vit.py#L438
        for more details.
        """
        for i, factor in zip(range(len(config.neck_hidden_sizes)), config.reassemble_factors):
            if i <= 1:
                self.layers.append(nn.Identity())
            elif i > 1:
                self.layers.append(DPTReassembleLayer(config, channels=config.neck_hidden_sizes[i], factor=factor))

        if config.readout_type != "project":
            raise ValueError(f"Readout type {config.readout_type} is not supported for DPT-Hybrid.")

        # for DPT-Hybrid the readout type is always "project"; the first two projections are identities
        self.readout_projects = nn.ModuleList()
        for i in range(len(config.neck_hidden_sizes)):
            if i <= 1:
                self.readout_projects.append(nn.Sequential(nn.Identity()))
            elif i > 1:
                self.readout_projects.append(
                    nn.Sequential(nn.Linear(2 * config.hidden_size, config.hidden_size), ACT2FN[config.hidden_act])
                )

    def _init_reassemble_dpt(self, config):
        for i, factor in zip(range(len(config.neck_hidden_sizes)), config.reassemble_factors):
            self.layers.append(DPTReassembleLayer(config, channels=config.neck_hidden_sizes[i], factor=factor))

        if config.readout_type == "project":
            self.readout_projects = nn.ModuleList()
            for _ in range(len(config.neck_hidden_sizes)):
                self.readout_projects.append(
                    nn.Sequential(nn.Linear(2 * config.hidden_size, config.hidden_size), ACT2FN[config.hidden_act])
                )

    def forward(self, hidden_states: List[torch.Tensor]) -> List[torch.Tensor]:
        """
        Args:
            hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`):
                List of hidden states from the backbone.
        """
        out = []

        for i, hidden_state in enumerate(hidden_states):
            if i not in self.neck_ignore_stages:
                # split off the readout ([CLS]) token and reshape tokens to (batch_size, num_channels, height, width)
                hidden_state, cls_token = hidden_state[:, 1:], hidden_state[:, 0]
                batch_size, sequence_length, num_channels = hidden_state.shape
                size = int(math.sqrt(sequence_length))
                hidden_state = hidden_state.reshape(batch_size, size, size, num_channels)
                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()

                feature_shape = hidden_state.shape
                if self.config.readout_type == "project":
                    # reshape to (batch_size, height*width, num_channels)
                    hidden_state = hidden_state.flatten(2).permute((0, 2, 1))
                    readout = cls_token.unsqueeze(1).expand_as(hidden_state)
                    # concatenate the readout token to the hidden states and project
                    hidden_state = self.readout_projects[i](torch.cat((hidden_state, readout), -1))
                    # reshape back to (batch_size, num_channels, height, width)
                    hidden_state = hidden_state.permute(0, 2, 1).reshape(feature_shape)
                elif self.config.readout_type == "add":
                    hidden_state = hidden_state.flatten(2) + cls_token.unsqueeze(-1)
                    hidden_state = hidden_state.reshape(feature_shape)
                hidden_state = self.layers[i](hidden_state)
            out.append(hidden_state)

        return out


class DPTReassembleLayer(nn.Module):
    def __init__(self, config, channels, factor):
        super().__init__()
        # projection to the target channel dimension
        self.projection = nn.Conv2d(in_channels=config.hidden_size, out_channels=channels, kernel_size=1)

        # up/down sampling depending on the factor
        if factor > 1:
            self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0)
        elif factor == 1:
            self.resize = nn.Identity()
        elif factor < 1:
            # a factor < 1 means downsampling
            self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=int(1 / factor), padding=1)

    def forward(self, hidden_state):
        hidden_state = self.projection(hidden_state)
        hidden_state = self.resize(hidden_state)
        return hidden_state


class DPTFeatureFusionStage(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(len(config.neck_hidden_sizes)):
            self.layers.append(DPTFeatureFusionLayer(config))

    def forward(self, hidden_states):
        # reverse the hidden states, fusing starts from the coarsest map
        hidden_states = hidden_states[::-1]

        fused_hidden_states = []
        # the first layer only uses the last hidden state
        fused_hidden_state = self.layers[0](hidden_states[0])
        fused_hidden_states.append(fused_hidden_state)
        # loop from the second layer to the last
        for hidden_state, layer in zip(hidden_states[1:], self.layers[1:]):
            fused_hidden_state = layer(fused_hidden_state, hidden_state)
            fused_hidden_states.append(fused_hidden_state)

        return fused_hidden_states


class DPTPreActResidualLayer(nn.Module):
    """
    ResidualConvUnit, pre-activate residual unit.

    Args:
        config (`[DPTConfig]`):
            Model configuration class defining the model architecture.
    """

    def __init__(self, config):
        super().__init__()

        self.use_batch_norm = config.use_batch_norm_in_fusion_residual
        self.activation1 = ACT2FN["relu"]
        self.convolution1 = nn.Conv2d(
            config.fusion_hidden_size,
            config.fusion_hidden_size,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=not self.use_batch_norm,
        )

        self.activation2 = ACT2FN["relu"]
        self.convolution2 = nn.Conv2d(
            config.fusion_hidden_size,
            config.fusion_hidden_size,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=not self.use_batch_norm,
        )

        if self.use_batch_norm:
            self.batch_norm1 = nn.BatchNorm2d(config.fusion_hidden_size)
            self.batch_norm2 = nn.BatchNorm2d(config.fusion_hidden_size)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        residual = hidden_state
        hidden_state = self.activation1(hidden_state)

        hidden_state = self.convolution1(hidden_state)

        if self.use_batch_norm:
            hidden_state = self.batch_norm1(hidden_state)

        hidden_state = self.activation2(hidden_state)
        hidden_state = self.convolution2(hidden_state)

        if self.use_batch_norm:
            hidden_state = self.batch_norm2(hidden_state)

        return hidden_state + residual


class DPTFeatureFusionLayer(nn.Module):
    """Feature fusion layer, merges feature maps from different stages.

    Args:
        config (`[DPTConfig]`):
            Model configuration class defining the model architecture.
        align_corners (`bool`, *optional*, defaults to `True`):
            The align_corners setting for bilinear upsampling.
    """

    def __init__(self, config, align_corners=True):
        super().__init__()

        self.align_corners = align_corners

        self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)

        self.residual_layer1 = DPTPreActResidualLayer(config)
        self.residual_layer2 = DPTPreActResidualLayer(config)

    def forward(self, hidden_state, residual=None):
        if residual is not None:
            if hidden_state.shape != residual.shape:
                residual = nn.functional.interpolate(
                    residual, size=(hidden_state.shape[2], hidden_state.shape[3]), mode="bilinear", align_corners=False
                )
            hidden_state = hidden_state + self.residual_layer1(residual)

        hidden_state = self.residual_layer2(hidden_state)
        hidden_state = nn.functional.interpolate(
            hidden_state, scale_factor=2, mode="bilinear", align_corners=self.align_corners
        )
        hidden_state = self.projection(hidden_state)

        return hidden_state


class DPTPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DPTConfig
    base_model_prefix = "dpt"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, DPTViTEncoder):
            module.gradient_checkpointing = value


DPT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`DPTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

DPT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`]
            for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare DPT Model transformer outputting raw hidden-states without any specific head on top.",
    DPT_START_DOCSTRING,
)
class DPTModel(DPTPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        # vit encoder
        if config.is_hybrid:
            self.embeddings = DPTViTHybridEmbeddings(config)
        else:
            self.embeddings = DPTViTEmbeddings(config)
        self.encoder = DPTViTEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = DPTViTPooler(config) if add_pooling_layer else None

        # initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        if self.config.is_hybrid:
            return self.embeddings
        else:
            return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr   r   r   r   )rH   Zheads_to_pruner   r   r&   r&   r'   _prune_headsx  s    zDPTModel._prune_headsZvision)r   output_typer   ZmodalityZexpected_outputN)r_   r   r   rr   ra   rb   c                 C   s   |d k	r|n| j j}|d k	r |n| j j}|d k	r4|n| j j}| || j j}| j||d}|sf|d n|j}| j|||||d}|d }	| 	|	}	| j
d k	r| 
|	nd }
|s|
d k	r|	|
fn|	f}||dd   |dd   S t|	|
|j|j|jdS )N)ra   r   r   r   rr   ra   r   )r)   r*   r+   r,   r   )rI   r   rr   use_return_dictZget_head_maskr   rs   r   r   r   r  r(   r+   r,   r   )rH   r_   r   r   rr   ra   Zembedding_outputZembedding_last_hidden_statesZencoder_outputsZsequence_outputpooled_outputZhead_outputsr&   r&   r'   ru     s6    
zDPTModel.forward)T)NNNN)r   r    r!   r3   r  r  r   DPT_INPUTS_DOCSTRINGr   _CHECKPOINT_FOR_DOCr(   _CONFIG_FOR_DOC_EXPECTED_OUTPUT_SHAPEr#   r$   r   rw   r   r   ru   rx   r&   r&   rK   r'   r   \  s.   
    
r   c                       s*   e Zd Zed fddZdd Z  ZS )r  r   c                    s*   t    t|j|j| _t | _d S r   )r2   r3   r   r   r7   r   ZTanh
activationr   rK   r&   r'   r3     s    
zDPTViTPooler.__init__c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r  )rH   r+   Zfirst_token_tensorr  r&   r&   r'   ru     s    

zDPTViTPooler.forward)r   r    r!   r   r3   ru   rx   r&   r&   rK   r'   r    s   r  c                       s<   e Zd ZdZ fddZeej eej dddZ  Z	S )DPTNecka;  
    DPTNeck. A neck is a module that is normally used between the backbone and the head. It takes a list of tensors as
    input and produces another list of tensors as output. For DPT, it includes 2 stages:

    * DPTReassembleStage
    * DPTFeatureFusionStage.

    Args:
        config (dict): config dict.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # postprocessing
        self.reassemble_stage = DPTReassembleStage(config)
        self.convs = nn.ModuleList()
        for channel in config.neck_hidden_sizes:
            self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False))

        # fusion
        self.fusion_stage = DPTFeatureFusionStage(config)

    def forward(self, hidden_states: List[torch.Tensor]) -> List[torch.Tensor]:
        if not isinstance(hidden_states, list):
            raise ValueError("hidden_states should be a list of tensors")

        if len(hidden_states) != len(self.config.neck_hidden_sizes):
            raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.")

        # postprocess hidden states into image-like feature maps
        features = self.reassemble_stage(hidden_states)

        features = [self.convs[i](feature) for i, feature in enumerate(features)]

        # fusion blocks
        output = self.fusion_stage(features)

        return output


class DPTDepthEstimationHead(nn.Module):
    """
    Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
    the predictions to the input resolution after the first convolutional layer (details can be found in the paper's
    supplementary material).
    """

    def __init__(self, config):
        super().__init__()

        self.config = config

        features = config.fusion_hidden_size
        self.head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            ACT2FN["relu"],
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            ACT2FN["relu"],
        )

    def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        # use the features selected by head_in_index (by default, the last ones)
        hidden_states = hidden_states[self.config.head_in_index]

        predicted_depth = self.head(hidden_states)

        predicted_depth = predicted_depth.squeeze(dim=1)

        return predicted_depth


@add_start_docstrings(
    """
    DPT Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2.
    """,
    DPT_START_DOCSTRING,
)
class DPTForDepthEstimation(DPTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.dpt = DPTModel(config, add_pooling_layer=False)

        # neck
        self.neck = DPTNeck(config)

        # depth estimation head
        self.head = DPTDepthEstimationHead(config)

        # initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DepthEstimatorOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        head_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth depth estimation maps for computing the loss.

        Returns:

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, DPTForDepthEstimation
        >>> import torch
        >>> import numpy as np
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large")
        >>> model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        ...     predicted_depth = outputs.predicted_depth

        >>> # interpolate to original size
        >>> prediction = torch.nn.functional.interpolate(
        ...     predicted_depth.unsqueeze(1),
        ...     size=image.size[::-1],
        ...     mode="bicubic",
        ...     align_corners=False,
        ... )

        >>> # visualize the prediction
        >>> output = prediction.squeeze().cpu().numpy()
        >>> formatted = (output * 255 / np.max(output)).astype("uint8")
        >>> depth = Image.fromarray(formatted)
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs = self.dpt(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,  # we need the intermediate hidden states
            return_dict=return_dict,
        )

        hidden_states = outputs.hidden_states if return_dict else outputs[1]

        # only keep certain features based on config.backbone_out_indices;
        # note that the hidden_states also include the initial embeddings
        if not self.config.is_hybrid:
            hidden_states = [
                feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices
            ]
        else:
            backbone_hidden_states = outputs.intermediate_activations if return_dict else list(outputs[-1])
            backbone_hidden_states.extend(
                feature
                for idx, feature in enumerate(hidden_states[1:])
                if idx in self.config.backbone_out_indices[2:]
            )

            hidden_states = backbone_hidden_states

        hidden_states = self.neck(hidden_states)

        predicted_depth = self.head(hidden_states)

        loss = None
        if labels is not None:
            raise NotImplementedError("Training is not implemented yet")

        if not return_dict:
            if output_hidden_states:
                output = (predicted_depth,) + outputs[1:]
            else:
                output = (predicted_depth,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return DepthEstimatorOutput(
            loss=loss,
            predicted_depth=predicted_depth,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )
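

# Editor's sketch (not upstream code): a small helper mirroring the docstring
# example above, turning a raw `predicted_depth` map into an 8-bit image
# array. The interpolation mode and normalization scheme are assumptions.
def _depth_map_to_uint8(predicted_depth: torch.Tensor, size: Tuple[int, int]):
    import numpy as np

    # (batch_size, height, width) -> (batch_size, 1, height, width) for interpolate
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1), size=size, mode="bicubic", align_corners=False
    )
    output = prediction.squeeze().cpu().numpy()
    # scale to [0, 255] for visualization
    return (output * 255 / np.max(output)).astype("uint8")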


class DPTSemanticSegmentationHead(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config

        features = config.fusion_hidden_size
        self.head = nn.Sequential(
            nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(features),
            ACT2FN["relu"],
            nn.Dropout(config.semantic_classifier_dropout),
            nn.Conv2d(features, config.num_labels, kernel_size=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
        )

    def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        # use the features selected by head_in_index (by default, the last ones)
        hidden_states = hidden_states[self.config.head_in_index]

        logits = self.head(hidden_states)

        return logits


class DPTAuxiliaryHead(nn.Module):
    def __init__(self, config):
        super().__init__()

        features = config.fusion_hidden_size
        self.head = nn.Sequential(
            nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(features),
            ACT2FN["relu"],
            nn.Dropout(0.1, False),
            nn.Conv2d(features, config.num_labels, kernel_size=1),
        )

    def forward(self, hidden_states):
        logits = self.head(hidden_states)

        return logits


@add_start_docstrings(
    """
    c                       sx   e Zd Z fddZeeeeedde	e
j e	e
j e	e
j e	e e	e e	e eee
j ef dddZ  ZS )	DPTForSemanticSegmentationc                    sN   t  | t|dd| _t|| _t|| _|jr<t	|nd | _
|   d S r"  )r2   r3   r   r   r  r#  r0  r  Zuse_auxiliary_headr4  auxiliary_headr  r   rK   r&   r'   r3     s    

z#DPTForSemanticSegmentation.__init__r$  Nr%  c                    s  |dk	r|n j j}|dk	r |n j j} j|||d|d}|rF|jn|d } j jsv fddt|dd D }n>|r|jn
t|d }	|		 fdd	t|dd D  |	} 
|} |}
d} jdk	r |d }d}|dk	r j jdkr
td
nxtjj|
|jdd ddd}|dk	rPtjj||jdd ddd}t j jd}|||}|||}| j j|  }|s|r|
f|dd  }n|
f|dd  }|dk	r|f| S |S t||
|r|jnd|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, DPTForSemanticSegmentation
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large-ade")
        >>> model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade")

        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs = self.dpt(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,  # we need the intermediate hidden states
            return_dict=return_dict,
        )

        hidden_states = outputs.hidden_states if return_dict else outputs[1]

        # only keep certain features based on config.backbone_out_indices;
        # note that the hidden_states also include the initial embeddings
        if not self.config.is_hybrid:
            hidden_states = [
                feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices
            ]
        else:
            backbone_hidden_states = outputs.intermediate_activations if return_dict else list(outputs[-1])
            backbone_hidden_states.extend(
                feature
                for idx, feature in enumerate(hidden_states[1:])
                if idx in self.config.backbone_out_indices[2:]
            )

            hidden_states = backbone_hidden_states

        hidden_states = self.neck(hidden_states)

        logits = self.head(hidden_states)

        auxiliary_logits = None
        if self.auxiliary_head is not None:
            auxiliary_logits = self.auxiliary_head(hidden_states[-1])

        loss = None
        if labels is not None:
            if self.config.num_labels == 1:
                raise ValueError("The number of labels should be greater than one")
            else:
                # upsample logits to the images' original size
                upsampled_logits = nn.functional.interpolate(
                    logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
                )
                if auxiliary_logits is not None:
                    upsampled_auxiliary_logits = nn.functional.interpolate(
                        auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
                    )
                # compute weighted loss
                loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
                main_loss = loss_fct(upsampled_logits, labels)
                auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
                loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss

        if not return_dict:
            if output_hidden_states:
                output = (logits,) + outputs[1:]
            else:
                output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )
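

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (editor's addition, not part of the upstream
# module): exercises DPTForDepthEstimation end to end, mirroring the class's
# own docstring example. Assumes network access to download the
# "Intel/dpt-large" checkpoint and the COCO test image.
if __name__ == "__main__":
    import requests
    from PIL import Image

    from transformers import AutoImageProcessor

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large")
    model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")

    inputs = image_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # (batch_size, height, width) depth map predicted by the DPT head
    print(outputs.predicted_depth.shape)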