"""
Torch utilities for the Trainer class.
"""

import datetime
import json
import math
import os
import sys
import warnings
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import dataclass
from logging import StreamHandler
from typing import Any, Dict, Iterator, List, Optional, Union

import numpy as np
import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
from torch.utils.data.distributed import DistributedSampler

from .integrations.deepspeed import is_deepspeed_zero3_enabled
from .tokenization_utils_base import BatchEncoding
from .utils import is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_run_on_sagemaker, logging


if is_training_run_on_sagemaker():
    logging.add_handler(StreamHandler(sys.stdout))

if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm

# this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0
try:
    from torch.optim.lr_scheduler import SAVE_STATE_WARNING
except ImportError:
    SAVE_STATE_WARNING = ""

logger = logging.get_logger(__name__)


def get_dataloader_sampler(dataloader):
    if hasattr(dataloader, "batch_sampler") and dataloader.batch_sampler is not None:
        return get_dataloader_sampler(dataloader.batch_sampler)
    elif hasattr(dataloader, "sampler"):
        return dataloader.sampler


def atleast_1d(tensor_or_array: Union[torch.Tensor, np.ndarray]):
    if isinstance(tensor_or_array, torch.Tensor):
        if hasattr(torch, "atleast_1d"):
            tensor_or_array = torch.atleast_1d(tensor_or_array)
        elif tensor_or_array.ndim < 1:
            tensor_or_array = tensor_or_array[None]
    else:
        tensor_or_array = np.atleast_1d(tensor_or_array)
    return tensor_or_array


def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100):
    """Concatenates `tensor1` and `tensor2` on first axis, applying padding on the second if necessary."""
    tensor1 = atleast_1d(tensor1)
    tensor2 = atleast_1d(tensor2)

    if len(tensor1.shape) == 1 or tensor1.shape[1] == tensor2.shape[1]:
        return torch.cat((tensor1, tensor2), dim=0)

    # Let's figure out the new shape
    new_shape = (tensor1.shape[0] + tensor2.shape[0], max(tensor1.shape[1], tensor2.shape[1])) + tensor1.shape[2:]

    # Now let's fill the result tensor
    result = tensor1.new_full(new_shape, padding_index)
    result[: tensor1.shape[0], : tensor1.shape[1]] = tensor1
    result[tensor1.shape[0] :, : tensor2.shape[1]] = tensor2
    return result


def numpy_pad_and_concatenate(array1, array2, padding_index=-100):
    """Concatenates `array1` and `array2` on first axis, applying padding on the second if necessary."""
    array1 = atleast_1d(array1)
    array2 = atleast_1d(array2)

    if len(array1.shape) == 1 or array1.shape[1] == array2.shape[1]:
        return np.concatenate((array1, array2), axis=0)

    # Let's figure out the new shape
    new_shape = (array1.shape[0] + array2.shape[0], max(array1.shape[1], array2.shape[1])) + array1.shape[2:]

    # Now let's fill the result tensor
    result = np.full_like(array1, padding_index, shape=new_shape)
    result[: array1.shape[0], : array1.shape[1]] = array1
    result[array1.shape[0] :, : array2.shape[1]] = array2
    return result


def nested_concat(tensors, new_tensors, padding_index=-100):
    """
    Concat the `new_tensors` to `tensors` on the first dim and pad them on the second if needed. Works for tensors or
    nested list/tuples/dict of tensors.
    """
    assert type(tensors) == type(
        new_tensors
    ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
    elif isinstance(tensors, torch.Tensor):
        return torch_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
    elif isinstance(tensors, Mapping):
        return type(tensors)(
            {k: nested_concat(t, new_tensors[k], padding_index=padding_index) for k, t in tensors.items()}
        )
    elif isinstance(tensors, np.ndarray):
        return numpy_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
    else:
        raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")


def find_batch_size(tensors):
    """
    Find the first dimension of a tensor in a nested list/tuple/dict of tensors.
    """
    if isinstance(tensors, (list, tuple)):
        for t in tensors:
            result = find_batch_size(t)
            if result is not None:
                return result
    elif isinstance(tensors, Mapping):
        for key, value in tensors.items():
            result = find_batch_size(value)
            if result is not None:
                return result
    elif isinstance(tensors, torch.Tensor):
        return tensors.shape[0] if len(tensors.shape) >= 1 else None
    elif isinstance(tensors, np.ndarray):
        return tensors.shape[0] if len(tensors.shape) >= 1 else None


def nested_numpify(tensors):
    """Numpify `tensors` (even if it's a nested list/tuple/dict of tensors)."""
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_numpify(t) for t in tensors)
    if isinstance(tensors, Mapping):
        return type(tensors)({k: nested_numpify(t) for k, t in tensors.items()})

    t = tensors.cpu()
    if t.dtype == torch.bfloat16:
        # NumPy does not support bfloat16, so we must convert to float32 before calling `.numpy()`.
        t = t.to(torch.float32)
    return t.numpy()


def nested_detach(tensors):
    """Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."""
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_detach(t) for t in tensors)
    elif isinstance(tensors, Mapping):
        return type(tensors)({k: nested_detach(t) for k, t in tensors.items()})
    return tensors.detach()


def nested_xla_mesh_reduce(tensors, name):
    if is_torch_tpu_available():
        import torch_xla.core.xla_model as xm

        if isinstance(tensors, (list, tuple)):
            return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
        if isinstance(tensors, Mapping):
            return type(tensors)(
                {k: nested_xla_mesh_reduce(t, f"{name}_{i}") for i, (k, t) in enumerate(tensors.items())}
            )

        tensors = atleast_1d(tensors)
        return xm.mesh_reduce(name, tensors, torch.cat)
    else:
        raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")


def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) -> Any:
    try:
        if isinstance(tensor, (tuple, list)):
            return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
        if isinstance(tensor, Mapping):
            return type(tensor)({k: distributed_concat(t, num_total_examples) for k, t in tensor.items()})
        tensor = atleast_1d(tensor).contiguous()
        output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
        dist.all_gather(output_tensors, tensor)
        concat = torch.cat(output_tensors, dim=0)

        # truncate the dummy elements added by SequentialDistributedSampler
        if num_total_examples is not None:
            concat = concat[:num_total_examples]
        return concat
    except AssertionError:
        raise AssertionError("Not currently using distributed training")


def distributed_broadcast_scalars(
    scalars: List[Union[int, float]],
    num_total_examples: Optional[int] = None,
    device: Optional[torch.device] = torch.device("cuda"),
) -> torch.Tensor:
    try:
        tensorized_scalar = torch.tensor(scalars).to(device)
        output_tensors = [tensorized_scalar.clone() for _ in range(dist.get_world_size())]
        dist.all_gather(output_tensors, tensorized_scalar)
        concat = torch.cat(output_tensors, dim=0)

        # truncate the dummy elements added by SequentialDistributedSampler
        if num_total_examples is not None:
            concat = concat[:num_total_examples]
        return concat
    except AssertionError:
        raise AssertionError("Not currently using distributed training")


def reissue_pt_warnings(caught_warnings):
    # Reissue warnings that are not the SAVE_STATE_WARNING
    if len(caught_warnings) > 1:
        for w in caught_warnings:
            if w.category != UserWarning or w.message != SAVE_STATE_WARNING:
                warnings.warn(w.message, w.category)


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Decorator to make all processes in distributed training wait for each local_master to do something.

    Args:
        local_rank (`int`): The rank of the local process.
    """
    if local_rank not in [-1, 0]:
        dist.barrier()
    yield
    if local_rank == 0:
        dist.barrier()


class DistributedSamplerWithLoop(DistributedSampler):
    """
    Like a `torch.utils.data.distributed.DistributedSampler` but loops at the end back to the beginning of the shuffled
    samples to make each process have a round multiple of batch_size samples.

    Args:
        dataset (`torch.utils.data.Dataset`):
            Dataset used for sampling.
        batch_size (`int`):
            The batch size used with this sampler.
        kwargs (`Dict[str, Any]`, *optional*):
            All other keyword arguments passed to `DistributedSampler`.
    """

    def __init__(self, dataset, batch_size, **kwargs):
        super().__init__(dataset, **kwargs)
        self.batch_size = batch_size

    def __iter__(self):
        indices = list(super().__iter__())
        remainder = 0 if len(indices) % self.batch_size == 0 else self.batch_size - len(indices) % self.batch_size
        # DistributedSampler already added samples from the beginning to make the number of samples a round multiple
        # of the world size, so we skip those.
        start_remainder = 1 if self.rank < len(self.dataset) % self.num_replicas else 0
        indices += indices[start_remainder : start_remainder + remainder]
        return iter(indices)


class SequentialDistributedSampler(Sampler):
    """
    Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end.

    Even though we only use this sampler for eval and predict (no training), which means that the model params won't
    have to be synced (i.e. will not hang for synchronization even if varied number of forward passes), we still add
    extra samples to the sampler to make it evenly divisible (like in `DistributedSampler`) to make it easy to `gather`
    or `reduce` resulting tensors at the end of the loop.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, batch_size=None):
        warnings.warn(
            "SequentialDistributedSampler is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        num_samples = len(self.dataset)
        # Add extra samples to make num_samples a multiple of batch_size if passed
        if batch_size is not None:
            self.num_samples = int(math.ceil(num_samples / (batch_size * num_replicas))) * batch_size
        else:
            self.num_samples = int(math.ceil(num_samples / num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.batch_size = batch_size

    def __iter__(self):
        indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert (
            len(indices) == self.total_size
        ), f"Indices length {len(indices)} and total size {self.total_size} mismatched"

        # subsample
        indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
        assert (
            len(indices) == self.num_samples
        ), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"

        return iter(indices)

    def __len__(self):
        return self.num_samples


def get_tpu_sampler(dataset: torch.utils.data.Dataset, batch_size: int):
    if xm.xrt_world_size() <= 1:
        return RandomSampler(dataset)
    return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())


def nested_new_like(arrays, num_samples, padding_index=-100):
    """Create the same nested structure as `arrays` with a first dimension always at `num_samples`."""
    if isinstance(arrays, (list, tuple)):
        return type(arrays)(nested_new_like(x, num_samples) for x in arrays)
    return np.full_like(arrays, padding_index, shape=(num_samples, *arrays.shape[1:]))


def expand_like(arrays, new_seq_length, padding_index=-100):
    """Expand the `arrays` so that the second dimension grows to `new_seq_length`. Uses `padding_index` for padding."""
    result = np.full_like(arrays, padding_index, shape=(arrays.shape[0], new_seq_length) + arrays.shape[2:])
    result[:, : arrays.shape[1]] = arrays
    return result


def nested_truncate(tensors, limit):
    """Truncate `tensors` at `limit` (even if it's a nested list/tuple/dict of tensors)."""
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_truncate(t, limit) for t in tensors)
    if isinstance(tensors, Mapping):
        return type(tensors)({k: nested_truncate(t, limit) for k, t in tensors.items()})
    return tensors[:limit]


class DistributedTensorGatherer:
    """
    A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.

    If our dataset has 16 samples with a batch size of 2 on 3 processes and we gather then transfer on CPU at every
    step, our sampler will generate the following indices:

        `[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1]`

    to get something of size a multiple of 3 (so that each process gets the same dataset length). Then process 0, 1 and
    2 will be responsible of making predictions for the following samples:

        - P0: `[0, 1, 2, 3, 4, 5]`
        - P1: `[6, 7, 8, 9, 10, 11]`
        - P2: `[12, 13, 14, 15, 0, 1]`

    The first batch treated on each process will be

        - P0: `[0, 1]`
        - P1: `[6, 7]`
        - P2: `[12, 13]`

    So if we gather at the end of the first batch, we will get a tensor (nested list/tuple of tensor) corresponding to
    the following indices:

        `[0, 1, 6, 7, 12, 13]`

    If we directly concatenate our results without taking any precautions, the user will then get the predictions for
    the indices in this order at the end of the prediction loop:

        `[0, 1, 6, 7, 12, 13, 2, 3, 8, 9, 14, 15, 4, 5, 10, 11, 0, 1]`

    For some reason, that's not going to roll their boat. This class is there to solve that problem.

    Args:
        world_size (`int`):
            The number of processes used in the distributed training.
        num_samples (`int`):
            The number of samples in our dataset.
        make_multiple_of (`int`, *optional*):
            If passed, the class assumes the datasets passed to each process are made to be a multiple of this argument
            (by adding samples).
        padding_index (`int`, *optional*, defaults to -100):
            The padding index to use if the arrays don't all have the same sequence length.
    """

    def __init__(self, world_size, num_samples, make_multiple_of=None, padding_index=-100):
        warnings.warn(
            "DistributedTensorGatherer is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.world_size = world_size
        self.num_samples = num_samples
        total_size = world_size if make_multiple_of is None else world_size * make_multiple_of
        self.total_samples = int(np.ceil(num_samples / total_size)) * total_size
        self.process_length = self.total_samples // world_size
        self._storage = None
        self._offsets = None
        self.padding_index = padding_index

    def add_arrays(self, arrays):
        """
        Add `arrays` to the internal storage, Will initialize the storage to the full size at the first arrays passed
        so that if we're bound to get an OOM, it happens at the beginning.
        """
        if arrays is None:
            return
        if self._storage is None:
            self._storage = nested_new_like(arrays, self.total_samples, padding_index=self.padding_index)
            self._offsets = list(range(0, self.total_samples, self.process_length))

        slice_len, self._storage = self._nested_set_tensors(self._storage, arrays)
        for i in range(self.world_size):
            self._offsets[i] += slice_len

    def _nested_set_tensors(self, storage, arrays):
        if isinstance(arrays, (list, tuple)):
            result = [self._nested_set_tensors(x, y) for x, y in zip(storage, arrays)]
            return result[0][0], type(arrays)(r[1] for r in result)
        assert (
            arrays.shape[0] % self.world_size == 0
        ), f"Arrays passed should all have a first dimension multiple of {self.world_size}, found {arrays.shape[0]}."

        slice_len = arrays.shape[0] // self.world_size
        for i in range(self.world_size):
            if len(arrays.shape) == 1:
                storage[self._offsets[i] : self._offsets[i] + slice_len] = arrays[i * slice_len : (i + 1) * slice_len]
            else:
                # Expand the storage on the second dimension if needed.
                if len(storage.shape) > 1 and storage.shape[1] < arrays.shape[1]:
                    storage = expand_like(storage, arrays.shape[1], padding_index=self.padding_index)
                storage[self._offsets[i] : self._offsets[i] + slice_len, : arrays.shape[1]] = arrays[
                    i * slice_len : (i + 1) * slice_len
                ]
        return slice_len, storage

    def finalize(self):
        """
        Return the properly gathered arrays and truncate to the number of samples (since the sampler added some extras
        to get each process a dataset of the same length).
        """
        if self._storage is None:
            return
        if self._offsets[0] != self.process_length:
            logger.warning("Not all data has been set. Are you sure you passed all values?")
        return nested_truncate(self._storage, self.num_samples)


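# Illustrative usage sketch (not part of the original module; `eval_steps()` and the shapes are
# placeholder assumptions). Each call to `add_arrays` must pass arrays whose first dimension is a
# multiple of `world_size`, i.e. the already-gathered per-step outputs:
#
#     gatherer = DistributedTensorGatherer(world_size=3, num_samples=16)
#     for step_logits in eval_steps():          # e.g. np.ndarray of shape (6, seq_len, vocab)
#         gatherer.add_arrays(step_logits)
#     predictions = gatherer.finalize()         # re-interleaved and truncated back to 16 samples

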
ddZ	d	S )LabelSmoothera@  
    Adds label-smoothing on a pre-computed output from a Transformers model.

    Args:
        epsilon (`float`, *optional*, defaults to 0.1):
            The label smoothing factor.
        ignore_index (`int`, *optional*, defaults to -100):
            The index in the labels to ignore when computing the loss.
    """

    epsilon: float = 0.1
    ignore_index: int = -100

    def __call__(self, model_output, labels, shift_labels=False):
        logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
        if shift_labels:
            logits = logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()

        log_probs = -nn.functional.log_softmax(logits, dim=-1)
        if labels.dim() == log_probs.dim() - 1:
            labels = labels.unsqueeze(-1)

        padding_mask = labels.eq(self.ignore_index)
        # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
        # will ignore them in any case.
        labels = torch.clamp(labels, min=0)
        nll_loss = log_probs.gather(dim=-1, index=labels)
        # works for fp16 input tensor too, by internally upcasting it to fp32
        smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)

        nll_loss.masked_fill_(padding_mask, 0.0)
        smoothed_loss.masked_fill_(padding_mask, 0.0)

        # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
        num_active_elements = padding_mask.numel() - padding_mask.long().sum()
        nll_loss = nll_loss.sum() / num_active_elements
        smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
        return (1 - self.epsilon) * nll_loss + self.epsilon * smoothed_loss


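# Illustrative usage sketch (assumed names, not part of the original module): the smoother is
# called with the raw model output and the labels, typically from a Trainer's `compute_loss`:
#
#     label_smoother = LabelSmoother(epsilon=0.1)
#     loss = label_smoother(model(**inputs), inputs["labels"])

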
d|D }tt| }|| d |d d  |d d< || d< dd |D S )a  
    Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar
    lengths. To do this, the indices are:

    - randomly permuted
    - grouped in mega-batches of size `mega_batch_mult * batch_size`
    - sorted by length in each mega-batch

    The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
    maximum length placed first, so that an OOM happens sooner rather than later.
    """
    # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
    if mega_batch_mult is None:
        mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
        # Just in case, for tiny datasets
        if mega_batch_mult == 0:
            mega_batch_mult = 1

    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    indices = torch.randperm(len(lengths), generator=generator)
    megabatch_size = mega_batch_mult * batch_size
    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
    megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]

    # The rest is to get the biggest batch first. Since each megabatch is sorted by descending length,
    # the longest element is the first of each megabatch.
    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
    max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
    # Switch to put the element of maximum length in first position so that an OOM happens sooner rather than later.
    megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]

    return [i for megabatch in megabatches for i in megabatch]


class LengthGroupedSampler(Sampler):
    r"""
    Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
    keeping a bit of randomness.
    N)r   r   r   model_input_namec                    s   |d kr|d krt d|| _|d kr d k	r2 nd t|d tsRt|d tr^ |d krnt d  d fdd|D }nt|tjrtd |	 }|| _
|| _d S )	N,One of dataset and lengths must be provided.	input_idsr   XCan only automatically infer lengths for datasets whose items are dictionaries with an '' key.c                    s   g | ]}t |  qS r   r,   r;   featurer   r   r    rr   I  s     z1LengthGroupedSampler.__init__.<locals>.<listcomp>zcIf lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]...)
ValueErrorr   r#   r   r   r$   r%   r   infor   r   r   )r   r   r   r   r   r   r   r   r    r   3  s,    

zLengthGroupedSampler.__init__c                 C   s
   t | jS rP   )r,   r   r   r   r   r    r   S  s    zLengthGroupedSampler.__len__c                 C   s   t | j| j| jd}t|S Nr   )r   r   r   r   r   r   r   r   r    r   V  s    zLengthGroupedSampler.__iter__)NNNN)r   r   r   r   r   r
   r   r	   strr   r   r   r   r   r   r    r   -  s       
 r   c                
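# Illustrative usage sketch (assumed toy values and a `DataLoader` import, not part of this
# module): each batch then draws indices of roughly similar length, minimizing padding waste:
#
#     lengths = [len(f["input_ids"]) for f in train_dataset]
#     sampler = LengthGroupedSampler(batch_size=8, lengths=lengths)
#     loader = DataLoader(train_dataset, batch_size=8, sampler=sampler)

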
   @   sT   e Zd ZdZdeee ee ee eeeee  ee	 dddZ
edd	d
ZdS )DistributedLengthGroupedSamplerz
    Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
    length while keeping a bit of randomness.
    """

    # Copied and adapted from PyTorch DistributedSampler.
    def __init__(
        self,
        batch_size: int,
        dataset: Optional[Dataset] = None,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        seed: int = 0,
        drop_last: bool = False,
        lengths: Optional[List[int]] = None,
        model_input_name: Optional[str] = None,
    ):
        if dataset is None and lengths is None:
            raise ValueError("One of dataset and lengths must be provided.")
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()

        self.batch_size = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last

        if lengths is None:
            model_input_name = model_input_name if model_input_name is not None else "input_ids"
            if (
                not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
                or model_input_name not in dataset[0]
            ):
                raise ValueError(
                    "Can only automatically infer lengths for datasets whose items are dictionaries with an "
                    f"'{model_input_name}' key."
                )
            lengths = [len(feature[model_input_name]) for feature in dataset]
        elif isinstance(lengths, torch.Tensor):
            logger.info(
                "If lengths is a torch.Tensor, DistributedLengthGroupedSampler will be slow. Converting lengths to"
                " List[int]..."
            )
            lengths = lengths.tolist()

        self.lengths = lengths

        # If the dataset length is evenly divisible by the number of replicas, there is no need to drop any data,
        # since the dataset will be split equally.
        if self.drop_last and len(self.lengths) % self.num_replicas != 0:
            # Split to nearest available length that is evenly divisible. This is to ensure each rank receives the
            # same amount of data when using this Sampler.
            self.num_samples = math.ceil((len(self.lengths) - self.num_replicas) / self.num_replicas)
        else:
            self.num_samples = math.ceil(len(self.lengths) / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas
        self.seed = seed

    def __iter__(self) -> Iterator:
        # Deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            indices += indices[: (self.total_size - len(indices))]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)


class ShardSampler(Sampler):
    """
    Sampler that shards batches between several processes. Dispatches indices batch by batch: on 2 processes with batch
    size 4, the first two batches are `[0, 1, 2, 3, 4, 5, 6, 7]` and `[8, 9, 10, 11, 12, 13, 14, 15]`, which shard into
    `[0, 1, 2, 3]` and `[8, 9, 10, 11]` for GPU-0 and `[4, 5, 6, 7]` and `[12, 13, 14, 15]` for GPU-1.

    The sampler thus yields `[0, 1, 2, 3, 8, 9, 10, 11]` on GPU-0 and `[4, 5, 6, 7, 12, 13, 14, 15]` on GPU-1.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 1,
        drop_last: bool = False,
        num_processes: int = 1,
        process_index: int = 0,
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_processes = num_processes
        self.process_index = process_index

        self.total_batch_size = total_batch_size = batch_size * num_processes

        num_batches = len(dataset) // total_batch_size if drop_last else math.ceil(len(dataset) / total_batch_size)
        self.total_num_samples = num_batches * total_batch_size

    def __iter__(self):
        indices = list(range(len(self.dataset)))

        # Add extra samples to make it evenly divisible. While loop is there in the edge case we have a tiny dataset
        # and it needs to be done several times.
        while len(indices) < self.total_num_samples:
            indices += indices[: (self.total_num_samples - len(indices))]

        result = []
        for batch_start in range(self.batch_size * self.process_index, self.total_num_samples, self.total_batch_size):
            result += indices[batch_start : batch_start + self.batch_size]

        return iter(result)

    def __len__(self):
        # Each shard only sees a fraction of total_num_samples.
        return self.total_num_samples // self.num_processes


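# Illustrative sketch of the sharding described in the docstring (assumed values, not part of
# the original module): with a dataset of 16 items, batch_size=4 and num_processes=2,
#
#     sampler = ShardSampler(dataset, batch_size=4, num_processes=2, process_index=0)
#     assert list(sampler) == [0, 1, 2, 3, 8, 9, 10, 11]   # process 1 gets [4..7, 12..15]

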
dS )IterableDatasetSharda  
    Wraps a PyTorch `IterableDataset` to generate samples for one of the processes only. Instances of this class will
    always yield a number of samples that is a round multiple of the actual batch size (which is `batch_size x
    num_processes`). Depending on the value of the `drop_last` attribute, it will either stop the iteration at the
    first batch that would be too small or loop with indices from the beginning.

    On two processes with an iterable dataset yielding of `[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]` with a batch size of
    2:

    - the shard on process 0 will yield `[0, 1, 4, 5, 8, 9]` so will see batches `[0, 1]`, `[4, 5]`, `[8, 9]`
    - the shard on process 1 will yield `[2, 3, 6, 7, 10, 11]` so will see batches `[2, 3]`, `[6, 7]`, `[10, 11]`

    <Tip warning={true}>

        If your IterableDataset implements some randomization that needs to be applied the same way on all processes
        (for instance, a shuffling), you should use a `torch.Generator` in a `generator` attribute of the `dataset` to
        generate your random numbers and call the [`~trainer_pt_utils.IterableDatasetShard.set_epoch`] method of this
        object. It will set the seed of this `generator` to `seed + epoch` on all processes before starting the
        iteration. Alternatively, you can also implement a `set_epoch()` method in your iterable dataset to deal with
        this.

    </Tip>

    Args:
        dataset (`torch.utils.data.IterableDataset`):
            The dataset to split in several shards.
        batch_size (`int`, *optional*, defaults to 1):
            The size of the batches per shard.
        drop_last (`bool`, *optional*, defaults to `False`):
            Whether or not to drop the last incomplete batch or complete the last batches by using the samples from the
            beginning.
        num_processes (`int`, *optional*, defaults to 1):
            The number of processes running concurrently.
        process_index (`int`, *optional*, defaults to 0):
            The index of the current process.
        seed (`int`, *optional*, defaults to 0):
            A random seed that will be used for the random number generation in
            [`~trainer_pt_utils.IterableDatasetShard.set_epoch`].
    """

    def __init__(
        self,
        dataset: IterableDataset,
        batch_size: int = 1,
        drop_last: bool = False,
        num_processes: int = 1,
        process_index: int = 0,
        seed: int = 0,
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_processes = num_processes
        self.process_index = process_index
        self.seed = seed
        self.epoch = 0
        self.num_examples = 0

    def set_epoch(self, epoch):
        self.epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def __iter__(self):
        self.num_examples = 0
        if (
            not hasattr(self.dataset, "set_epoch")
            and hasattr(self.dataset, "generator")
            and isinstance(self.dataset.generator, torch.Generator)
        ):
            self.dataset.generator.manual_seed(self.seed + self.epoch)
        real_batch_size = self.batch_size * self.num_processes
        process_slice = range(self.process_index * self.batch_size, (self.process_index + 1) * self.batch_size)

        first_batch = None
        current_batch = []
        for element in self.dataset:
            self.num_examples += 1
            current_batch.append(element)
            # Wait to have a full batch before yielding elements.
            if len(current_batch) == real_batch_size:
                for i in process_slice:
                    yield current_batch[i]
                if first_batch is None:
                    first_batch = current_batch.copy()
                current_batch = []

        # Finished if drop_last is True, otherwise complete the last batch with elements from the beginning.
        if not self.drop_last and len(current_batch) > 0:
            if first_batch is None:
                first_batch = current_batch.copy()
            while len(current_batch) < real_batch_size:
                current_batch += first_batch
            for i in process_slice:
                yield current_batch[i]

    def __len__(self):
        # Will raise an error if the underlying dataset is not sized.
        if self.drop_last:
            return (len(self.dataset) // (self.batch_size * self.num_processes)) * self.batch_size
        else:
            return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size


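# Illustrative sketch (assumed `my_iterable_dataset`, not part of the original module): wrapping
# an iterable dataset that yields 0..11 with batch_size=2 on 2 processes reproduces the
# docstring example above:
#
#     shard = IterableDatasetShard(my_iterable_dataset, batch_size=2, num_processes=2, process_index=0)
#     list(shard)  # -> [0, 1, 4, 5, 8, 9]

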
def _get_learning_rate(self):
    if self.is_deepspeed_enabled:
        # with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
        # not run for the first few dozen steps while loss scale is too large, and thus during
        # that time `get_last_lr` will fail if called during that warm up stage, so work around it:
        try:
            last_lr = self.lr_scheduler.get_last_lr()[0]
        except AssertionError as e:
            if "need to call step" in str(e):
                logger.warning("tried to get lr value before scheduler/optimizer started stepping, returning lr=0")
                last_lr = 0
            else:
                raise
    else:
        if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            last_lr = self.optimizer.param_groups[0]["lr"]
        else:
            last_lr = self.lr_scheduler.get_last_lr()[0]
        if torch.is_tensor(last_lr):
            last_lr = last_lr.item()
    return last_lr


def _secs2timedelta(secs):
    """
    convert seconds to hh:mm:ss.msec, msecs rounded to 2 decimals
    """
    msec = int(abs(secs - int(secs)) * 100)
    return f"{datetime.timedelta(seconds=int(secs))}.{msec:02d}"


def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]:
    """
    Reformat Trainer metrics values to a human-readable format

    Args:
        metrics (`Dict[str, float]`):
            The metrics returned from train/evaluate/predict

    Returns:
        metrics (`Dict[str, float]`): The reformatted metrics
    Z_mem_   MBZ_runtimeZ
total_flos   ZGFr   )r  rI   r  r   rD   r   round)r   r  Zmetrics_copyr@   vr   r   r    metrics_formato  s    r  c                 C   s   |   sdS td| d | |}tdd | D }tdd | D }t| D ],}td|d| d	|| d
|  q^dS )a@  
    Log metrics in a specially formatted way

    Under distributed environment this is done only for a process with rank 0.

    Args:
        split (`str`):
            Mode/split name: one of `train`, `eval`, `test`
        metrics (`Dict[str, float]`):
            The metrics returned from train/evaluate/predict

    Notes on memory reports:

    In order to get memory usage report you need to install `psutil`. You can do that with `pip install psutil`.

    Now when this method is run, you will see a report that will include:

    ```
    init_mem_cpu_alloc_delta   =     1301MB
    init_mem_cpu_peaked_delta  =      154MB
    init_mem_gpu_alloc_delta   =      230MB
    init_mem_gpu_peaked_delta  =        0MB
    train_mem_cpu_alloc_delta  =     1345MB
    train_mem_cpu_peaked_delta =        0MB
    train_mem_gpu_alloc_delta  =      693MB
    train_mem_gpu_peaked_delta =        7MB
    ```

    **Understanding the reports:**

    - the first segment, e.g., `train__`, tells you which stage the metrics are for. Reports starting with `init_`
        will be added to the first stage that gets run. So that if only evaluation is run, the memory usage for the
        `__init__` will be reported along with the `eval_` metrics.
    - the third segment, is either `cpu` or `gpu`, tells you whether it's the general RAM or the gpu0 memory
        metric.
    - `*_alloc_delta` - is the difference in the used/allocated memory counter between the end and the start of the
        stage - it can be negative if a function released more memory than it allocated.
    - `*_peaked_delta` - is any extra memory that was consumed and then freed - relative to the current allocated
        memory counter - it is never negative. When you look at the metrics of any stage you add up `alloc_delta` +
        `peaked_delta` and you know how much memory was needed to complete that stage.

    The reporting happens only for process of rank 0 and gpu 0 (if there is a gpu). Typically this is enough since the
    main process does the bulk of work, but it could be not quite so if model parallel is used and then other GPUs may
    use a different amount of gpu memory. This is also not the same under DataParallel where gpu0 may require much more
    memory than the rest since it stores the gradient and optimizer states for all participating GPUS. Perhaps in the
    future these reports will evolve to measure those too.

    The CPU RAM metric measures RSS (Resident Set Size) includes both the memory which is unique to the process and the
    memory shared with other processes. It is important to note that it does not include swapped out memory, so the
    reports could be imprecise.

    The CPU peak memory is measured using a sampling thread. Due to python's GIL it may miss some of the peak memory if
    that thread didn't get a chance to run when the highest memory was used. Therefore this report can be less than
    reality. Using `tracemalloc` would have reported the exact peak memory, but it doesn't report memory allocations
    outside of python. So if some C++ CUDA extension allocated its own memory it won't be reported. And therefore it
    was dropped in favor of the memory sampling approach, which reads the current process memory usage.

    The GPU allocated and peak memory reporting is done with `torch.cuda.memory_allocated()` and
    `torch.cuda.max_memory_allocated()`. This metric reports only "deltas" for pytorch-specific allocations, as
    `torch.cuda` memory management system doesn't track any memory allocated outside of pytorch. For example, the very
    first cuda call typically loads CUDA kernels, which may take from 0.5 to 2GB of GPU memory.

    Note that this tracker doesn't account for memory allocations outside of [`Trainer`]'s `__init__`, `train`,
    `evaluate` and `predict` calls.

    Because `evaluation` calls may happen during `train`, we can't handle nested invocations because
    `torch.cuda.max_memory_allocated` is a single counter, so if it gets reset by a nested eval call, `train`'s tracker
    will report incorrect info. If this [pytorch issue](https://github.com/pytorch/pytorch/issues/16266) gets resolved
    it will be possible to change this class to be re-entrant. Until then we will only track the outer level of
    `train`, `evaluate` and `predict` methods. Which means that if `eval` is called during `train`, it's the latter
    that will account for its memory usage and that of the former.

    This also means that if any other tool that is used along the [`Trainer`] calls
    `torch.cuda.reset_peak_memory_stats`, the gpu peak memory stats could be invalid. And the [`Trainer`] will disrupt
    the normal behavior of any such tools that rely on calling `torch.cuda.reset_peak_memory_stats` themselves.

    For best performance you may want to consider turning the memory profiling off for production runs.
    Nz***** z metrics *****c                 s   s   | ]}t t|V  qd S rP   r,   r   r   r   r   r    r>     s     zlog_metrics.<locals>.<genexpr>c                 s   s   | ]}t t|V  qd S rP   r  r   r   r   r    r>     s     z  z <z = >)is_world_process_zeroprintr  r/   keysvaluesr   )r   splitr  Zmetrics_formattedZk_widthZv_widthrN   r   r   r    log_metrics  s    O
r   Tc              	   C   s   |   sdS tj| jj| d}t|d}tj||ddd W 5 Q R X |rtj| jjd}tj	|rt|d}t
|}W 5 Q R X ni }|| t|d}tj||ddd W 5 Q R X dS )	a  
    Save metrics into a json file for that split, e.g. `train_results.json`.

    Under distributed environment this is done only for a process with rank 0.

    Args:
        split (`str`):
            Mode/split name: one of `train`, `eval`, `test`, `all`
        metrics (`Dict[str, float]`):
            The metrics returned from train/evaluate/predict
        combined (`bool`, *optional*, defaults to `True`):
            Creates combined metrics by updating `all_results.json` with metrics of this call

    To understand the metrics please read the docstring of [`~Trainer.log_metrics`]. The only difference is that raw
    unformatted numbers are saved in the current method.

    Nz_results.jsonr   r   T)indent	sort_keyszall_results.jsonr   )r  ospathjoinargs
output_diropenjsondumpexistsloadupdate)r   r  r  combinedr$  fZall_metricsr   r   r    save_metrics  s    
r0  c                 C   s.   |   sdS tj| jjd}| j| dS )z
    Saves the Trainer state, since Trainer.save_model saves only the tokenizer with the model

    Under distributed environment this is done only for a process with rank 0.
    Nztrainer_state.json)r  r#  r$  r%  r&  r'  stateZsave_to_json)r   r$  r   r   r    
save_state	  s    r2  c                    s4   t  rdd  ndd  t fdd|  D S )zn
    Calculate model's total param count. If trainable_only is True then count only those requiring grads
    """
    if is_deepspeed_zero3_enabled():

        def numel(p):
            return p.ds_numel if hasattr(p, "ds_numel") else p.numel()

    else:

        def numel(p):
            return p.numel()

    return sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad)


def get_parameter_names(model, forbidden_layer_types):
    """
    Returns the names of the model parameters that are not inside a forbidden layer.
    """
    result = []
    for name, child in model.named_children():
        result += [
            f"{name}.{n}"
            for n in get_parameter_names(child, forbidden_layer_types)
            if not isinstance(child, tuple(forbidden_layer_types))
        ]
    # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
    result += list(model._parameters.keys())
    return result


def get_module_class_from_name(module, name):
    """
    Gets a class from a module by its name.

    Args:
        module (`torch.nn.Module`): The module to get the class from.
        name (`str`): The name of the class.
    """
    modules_children = list(module.children())
    if module.__class__.__name__ == name:
        return module.__class__
    elif len(modules_children) == 0:
        return
    else:
        for child_module in modules_children:
            module_class = get_module_class_from_name(child_module, name)
            if module_class is not None:
                return module_class


def remove_dummy_checkpoint(is_main_process, output_dir, filenames):
    if is_main_process:
        for filename in filenames:
            file = os.path.join(output_dir, filename)
            if os.path.isfile(file):
                os.remove(file)


if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp

    @smp.step()
    def smp_forward_backward(model, inputs, gradient_accumulation_steps=1):
        outputs = model(**inputs)
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        loss /= gradient_accumulation_steps
        model.backward(loss)
        return loss

    @smp.step()
    def smp_forward_only(model, inputs):
        return model(**inputs)

    def smp_gather(tensor):
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(smp_gather(t) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: smp_gather(v) for k, v in tensor.items()})
        elif not isinstance(tensor, torch.Tensor):
            raise TypeError(
                f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
            )
        all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
        all_tensors = [atleast_1d(t) for t in all_tensors]
        return torch.cat([t.cpu() for t in all_tensors])

    def smp_nested_concat(tensor):
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(smp_nested_concat(t) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: smp_nested_concat(v) for k, v in tensor.items()})
        # It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
        # which is also the name of the decorator so Python is confused.
        return tensor.concat().detach().cpu()