from typing import Optional, Any

import torch
from torch import Tensor
from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer

from .. import functional as F
from .. import init
from ._functions import SyncBatchNorm as sync_batch_norm
from .lazy import LazyModuleMixin
from .module import Module

__all__ = [
    "BatchNorm1d",
    "LazyBatchNorm1d",
    "BatchNorm2d",
    "LazyBatchNorm2d",
    "BatchNorm3d",
    "LazyBatchNorm3d",
    "SyncBatchNorm",
]


class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm."""

    _version = 2
    __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
    num_features: int
    eps: float
    momentum: float
    affine: bool
    track_running_stats: bool

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer("running_mean", torch.zeros(num_features, **factory_kwargs))
            self.register_buffer("running_var", torch.ones(num_features, **factory_kwargs))
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            # num_batches_tracked is always integral, so the dtype entry of
            # factory_kwargs is dropped for this buffer.
            self.register_buffer(
                "num_batches_tracked",
                torch.tensor(
                    0,
                    dtype=torch.long,
                    **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
                ),
            )
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            # The running-stat buffers are only registered when
            # track_running_stats is enabled.
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        raise NotImplementedError

    def extra_repr(self):
        return (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}".format(**self.__dict__)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if (version is None or version < 2) and self.track_running_stats:
            # at version 2: added the num_batches_tracked buffer, defaulting to 0
            num_batches_tracked_key = prefix + "num_batches_tracked"
            if num_batches_tracked_key not in state_dict:
                state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class _BatchNorm(_NormBase):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum (when available)
        # only so that it gets updated in the ONNX graph when exported.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked.add_(1)
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # Mini-batch stats are used in training mode, and in eval mode when
        # the running-stat buffers are None.
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        # Buffers are only passed when they should be updated (training with
        # tracking enabled) or when they are used for normalization (eval with
        # non-None buffers).
        return F.batch_norm(
            input,
            # If buffers are not to be tracked, ensure that they won't be updated.
            self.running_mean if not self.training or self.track_running_stats else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )


class _LazyNormBase(LazyModuleMixin, _NormBase):

    weight: UninitializedParameter
    bias: UninitializedParameter

    def __init__(
        self,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        # affine and track_running_stats are hard-coded to False here to avoid
        # creating tensors that would immediately be replaced below.
        super().__init__(0, eps, momentum, False, False, **factory_kwargs)
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = UninitializedParameter(**factory_kwargs)
            self.bias = UninitializedParameter(**factory_kwargs)
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer(**factory_kwargs)
            self.running_var = UninitializedBuffer(**factory_kwargs)
            self.num_batches_tracked = torch.tensor(
                0,
                dtype=torch.long,
                **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
            )

    def reset_parameters(self) -> None:
        if not self.has_uninitialized_params() and self.num_features != 0:
            super().reset_parameters()

    def initialize_parameters(self, input) -> None:
        if self.has_uninitialized_params():
            self.num_features = input.shape[1]
            if self.affine:
                assert isinstance(self.weight, UninitializedParameter)
                assert isinstance(self.bias, UninitializedParameter)
                self.weight.materialize((self.num_features,))
                self.bias.materialize((self.num_features,))
            if self.track_running_stats:
                self.running_mean.materialize((self.num_features,))
                self.running_var.materialize((self.num_features,))
            self.reset_parameters()


class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D or 3D input as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of features or channels of the input). By default, the
    elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0.
    At train time in the forward pass, the standard-deviation is calculated via the biased estimator,
    equivalent to ``torch.var(input, unbiased=False)``. However, the value stored in the
    moving average of the standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.
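
    For instance, setting ``momentum=1.0`` makes the running estimate equal to the
    statistics of the last batch, so the unbiased update can be checked directly::

        >>> m = nn.BatchNorm1d(2, momentum=1.0)
        >>> x = torch.randn(16, 2)
        >>> _ = m(x)
        >>> torch.allclose(m.running_var, x.var(dim=0, unbiased=True))
        True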

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.
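
        For example, with the default ``momentum=0.1`` and freshly initialized (zero)
        running statistics, a single training step moves the running mean a tenth of
        the way toward the batch mean::

            >>> m = nn.BatchNorm1d(3)
            >>> x = torch.randn(8, 3)
            >>> _ = m(x)
            >>> torch.allclose(m.running_mean, 0.1 * x.mean(dim=0))
            True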

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.

    Args:
        num_features: number of features or channels :math:`C` of the input
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size,
          :math:`C` is the number of features or channels, and :math:`L` is the sequence length
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")


class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization of
    the ``num_features`` argument of the :class:`BatchNorm1d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
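
    A minimal usage sketch (``num_features`` is inferred from ``input.size(1)`` on the
    first forward call)::

        >>> m = nn.LazyBatchNorm1d()
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
        >>> m.num_features
        100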
    """

    cls_to_become = BatchNorm1d

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")


class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
    standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
    standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError(f"expected 4D input (got {input.dim()}D input)")


class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization of
    the ``num_features`` argument of the :class:`BatchNorm2d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm2d

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError(f"expected 4D input (got {input.dim()}D input)")


class BatchNorm3d(_BatchNorm):
    r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
    standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
    standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
    or Spatio-temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError(f"expected 5D input (got {input.dim()}D input)")


class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization of
    the ``num_features`` argument of the :class:`BatchNorm3d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm3d

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError(f"expected 5D input (got {input.dim()}D input)")


class SyncBatchNorm(_BatchNorm):
    r"""Applies Batch Normalization over an N-dimensional input (a mini-batch of [N-2]D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over all
    mini-batches of the same process groups. :math:`\gamma` and :math:`\beta`
    are learnable parameter vectors of size `C` (where `C` is the input size).
    By default, the elements of :math:`\gamma` are set to 1 and the elements of
    :math:`\beta` are set to 0.
    The standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done for each channel in the ``C`` dimension, computing
    statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch
    Normalization or Spatio-temporal Batch Normalization.

    Currently :class:`SyncBatchNorm` only supports
    :class:`~torch.nn.DistributedDataParallel` (DDP) with single GPU per process. Use
    :meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert
    :attr:`BatchNorm*D` layer to :class:`SyncBatchNorm` before wrapping
    Network with DDP.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, +)`
        eps: a value added to the denominator for numerical stability.
            Default: ``1e-5``
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
        process_group: synchronization of stats happen within each process group
            individually. Default behavior is synchronization across the whole
            world

    Shape:
        - Input: :math:`(N, C, +)`
        - Output: :math:`(N, C, +)` (same shape as input)

    .. note::
        Synchronization of batchnorm statistics occurs only while training, i.e.
        synchronization is disabled when ``model.eval()`` is set or if
        ``self.training`` is otherwise ``False``.

    Examples::

        >>> # xdoctest: +SKIP
        >>> # With Learnable Parameters
        >>> m = nn.SyncBatchNorm(100)
        >>> # creating process group (optional)
        >>> # ranks is a list of int identifying rank ids.
        >>> ranks = list(range(8))
        >>> r1, r2 = ranks[:4], ranks[4:]
        >>> # Note: every rank calls into new_group for every
        >>> # process group created, even if that rank is not
        >>> # part of the group.
        >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
        >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm3d(100, affine=False, process_group=process_group)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)

        >>> # network is nn.BatchNorm layer
        >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
        >>> # only single gpu per process is currently supported
        >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
        >>>                         sync_bn_network,
        >>>                         device_ids=[args.local_rank],
        >>>                         output_device=args.local_rank)
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        process_group: Optional[Any] = None,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
        self.process_group = process_group

    def _check_input_dim(self, input):
        if input.dim() < 2:
            raise ValueError(f"expected at least 2D input (got {input.dim()}D input)")

    def _check_non_zero_input_channels(self, input):
        if input.size(1) == 0:
            raise ValueError("SyncBatchNorm number of input channels should be non-zero")

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)
        self._check_non_zero_input_channels(input)

        # exponential_average_factor is set to self.momentum (when available)
        # only so that it gets updated in the ONNX graph when exported.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            assert self.num_batches_tracked is not None
            self.num_batches_tracked.add_(1)
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        # Mini-batch stats are used in training mode, and in eval mode when
        # the running-stat buffers are None.
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        # If buffers are not to be tracked, ensure that they won't be updated.
        running_mean = self.running_mean if not self.training or self.track_running_stats else None
        running_var = self.running_var if not self.training or self.track_running_stats else None

        # Don't sync batchnorm stats in inference mode (model.eval()).
        need_sync = (
            bn_training
            and self.training
            and torch.distributed.is_available()
            and torch.distributed.is_initialized()
        )
        if need_sync:
            # currently only GPU (or the registered private backend) input is supported
            if input.device.type not in ["cuda", torch._C._get_privateuse1_backend_name()]:
                raise ValueError(
                    "SyncBatchNorm expected input tensor to be on GPU or "
                    f"{torch._C._get_privateuse1_backend_name()}"
                )
            process_group = torch.distributed.group.WORLD
            if self.process_group:
                process_group = self.process_group
            world_size = torch.distributed.get_world_size(process_group)
            need_sync = world_size > 1

        # fallback to framework BN when synchronization is not necessary
        if not need_sync:
            return F.batch_norm(
                input, running_mean, running_var, self.weight, self.bias,
                bn_training, exponential_average_factor, self.eps,
            )
        else:
            assert bn_training
            return sync_batch_norm.apply(
                input, self.weight, self.bias, running_mean, running_var,
                self.eps, exponential_average_factor, process_group, world_size,
            )

    @classmethod
    def convert_sync_batchnorm(cls, module, process_group=None):
        r"""Helper function to convert all :attr:`BatchNorm*D` layers in the model to
        :class:`torch.nn.SyncBatchNorm` layers.

        Args:
            module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers
            process_group (optional): process group to scope synchronization,
                default is the whole world

        Returns:
            The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm`
            layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer,
            a new :class:`torch.nn.SyncBatchNorm` layer object will be returned
            instead.

        Example::

            >>> # Network with nn.BatchNorm layer
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> module = torch.nn.Sequential(
            >>>            torch.nn.Linear(20, 100),
            >>>            torch.nn.BatchNorm1d(100),
            >>>          ).cuda()
            >>> # creating process group (optional)
            >>> # ranks is a list of int identifying rank ids.
            >>> ranks = list(range(8))
            >>> r1, r2 = ranks[:4], ranks[4:]
            >>> # Note: every rank calls into new_group for every
            >>> # process group created, even if that rank is not
            >>> # part of the group.
            >>> # xdoctest: +SKIP("distributed")
            >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
            >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
            >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)

        """
        module_output = module
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module_output = torch.nn.SyncBatchNorm(
                module.num_features,
                module.eps,
                module.momentum,
                module.affine,
                module.track_running_stats,
                process_group,
            )
            if module.affine:
                with torch.no_grad():
                    module_output.weight = module.weight
                    module_output.bias = module.bias
            module_output.running_mean = module.running_mean
            module_output.running_var = module.running_var
            module_output.num_batches_tracked = module.num_batches_tracked
            if hasattr(module, "qconfig"):
                module_output.qconfig = module.qconfig
        for name, child in module.named_children():
            module_output.add_module(name, cls.convert_sync_batchnorm(child, process_group))
        return module_output