import torch
import torch.distributed as dist
from torch.autograd import Function
from torch.distributed import group, ReduceOp


def broadcast(tensor, src, group=group.WORLD):
    """
    Broadcasts the tensor to the whole group.

    ``tensor`` must have the same number of elements in all processes
    participating in the collective.

    Arguments:
        tensor (Tensor): Data to be sent if ``src`` is the rank of current
            process.
        src (int): Source rank.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Received tensor from the broadcast op.

    """
    return _Broadcast.apply(src, group, tensor)
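
# A minimal usage sketch for ``broadcast`` (an illustration, not part of the
# original module). It assumes a default process group with 2 ranks has
# already been initialized via ``dist.init_process_group``:
#
#   >>> t = torch.ones(2, requires_grad=True)
#   >>> out = broadcast(t, src=0)  # every rank now holds rank 0's data
#   >>> out.sum().backward()       # grads are summed back onto the src rank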


def gather(tensor, dst=0, group=group.WORLD):
    """
    Gathers a list of tensors in a single process.

    Arguments:
        tensor (Tensor): Input tensor.
        dst (int, optional): Destination rank (default is 0).
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        tuple[Tensor]: List of appropriately-sized tensors with the gathered data.
    """
    return _Gather.apply(dst, group, tensor)
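
# A minimal usage sketch for ``gather`` (an illustration, not part of the
# original module), assuming an initialized 2-rank default process group:
#
#   >>> t = torch.ones(2) * (dist.get_rank() + 1)
#   >>> outs = gather(t, dst=0)  # on rank 0: (tensor([1., 1.]), tensor([2., 2.]))
#   >>>                          # on other ranks the returned tensors stay zeros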


def scatter(tensors, src=0, group=group.WORLD):
    """
    Scatters a list of tensors to all processes in a group.

    Each process receives exactly one tensor from the source rank and returns
    it as the output tensor.

    Arguments:
        tensors (list[Tensor]): List of tensors to scatter on the source rank.
            Receivers must pass ``None``.
        src (int, optional): Source rank (default is 0).
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output tensor from the scatter operation.

    """
    return _Scatter.apply(src, group, *tensors)
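
# A minimal usage sketch for ``scatter`` (an illustration, not part of the
# original module), assuming an initialized 2-rank default process group.
# Since the autograd wrapper reads ``tensors[0]`` for the output shape, this
# sketch passes a same-shaped list on every rank:
#
#   >>> ts = [torch.zeros(2), torch.ones(2)]
#   >>> out = scatter(ts, src=0)  # rank r ends up holding rank 0's ts[r]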


def reduce(tensor, dst, op=ReduceOp.SUM, group=group.WORLD):
    """
    Reduces the tensor data across all machines.

    Only the process with rank ``dst`` is going to receive the final result.

    Arguments:
        tensor (Tensor): Input of the collective.
        dst (int): Destination rank.
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum.  Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.

    """
    return _Reduce.apply(dst, op, group, tensor)
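
# A minimal usage sketch for ``reduce`` (an illustration, not part of the
# original module), assuming an initialized 2-rank default process group:
#
#   >>> t = torch.ones(2)
#   >>> out = reduce(t, dst=0)  # rank 0 receives the element-wise sum: tensor([2., 2.])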


def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=group.WORLD):
    """
    Reduces, then scatters a list of tensors to all processes in a group.

    Arguments:
        output (Tensor): Output tensor.
        input_list (list[Tensor]): List of tensors to reduce and scatter.
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum.  Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.

    """
    return _Reduce_Scatter.apply(op, group, output, *input_list)
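
# A minimal usage sketch for ``reduce_scatter`` (an illustration, not part of
# the original module), assuming an initialized 2-rank default process group:
#
#   >>> out = torch.empty(2)
#   >>> ins = [torch.ones(2), torch.ones(2) * 2]  # chunk r is destined for rank r
#   >>> out = reduce_scatter(out, ins)  # rank r gets chunk r summed over all ranks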


def all_gather(tensor, group=group.WORLD):
    """
    Gathers tensors from the whole group in a list.

    Arguments:
        tensor (Tensor): Tensor to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        tuple([Tensor]): Output of the collective.

    """
    return _AllGather.apply(group, tensor)
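
# A minimal usage sketch for ``all_gather`` (an illustration, not part of the
# original module), assuming an initialized 2-rank default process group:
#
#   >>> t = torch.ones(2) * (dist.get_rank() + 1)
#   >>> outs = all_gather(t)  # (tensor([1., 1.]), tensor([2., 2.])) on every rank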
   

def _all_gather_base(output_tensor, input_tensor, group=group.WORLD):
    """
    Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor.

    Args:
        output_tensor (Tensor): Output tensor. It should contain
            correctly-sized tensors to be used for output of the collective.
        input_tensor (Tensor): Tensor to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Examples:
        >>> # All tensors below are of torch.int64 dtype.
        >>> # We have 2 process groups, 2 ranks.
        >>> # xdoctest: +SKIP("incorrect want text")
        >>> output_tensor = torch.zeros(2, dtype=torch.int64)
        >>> output_tensor
        [tensor([0, 0])] # Rank 0 and 1
        >>> tensor = torch.arange(1, dtype=torch.int64) + 1 + rank
        >>> tensor
        tensor([1]) # Rank 0
        tensor([2]) # Rank 1
        >>> _all_gather_base(output_tensor, tensor)
        >>> output_tensor
        tensor([1,2]) # Rank 0
        tensor([1,2]) # Rank 1

    .. warning::
        `_all_gather_base` is experimental and subject to change.
        It is the caller's responsibility to ensure the output_tensor
        is correctly sized.

    """
    return _AllGatherBase.apply(output_tensor, input_tensor, group)


def all_to_all(output_tensor_list, input_tensor_list, group=group.WORLD):
    """
    Each process scatters a list of input tensors to all processes in the group
    and returns a gathered list of tensors in the output list.

    Arguments:
        output_tensor_list (list[Tensor]): List of tensors to gather, one per rank.
        input_tensor_list (list[Tensor]): List of tensors to scatter, one per rank.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        tuple([Tensor]): Output of the collective.

    """
    return _AlltoAll.apply(group, output_tensor_list, *input_tensor_list)
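
# A minimal usage sketch for ``all_to_all`` (an illustration, not part of the
# original module), assuming an initialized 2-rank default process group:
#
#   >>> rank = dist.get_rank()
#   >>> ins = [torch.ones(2) * (2 * rank + i) for i in range(2)]  # ins[i] goes to rank i
#   >>> outs = [torch.empty(2) for _ in range(2)]
#   >>> outs = all_to_all(outs, ins)  # outs[j] on rank r is rank j's ins[r]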
   

def all_to_all_single(
    output, input, output_split_sizes=None, input_split_sizes=None, group=group.WORLD
):
    """
    Each process splits the input tensor and then scatters the split list
    to all processes in the group. It then concatenates the tensors received
    from all processes in the group and returns a single output tensor.

    Arguments:
        output (Tensor): Gathered concatenated output tensor.
        input (Tensor): Input tensor to scatter.
        output_split_sizes (list[Int], optional): Output split sizes for dim 0.
            If specified as None or empty, dim 0 of the ``output`` tensor must be
            evenly divisible by ``world_size``.
        input_split_sizes (list[Int], optional): Input split sizes for dim 0.
            If specified as None or empty, dim 0 of the ``input`` tensor must be
            evenly divisible by ``world_size``.

    Returns:
        Tensor: Output of the collective.

    """
    return _AlltoAllSingle.apply(
        group, output, output_split_sizes, input_split_sizes, input
    )
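
# A minimal usage sketch for ``all_to_all_single`` (an illustration, not part
# of the original module), assuming an initialized 2-rank default process group:
#
#   >>> inp = torch.arange(4.0) + 4 * dist.get_rank()  # dim 0 splits evenly by 2
#   >>> out = torch.empty(4)
#   >>> out = all_to_all_single(out, inp)  # rank r gets split r of every rank's input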


def all_reduce(tensor, op=ReduceOp.SUM, group=group.WORLD):
    """
    Reduces the tensor data across all machines in such a way that all get
    the final result.

    After the call, the returned tensor is going to be bitwise
    identical in all processes.

    Arguments:
        tensor (Tensor): Input of the collective.
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum.  Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.

    """
    return _AllReduce.apply(op, group, tensor)
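
# A minimal usage sketch for ``all_reduce`` (an illustration, not part of the
# original module), assuming an initialized 2-rank default process group:
#
#   >>> t = torch.ones(2) * (dist.get_rank() + 1)
#   >>> out = all_reduce(t)  # every rank now holds tensor([3., 3.])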
   

class _Broadcast(Function):
    @staticmethod
    def forward(ctx, src, group, tensor):
        ctx.src = src
        ctx.group = group
        ctx.rank = dist.get_rank()
        # torch.distributed makes all the calls in place;
        # we allocate new tensors to avoid this.
        tensor = tensor.clone()
        dist.broadcast(tensor, src, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        gx = _Reduce.apply(ctx.src, ReduceOp.SUM, ctx.group, grad_output)
        if ctx.src != ctx.rank:
            gx.zero_()
        return (None, None, gx)


class _Gather(Function):
    @staticmethod
    def forward(ctx, dst, group, tensor):
        ctx.dst = dst
        ctx.group = group
        # Allocate the destination list from the group size; each entry must
        # be correctly sized for the gather.
        tensor_list = [
            torch.zeros_like(tensor) for i in range(dist.get_world_size(group=group))
        ]

        tensor = tensor.contiguous()
        if dist.get_rank(group=group) == dst:
            dist.gather(tensor, tensor_list, dst, group=group)
        else:
            dist.gather(tensor, None, dst, group=group)
        return tuple(tensor_list)

    @staticmethod
    def backward(ctx, *grad_outputs):
        return (None, None) + (_Scatter.apply(ctx.dst, ctx.group, *grad_outputs),)


class _Scatter(Function):
    @staticmethod
    def forward(ctx, src, group, *tensors):
        ctx.src = src
        ctx.group = group
        assert all(t.size() == tensors[0].size() for t in tensors)
        output = torch.zeros_like(tensors[0])
        if dist.get_rank(group=group) == src:
            dist.scatter(output, list(tensors), src, group=group)
        else:
            dist.scatter(output, None, src, group=group)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None) + _Gather.apply(ctx.src, ctx.group, grad_output)


class _Reduce(Function):
    @staticmethod
    def forward(ctx, src, op, group, tensor):
        ctx.src = src
        ctx.group = group
        tensor = tensor.clone()
        dist.reduce(tensor, src, op=op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),)


class _Reduce_Scatter(Function):
    @staticmethod
    def forward(ctx, op, group, tensor, *input_tensor_list):
        ctx.group = group
        # Collectives require contiguous tensors.
        tensor = tensor.contiguous()
        input_tensor_list = tuple(t.contiguous() for t in input_tensor_list)
        dist.reduce_scatter(tensor, list(input_tensor_list), op=op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None, None) + _AllGather.apply(ctx.group, grad_output)


class _AllGather(Function):
    @staticmethod
    def forward(ctx, group, tensor):
        # Collectives require contiguous tensors.
        tensor = tensor.contiguous()

        ctx.group = group
        out_tensor_list = [
            torch.empty_like(tensor) for _ in range(dist.get_world_size(group=group))
        ]

        dist.all_gather(out_tensor_list, tensor, group=group)
        return tuple(out_tensor_list)

    @staticmethod
    def backward(ctx, *grad_outputs):
        if dist.get_backend(group=ctx.group) is dist.Backend.NCCL:
            rank = dist.get_rank()
            gx = torch.empty_like(grad_outputs[rank])
            _Reduce_Scatter.apply(ReduceOp.SUM, ctx.group, gx, *grad_outputs)
        else:
            # Many backends do not support ReduceScatter, so an AlltoAll
            # followed by a sum emulates the ReduceScatter behavior.
            tensor_list = [torch.empty_like(tensor) for tensor in grad_outputs]
            gxs = _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)
            gx = torch.sum(torch.stack(gxs), dim=0)
        return (None, gx)


class _AllGatherBase(Function):
    @staticmethod
    def forward(ctx, output_tensor, input_tensor, group):
        ctx.group = group
        dist._all_gather_base(output_tensor, input_tensor.contiguous(), group=group)
        return output_tensor

    @staticmethod
    def backward(ctx, grad_output):
        if dist.get_backend(group=ctx.group) is dist.Backend.NCCL:
            world_size = dist.get_world_size(group=ctx.group)
            out_size = list(grad_output.size())
            if out_size[0] % world_size != 0:
                raise RuntimeError(
                    f"Tensor with dimensions: {out_size} does not have first dimension divisible by world_size: {world_size}"
                )
            out_size[0] = out_size[0] // world_size
            gx = torch.empty(
                out_size, device=grad_output.device, dtype=grad_output.dtype
            )
            dist._reduce_scatter_base(gx, grad_output, ReduceOp.SUM, ctx.group)
        else:
            raise RuntimeError("Backend not supported!")
        return (None, gx, None)


class _AlltoAll(Function):
    @staticmethod
    def forward(ctx, group, out_tensor_list, *tensors):
        ctx.group = group
        ctx.input_tensor_size_list = [
            tensors[i].size() for i in range(dist.get_world_size(group=group))
        ]
        my_rank = dist.get_rank(group=group)
        tensors = tuple(t.contiguous() for t in tensors)
        # Gloo is implemented by means of scatter, since its async
        # send/recv operations have issues.
        if dist.get_backend(group=group) is dist.Backend.GLOO:
            for i in range(dist.get_world_size(group=group)):
                to_send = None
                if i == my_rank:
                    to_send = list(tensors)
                dist.scatter(out_tensor_list[i], to_send, i, group=group)
        else:
            dist.all_to_all(out_tensor_list, list(tensors), group=group)
        return tuple(out_tensor_list)

    @staticmethod
    def backward(ctx, *grad_outputs):
        tensor_list = [
            torch.empty(
                size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype
            )
            for size in ctx.input_tensor_size_list
        ]
        return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)


class _AlltoAllSingle(Function):
    @staticmethod
    def forward(ctx, group, output, output_split_sizes, input_split_sizes, input):
        ctx.group = group
        ctx.input_size = input.size()
        # The split sizes swap roles in the backward all-to-all.
        ctx.output_split_sizes = input_split_sizes
        ctx.input_split_sizes = output_split_sizes
        dist.all_to_all_single(
            output,
            input,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        tensor = torch.empty(
            ctx.input_size, device=grad_output.device, dtype=grad_output.dtype
        )
        return (None, None, None, None) + (
            _AlltoAllSingle.apply(
                ctx.group,
                tensor,
                ctx.output_split_sizes,
                ctx.input_split_sizes,
                grad_output.contiguous(),
            ),
        )


class _AllReduce(Function):
    @staticmethod
    def forward(ctx, op, group, tensor):
        ctx.group = group
        ctx.op = op
        tensor = tensor.clone()
        dist.all_reduce(tensor, op=op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),)