"""
General namespace and dataclass related classes
"""

import argparse
import copy
import enum
import functools
import os
import typing
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import timedelta
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import torch

from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_STATE_DICT_TYPE
from .environment import str_to_bool
from .versions import compare_versions

class KwargsHandler:
    """
    Internal mixin that implements a `to_kwargs()` method for a dataclass.
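
    Example (an illustrative sketch; any dataclass mixing this in behaves the same way):

    ```python
    from accelerate.utils import DistributedDataParallelKwargs

    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    # `to_kwargs()` keeps only the values that differ from the dataclass defaults.
    assert kwargs.to_kwargs() == {"find_unused_parameters": True}
    ```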
    """

    def to_dict(self):
        return copy.deepcopy(self.__dict__)

    def to_kwargs(self):
        """
        Returns a dictionary containing the attributes with values different from the default of this class.
        """
        # import `clear_environment` here to avoid a circular import problem
        from .other import clear_environment

        with clear_environment():
            default_dict = self.__class__().to_dict()
        this_dict = self.to_dict()
        return {k: v for k, v in this_dict.items() if default_dict[k] != v}

@dataclass
class AutocastKwargs(KwargsHandler):
    """
    Use this object in your [`Accelerator`] to customize how `torch.autocast` behaves. Please refer to the
    documentation of this [context manager](https://pytorch.org/docs/stable/amp.html#torch.autocast) for more
    information on each argument.

    Example:

    ```python
    from accelerate import Accelerator
    from accelerate.utils import AutocastKwargs

    kwargs = AutocastKwargs(cache_enabled=True)
    accelerator = Accelerator(kwargs_handlers=[kwargs])
    ```
    """

    enabled: bool = True
    cache_enabled: bool = None

r,   c                   @   sf   e Zd ZU dZdZeed< dZeed< dZ	eed< dZ
eed	< dZeed
< dZeed< dZeed< dS )DistributedDataParallelKwargsa  
    Use this object in your [`Accelerator`] to customize how your model is wrapped in a
    `torch.nn.parallel.DistributedDataParallel`. Please refer to the documentation of this
    [wrapper](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) for more
    information on each argument.

    <Tip warning={true}>

    `gradient_as_bucket_view` is only available in PyTorch 1.7.0 and later versions.

    `static_graph` is only available in PyTorch 1.11.0 and later versions.

    </Tip>

    Example:

    ```python
    from accelerate import Accelerator
    from accelerate.utils import DistributedDataParallelKwargs

    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[kwargs])
    ```
    r   dimTbroadcast_buffers   bucket_cap_mbFfind_unused_parameterscheck_reductiongradient_as_bucket_viewstatic_graphN)r(   r)   r*   r+   r2   intr0   r3   r/   r5   r6   r7   r8   r9   r   r   r   r   r1   Q   s   
r1   c                   @   sN   e Zd ZU dZdZeed< dZeed< dZeed< dZ	e
ed	< d
Zeed< dS )GradScalerKwargsa  
    Use this object in your [`Accelerator`] to customize the behavior of mixed precision, specifically how the
    `torch.cuda.amp.GradScaler` used is created. Please refer to the documentation of this
    [scaler](https://pytorch.org/docs/stable/amp.html?highlight=gradscaler) for more information on each argument.

    <Tip warning={true}>

    `GradScaler` is only available in PyTorch 1.5.0 and later versions.

    </Tip>

    Example:

    ```python
    from accelerate import Accelerator
    from accelerate.utils import GradScalerKwargs

    kwargs = GradScalerKwargs(backoff_filter=0.25)
    accelerator = Accelerator(kwargs_handlers=[kwargs])
    ```
    g      @
init_scaleg       @growth_factorg      ?backoff_factori  growth_intervalTr-   N)r(   r)   r*   r+   r<   floatr0   r=   r>   r?   r:   r-   r/   r   r   r   r   r;   u   s   
r;   c                   @   sD   e Zd ZU dZdZee ed< dZee ed< e	ddZ
e	ed< dS )	InitProcessGroupKwargsaB  
    Use this object in your [`Accelerator`] to customize the initialization of the distributed processes. Please refer
    to the documentation of this
    [method](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for more
    information on each argument.

    ```python
    from datetime import timedelta
    from accelerate import Accelerator
    from accelerate.utils import InitProcessGroupKwargs

    kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=800))
    accelerator = Accelerator(kwargs_handlers=[kwargs])
    ```
    ZncclbackendNinit_methodi  )secondstimeout)r(   r)   r*   r+   rB   r   strr0   rC   r   rE   r   r   r   r   rA      s   
rA   c                   @   sl   e Zd ZU dZdZeed< dZeed< dZe	ed< dZ
eed< d	Ze	ed
< dZeeeef ed< dd ZdS )FP8RecipeKwargsa[  
    Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision
    training. Please refer to the documentation of this
    [class](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html#transformer_engine.common.recipe.DelayedScaling)
    for more information on each argument.

    ```python
    from accelerate import Accelerator
    from accelerate.utils import FP8RecipeKwargs

    kwargs = FP8RecipeKwargs(fp8_format="HYBRID")
    accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[kwargs])
    ```
    r   marginr   intervalE4M3
fp8_formatamax_history_lenmost_recentamax_compute_algo)FFFoverride_linear_precisionc                 C   s4   | j  | _ | j dkrtd| jdkr0tdd S )N)rJ   ZHYBRIDz(`fp8_format` must be 'E4M3' or 'HYBRID'.)maxrM   z2`amax_compute_algo` must be 'max' or 'most_recent')rK   upper
ValueErrorrN   r   r   r   r   __post_init__   s
    

zFP8RecipeKwargs.__post_init__N)r(   r)   r*   r+   rH   r:   r0   rI   rK   rF   rL   rN   rO   r   r/   rS   r   r   r   r   rG      s   
rG   c                   @   s4   e Zd ZdZdZdZdZdZdZdZ	dZ
d	Zd
ZdS )DistributedTypea  
    Represents a type of distributed environment.

    Values:

        - **NO** -- Not a distributed environment, just a single process.
        - **MULTI_CPU** -- Distributed on multiple CPU nodes.
        - **MULTI_GPU** -- Distributed on multiple GPUs.
        - **MULTI_NPU** -- Distributed on multiple NPUs.
        - **MULTI_XPU** -- Distributed on multiple XPUs.
        - **DEEPSPEED** -- Using DeepSpeed.
        - **TPU** -- Distributed on TPUs.
class SageMakerDistributedType(str, enum.Enum):
    """
    Represents a type of distributed environment.

    Values:

        - **NO** -- Not a distributed environment, just a single process.
        - **DATA_PARALLEL** -- using sagemaker distributed data parallelism.
        - **MODEL_PARALLEL** -- using sagemaker distributed model parallelism.
    """

    NO = "NO"
    DATA_PARALLEL = "DATA_PARALLEL"
    MODEL_PARALLEL = "MODEL_PARALLEL"

class ComputeEnvironment(str, enum.Enum):
    """
    Represents a type of the compute environment.

    Values:

        - **LOCAL_MACHINE** -- private/custom cluster hardware.
        - **AMAZON_SAGEMAKER** -- Amazon SageMaker as compute environment.
    """

    LOCAL_MACHINE = "LOCAL_MACHINE"
    AMAZON_SAGEMAKER = "AMAZON_SAGEMAKER"

ra   c                   @   s<   e Zd ZdZdZdZdZdZdZdZ	dZ
d	Zd
ZdZdZdS )DynamoBackenday  
    Represents a dynamo backend (see https://github.com/pytorch/torchdynamo).

    Values:

        - **NO** -- Do not use torch dynamo.
        - **EAGER** -- Uses PyTorch to run the extracted GraphModule. This is quite useful in debugging TorchDynamo
          issues.
        - **AOT_EAGER** -- Uses AotAutograd with no compiler, i.e, just using PyTorch eager for the AotAutograd's
          extracted forward and backward graphs. This is useful for debugging, and unlikely to give speedups.
        - **INDUCTOR** -- Uses TorchInductor backend with AotAutograd and cudagraphs by leveraging codegened Triton
          kernels. [Read
          more](https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747)
        - **NVFUSER** -- nvFuser with TorchScript. [Read
          more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593)
        - **AOT_NVFUSER** -- nvFuser with AotAutograd. [Read
          more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593)
        - **AOT_CUDAGRAPHS** -- cudagraphs with AotAutograd. [Read
          more](https://github.com/pytorch/torchdynamo/pull/757)
        - **OFI** -- Uses Torchscript optimize_for_inference. Inference only. [Read
          more](https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html)
        - **FX2TRT** -- Uses Nvidia TensorRT for inference optimizations. Inference only. [Read
          more](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst)
        - **ONNXRT** -- Uses ONNXRT for inference on CPU/GPU. Inference only. [Read more](https://onnxruntime.ai/)
        - **IPEX** -- Uses IPEX for inference on CPU. Inference only. [Read
          more](https://github.com/intel/intel-extension-for-pytorch).

    rU   EAGER	AOT_EAGERINDUCTORNVFUSERAOT_NVFUSERAOT_CUDAGRAPHSOFIFX2TRTONNXRTIPEXN)r(   r)   r*   r+   rU   re   rf   rg   rh   ri   rj   rk   rl   rm   rn   r   r   r   r   rd     s   rd   c                   @   s   e Zd ZdZdd ZdS )EnumWithContainsz\A metaclass that adds the ability to check if `self` contains an item with the `in` operatorc                 C   s(   z| | W n t k
r"   Y dS X dS )NFT)rR   )clsitemr   r   r   __contains__5  s
    zEnumWithContains.__contains__N)r(   r)   r*   r+   rr   r   r   r   r   ro   2  s   ro   c                   @   s$   e Zd ZdZdd Zedd ZdS )BaseEnumzDAn enum class that can get the value of an item with `str(Enum.key)`c                 C   s   | j S r   )valuer   r   r   r   __str__@  s    zBaseEnum.__str__c                 C   s   t tt| S )z.Method to list all the possible items in `cls`)listmaprF   )rp   r   r   r   rv   C  s    zBaseEnum.listN)r(   r)   r*   r+   ru   classmethodrv   r   r   r   r   rs   =  s   rs   )	metaclassc                   @   s(   e Zd ZdZdZdZdZdZdZdZ	dS )	
class LoggerType(BaseEnum):
    """Represents a type of supported experiment tracker

    Values:

        - **ALL** -- all available trackers in the environment that are supported
        - **AIM** -- aim as an experiment tracker
        - **TENSORBOARD** -- TensorBoard as an experiment tracker
        - **WANDB** -- wandb as an experiment tracker
        - **COMETML** -- comet_ml as an experiment tracker
        - **MLFLOW** -- mlflow as an experiment tracker
    """

    ALL = "all"
    AIM = "aim"
    TENSORBOARD = "tensorboard"
    WANDB = "wandb"
    COMETML = "comet_ml"
    MLFLOW = "mlflow"

rz   c                   @   s    e Zd ZdZdZdZdZdZdS )PrecisionTypezRepresents a type of precision used on floating point values

    Values:

        - **NO** -- using full precision (FP32)
        - **FP16** -- using half precision
        - **BF16** -- using brain floating point precision
    nofp8fp16bf16N)r(   r)   r*   r+   rU   FP8ZFP16ZBF16r   r   r   r   r}   \  s
   	r}   c                   @   s$   e Zd ZdZdZdZdZdZdZdS )RNGTypetorchcudaZnpuZxlaZxpu	generatorN)	r(   r)   r*   ZTORCHCUDAZNPUZXLAZXPU	GENERATORr   r   r   r   r   l  s   r   c                   @   s   e Zd ZdZdZdZdS )CustomDtypezd
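
# A hedged illustration of the `BaseEnum`/`EnumWithContains` behavior defined above: the
# metaclass lets raw strings be tested for membership with `in`, and `list()` returns the
# string values (via `__str__`). For example:
#
#     assert "fp16" in PrecisionType
#     assert "int8" not in PrecisionType
#     assert PrecisionType.list() == ["no", "fp8", "fp16", "bf16"]
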
class CustomDtype(enum.Enum):
    """
    An enum that contains multiple custom dtypes that can be used for `infer_auto_device_map`.
    """

    FP8 = "fp8"
    INT4 = "int4"


@dataclass
class TensorInformation:
    shape: torch.Size
    dtype: torch.dtype

r   c                   @   s   e Zd ZU dZedddidZeed< edddidZeed< ed	dd
idZ	e
ed< edddidZeed< edddidZeed< dedddZdd ZdS )ProjectConfigurationzP
    Configuration for the Accelerator object based on inner-project needs.
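
    Example (an illustrative sketch; the directory name is a placeholder):

    ```python
    from accelerate import Accelerator
    from accelerate.utils import ProjectConfiguration

    config = ProjectConfiguration(project_dir="my_project", automatic_checkpoint_naming=True)
    accelerator = Accelerator(project_config=config)
    # `logging_dir` was left unset, so it falls back to `project_dir`.
    assert config.logging_dir == "my_project"
    ```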
    Nhelpz'A path to a directory for storing data.defaultmetadataproject_dirziA path to a directory for storing logs of locally-compatible loggers. If None, defaults to `project_dir`.logging_dirFz?Whether saved states should be automatically iteratively named.automatic_checkpoint_namingz1The maximum number of total saved states to keep.total_limitr   zThe current save iteration.	iteration)r   c                 C   s   || _ | jdkr|| _dS )zISets `self.project_dir` and `self.logging_dir` to the appropriate values.N)r   r   )r   r   r   r   r   set_directories  s    
z$ProjectConfiguration.set_directoriesc                 C   s   |  | j d S r   )r   r   r   r   r   r   rS     s    z"ProjectConfiguration.__post_init__)N)r(   r)   r*   r+   r   r   rF   r0   r   r   r/   r   r:   r   r   rS   r   r   r   r   r     s,   
 r   c                   @   sZ   e Zd ZU dZedddidZeed< edddidZe	ed	< eddd
idZ
e	ed< dS )GradientAccumulationPluginz?
    A plugin to configure gradient accumulation behavior.
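
    Example (an illustrative sketch):

    ```python
    from accelerate import Accelerator
    from accelerate.utils import GradientAccumulationPlugin

    plugin = GradientAccumulationPlugin(num_steps=4)
    accelerator = Accelerator(gradient_accumulation_plugin=plugin)
    ```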
    """

    num_steps: int = field(default=None, metadata={"help": "The number of steps to accumulate gradients for."})
    adjust_scheduler: bool = field(
        default=True,
        metadata={
            "help": "Whether to adjust the scheduler steps to account for the number of steps being accumulated. Should be `True` if the used scheduler was not adjusted for gradient accumulation."
        },
    )
    sync_with_dataloader: bool = field(
        default=True,
        metadata={
            "help": "Whether to synchronize setting the gradients when at the end of the dataloader. Should only be set to `False` if you know what you're doing."
        },
    )

  r   c                   @   s   e Zd ZU dZeddddd eD  idZeed< eddd	idZe	ed
< edddidZ
eed< edddidZeed< edddidZeed< edddidZeed< dd Zdd ZdS )TorchDynamoPluginzA
    This plugin is used to compile a model with PyTorch 2.0
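
    Example (an illustrative sketch; `backend` accepts any name from `DynamoBackend`):

    ```python
    from accelerate import Accelerator
    from accelerate.utils import TorchDynamoPlugin

    plugin = TorchDynamoPlugin(backend="inductor", mode="default")
    accelerator = Accelerator(dynamo_backend=plugin.backend)
    ```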
    Nr   zPossible options are c                 C   s   g | ]}|j  qS r   )rt   lower)r   br   r   r   
<listcomp>  s     zTorchDynamoPlugin.<listcomp>r   rB   zCPossible options are 'default', 'reduce-overhead' or 'max-autotune'modez6Whether it is ok to break model into several subgraphs	fullgraphz(Whether to use dynamic shape for tracingdynamicz/A dictionary of options to pass to the backend.optionsFz-Turn torch.compile() into a no-op for testingdisablec                 C   s   d}| j d kr"tj|d d| _ t| j  | _ | jd krPtj|d d| _| jd krvttj|d ddk| _| j	d krttj|d	 ddk| _	d S )
NZACCELERATE_DYNAMO_ZBACKENDr~   ZMODEr   ZUSE_FULLGRAPHFalser   ZUSE_DYNAMIC)
rB   osenvirongetrd   rQ   r   r   r   r   r   prefixr   r   r   rS     s    



zTorchDynamoPlugin.__post_init__c                 C   s"   t | j}|d j |d< |S )NrB   )r   r   r   rt   r   )r   Zdynamo_configr   r   r   r     s    zTorchDynamoPlugin.to_dict)r(   r)   r*   r+   r   rd   rB   r0   r   rF   r   r/   r   r   r   r   rS   r   r   r   r   r   r     s   
 r   c                   @   sd  e Zd ZU dZedddidZeed< edddidZe	ed< eddd	idZ
eed
< edddidZe	ed< edddidZeed< edddidZeed< edddidZeed< edddidZeed< edddidZeed< edddidZeed< edddidZeed< dd Zd-ddZd.d d!Zd"d# Zd$d% Zd&d' Zed/d)d*Zd+d, ZdS )0DeepSpeedPluginz5
    This plugin is used to integrate DeepSpeed.
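
    Example (an illustrative sketch; the stage and accumulation values are placeholders):

    ```python
    from accelerate import Accelerator
    from accelerate.utils import DeepSpeedPlugin

    plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=2)
    accelerator = Accelerator(mixed_precision="fp16", deepspeed_plugin=plugin)
    ```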
    Nr   zkpath to DeepSpeed config file or dict or an object of class `accelerate.utils.deepspeed.HfDeepSpeedConfig`.r   hf_ds_configzNumber of steps to accumulate gradients before updating optimizer states. If not set, will use the value from the `Accelerator` directly.gradient_accumulation_stepsz#Enable gradient clipping with valuegradient_clippingzMPossible options are 0,1,2,3; Default will be taken from environment variable
zero_stageTzUIf both train & eval dataloaders are specified, this will decide the train_batch_sizeis_train_batch_minzMPossible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.offload_optimizer_devicezFPossible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.offload_param_devicezJPossible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.offload_optimizer_nvme_pathoffload_param_nvme_pathz{Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models.Only applicable with ZeRO Stage-3.zero3_init_flagzQFlag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.zero3_save_16bit_modelc           
   	      s(  ddl m}  jd kr:tjdd}| r4t|n| _ jd krdtjdd}|dkrdt	| _ j
d krttjdd _
 jd krtjd	d _ jd krtjd
d _ jd krtjdd _ jd krtjdd _ jd kr
tjdddk _ jd kr&tjdd _t jts\t jtrN jdks\t j|rLt j|sv| j _d jjkrd jjd< d jjkrtd   ddddddddd} fdd| D }| D ]} j|f|ddi q j  | D ]6\}} j|}|d k	r|dkrt || qnndd j j
 j jd krn jnd d! j jd kr jnd d! jd"d#}	 jr j|	d< ||	 _ jj _t	d$ jd%<  jd kr ttjd&t j  dk _ jr$ j  s$t!"d' d _d S )(Nr   HfDeepSpeedConfig&ACCELERATE_GRADIENT_ACCUMULATION_STEPSautoACCELERATE_GRADIENT_CLIPPINGnoneACCELERATE_DEEPSPEED_ZERO_STAGE   -ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE)ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_DEVICE0ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_NVME_PATH,ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_NVME_PATH+ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODELfalsetrueZ ACCELERATE_DEEPSPEED_CONFIG_FILEr   zero_optimizationzDPlease specify the ZeRO optimization config in the DeepSpeed config.r   zzero_optimization.stagez*zero_optimization.offload_optimizer.devicez&zero_optimization.offload_param.devicez)zero_optimization.offload_param.nvme_pathz-zero_optimization.offload_optimizer.nvme_pathz;zero_optimization.stage3_gather_16bit_weights_on_model_save)r   r   r   r   r   r   r   r   c                    s*   i | ]"\}}t  |d k	r|t  |qS r   )getattrr   r   r   r   r#   W  s       z1DeepSpeedPlugin.__post_init__.<locals>.<dictcomp>
must_matchFZnvme)deviceZ	nvme_path)ZstageZoffload_optimizerZoffload_paramZ)stage3_gather_16bit_weights_on_model_save)train_batch_sizetrain_micro_batch_size_per_gpur   r   infZsteps_per_printZACCELERATE_DEEPSPEED_ZERO3_INITzSDeepSpeed Zero3 Init flag is only applicable for ZeRO Stage 3. Setting it to False.)#Z	deepspeedr   r   r   r   r   isdigitr:   r   r@   r   r   r   r   r   r   r   
isinstancedictrF   configrR   _deepspeed_config_checksr&   keys
fill_matchZset_stage_and_offload	get_valuesetattrdeepspeed_configr   r   Zis_zero3warningswarn)
r   r   gasr   Zplugin_to_config_mappingkwargskeyrt   Zconfig_valuer   r   r   r   rS     s    






 










zDeepSpeedPlugin.__post_init__c                 K   s   |d krg n|}| j |\}}|d kr,d S ||dkrh||krR|| ||< d S td| d| d|spd S ||}|d k	r||kr||| kr|d| d| d| d||   d S )Nr   `z'` not found in kwargs. Please specify `zY` without `auto`(set to correct value) in the DeepSpeed config file or pass it in kwargs.z- ds =z vs arg )r   Zfind_config_noder   rR   append)r   Zds_key_long
mismatchesr   r   r   Zds_keyZds_valr   r   r   r     s"    
zDeepSpeedPlugin.fill_match c           	      K   s   |dkrg n|}|dkr| j }| D ]R\}}t|tr\| jf || d |||d| q&| j|| |fd|i| q&t|dkr|dkrd|}td| d	dS )
z=Process the DeepSpeed config with the values from the kwargs.N.)r   r   r   r   r   r   r   
zSPlease correct the following DeepSpeed config values that mismatch kwargs  values:
zF
The easiest method is to set these DeepSpeed config values to 'auto'.)	r   r&   r   r   deepspeed_config_processr   lenjoinrR   )	r   r   r   r   r   r   r   rt   Zmismatches_msgr   r   r   r     s&    

   

z(DeepSpeedPlugin.deepspeed_config_processc                 C   s   | j }|dk|dkd}|dkr8d|krTddd|d< n|dkrTd|krTddi|d< |dkr|dkrhdnd}t||i dd d	krtd
| d| ddD ]}||krddi||< q| jdddi| | jdddi| d S )Nr   r   )fp16.enabledbf16.enabledT)r-   Z	auto_castr-   r~   r   r   z*`--mixed_precision` arg cannot be set to `z` when `z&` is set in the DeepSpeed config file.)r   r   Fr   r   r   )r   )r   )r   rF   r   r   rR   r   )r   mixed_precision	ds_configr   Z
diff_dtyper   r   r   r   set_mixed_precision  s*     z#DeepSpeedPlugin.set_mixed_precisionc                 C   s   ddl m} | jr| s tdt| j}d|ks@|d dkrHd|d< d|ks\|d dkrdd|d< |d dkrv|d= tdd	d
rddlm	} nddl
m	} ||| _d S )Nr   )is_transformers_availablezoWhen `zero3_init_flag` is set, it requires Transformers to be installed. Please run `pip install transformers`.r   r   r   r   Ztransformers<z4.33r   r   )Zimportsr   r   	Exceptionr   r   r   r   Ztransformers.deepspeedr   Ztransformers.integrationsdschf)r   r   r   r   r   r   r   set_deepspeed_weakref  s(    
z%DeepSpeedPlugin.set_deepspeed_weakrefc                 C   s   | j S r   )r   r   r   r   r   is_zero3_init_enabled  s    z%DeepSpeedPlugin.is_zero3_init_enabledFc                 c   sH   | j }||krd V  n.|| _ d | _|   d V  || _ d | _|   d S r   )r   r   r   )r   enableoldr   r   r   zero3_init_context_manager  s    z*DeepSpeedPlugin.zero3_init_context_managerc              	      sb   ddddddddd	g	 d
d  D  t jddd}t fdd|D r^td  dd S )Nr   r   r   r   r   r   r   r   ZACCELERATE_MIXED_PRECISIONc                 S   s$   g | ]}| d d dd qS )ZACCELERATE_r   Z
DEEPSPEED_)replacer   r   namer   r   r   r     s    z<DeepSpeedPlugin._deepspeed_config_checks.<locals>.<listcomp>ZACCELERATE_CONFIG_DS_FIELDSr   ,c                 3   s   | ]}| kV  qd S r   r   r   Zenv_variable_names_to_ignorer   r   	<genexpr>  s     z;DeepSpeedPlugin._deepspeed_config_checks.<locals>.<genexpr>z_When using `deepspeed_config_file`, the following accelerate config variables will be ignored: am  .
Please specify them appropriately in the DeepSpeed config file.
If you are using an accelerate config file, remove others config variables mentioned in the above specified list.
The easiest method is to create a new config following the questionnaire via `accelerate config`.
It will only ask for the necessary config variables when using `deepspeed_config_file`.)r   r   r   splitanyrR   )r   Z'deepspeed_fields_from_accelerate_configr   r   r   r     s$    
z(DeepSpeedPlugin._deepspeed_config_checks)NT)r   NNT)F)r(   r)   r*   r+   r   r   r   r0   r   r:   r   r@   r   r   rF   r   r/   r   r   r   r   r   rS   r   r   r   r   r   r   r   r   r   r   r   r   r     sr   
   h

r   c                   @   s  e Zd ZU dZedddidZded< edddidZded	< eddd
idZded< edddidZ	e
e ed< edddidZded< edddidZe
eejj  ed< edddidZded< edddidZded< edddidZded< edddidZeed< edddidZeed< edddidZe
eejjgdf  ed< eddd idZeed!< eddd"idZeed#< eddd$idZeed%< d&d' Zed(d) Zd*d+ Zd,d- Z d.d/ Z!dS )0FullyShardedDataParallelPluginzG
    This plugin is used to enable fully sharded data parallelism.
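
    Example (an illustrative sketch; the state-dict configs shown are the common FSDP defaults):

    ```python
    from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig

    from accelerate import Accelerator
    from accelerate.utils import FullyShardedDataParallelPlugin

    fsdp_plugin = FullyShardedDataParallelPlugin(
        state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),
        optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=False, rank0_only=False),
    )
    accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
    ```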
    Nr   zdFSDP Sharding Strategy of type `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`r   z
typing.Anysharding_strategyzdFSDP Backward Prefetch of type `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`backward_prefetcha  A config to enable mixed precision training with FullyShardedDataParallel. The 3 flags that are set are `param_dtype`, `reduce_dtype`, `buffer_dtype`. Each flag expects `torch.dtype` as the value. It is of type `torch.distributed.fsdp.fully_sharded_data_parallel.MixedPrecision`.mixed_precision_policyzCA callable specifying a policy to recursively wrap layers with FSDPauto_wrap_policyzDecides Whether to offload parameters and gradients to CPU. It is of type `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload`.cpu_offloadz%A list of modules to ignore for FSDP.ignored_modulesz_FSDP State Dict Type of type `torch.distributed.fsdp.fully_sharded_data_parallel.StateDictType`state_dict_typezcFSDP State Dict Config of type `torch.distributed.fsdp.fully_sharded_data_parallel.StateDictConfig`state_dict_configzrFSDP Optimizer State Dict Config of type `torch.distributed.fsdp.fully_sharded_data_parallel.OptimStateDictConfig`optim_state_dict_configFaR  If False, then FSDP allows the CPU thread to schedule all-gathers without any extra synchronization. If True, then FSDP explicitly synchronizes the CPU thread to prevent too many in-flight all-gathers. This bool only affects the sharded strategies that schedule all-gathers. Enabling this can help lower the number of CUDA malloc retries.limit_all_gathersaB  If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable paramteres. Useful in cases such as parameter-efficient fine-tuning. Please refer this [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019)use_orig_paramszA Callable[torch.nn.Module] -> None that specifies how modules that are currently on the meta device should be initialized onto an actual device.param_init_fnTzIf True, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to ensure they are the same across all ranks after initializationsync_module_stateszIf True, then FSDP explicitly prefetches the next upcoming all-gather while executing in the forward pass. only use with Static graphs.forward_prefetchzIf True, activation checkpointing is a technique to reduce memory usage by clearing activations of certain layers and recomputing them during a backward pass. Effectively, this trades extra computation time for reduced memory usage.activation_checkpointingc                 C   sd  ddl m}m}m} d}| jd kr>|ttj|d d| _| j	d kr|t
tj|d ddkrp|dd	| _	n|d
d	| _	| jd krtj|d d}|td kr|t|d | _| jd krtj|d d}| | t
tj|d ddk| _t
tj|d ddk| _t
tj|d ddk| _t
tj|d ddk| _| jr`dd | _d S )Nr   )BackwardPrefetch
CPUOffloadShardingStrategyZFSDP_ZSHARDING_STRATEGYr   ZOFFLOAD_PARAMSr   T)Zoffload_paramsFZBACKWARD_PREFETCHZNO_PREFETCHZSTATE_DICT_TYPEFULL_STATE_DICTZUSE_ORIG_PARAMSZSYNC_MODULE_STATESTrueZFORWARD_PREFETCHZACTIVATION_CHECKPOINTINGc                 S   s   | j tj ddS )NF)r   recurse)Zto_emptyr   r   Zcurrent_device)xr   r   r   <lambda>      z>FullyShardedDataParallelPlugin.__post_init__.<locals>.<lambda>)2torch.distributed.fsdp.fully_sharded_data_parallelr  r  r  r   r:   r   r   r   r  r   r   r   indexr  set_state_dict_typer	  r  r  r  r
  )r   r  r  r  r   Zprefetch_policystate_dict_type_policyr   r   r   rS   |  s*    




z,FullyShardedDataParallelPlugin.__post_init__c                 C   sX   t |  }| jj|kr| jS t|dkr.dS |D ] }t||}|dk	r2|  S q2dS )z
        Gets a class from a module by its name.

        Args:
            module (`torch.nn.Module`): The module to get the class from.
            name (`str`): The name of the class.
        r   N)rv   childrenr%   r(   r   r   get_module_class_from_name)moduler   Zmodules_childrenZchild_moduleZmodule_classr   r   r   r    s    	z9FullyShardedDataParallelPlugin.get_module_class_from_namec                 C   s   ddl m}m} t|dd d k	r,d|jnd}| jd krtj	dd}|t
d krtj	d|d}t }|D ],}t||}	|	d krtd	qr||	 qrtj||d
| _n6|t
d krttj	dd}
|
dkrtj||
d| _d S )Nr   )size_based_auto_wrap_policytransformer_auto_wrap_policy_no_split_modulesr   r   r   ZNO_WRAPZFSDP_TRANSFORMER_CLS_TO_WRAPz@Could not find the transformer layer class to wrap in the model.)Ztransformer_layer_clsr   ZFSDP_MIN_NUM_PARAMS)min_num_params)Ztorch.distributed.fsdp.wrapr  r   r   r   r!  r  r   r   r   r   r   setr   r  r   add	functoolspartialr:   )r   modelr  r   Z%default_transformer_cls_names_to_wrapr  Ztransformer_cls_names_to_wrapZtransformer_cls_to_wrapZlayer_classZtransformer_clsr"  r   r   r   set_auto_wrap_policy  s:    
 

 z3FullyShardedDataParallelPlugin.set_auto_wrap_policyc                 C   sX   |dkrt j}n|dkr t j}ntd| ddlm} | jd krT||||d| _d S )Nr   r   zUnknown mixed precision value: r   )MixedPrecision)Zparam_dtypeZreduce_dtypeZbuffer_dtype)r   float16bfloat16rR   r  r)  r  )r   r   r   r)  r   r   r   r     s    
z2FullyShardedDataParallelPlugin.set_mixed_precisionc                 C   sh   ddl m}m}m} |t|d | _| j|jkrd| jd krL|ddd| _| j	d krd|ddd| _	d S )Nr   )FullOptimStateDictConfigFullStateDictConfigStateDictTyper   T)Zoffload_to_cpuZ
rank0_only)
r  r,  r-  r.  r   r  r  r  r  r  )r   r  r,  r-  r.  r   r   r   r    s    

z2FullyShardedDataParallelPlugin.set_state_dict_type)"r(   r)   r*   r+   r   r   r0   r   r  r  r   r   r  r  r	   r   nnModuler  r  r  r  r/   r	  r
  r  r  r  rS   staticmethodr  r(  r   r  r   r   r   r   r     s   
   	     	      	
 r   c                   @   s<  e Zd ZU dZedddidZeed< edddidZeed< eddd	idZ	eed
< edddidZ
eed< edddidZeed< edddidZeed< edddidZeed< edddidZeed< edddidZeed< edddidZeed< edddidZeed< edddidZeed< edddidZeed < eddd!idZeed"< eddd#idZeed$< ed%dd&idZeed'< eddd(idZeed)< eddd*idZeed+< eddd,idZeed-< eddd.idZeed/< eddd0idZeed1< ed2dd3idZeed4< eddd5idZ e!e ed6< eddd7idZ"e#e$ ed8< eddd9idZ%e#e$ ed:< ed;dd<idZ&eed=< ed>dd?idZ'eed@< edddAidZ(eedB< edddCidZ)eedD< edddEidZ*eedF< edddGidZ+eedH< ed>ddIidZ,eedJ< edKddLidZ-eedM< edNddOidZ.eedP< ed>ddQidZ/eedR< edddSidZ0e#e1 edT< edddUidZ2e#e3ee1f  edV< edddWidZ4e#e$ edX< edddYidZ5e#e$ edZ< eddd[idZ6e#e3ee1f  ed\< d]d^ Z7dkd_d`Z8dadb Z9dcdd Z:dedf Z;dgdh Z<didj Z=dS )lMegatronLMPluginz
    Plugin for Megatron-LM to enable tensor, pipeline, sequence and data parallelism. Also to enable selective
    activation recomputation and optimized fused kernels.
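
    Example (an illustrative sketch; the parallelism degrees shown are placeholders):

    ```python
    from accelerate import Accelerator
    from accelerate.utils import MegatronLMPlugin

    plugin = MegatronLMPlugin(tp_degree=2, pp_degree=2, num_micro_batches=2)
    accelerator = Accelerator(megatron_lm_plugin=plugin)
    ```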
    Nr   ztensor parallelism degree.r   	tp_degreezpipeline parallelism degree.	pp_degreeznumber of micro-batches.num_micro_batchesz>gradient clipping value based on global L2 Norm (0 to disable)r   zenable sequence parallelismsequence_parallelismz)enable selective activation recomputationrecompute_activationzenable distributed optimizeruse_distributed_optimizerz/Rank where encoder and decoder should be split."pipeline_model_parallel_split_rankz,Number of layers per virtual pipeline stage.%num_layers_per_virtual_pipeline_stageTzUIf both train & eval dataloaders are specified, this will decide the micro_batch_sizer   zTotal number of iterations to train over all training runs. Note that either train-iters or train-samples should be provided when using `MegatronLMDummyScheduler`train_iterszTotal number of samples to train over all training runs. Note that either train-iters or train-samples should be provided when using `MegatronLMDummyScheduler`train_samplesZconstantzKWeight decay increment function. choices=["constant", "linear", "cosine"]. weight_decay_incr_stylez7Initial weight decay coefficient for L2 regularization.start_weight_decayz:End of run weight decay coefficient for L2 regularization.end_weight_decayZlinearzGLearning rate decay function. choices=['constant', 'linear', 'cosine'].lr_decay_stylezPNumber of iterations for learning rate decay. If None defaults to `train_iters`.lr_decay_iterszONumber of samples for learning rate decay. If None defaults to `train_samples`.lr_decay_samplesz;number of iterations to linearly warmup learning rate over.lr_warmup_itersz8number of samples to linearly warmup learning rate over.lr_warmup_sampleszLfraction of lr-warmup-(iters/samples) to linearly warmup learning rate over.lr_warmup_fractionr   zPMinumum value for learning rate. The scheduler clip values below this threshold.min_lrz^Number of samples consumed in the same order as the dataloaders to `accelerator.prepare` call.consumed_samplesz"Condition to disable weight decay.no_wd_decay_condz!Condition to scale learning rate.scale_lr_cond      ?zLearning rate multiplier.lr_multFzUWhether the format of dataset follows Megatron-LM Indexed/Cached/MemoryMapped format.megatron_dataset_flagz#Maximum sequence length to process.
seq_lengthz3Maximum sequence length to process for the encoder.encoder_seq_lengthz3Maximum sequence length to process for the decoder.decoder_seq_lengthzPath to save tensorboard logs.tensorboard_dirz#Whether to set all logging options.set_all_logging_optionsd   z?Number of iterations to run for evaluation validation/test for.
eval_itersi  z6Interval between running evaluation on validation set.eval_intervalz(Whether to return logits from the model.return_logitszCustom train step class.custom_train_step_classzCustom train step kwargs.custom_train_step_kwargszCustom model provider function.custom_model_provider_functionzCustom prepare model function.custom_prepare_model_functionz5Other Megatron-LM arguments. Please refer Megatron-LMother_megatron_argsc                 C   s
  d}| j d kr&ttj|d d| _ | jd krHttj|d d| _| jd krjttj|d d| _| jd krttj|d d| _| j	d krt
tj|d d	dk| _	| jd krt
tj|d
 d	dk| _| jd krt
tj|d d	dk| _| jdks| jrd| _nd| _| jd k	rjt| jdkrN| jddg nt| jdkrj| jd | j | j| j| j| j| j| j| j| j| j| j| j| j| j| j| jd| _| j	rd| jd< | jd k	r| j| jd< | jr|   | jd k	r| j| j d S )NZMEGATRON_LM_Z	TP_DEGREEr   Z	PP_DEGREEZNUM_MICRO_BATCHESZGRADIENT_CLIPPINGrJ  ZRECOMPUTE_ACTIVATIONr   ZUSE_DISTRIBUTED_OPTIMIZERZSEQUENCE_PARALLELISMlocalr   r   r   )Ztensor_model_parallel_sizeZpipeline_model_parallel_sizer9  r:  DDP_implr8  Zsequence_parallelZ	clip_gradr5  rG  rH  rI  rK  rL  rS  rT  Z	selectiveZrecompute_granularityrP  ) r3  r:   r   r   r   r4  r5  r   r@   r7  r   r8  r6  r\  rG  r   extendr   r9  r:  rH  rI  rK  rL  rS  rT  megatron_lm_default_argsrP  rQ  set_tensorboard_logging_optionsrZ  updater   r   r   r   rS     sb    







zMegatronLMPlugin.__post_init__c                 C   s  d|j j krd}|j j}|j j}|j j}|j j}|j j}|j j}	d|j	j
 krXd}
| jd k	r| jd k	rvtd | j| _n4| jd k	r| j| _n |d k	r|d jd | _n|| _| j| jd< nd	|j j krd
}|j j}|j j}|j j}|j j}|j j}	d}
| jd k	r2| jd k	r(td | j| _n8| jd k	rH| j| _n"|d k	rd|d jd | _n|| _| j| jd< | j| jd< d| jd< nd|j j krXd}|j j}|j j}|j j}t|j dr|j jnd}|j j}	d}
| jd kr|d k	r
|d jd | _n|| _| jd kr>|d k	r8|d jd | _n|| _| j| jd< | j| jd< ntd|| jd< || jd< || jd< || jd< || jd< |
| jd< |	| jd< |j j| jd< |dkr|| jd< d S )Nzmegatron-bertZbertZmaskedlmTzOBoth `seq_length` and `encoder_seq_length` are set. Using `encoder_seq_length`.Z	input_idsr   rM  Zgpt2ZgptzOBoth `seq_length` and `decoder_seq_length` are set. Using `decoder_seq_length`.rU  ZGPT2BPETokenizerZtokenizer_typeZt5n_positionsi   labelsrN  rO  u   🤗 Accelerate Megatron-LM integration supports only BERT, GPT and T5 model. Please check the model you are using is one of those.model_type_name
num_layershidden_sizenum_attention_headsmax_position_embeddingspretraining_flagorig_vocab_sizeZmodel_return_dict
num_labels)r   Z
model_typer   Znum_hidden_layersre  rf  rg  rj  Z
vocab_sizer%   r(   rM  rN  r   r   r   r^  Zn_layerZn_embdZn_headra  rO  rU  rd  Zd_modelZ	num_headshasattrrR   Zreturn_dict)r   r'  Z
batch_datarc  rd  re  rf  rg  rj  ri  rh  r   r   r   set_network_size_args  s    



















z&MegatronLMPlugin.set_network_size_argsc                 C   s<   |dkrd| j d< n$|dkr8d| j d< d| _| j| j d< d S )Nr   Tr   r[  r\  )r^  r\  )r   r   r   r   r   r     s    
z$MegatronLMPlugin.set_mixed_precisionc                 C   sD   || _ || _|| | j | _| j | jd< | j| jd< | j| jd< d S )Ndata_parallel_sizemicro_batch_sizeglobal_batch_size)rm  rn  r5  ro  r^  )r   rn  Z	dp_degreer   r   r   set_training_args  s    z"MegatronLMPlugin.set_training_argsc                 C   s   |j j }d|krXd| jd< |jd d | jd< |jd d | jd< |jd | jd	< n4d
|kr|d
| jd< |jd | jd< ntd| d|jd | jd< |jd | jd< d S )NZadam	optimizerZbetasr   Z
adam_beta1r   Z
adam_beta2ZepsZadam_epsZsgdZmomentumZsgd_momentumz
Optimizer z  is not supported by Megatron-LMlrZweight_decay)r%   r(   r   r^  defaultsrR   )r   rq  Zoptimizer_namer   r   r   set_optimizer_type%  s    

z#MegatronLMPlugin.set_optimizer_typec                 C   s   | j d kr6|j| jd  | _ | jd k	r6d | _td | jd krl|j| jd  | _| jd k	rftd d| _| j | jd< | j| jd< | j| jd< | j| jd< | j	| jd	< | j
| jd
< | j| jd< | j| jd< | j| jd< | j| jd< | j| jd< | j| jd< d S )Nrm  zXIgnoring `train_samples` as `train_iters` based on scheduler is being used for training.z`Ignoring `lr_warmup_samples` as `lr_warmup_iters` based on scheduler is being used for training.r   r;  rC  r<  rD  rA  rB  rE  r@  r=  r>  r?  rF  )r;  Ztotal_num_stepsr^  r<  r   r   rC  Zwarmup_num_stepsrD  rA  rB  rE  r@  r=  r>  r?  rF  )r   Z	schedulerr   r   r   set_scheduler_args5  s4    



z#MegatronLMPlugin.set_scheduler_argsc                 C   s|   ddl m} t }||}| }t|d | _| j D ]:\}}|drZd| j	|< q<|dr<d| j	|
dd< q<d S )Nr   )_add_logging_argsZlog_TZno_log_Zno_r   )Zmegatron.argumentsrv  argparseArgumentParserparse_known_argsvarsZdataset_argsr&   
startswithr^  r   )r   rv  parserZlogging_argsr   rt   r   r   r   r_  R  s    

z0MegatronLMPlugin.set_tensorboard_logging_options)N)>r(   r)   r*   r+   r   r3  r:   r0   r4  r5  r   r@   r6  r/   r7  r8  r9  r:  r   rF   r;  r<  r=  r>  r?  r@  rA  rB  rC  rD  rE  rF  rG  r
   rH  r   r   rI  rK  rL  rM  rN  rO  rP  rQ  rS  rT  rU  rV  r   rW  r   rX  rY  rZ  rS   rl  r   rp  rt  ru  r_  r   r   r   r   r2    s8  
        9
Sr2  c                   @   s   e Zd ZU dZedddidZeed< edddidZe	ed	< eddd
idZ
eed< edddidZeed< edddidZeed< edddidZeed< edddidZejed< edddidZee ed< edddidZee ed< dd ZdS )BnbQuantizationConfigzD
@dataclass
class BnbQuantizationConfig:
    """
    A plugin to enable BitsAndBytes 4bit and 8bit quantization
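
    Example (an illustrative sketch; the resulting config is typically passed to
    `accelerate.utils.load_and_quantize_model`):

    ```python
    from accelerate.utils import BnbQuantizationConfig

    bnb_config = BnbQuantizationConfig(load_in_4bit=True, bnb_4bit_compute_dtype="bf16")
    ```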
    """

    load_in_8bit: bool = field(default=False, metadata={"help": "enable 8bit quantization."})

    llm_int8_threshold: float = field(
        default=6.0, metadata={"help": "value of the outlier threshold. only relevant when load_in_8bit=True"}
    )

    load_in_4bit: bool = field(default=False, metadata={"help": "enable 4bit quantization."})

    bnb_4bit_quant_type: str = field(
        default="fp4",
        metadata={
            "help": "set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','nf4'}."
        },
    )

    bnb_4bit_use_double_quant: bool = field(
        default=False,
        metadata={
            "help": "enable nested quantization where the quantization constants from the first quantization are quantized again."
        },
    )

    bnb_4bit_compute_dtype: str = field(
        default="fp16",
        metadata={
            "help": "This sets the computational type which might be different than the input time. For example, inputs might be "
            "fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}."
        },
    )

    torch_dtype: torch.dtype = field(
        default=None,
        metadata={
            "help": "this sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value"
            "to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model "
        },
    )

    skip_modules: List[str] = field(
        default=None,
        metadata={
            "help": "an explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`."
        },
    )

    keep_in_fp32_modules: List[str] = field(
        default=None,
        metadata={"help": "an explicit list of the modules that we don't quantize. We keep them in `torch.float32`."},
    )

    def __post_init__(self):
        """
        Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
        """
        if not isinstance(self.load_in_8bit, bool):
            raise ValueError("load_in_8bit must be a boolean")

        if not isinstance(self.load_in_4bit, bool):
            raise ValueError("load_in_4bit must be a boolean")

        if self.load_in_4bit and self.load_in_8bit:
            raise ValueError("load_in_4bit and load_in_8bit can't be both True")

        if not self.load_in_4bit and not self.load_in_8bit:
            raise ValueError("load_in_4bit and load_in_8bit can't be both False")

        if not isinstance(self.llm_int8_threshold, (int, float)):
            raise ValueError("llm_int8_threshold must be a float or an int")

        if not isinstance(self.bnb_4bit_quant_type, str):
            raise ValueError("bnb_4bit_quant_type must be a string")
        elif self.bnb_4bit_quant_type not in ["fp4", "nf4"]:
            raise ValueError(f"bnb_4bit_quant_type must be in ['fp4','nf4'] but found {self.bnb_4bit_quant_type}")

        if not isinstance(self.bnb_4bit_use_double_quant, bool):
            raise ValueError("bnb_4bit_use_double_quant must be a boolean")

        if isinstance(self.bnb_4bit_compute_dtype, str):
            if self.bnb_4bit_compute_dtype == "fp32":
                self.bnb_4bit_compute_dtype = torch.float32
            elif self.bnb_4bit_compute_dtype == "fp16":
                self.bnb_4bit_compute_dtype = torch.float16
            elif self.bnb_4bit_compute_dtype == "bf16":
                self.bnb_4bit_compute_dtype = torch.bfloat16
            else:
                raise ValueError(
                    f"bnb_4bit_compute_dtype must be in ['fp32','fp16','bf16'] but found {self.bnb_4bit_compute_dtype}"
                )
        elif not isinstance(self.bnb_4bit_compute_dtype, torch.dtype):
            raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")

        if self.skip_modules is not None and not isinstance(self.skip_modules, list):
            raise ValueError("skip_modules must be a list of strings")

        if self.keep_in_fp32_modules is not None and not isinstance(self.keep_in_fp32_modules, list):
            raise ValueError("keep_in_fp_32_modules must be a list of strings")

        if self.load_in_4bit:
            self.target_dtype = CustomDtype.INT4

        if self.load_in_8bit:
            self.target_dtype = torch.int8

        if self.load_in_4bit and self.llm_int8_threshold != 6.0:
            warnings.warn("llm_int8_threshold can only be used for model loaded in 8bit")

        if isinstance(self.torch_dtype, str):
            if self.torch_dtype == "fp32":
                self.torch_dtype = torch.float32
            elif self.torch_dtype == "fp16":
                self.torch_dtype = torch.float16
            elif self.torch_dtype == "bf16":
                self.torch_dtype = torch.bfloat16
            else:
                raise ValueError(f"torch_dtype must be in ['fp32','fp16','bf16'] but found {self.torch_dtype}")

        if self.load_in_8bit and self.torch_dtype is None:
            self.torch_dtype = torch.float16

        if self.load_in_4bit and self.torch_dtype is None:
            self.torch_dtype = self.bnb_4bit_compute_dtype

        if not isinstance(self.torch_dtype, torch.dtype):
            raise ValueError("torch_dtype must be a torch.dtype")