# coding=utf-8
# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OneFormer model configuration"""

from typing import Dict, List, Optional

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING


logger = logging.get_logger(__name__)

ONEFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "shi-labs/oneformer_ade20k_swin_tiny": (
        "https://huggingface.co/shi-labs/oneformer_ade20k_swin_tiny/blob/main/config.json"
    ),
    # See all OneFormer models at https://huggingface.co/models?filter=oneformer
}


class OneFormerConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`OneFormerModel`]. It is used to instantiate a
    OneFormer model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the OneFormer
    [shi-labs/oneformer_ade20k_swin_tiny](https://huggingface.co/shi-labs/oneformer_ade20k_swin_tiny) architecture
    trained on [ADE20k-150](https://huggingface.co/datasets/scene_parse_150).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        backbone_config (`PretrainedConfig`, *optional*, defaults to `SwinConfig`):
            The configuration of the backbone model.
        ignore_value (`int`, *optional*, defaults to 255):
            Value to ignore in the ground-truth labels when computing the loss.
        num_queries (`int`, *optional*, defaults to 150):
            Number of object queries.
        no_object_weight (`float`, *optional*, defaults to 0.1):
            Weight for the no-object class predictions.
        class_weight (`float`, *optional*, defaults to 2.0):
            Weight for the classification cross-entropy loss.
        mask_weight (`float`, *optional*, defaults to 5.0):
            Weight for the binary cross-entropy mask loss.
        dice_weight (`float`, *optional*, defaults to 5.0):
            Weight for the dice loss.
        contrastive_weight (`float`, *optional*, defaults to 0.5):
            Weight for the contrastive loss.
        contrastive_temperature (`float`, *optional*, defaults to 0.07):
            Initial value for scaling the contrastive logits.
        train_num_points (`int`, *optional*, defaults to 12544):
            Number of points to sample when calculating losses on the mask predictions.
        oversample_ratio (`float`, *optional*, defaults to 3.0):
            Ratio deciding how many points to oversample.
        importance_sample_ratio (`float`, *optional*, defaults to 0.75):
            Ratio of points that are sampled via importance sampling.
        init_std (`float`, *optional*, defaults to 0.02):
            Standard deviation for normal initialization.
        init_xavier_std (`float`, *optional*, defaults to 1.0):
            Standard deviation for Xavier uniform initialization.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            Epsilon for layer normalization.
        is_training (`bool`, *optional*, defaults to `False`):
            Whether to run in training or inference mode.
        use_auxiliary_loss (`bool`, *optional*, defaults to `True`):
            Whether to calculate the loss using the intermediate predictions from the transformer decoder.
        output_auxiliary_logits (`bool`, *optional*, defaults to `True`):
            Whether to return the intermediate predictions from the transformer decoder.
        strides (`List[int]`, *optional*, defaults to `[4, 8, 16, 32]`):
            List containing the strides for the feature maps in the encoder.
        task_seq_len (`int`, *optional*, defaults to 77):
            Sequence length for tokenizing the text list input.
        text_encoder_width (`int`, *optional*, defaults to 256):
            Hidden size for the text encoder.
        text_encoder_context_length (`int`, *optional*, defaults to 77):
            Input sequence length for the text encoder.
        text_encoder_num_layers (`int`, *optional*, defaults to 6):
            Number of transformer layers in the text encoder.
        text_encoder_vocab_size (`int`, *optional*, defaults to 49408):
            Vocabulary size for the tokenizer.
        text_encoder_proj_layers (`int`, *optional*, defaults to 2):
            Number of layers in the MLP used to project the text queries.
        text_encoder_n_ctx (`int`, *optional*, defaults to 16):
            Number of learnable text context queries.
        conv_dim (`int`, *optional*, defaults to 256):
            Feature map dimension to map the outputs from the backbone to.
        mask_dim (`int`, *optional*, defaults to 256):
            Dimension for the feature maps in the pixel decoder.
        hidden_dim (`int`, *optional*, defaults to 256):
            Dimension for the hidden states in the transformer decoder.
        encoder_feedforward_dim (`int`, *optional*, defaults to 1024):
            Dimension for the FFN layer in the pixel decoder.
        norm (`str`, *optional*, defaults to `"GN"`):
            Type of normalization.
        encoder_layers (`int`, *optional*, defaults to 6):
            Number of layers in the pixel decoder.
        decoder_layers (`int`, *optional*, defaults to 10):
            Number of layers in the transformer decoder.
        use_task_norm (`bool`, *optional*, defaults to `True`):
            Whether to normalize the task token.
        num_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads in the transformer layers of the pixel and transformer decoders.
        dropout (`float`, *optional*, defaults to 0.1):
            Dropout probability for the pixel and transformer decoders.
        dim_feedforward (`int`, *optional*, defaults to 2048):
            Dimension for the FFN layer in the transformer decoder.
        pre_norm (`bool`, *optional*, defaults to `False`):
            Whether to normalize the hidden states before the attention layers in the transformer decoder.
        enforce_input_proj (`bool`, *optional*, defaults to `False`):
            Whether to project the hidden states in the transformer decoder.
        query_dec_layers (`int`, *optional*, defaults to 2):
            Number of layers in the query transformer.
        common_stride (`int`, *optional*, defaults to 4):
            Common stride used for the features in the pixel decoder.

    Examples:
    ```python
    >>> from transformers import OneFormerConfig, OneFormerModel

    >>> # Initializing a OneFormer shi-labs/oneformer_ade20k_swin_tiny configuration
    >>> configuration = OneFormerConfig()
    >>> # Initializing a model (with random weights) from the shi-labs/oneformer_ade20k_swin_tiny style configuration
    >>> model = OneFormerModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
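
    A custom backbone can be used by passing its configuration object (or its serialized dict form). Below is a
    minimal sketch assuming a Swin variant; the `depths` value is illustrative and does not correspond to a
    released checkpoint:

    ```python
    >>> from transformers import OneFormerConfig, SwinConfig

    >>> # Hypothetical deeper Swin backbone (illustrative depths)
    >>> backbone_config = SwinConfig(depths=[2, 2, 18, 2], out_features=["stage1", "stage2", "stage3", "stage4"])
    >>> configuration = OneFormerConfig(backbone_config=backbone_config)
    ```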
    """
    model_type = "oneformer"
    attribute_map = {"hidden_size": "hidden_dim"}

    def __init__(
        self,
        backbone_config: Optional[Dict] = None,
        ignore_value: int = 255,
        num_queries: int = 150,
        no_object_weight: float = 0.1,
        class_weight: float = 2.0,
        mask_weight: float = 5.0,
        dice_weight: float = 5.0,
        contrastive_weight: float = 0.5,
        contrastive_temperature: float = 0.07,
        train_num_points: int = 12544,
        oversample_ratio: float = 3.0,
        importance_sample_ratio: float = 0.75,
        init_std: float = 0.02,
        init_xavier_std: float = 1.0,
        layer_norm_eps: float = 1e-05,
        is_training: bool = False,
        use_auxiliary_loss: bool = True,
        output_auxiliary_logits: bool = True,
        strides: Optional[List[int]] = None,
        task_seq_len: int = 77,
        text_encoder_width: int = 256,
        text_encoder_context_length: int = 77,
        text_encoder_num_layers: int = 6,
        text_encoder_vocab_size: int = 49408,
        text_encoder_proj_layers: int = 2,
        text_encoder_n_ctx: int = 16,
        conv_dim: int = 256,
        mask_dim: int = 256,
        hidden_dim: int = 256,
        encoder_feedforward_dim: int = 1024,
        norm: str = "GN",
        encoder_layers: int = 6,
        decoder_layers: int = 10,
        use_task_norm: bool = True,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        dim_feedforward: int = 2048,
        pre_norm: bool = False,
        enforce_input_proj: bool = False,
        query_dec_layers: int = 2,
        common_stride: int = 4,
        **kwargs,
    ):
        if backbone_config is None:
            logger.info("`backbone_config` is unset. Initializing the config with the default `Swin` backbone.")
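            # Defaults matching the Swin Tiny architecture (cf. the `shi-labs/oneformer_ade20k_swin_tiny` checkpoint)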
            backbone_config = CONFIG_MAPPING["swin"](
                image_size=224,
                in_channels=3,
                patch_size=4,
                embed_dim=96,
                depths=[2, 2, 6, 2],
                num_heads=[3, 6, 12, 24],
                window_size=7,
                drop_path_rate=0.3,
                use_absolute_embeddings=False,
                out_features=["stage1", "stage2", "stage3", "stage4"],
            )
        elif isinstance(backbone_config, dict):
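            # `backbone_config` arrived as a plain dict (e.g. deserialized from a saved config file);
            # rebuild the matching config class through the auto mapping using its `model_type`.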
            backbone_model_type = backbone_config.get("model_type")
            config_class = CONFIG_MAPPING[backbone_model_type]
            backbone_config = config_class.from_dict(backbone_config)

        self.backbone_config = backbone_config

        self.ignore_value = ignore_value
        self.num_queries = num_queries
        self.no_object_weight = no_object_weight
        self.class_weight = class_weight
        self.mask_weight = mask_weight
        self.dice_weight = dice_weight
        self.contrastive_weight = contrastive_weight
        self.contrastive_temperature = contrastive_temperature
        self.train_num_points = train_num_points
        self.oversample_ratio = oversample_ratio
        self.importance_sample_ratio = importance_sample_ratio
        self.init_std = init_std
        self.init_xavier_std = init_xavier_std
        self.layer_norm_eps = layer_norm_eps
        self.is_training = is_training
        self.use_auxiliary_loss = use_auxiliary_loss
        self.output_auxiliary_logits = output_auxiliary_logits
        self.strides = strides if strides is not None else [4, 8, 16, 32]
        self.task_seq_len = task_seq_len
        self.text_encoder_width = text_encoder_width
        self.text_encoder_context_length = text_encoder_context_length
        self.text_encoder_num_layers = text_encoder_num_layers
        self.text_encoder_vocab_size = text_encoder_vocab_size
        self.text_encoder_proj_layers = text_encoder_proj_layers
        self.text_encoder_n_ctx = text_encoder_n_ctx
        self.conv_dim = conv_dim
        self.mask_dim = mask_dim
        self.hidden_dim = hidden_dim
        self.encoder_feedforward_dim = encoder_feedforward_dim
        self.norm = norm
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.use_task_norm = use_task_norm
        self.num_attention_heads = num_attention_heads
        self.dropout = dropout
        self.dim_feedforward = dim_feedforward
        self.pre_norm = pre_norm
        self.enforce_input_proj = enforce_input_proj
        self.query_dec_layers = query_dec_layers
        self.common_stride = common_stride
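        # Expose the transformer decoder depth under the standard `num_hidden_layers` attribute name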
        self.num_hidden_layers = decoder_layers

        super().__init__(**kwargs)
