U
    0-e                     @   s  d dl Z d dlmZ d dlmZmZmZmZ d dlZd dl	m
Z
mZmZ ddlmZ ddlmZmZmZmZ ddlmZmZmZmZmZmZmZmZmZmZmZ ee Z!dd	ddd dd	d	d dddd
d	dZ"i Z#e#$ D ]\Z%Z&ede%re"'e& qG dd de
Z(G dd deZ)G dd dZ*G dd dee*Z+ed	drdd dl,m-  m.Z/ G dd de/j0Z1G dd dee*Z2d$eeej3 ee4 ee4 e5e5eeee6ef   ee5 e5ee edddZ7G dd de
Z8G d d! d!eZ9d%d"d#Z:dS )&    N)suppress)CallableListOptionalUnion)BatchSampler
DataLoaderIterableDataset   )
get_logger)AcceleratorStateDistributedTypeGradientStateis_tpu_available)RNGType	broadcastbroadcast_object_listconcatenatefind_batch_sizeget_data_structureinitialize_tensorsis_torch_versionsend_to_deviceslice_tensorssynchronize_rng_statesF   )
batch_sizeshufflesamplerbatch_samplerZnum_workersZ
collate_fnZ
pin_memory	drop_lasttimeoutZworker_init_fnZmultiprocessing_context	generatorZprefetch_factorZpersistent_workers>=c                   @   sT   e Zd ZdZdeeeeedddZed	d
 Z	dd Z
dd Zdd Zdd ZdS )BatchSamplerSharda  
    Wraps a PyTorch `BatchSampler` to generate batches for one of the processes only. Instances of this class will
    always yield a number of batches that is a round multiple of `num_processes` and that all have the same size.
    Depending on the value of the `drop_last` attribute of the batch sampler passed, it will either stop the iteration
    at the first batch that would be too small / not present on all processes or loop with indices from the beginning.

    Args:
        batch_sampler (`torch.utils.data.sampler.BatchSampler`):
            The batch sampler to split in several shards.
        num_processes (`int`, *optional*, defaults to 1):
            The number of processes running concurrently.
        process_index (`int`, *optional*, defaults to 0):
            The index of the current process.
        split_batches (`bool`, *optional*, defaults to `False`):
            Whether the shards should be created by splitting a batch to give a piece of it on each process, or by
            yielding different full batches on each process.

            On two processes with a sampler of `[[0, 1, 2, 3], [4, 5, 6, 7]]`, this will result in:

            - the sampler on process 0 to yield `[0, 1, 2, 3]` and the sampler on process 1 to yield `[4, 5, 6, 7]` if
              this argument is set to `False`.
            - the sampler on process 0 to yield `[0, 1]` then `[4, 5]` and the sampler on process 1 to yield `[2, 3]`
              then `[6, 7]` if this argument is set to `True`.
        even_batches (`bool`, *optional*, defaults to `True`):
            Whether or not to loop back at the beginning of the sampler when the number of samples is not a round
            multiple of (original batch size / number of processes).

    <Tip warning={true}>

    `BatchSampler`s with varying batch sizes are not enabled by default. To enable this behaviour, set `even_batches`
    equal to `False`

    </Tip>r
   r   FT)r   num_processesprocess_indexsplit_batcheseven_batchesc                 C   s   |r*|j | dkr*td|j  d| d|| _|| _|| _|| _|| _t|dd | _ t|dd| _| j d kr|| jr|tdd S )	Nr   zDTo use `BatchSamplerShard` in `split_batches` mode, the batch size (;) needs to be a round multiple of the number of processes ().r   r    FzNYou need to use `even_batches=False` when the batch sampler has no batch size.)	r   
ValueErrorr   r%   r&   r'   r(   getattrr    )selfr   r%   r&   r'   r(    r.   W/var/www/html/Darija-Ai-Train/env/lib/python3.8/site-packages/accelerate/data_loader.py__init__f   s    zBatchSamplerShard.__init__c                 C   s
   t | jS Nlenr   r-   r.   r.   r/   total_length}   s    zBatchSamplerShard.total_lengthc                 C   s   | j rt| jS t| j| j dkr4t| j| j S t| j| j }| jrN|S | jr\|d S | jt| j| j k rz|d S |S d S Nr   r
   )r'   r3   r   r%   r    r(   r&   r-   lengthr.   r.   r/   __len__   s    
zBatchSamplerShard.__len__c                 C   s   | j r|  S |  S r1   )r'   _iter_with_split_iter_with_no_splitr4   r.   r.   r/   __iter__   s    zBatchSamplerShard.__iter__c                 c   s   g }| j j| j }t| j D ]@\}}|dkr0|}t|| jkr||| j || jd   V  q| jst|dkrt|| jk r| jst||| j kr||| j || jd   V  n>t|| jk r||7 }q|| }||| j || jd   V  d S r6   )r   r   r%   	enumerater3   r&   r    r(   )r-   initial_dataZbatch_lengthidxbatchr.   r.   r/   r:      s       
z"BatchSamplerShard._iter_with_splitc                 c   sl  g }g }t | jD ]j\}}| js2|| jk r2||7 }|| j | jkrF|}|| j | jd kr| jd ksrt|| jkr|V  g }q| jsht|dkrh| jst|dkr|V  nt|| jkr|V  t|| j| j k r||7 }qt|| jkrg }|d7 }d}|| j dkst|dkrh|| j t| }|||| 7 }|| j | jkrV|V  |}g }|d7 }qd S Nr
   r   )r=   r   r    r%   r&   r   r3   r(   )r-   r>   Zbatch_to_yieldr?   r@   Zcycle_indexZ	end_indexr.   r.   r/   r;      sD    
z%BatchSamplerShard._iter_with_no_splitN)r
   r   FT)__name__
__module____qualname____doc__r   intboolr0   propertyr5   r9   r<   r:   r;   r.   r.   r.   r/   r$   C   s$   %    
r$   c                   @   s2   e Zd ZdZdeeeeeedddZdd	 Zd
S )IterableDatasetSharda  
    Wraps a PyTorch `IterableDataset` to generate samples for one of the processes only. Instances of this class will
    always yield a number of samples that is a round multiple of the actual batch size (depending of the value of
    `split_batches`, this is either `batch_size` or `batch_size x num_processes`). Depending on the value of the
    `drop_last` attribute of the batch sampler passed, it will either stop the iteration at the first batch that would
    be too small or loop with indices from the beginning.

    Args:
        dataset (`torch.utils.data.dataset.IterableDataset`):
            The batch sampler to split in several shards.
        batch_size (`int`, *optional*, defaults to 1):
            The size of the batches per shard (if `split_batches=False`) or the size of the batches (if
            `split_batches=True`).
        drop_last (`bool`, *optional*, defaults to `False`):
            Whether or not to drop the last incomplete batch or complete the last batches by using the samples from the
            beginning.
        num_processes (`int`, *optional*, defaults to 1):
            The number of processes running concurrently.
        process_index (`int`, *optional*, defaults to 0):
            The index of the current process.
        split_batches (`bool`, *optional*, defaults to `False`):
            Whether the shards should be created by splitting a batch to give a piece of it on each process, or by
            yielding different full batches on each process.

            On two processes with an iterable dataset yielding of `[0, 1, 2, 3, 4, 5, 6, 7]`, this will result in:

            - the shard on process 0 to yield `[0, 1, 2, 3]` and the shard on process 1 to yield `[4, 5, 6, 7]` if this
              argument is set to `False`.
            - the shard on process 0 to yield `[0, 1, 4, 5]` and the sampler on process 1 to yield `[2, 3, 6, 7]` if
              this argument is set to `True`.
    r
   Fr   )datasetr   r    r%   r&   r'   c                 C   sV   |r.|dkr.|| dkr.t d| d| d|| _|| _|| _|| _|| _|| _d S )Nr
   r   zGTo use `IterableDatasetShard` in `split_batches` mode, the batch size (r)   r*   )r+   rJ   r   r    r%   r&   r'   )r-   rJ   r   r    r%   r&   r'   r.   r.   r/   r0      s    	zIterableDatasetShard.__init__c                 c   s   | j r| jn
| j| j }| j r*| j| j n| j}t| j| | jd | }d }g }| jD ]B}|| t||krX|D ]}|| V  qv|d kr| }g }qX| j	st|dkr|d kr| }t||k r||7 }q|D ]}|| V  qd S rA   )
r'   r   r%   ranger&   rJ   appendr3   copyr    )r-   Zreal_batch_sizeZprocess_batch_sizeZprocess_slicefirst_batchcurrent_batchelementir.   r.   r/   r<     s(    


zIterableDatasetShard.__iter__N)r
   Fr
   r   F)	rB   rC   rD   rE   r	   rF   rG   r0   r<   r.   r.   r.   r/   rI      s   #     rI   c                   @   s0   e Zd ZdZdd Zdd Zdd Zdd	 Zd
S )DataLoaderStateMixina  
    Mixin class that adds a state to a `DataLoader` to keep track of the status inside the dataloader such as at the
    end of the iteration, the number of items in the dataset in the last batch relative to the batch size, and other
    useful information that might be needed.

    **Available attributes:**

        - **end_of_dataloader** (`bool`) -- Whether at the last iteration or batch
        - **remainder** (`int`) -- The number of items that are remaining in the last batch, relative to the total
          batch size

    c                 K   s   d| _ d| _d S NFend_of_dataloader	remainder)clskwargsr.   r.   r/   __init_subclass__;  s    z&DataLoaderStateMixin.__init_subclass__c                 C   s   d| _ d| _d S rS   rU   r4   r.   r.   r/   reset?  s    zDataLoaderStateMixin.resetc              	   C   sL   |    tt& t| jdt| j}|| j | _W 5 Q R X | j	|  dS )z6Prepares the gradient state for the current dataloadertotal_dataset_lengthN)
r[   r   	Exceptionr,   rJ   r3   total_batch_sizerW   gradient_stateZ_add_dataloaderr7   r.   r.   r/   beginC  s
    
zDataLoaderStateMixin.beginc                 C   s   | j |  dS )z9Cleans up the gradient state after exiting the dataloaderN)r_   Z_remove_dataloaderr4   r.   r.   r/   endK  s    zDataLoaderStateMixin.endN)rB   rC   rD   rE   rZ   r[   r`   ra   r.   r.   r.   r/   rR   -  s
   rR   c                       sF   e Zd ZdZd fdd	Z fddZedd	 Zed
d Z  Z	S )DataLoaderSharda  
    Subclass of a PyTorch `DataLoader` that will deal with device placement and current distributed setup.

    Args:
        dataset (`torch.utils.data.dataset.Dataset`):
            The dataset to use to build this datalaoder.
        device (`torch.device`, *optional*):
            If passed, the device to put all batches on.
        rng_types (list of `str` or [`~utils.RNGType`]):
            The list of random number generators to synchronize at the beginning of each iteration. Should be one or
            several of:

            - `"torch"`: the base torch random number generator
            - `"cuda"`: the CUDA random number generator (GPU only)
            - `"xla"`: the XLA random number generator (TPU only)
            - `"generator"`: an optional `torch.Generator`
        synchronized_generator (`torch.Generator`, *optional*):
            A random number generator to keep synchronized across processes.
        split_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning.
        kwargs:
            All other keyword arguments to pass to the regular `DataLoader` initialization.

    **Available attributes:**

        - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
            Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
            number of processes

        - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
    Nr   c                    s4   t  j|f| || _|| _|| _|| _t | _d S r1   )superr0   device	rng_typessynchronized_generatorskip_batchesr   r_   )r-   rJ   rd   re   rf   rg   rY   	__class__r.   r/   r0   q  s    zDataLoaderShard.__init__c                 #   s   | j d k	rt| j | j |   t  }zt|}W n tk
rP   d V  Y nX d}z>| jd k	rnt	|| j}t|}|| j
kr|V  |d7 }|}W qV tk
r   d| _|| j
kr|V  Y qY qVX qV|   d S )Nr   r
   T)re   r   rf   r`   rc   r<   nextStopIterationrd   r   rg   rV   ra   )r-   Zdataloader_iterrO   batch_index
next_batchrh   r.   r/   r<   y  s.    




zDataLoaderShard.__iter__c                 C   s<   t | jtr| jn| j}t|ddr*|jS |jt|dd S )Nr'   Fr%   r
   )
isinstancer   r   r   r,   r   )r-   r   r.   r.   r/   r^     s
    
z DataLoaderShard.total_batch_sizec                 C   s"   t | jdr| jjS t| jS d S )Nr5   )hasattrrJ   r5   r3   r4   r.   r.   r/   r\     s    z$DataLoaderShard.total_dataset_length)NNNr   )
rB   rC   rD   rE   r0   r<   rH   r^   r\   __classcell__r.   r.   rh   r/   rb   P  s    
rb   )Zcheck_devicec                       sN   e Zd ZdZeejd fddZ fddZe	dd Z
e	d	d
 Z  ZS )MpDeviceLoaderWrappera  
        Wrapper for the xpl.MpDeviceLoader class that knows the total batch size.

        XLA preloading threads will all call DataLoaderShard's __iter__(). Remove rng_types from DataLoaderShard to
        prevent it from using the XLA device in the preloading threads, and synchronize the RNG once from the main
        thread only.

        **Available attributes:**

        - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
            Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
            number of processes

        - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
        )
dataloaderrd   c                    s$   t  || | jj| _d | j_d S r1   )rc   r0   _loaderre   
_rng_types)r-   rr   rd   rh   r.   r/   r0     s    
zMpDeviceLoaderWrapper.__init__c                    s$   | j d k	rt| j | jj t  S r1   )rt   r   rs   rf   rc   r<   r4   rh   r.   r/   r<     s    
zMpDeviceLoaderWrapper.__iter__c                 C   s   | j jS r1   )rs   r^   r4   r.   r.   r/   r^     s    z&MpDeviceLoaderWrapper.total_batch_sizec                 C   s   | j jS r1   )rs   r\   r4   r.   r.   r/   r\     s    z*MpDeviceLoaderWrapper.total_dataset_length)rB   rC   rD   rE   rb   torchrd   r0   r<   rH   r^   r\   rp   r.   r.   rh   r/   rq     s   
rq   c                       sb   e Zd ZdZdeed fddZdd	 Z fd
dZ fddZe	dd Z
e	dd Z  ZS )DataLoaderDispatchera  
    Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each
    process their part of the batch.

    Args:
        split_batches (`bool`, *optional*, defaults to `False`):
            Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
            yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing of
            `num_processes` batches at each iteration). Another way to see this is that the observed batch size will be
            the same as the initial `dataloader` if this option is set to `True`, the batch size of the initial
            `dataloader` multiplied by `num_processes` otherwise. Setting this option to `True` requires that the batch
            size of the `dataloader` is a round multiple of `batch_size`.
        skip_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning of an iteration.

    **Available attributes:**

        - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
            Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
            number of processes

        - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
    Fr   N)r'   
_drop_lastc           	         s   d}t ddr*ddlm} t||r*|j}t j|f| || _|rXtj	j
jj||d t | _t | _|| _|| _|d krtn|| _d S )NFr#   z1.11.0r   )ShufflerIterDataPipe)r   )r   Z-torch.utils.data.datapipes.iter.combinatoricsrx   rn   Z_shuffle_enabledrc   r0   r'   ru   utilsdataZgraph_settingsZapply_shuffle_settingsr   r_   r   staterw   rg   r   slice_fn)	r-   rJ   r'   rg   rw   r|   rY   r   rx   rh   r.   r/   r0     s    

zDataLoaderDispatcher.__init__c                 C   s   d\}}| j jdkrzP| jr&t|}n0g }t| j jD ]}|t| q6t|dd}t|dg}W q t	k
r   d dg}Y qX n
d | j
g}t| |d | _
| j
r| js| js| j jdkrt|dkrt|dd}t|dg}nd dg}t| ||fS )N)NNr   dimFTr
   )r{   r&   r'   rj   rK   r%   rL   r   r   rk   _stop_iterationr   rw   r3   )r-   iteratorZbatchesr@   _
batch_infor.   r.   r/   _fetch_batches  s.    


z#DataLoaderDispatcher._fetch_batchesc                 #   s  |    d }tddr"t  }n| jjdkr8t  }d}d| _d }| |\}}d}|s|| }}| jjdkrt|d }t	|| jj
}t|dd}| js|d kr| j|td| jj| jj| jjd}|d krtd| dt|}	|	| jj }
| j}|s,| |\}}| jr,|d d kr,d	}| jsd|rd|	| jj dkrdt||gdd
}|
d7 }
t| jj|
 | jjd |
 }| j||| jj| jjd}|rd	| _|	| _|| jkr|V  |d7 }qX|   d S )Nr#   z2.0.1r   F)Zfrom_process)r&   r%   z"Batch does not contain any data (`zM`). At the end of all iterable data available before expected stop iteration.Tr}   r
   )r`   r   rc   r<   r{   r&   r   r   r   r   rd   r   rw   r|   slicer%   r+   r   r   rV   rW   rg   ra   )r-   Zmain_iteratorZstop_iterationrN   rm   Znext_batch_inforl   r@   r   Zobserved_batch_sizer   Z
data_slicerh   r.   r/   r<   #  sh    



 
zDataLoaderDispatcher.__iter__c                    s<   t   }| jr|S | jr&|| jj S t|| jj S d S r1   )rc   r9   r'   rw   r{   r%   mathceil)r-   Zwhole_lengthrh   r.   r/   r9   l  s    
zDataLoaderDispatcher.__len__c                 C   s   | j r| jjS | jj| jj S r1   )r'   rJ   r   r%   r4   r.   r.   r/   r^   u  s    z%DataLoaderDispatcher.total_batch_sizec                 C   s
   t | jS r1   )r3   rJ   r4   r.   r.   r/   r\   {  s    z)DataLoaderDispatcher.total_dataset_length)Fr   FN)rB   rC   rD   rE   rG   r0   r   r<   r9   rH   r^   r\   rp   r.   r.   rh   r/   rv     s            %I	
rv   T)rr   rd   r%   r&   r'   put_on_devicere   dispatch_batchesr(   slice_fn_for_dispatchreturnc
                    s  |dkr|sd}nt  jt}|r.|s.tdt }
|dkrB|
j}|dkrP|
j}|r jdkr j| dkrtd j d| d j}t |ts jnd}d}d}|dks|
j	t
jkrr|srt |trt jd	ddk	r jj}t| j j|||d
}npt  jt}|r jj}n jj}t|d	rN|jdkrHt |_|j}|rZ jn j}t|||||d}dddddg|dk	r|dkrd	|kr|d	  fddtD }|dkr j|d< |r|s j| n j|d< |r|d	 t|f|| j|	d| np|r\t|f|rB|
j	t
jkrB|nd| j||d| n2t|f|rz|
j	t
jkrz|nd|||d| |
j	t
jkrt |S  S )a  
    Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.

    Depending on the value of the `drop_last` attribute of the `dataloader` passed, it will either stop the iteration
    at the first batch that would be too small / not present on all processes or loop with indices from the beginning.

    Args:
        dataloader (`torch.utils.data.dataloader.DataLoader`):
            The data loader to split across several devices.
        device (`torch.device`):
            The target device for the returned `DataLoader`.
        num_processes (`int`, *optional*):
            The number of processes running concurrently. Will default to the value given by
            [`~state.AcceleratorState`].
        process_index (`int`, *optional*):
            The index of the current process. Will default to the value given by [`~state.AcceleratorState`].
        split_batches (`bool`, *optional*, defaults to `False`):
            Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
            yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing of
            `num_processes` batches at each iteration).

            Another way to see this is that the observed batch size will be the same as the initial `dataloader` if
            this option is set to `True`, the batch size of the initial `dataloader` multiplied by `num_processes`
            otherwise.

            Setting this option to `True` requires that the batch size of the `dataloader` is a round multiple of
            `batch_size`.
        put_on_device (`bool`, *optional*, defaults to `False`):
            Whether or not to put the batches on `device` (only works if the batches are nested list, tuples or
            dictionaries of tensors).
        rng_types (list of `str` or [`~utils.RNGType`]):
            The list of random number generators to synchronize at the beginning of each iteration. Should be one or
            several of:

            - `"torch"`: the base torch random number generator
            - `"cuda"`: the CUDA random number generator (GPU only)
            - `"xla"`: the XLA random number generator (TPU only)
            - `"generator"`: the `torch.Generator` of the sampler (or batch sampler if there is no sampler in your
              dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type.

        dispatch_batches (`bool`, *optional*):
            If set to `True`, the datalaoder prepared is only iterated through on the main process and then the batches
            are split and broadcast to each process. Will default to `True` when the underlying dataset is an
            `IterableDataset`, `False` otherwise.
        even_batches (`bool`, *optional*, defaults to `True`):
            If set to `True`, in cases where the total batch size across all processes does not exactly divide the
            dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
            all workers.
        slice_fn_for_dispatch (`Callable`, *optional*`):
            If passed, this function will be used to slice tensors across `num_processes`. Will default to
            [`~utils.slice_tensors`]. This argument is used only when `dispatch_batches` is set to `True` and will be
            ignored otherwise.

    Returns:
        `torch.utils.data.dataloader.DataLoader`: A new data loader that will yield the portion of the batches

    <Tip warning={true}>

    `BatchSampler`s with varying batch sizes are not enabled by default. To enable this behaviour, set `even_batches`
    equal to `False`

    </Tip>
    NFz<Using `dispatch_batches=True` requires `put_on_device=True`.r
   r   z?To use a `DataLoader` in `split_batches` mode, the batch size (r)   r*   r"   )r   r    r%   r&   r'   )r%   r&   r'   r(   r   r   r   r   r    c                    s&   i | ]}|kr|t  |t| qS r.   r,   _PYTORCH_DATALOADER_KWARGS.0krr   Zignore_kwargsr.   r/   
<dictcomp>  s    z'prepare_data_loader.<locals>.<dictcomp>)r'   r   rw   r|   )rd   r   r   re   rf   )rd   r   re   rf   )rn   rJ   r	   r+   r   r%   r&   r   r   Zdistributed_typer   ZMEGATRON_LMr,   r"   rI   r    r   r   ro   ru   	Generatorr$   remover   poprv   rb   ZTPUrq   )rr   rd   r%   r&   r'   r   re   r   r(   r   r{   Znew_datasetnew_batch_samplersampler_is_batch_samplerrf   r   r   rY   r.   r   r/   prepare_data_loader  s    K








	
r   c                   @   s6   e Zd ZdZdddZdd Zedd Zd	d
 ZdS )SkipBatchSamplerzx
    A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
    r   c                 C   s   || _ || _d S r1   )r   rg   )r-   r   rg   r.   r.   r/   r0   I  s    zSkipBatchSampler.__init__c                 c   s(   t | jD ]\}}|| jkr
|V  q
d S r1   )r=   r   rg   )r-   indexZsamplesr.   r.   r/   r<   M  s    
zSkipBatchSampler.__iter__c                 C   s
   t | jS r1   r2   r4   r.   r.   r/   r5   R  s    zSkipBatchSampler.total_lengthc                 C   s   t | j| j S r1   )r3   r   rg   r4   r.   r.   r/   r9   V  s    zSkipBatchSampler.__len__N)r   )	rB   rC   rD   rE   r0   r<   rH   r5   r9   r.   r.   r.   r/   r   D  s   

r   c                       s.   e Zd ZdZd fdd	Z fddZ  ZS )SkipDataLoadera  
    Subclass of a PyTorch `DataLoader` that will skip the first batches.

    Args:
        dataset (`torch.utils.data.dataset.Dataset`):
            The dataset to use to build this datalaoder.
        skip_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning.
        kwargs:
            All other keyword arguments to pass to the regular `DataLoader` initialization.
    r   c                    s   t  j|f| || _d S r1   )rc   r0   rg   )r-   rJ   rg   rY   rh   r.   r/   r0   g  s    zSkipDataLoader.__init__c                 #   s,   t t  D ]\}}|| jkr|V  qd S r1   )r=   rc   r<   rg   )r-   r   r@   rh   r.   r/   r<   k  s    
zSkipDataLoader.__iter__)r   )rB   rC   rD   rE   r0   r<   rp   r.   r.   rh   r/   r   Z  s   r   c                    sN   j }d}t|trd}n(t jt}|r0 jn j}t||d}dddddg fd	d
tD }|dkr j|d<  j	|d< t t
r|dkr||d< t
|f j| jd| nt tr|dkr||d< n |r||d<  j	|d< n||d< t|f j j jd| n4|dkr6t|fd|i| nt|fd|i|  S )ze
    Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
    FN)rg   r   r   r   r   r    c                    s&   i | ]}|kr|t  |t| qS r.   r   r   r   r.   r/   r     s    z&skip_first_batches.<locals>.<dictcomp>rg   )r'   r   rw   )rd   re   rf   )rJ   rn   r	   r   r   r   r   r   r    r   rv   r'   rw   rb   rd   re   rf   r   r   )rr   Znum_batchesrJ   r   r   r   rY   r.   r   r/   skip_first_batchesq  sf    





r   )	NNNFFNNTN)r   );r   
contextlibr   typingr   r   r   r   ru   Ztorch.utils.datar   r   r	   loggingr   r{   r   r   r   r   ry   r   r   r   r   r   r   r   r   r   r   r   rB   loggerr   Z%_PYTORCH_DATALOADER_ADDITIONAL_KWARGSitemsvZadditional_kwargsupdater$   rI   rR   rb   Z%torch_xla.distributed.parallel_loaderdistributedZparallel_loaderZxplZMpDeviceLoaderrq   rv   rd   rF   rG   strr   r   r   r   r.   r.   r.   r/   <module>   s~   4
 Q#W% 4          E