""" PyTorch FocalNet model."""


import collections.abc
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin
from .configuration_focalnet import FocalNetConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "FocalNetConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "microsoft/focalnet-tiny"
_EXPECTED_OUTPUT_SHAPE = [1, 49, 768]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "microsoft/focalnet-tiny"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"


FOCALNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "microsoft/focalnet-tiny",
]


@dataclass
class FocalNetEncoderOutput(ModelOutput):
    """
    FocalNet encoder's outputs, with potential hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.

        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    Nlast_hidden_statehidden_statesreshaped_hidden_states)__name__
__module____qualname____doc__r   torchFloatTensor__annotations__r   r   r   r    r#   r#   o/var/www/html/Darija-Ai-Train/env/lib/python3.8/site-packages/transformers/models/focalnet/modeling_focalnet.pyr   ?   s   
r   c                   @   s^   e Zd ZU dZdZejed< dZe	ej ed< dZ
e	eej  ed< dZe	eej  ed< dS )FocalNetModelOutputa  
    FocalNet model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
            Average pooling of the last layer hidden-state.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    last_hidden_state: torch.FloatTensor = None
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class FocalNetMaskedImageModelingOutput(ModelOutput):
    """
    FocalNet masked image model outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
            Masked image modeling (MLM) loss.
        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Reconstructed pixel values.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    Nlossreconstructionr   r   )r   r   r   r   r(   r   r    r!   r"   r)   r   r   r   r#   r#   r#   r$   r'   w   s
   
r'   c                   @   s^   e Zd ZU dZdZeej ed< dZ	ejed< dZ
eeej  ed< dZeeej  ed< dS )FocalNetImageClassifierOutputaS  
    FocalNet outputs for image classification.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None


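# Hedged illustration, not part of the original module: each output class above
# pairs the sequence-shaped `hidden_states` with `reshaped_hidden_states` folded
# back onto the spatial grid. The helper name below is hypothetical.
def _reshaped_hidden_states_example() -> None:
    batch_size, height, width, hidden_size = 2, 7, 7, 768
    sequence = torch.randn(batch_size, height * width, hidden_size)
    # (batch_size, sequence_length, hidden_size) -> (batch_size, hidden_size, height, width)
    reshaped = sequence.transpose(1, 2).reshape(batch_size, hidden_size, height, width)
    assert reshaped.shape == (batch_size, hidden_size, height, width)

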
class FocalNetEmbeddings(nn.Module):
    """
    Construct the patch embeddings and layernorm. Optionally, also the mask token.
    """

    def __init__(self, config, use_mask_token=False):
        super().__init__()

        self.patch_embeddings = FocalNetPatchEmbeddings(
            config=config,
            image_size=config.image_size,
            patch_size=config.patch_size,
            num_channels=config.num_channels,
            embed_dim=config.embed_dim,
            use_conv_embed=config.use_conv_embed,
            is_stem=True,
        )
        self.patch_grid = self.patch_embeddings.grid_size
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None

        self.norm = nn.LayerNorm(config.embed_dim, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
    ) -> Tuple[torch.Tensor]:
        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
        embeddings = self.norm(embeddings)
        batch_size, seq_len, _ = embeddings.size()

        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        embeddings = self.dropout(embeddings)

        return embeddings, output_dimensions


class FocalNetPatchEmbeddings(nn.Module):
    def __init__(
        self,
        config,
        image_size,
        patch_size,
        num_channels,
        embed_dim,
        add_norm=False,
        use_conv_embed=False,
        is_stem=False,
    ):
        super().__init__()
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])

        if use_conv_embed:
            # if we choose to use conv embedding, then we treat the stem and non-stem differently
            if is_stem:
                kernel_size = 7
                padding = 2
                stride = 4
            else:
                kernel_size = 3
                padding = 1
                stride = 2
            self.projection = nn.Conv2d(
                num_channels, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
            )
        else:
            self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

        if add_norm:
            self.norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        else:
            self.norm = None

    def maybe_pad(self, pixel_values, height, width):
        if width % self.patch_size[1] != 0:
            pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        if height % self.patch_size[0] != 0:
            pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        return pixel_values

    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
        _, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        # pad the input to be divisible by self.patch_size, if needed
        pixel_values = self.maybe_pad(pixel_values, height, width)
        embeddings = self.projection(pixel_values)
        _, _, height, width = embeddings.shape
        output_dimensions = (height, width)
        embeddings = embeddings.flatten(2).transpose(1, 2)

        if self.norm is not None:
            embeddings = self.norm(embeddings)

        return embeddings, output_dimensions


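# Hedged usage sketch, not part of the original module: with the default
# `FocalNetConfig` (224x224 images, patch size 4, embed_dim 96), the patch
# embedding yields (224 / 4) ** 2 = 3136 tokens. The helper name is hypothetical.
def _patch_embeddings_example() -> None:
    config = FocalNetConfig()
    patch_embeddings = FocalNetPatchEmbeddings(
        config=config,
        image_size=config.image_size,
        patch_size=config.patch_size,
        num_channels=config.num_channels,
        embed_dim=config.embed_dim,
    )
    pixel_values = torch.randn(1, config.num_channels, config.image_size, config.image_size)
    embeddings, (height, width) = patch_embeddings(pixel_values)
    assert (height, width) == (56, 56)
    assert embeddings.shape == (1, height * width, config.embed_dim)

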
zFocalNetPatchEmbeddings.forward)FFF)r   r   r   r6   rj   r   r    r!   r   rV   intrT   rW   r#   r#   rE   r$   r7      s      *	r7           F)input	drop_probtrainingrI   c                 C   sd   |dks|s| S d| }| j d fd| jd   }|tj|| j| jd }|  | || }|S )aF  
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    rp   r   r   )r   )dtypedevice)rk   ndimr    Zrandrt   ru   Zfloor_div)rq   rr   rs   Z	keep_probrk   Zrandom_tensoroutputr#   r#   r$   	drop_path!  s    
ry   c                       sP   e Zd ZdZdee dd fddZejejdddZ	e
d	d
dZ  ZS )FocalNetDropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).N)rr   rI   c                    s   t    || _d S N)r5   r6   rr   )rC   rr   rE   r#   r$   r6   9  s    
zFocalNetDropPath.__init__)r   rI   c                 C   s   t || j| jS r{   )ry   rr   rs   )rC   r   r#   r#   r$   rT   =  s    zFocalNetDropPath.forward)rI   c                 C   s   d | jS )Nzp={})formatrr   rC   r#   r#   r$   
extra_repr@  s    zFocalNetDropPath.extra_repr)N)r   r   r   r   r   floatr6   r    rV   rT   strr~   rW   r#   r#   rE   r$   rz   6  s   rz   c                       s&   e Zd Zd fdd	Zdd Z  ZS )	FocalNetModulationrY   Trp   c           	         s"  t    || _|j| | _|j| | _|| _|j| _|j	| _	t
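# Hedged sketch, not part of the original module: `drop_path` zeroes whole samples
# during training and rescales the survivors by 1 / keep_prob so the expected
# activation is unchanged; in eval mode it is the identity. The helper name below
# is hypothetical.
def _drop_path_sanity_check() -> None:
    hidden_states = torch.ones(8, 3, 4)
    # Identity when not training.
    assert torch.equal(drop_path(hidden_states, drop_prob=0.5, training=False), hidden_states)
    dropped = drop_path(hidden_states, drop_prob=0.5, training=True)
    # Each sample is either dropped (all zeros) or rescaled by 1 / (1 - 0.5) = 2.0.
    assert all(value in (0.0, 2.0) for value in dropped[:, 0, 0].tolist())

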
j|d| | jd  |d| _t
j||dd|d| _t
 | _t
||| _t
|| _t
 | _g | _t| jD ]P}| j| | j }| jt
t
j|||d||d ddt
  | j| q| jrt
j||jd| _d S )NrY   r   )bias)r[   r\   r   F)r[   r\   groupsr]   r   r4   )r5   r6   dimZfocal_windowsZfocal_windowZfocal_levelsfocal_levelfocal_factor use_post_layernorm_in_modulationnormalize_modulatorr   Linearprojection_inrc   projection_contextZGELU
activationprojection_outr@   projection_dropout
ModuleListfocal_layersZkernel_sizesrangeappend
Sequentialr=   r>   	layernorm)	rC   r-   indexr   r   r   r   kr[   rE   r#   r$   r6   E  sB    
 

      zFocalNetModulation.__init__c           
      C   s0  |j d }| |dddd }t|||| jd fd\}}| _d}t| jD ]4}| j	| |}||| jdd||d f   }qT| 
|jdddjddd}||| jdd| jdf   }| jr|| jd  }| || _|| j }	|	dddd }	| jr| |	}	| |	}	| |	}	|	S )	z
        Args:
            hidden_state:
                Input features with shape of (batch_size, height, width, num_channels)
        rJ   r   r
   r   rY   NT)Zkeepdim)rk   r   permute
contiguousr    splitr   Zgatesr   r   r   meanr   r   Z	modulatorr   r   r   r   )
rC   hidden_stater0   xqctxZctx_alllevelZ
ctx_globalZx_outr#   r#   r$   rT   f  s&    
"$ 



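# Hedged shape check, not part of the original module: focal modulation maps a
# (batch_size, height, width, num_channels) feature map to the same shape,
# aggregating context over `config.focal_levels[index]` depthwise-conv scales plus
# one global level. Default config values are assumed; the helper name is
# hypothetical.
def _focal_modulation_example() -> None:
    config = FocalNetConfig()
    modulation = FocalNetModulation(config=config, index=0, dim=config.embed_dim)
    hidden_state = torch.randn(2, 56, 56, config.embed_dim)
    assert modulation(hidden_state).shape == hidden_state.shape

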
zFocalNetModulation.forward)rY   Trp   r   r   r   r6   rT   rW   r#   r#   rE   r$   r   D  s   !r   c                       s&   e Zd Zd fdd	Zdd Z  ZS )FocalNetMlpNrp   c                    sR   t    |p|}|p|}t||| _t|j | _t||| _t	|| _
d S r{   )r5   r6   r   r   fc1r   Z
hidden_actr   fc2r@   drop)rC   r-   in_featureshidden_featuresout_featuresr   rE   r#   r$   r6     s    
zFocalNetMlp.__init__c                 C   s6   |  |}| |}| |}| |}| |}|S r{   )r   r   r   r   )rC   r   r#   r#   r$   rT     s    




zFocalNetMlp.forward)NNrp   r   r#   r#   rE   r$   r     s   	r   c                       s*   e Zd ZdZd fdd	Zdd Z  ZS )FocalNetLayera  Focal Modulation Network layer (block).

    Args:
        config (`FocalNetConfig`):
            Model config.
        index (`int`):
            Layer index.
        dim (`int`):
            Number of input channels.
        input_resolution (`Tuple[int]`):
            Input resolution.
        drop_path (`float`, *optional*, defaults to 0.0):
            Stochastic depth rate.
    """

    def __init__(self, config, index, dim, input_resolution, drop_path=0.0):
        super().__init__()

        self.config = config

        # layer-specific attributes
        self.dim = dim
        self.input_resolution = input_resolution

        # general attributes
        self.drop = config.hidden_dropout_prob
        self.use_post_layernorm = config.use_post_layernorm

        self.norm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.modulation = FocalNetModulation(
            config=config,
            index=index,
            dim=dim,
            projection_dropout=self.drop,
        )

        self.drop_path = FocalNetDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        mlp_hidden_dim = int(dim * config.mlp_ratio)
        self.mlp = FocalNetMlp(config=config, in_features=dim, hidden_features=mlp_hidden_dim, drop=self.drop)

        self.gamma_1 = 1.0
        self.gamma_2 = 1.0
        if config.use_layerscale:
            self.gamma_1 = nn.Parameter(config.layerscale_value * torch.ones(dim), requires_grad=True)
            self.gamma_2 = nn.Parameter(config.layerscale_value * torch.ones(dim), requires_grad=True)

    def forward(self, hidden_state, input_dimensions):
        height, width = input_dimensions
        batch_size, _, num_channels = hidden_state.shape
        shortcut = hidden_state

        # Focal Modulation
        hidden_state = hidden_state if self.use_post_layernorm else self.norm1(hidden_state)
        hidden_state = hidden_state.view(batch_size, height, width, num_channels)
        hidden_state = self.modulation(hidden_state).view(batch_size, height * width, num_channels)
        hidden_state = hidden_state if not self.use_post_layernorm else self.norm1(hidden_state)

        # FFN
        hidden_state = shortcut + self.drop_path(self.gamma_1 * hidden_state)
        hidden_state = hidden_state + self.drop_path(
            self.gamma_2
            * (self.norm2(self.mlp(hidden_state)) if self.use_post_layernorm else self.mlp(self.norm2(hidden_state)))
        )

        return hidden_state


class FocalNetStage(nn.Module):
    def __init__(self, config, index, input_resolution):
        super().__init__()

        self.config = config
        self.num_stages = len(config.depths)

        embed_dim = [config.embed_dim * (2**i) for i in range(self.num_stages)]
        dim = embed_dim[index]
        out_dim = embed_dim[index + 1] if (index < self.num_stages - 1) else None
        downsample = FocalNetPatchEmbeddings if (index < self.num_stages - 1) else None

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
        drop_path = dpr[sum(config.depths[:index]) : sum(config.depths[: index + 1])]

        self.layers = nn.ModuleList(
            [
                FocalNetLayer(
                    config=config,
                    index=index,
                    dim=dim,
                    input_resolution=input_resolution,
                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                )
                for i in range(config.depths[index])
            ]
        )

        if downsample is not None:
            self.downsample = downsample(
                config=config,
                image_size=input_resolution,
                patch_size=2,
                num_channels=dim,
                embed_dim=out_dim,
                add_norm=True,
                use_conv_embed=config.use_conv_embed,
                is_stem=False,
            )
        else:
            self.downsample = None

        self.pointing = False

    def forward(self, hidden_states: torch.Tensor, input_dimensions: Tuple[int, int]) -> Tuple[torch.Tensor]:
        height, width = input_dimensions
        for layer_module in self.layers:
            hidden_states = layer_module(hidden_states, input_dimensions)

        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            height, width = input_dimensions
            hidden_states = hidden_states_before_downsampling.transpose(1, 2).reshape(
                hidden_states_before_downsampling.shape[0], -1, height, width
            )
            hidden_states, output_dimensions = self.downsample(hidden_states)
        else:
            output_dimensions = (height, width, height, width)

        stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)

        return stage_outputs


class FocalNetEncoder(nn.Module):
    def __init__(self, config, grid_size):
        super().__init__()
        self.num_stages = len(config.depths)
        self.config = config

        self.stages = nn.ModuleList(
            [
                FocalNetStage(
                    config=config,
                    index=i_layer,
                    input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
                )
                for i_layer in range(self.num_stages)
            ]
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        output_hidden_states: Optional[bool] = False,
        output_hidden_states_before_downsampling: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, FocalNetEncoderOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None

        if output_hidden_states:
            batch_size, _, hidden_size = hidden_states.shape
            # rearrange b (h w) c -> b c h w
            reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
            reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
            all_hidden_states += (hidden_states,)
            all_reshaped_hidden_states += (reshaped_hidden_state,)

        for i, stage_module in enumerate(self.stages):
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                stage_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(stage_module), hidden_states, input_dimensions
                )
            else:
                stage_outputs = stage_module(hidden_states, input_dimensions)

            hidden_states = stage_outputs[0]
            hidden_states_before_downsampling = stage_outputs[1]
            output_dimensions = stage_outputs[2]

            input_dimensions = (output_dimensions[-2], output_dimensions[-1])

            if output_hidden_states and output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
                # rearrange b (h w) c -> b c h w, using the original (pre-downsampling) height and width
                reshaped_hidden_state = hidden_states_before_downsampling.view(
                    batch_size, output_dimensions[0], output_dimensions[1], hidden_size
                )
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states_before_downsampling,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)
            elif output_hidden_states and not output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states.shape
                # rearrange b (h w) c -> b c h w
                reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_reshaped_hidden_states] if v is not None)

        return FocalNetEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            reshaped_hidden_states=all_reshaped_hidden_states,
        )


class FocalNetPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = FocalNetConfig
    base_model_prefix = "focalnet"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, FocalNetEncoder):
            module.gradient_checkpointing = value


FOCALNET_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`FocalNetConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

FOCALNET_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`AutoImageProcessor.__call__`] for details.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare FocalNet Model outputting raw hidden-states without any specific head on top.",
    FOCALNET_START_DOCSTRING,
)
class FocalNetModel(FocalNetPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
        super().__init__(config)
        self.config = config
        self.num_stages = len(config.depths)
        self.num_features = int(config.embed_dim * 2 ** (self.num_stages - 1))

        self.embeddings = FocalNetEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = FocalNetEncoder(config, self.embeddings.patch_grid)

        self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    @add_start_docstrings_to_model_forward(FOCALNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=FocalNetModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, FocalNetModelOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)

        encoder_outputs = self.encoder(
            embedding_output,
            input_dimensions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        pooled_output = None
        if self.pooler is not None:
            pooled_output = self.pooler(sequence_output.transpose(1, 2))
            pooled_output = torch.flatten(pooled_output, 1)

        if not return_dict:
            output = (sequence_output, pooled_output) + encoder_outputs[1:]
            return output

        return FocalNetModelOutput(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
        )


@add_start_docstrings(
    """FocalNet Model with a decoder on top for masked image modeling.

    This follows the same implementation as in [SimMIM](https://arxiv.org/abs/2111.09886).

    <Tip>

    Note that we provide a script to pre-train this model on custom data in our [examples
    directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).

    </Tip>
    """,
    FOCALNET_START_DOCSTRING,
)
class FocalNetForMaskedImageModeling(FocalNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.focalnet = FocalNetModel(config, add_pooling_layer=False, use_mask_token=True)

        self.num_stages = len(config.depths)
        num_features = int(config.embed_dim * 2 ** (self.num_stages - 1))
        self.decoder = nn.Sequential(
            nn.Conv2d(
                in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1
            ),
            nn.PixelShuffle(config.encoder_stride),
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FOCALNET_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FocalNetMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, FocalNetMaskedImageModelingOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Returns:

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, FocalNetConfig, FocalNetForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-base-simmim-window6-192")
        >>> config = FocalNetConfig()
        >>> model = FocalNetForMaskedImageModeling(config)

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 192, 192]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.focalnet(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        # Reshape to (batch_size, num_channels, height, width)
        sequence_output = sequence_output.transpose(1, 2)
        batch_size, num_channels, sequence_length = sequence_output.shape
        height = width = math.floor(sequence_length**0.5)
        sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)

        # Reconstruct pixel values
        reconstructed_pixel_values = self.decoder(sequence_output)

        masked_im_loss = None
        if bool_masked_pos is not None:
            size = self.config.image_size // self.config.patch_size
            bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
            mask = (
                bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
                .repeat_interleave(self.config.patch_size, 2)
                .unsqueeze(1)
                .contiguous()
            )
            reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
            masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

        if not return_dict:
            output = (reconstructed_pixel_values,) + outputs[2:]
            return ((masked_im_loss,) + output) if masked_im_loss is not None else output

        return FocalNetMaskedImageModelingOutput(
            loss=masked_im_loss,
            reconstruction=reconstructed_pixel_values,
            hidden_states=outputs.hidden_states,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )


@add_start_docstrings(
    """
    FocalNet Model with an image classification head on top (a linear layer on top of the pooled output) e.g. for
    ImageNet.
    """,
    FOCALNET_START_DOCSTRING,
)
class FocalNetForImageClassification(FocalNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.focalnet = FocalNetModel(config)

        # Classifier head
        self.classifier = (
            nn.Linear(self.focalnet.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FOCALNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=FocalNetImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, FocalNetImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.focalnet(
            pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return FocalNetImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )


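# Hedged usage sketch, not part of the original module. It spells out what the
# `add_code_sample_docstrings` decorator above documents: classify an image with
# the "microsoft/focalnet-tiny" checkpoint (the file's own _IMAGE_CLASS_CHECKPOINT);
# for this image the expected top class is "tabby, tabby cat". The helper name is
# hypothetical and the snippet needs network access.
def _image_classification_example() -> None:
    import requests
    from PIL import Image

    from transformers import AutoImageProcessor

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-tiny")
    model = FocalNetForImageClassification.from_pretrained("microsoft/focalnet-tiny")

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    print(model.config.id2label[logits.argmax(-1).item()])

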
@add_start_docstrings(
    """
    FocalNet backbone, to be used with frameworks like X-Decoder.
    """,
    FOCALNET_START_DOCSTRING,
)
class FocalNetBackbone(FocalNetPreTrainedModel, BackboneMixin):
    def __init__(self, config: FocalNetConfig):
        super().__init__(config)
        super()._init_backbone(config)

        self.num_features = [config.embed_dim] + config.hidden_sizes
        self.focalnet = FocalNetModel(config)

        # initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FOCALNET_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-tiny-lrf")
        >>> model = AutoBackbone.from_pretrained("microsoft/focalnet-tiny-lrf")

        >>> inputs = processor(image, return_tensors="pt")
        >>> outputs = model(**inputs)
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs = self.focalnet(pixel_values, output_hidden_states=True, return_dict=True)

        hidden_states = outputs.reshaped_hidden_states

        feature_maps = ()
        for idx, stage in enumerate(self.stage_names):
            if stage in self.out_features:
                feature_maps += (hidden_states[idx],)

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output += (outputs.hidden_states,)
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=None,
        )

