U
    0-eeo                     @   s  d dl Z d dlZd dlZd dlmZ d dlmZmZmZm	Z	m
Z
 d dlZddlmZ ddlmZ ddlmZmZmZmZmZmZmZ g Ze reej e reej e reej e reej e reej eeZ dd	 Z!d
d Z"G dd dZ#G dd de#Z$G dd de#Z%G dd de#Z&G dd de#Z'G dd de#Z(e'e&e(e$e%dZ)dee
e*ee#f  e
e*ej+f dddZ,dS )    Nwraps)AnyDictListOptionalUnion   )
get_logger)PartialState)
LoggerTypeis_aim_availableis_comet_ml_availableis_mlflow_availableis_tensorboard_availableis_wandb_availablelistifyc                    s   t   fdd}|S )a  
    Decorator to selectively run the decorated function on the main process only based on the `main_process_only`
    attribute in a class.

    Checks at function execution rather than initialization time, not triggering the initialization of the
    `PartialState`.
    c                    s8   t | ddr$t  | f||S  | f||S d S )Nmain_process_onlyF)getattrr   on_main_process)selfargskwargsfunction T/var/www/html/Darija-Ai-Train/env/lib/python3.8/site-packages/accelerate/tracking.pyexecute_on_main_processD   s    z0on_main_process.<locals>.execute_on_main_processr   )r   r   r   r   r   r   ;   s    	r   c                   C   s   t S )z@Returns a list of all supported available trackers in the system)_available_trackersr   r   r   r   get_available_trackersN   s    r   c                   @   sH   e Zd ZdZdZdddZedddZeee	 d	d
dZ
dd ZdS )GeneralTrackera`  
    A base Tracker class to be used for all logging integration implementations.

    Each function should take in `**kwargs` that will automatically be passed in from a base dictionary provided to
    [`Accelerator`].

    Should implement `name`, `requires_logging_directory`, and `tracker` properties such that:

    `name` (`str`): String representation of the tracker class name, such as "TensorBoard" `requires_logging_directory`
    (`bool`): Whether the logger requires a directory to store their logs. `tracker` (`object`): Should return internal
    tracking mechanism used by a tracker class (such as the `run` for wandb)

    Implementations can also include a `main_process_only` (`bool`) attribute to toggle if relevent logging, init, and
    other functions should occur on the main process or across all processes (by default will use `True`)
    TFc                 C   s   |sd}t | ds|d7 }t | ds@t|dkr8|d7 }|d7 }dt| krht|dkr`|d7 }|d	7 }t|dkrtd
| d S )N namez`name`requires_logging_directoryr   z, z`requires_logging_directory`trackerz	`tracker`zThe implementation for this tracker class is missing the following required attributes. Please define them in the class definition: )hasattrlendirNotImplementedError)r   Z_blankerrr   r   r   __init__f   s     

zGeneralTracker.__init__valuesc                 C   s   dS )a  
        Logs `values` as hyperparameters for the run. Implementations should use the experiment configuration
        functionality of a tracking API.

        Args:
            values (Dictionary `str` to `bool`, `str`, `float` or `int`):
                Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
                `str`, `float`, `int`, or `None`.
        Nr   r   r,   r   r   r   store_init_configuration|   s    
z'GeneralTracker.store_init_configurationr,   stepc                 K   s   dS )a  
        Logs `values` to the current run. Base `log` implementations of a tracking API should go in here, along with
        special behavior for the `step parameter.

        Args:
            values (Dictionary `str` to `str`, `float`, or `int`):
                Values to be logged as key-value pairs. The values need to have type `str`, `float`, or `int`.
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
        Nr   r   r,   r0   r   r   r   r   log   s    zGeneralTracker.logc                 C   s   dS )z
        Should run any finalizing functions within the tracking API. If the API should not have one, just don't
        overwrite that method.
        Nr   r   r   r   r   finish   s    zGeneralTracker.finishN)F)__name__
__module____qualname____doc__r   r*   dictr.   r   intr2   r4   r   r   r   r   r    S   s   
r    c                       s   e Zd ZdZdZdZeeeee	j
f d fddZedd Zeed	d
dZedeee dddZeeee dddZedd Z  ZS )TensorBoardTrackera  
    A `Tracker` class that supports `tensorboard`. Should be initialized at the start of your script.

    Args:
        run_name (`str`):
            The name of the experiment run
        logging_dir (`str`, `os.PathLike`):
            Location for TensorBoard logs to be stored.
        kwargs:
            Additional key word arguments passed along to the `tensorboard.SummaryWriter.__init__` method.
    tensorboardTrun_namelogging_dirc                    s   zddl m} W n tk
r,   dd l}Y nX t   || _tj	||| _
|j| j
f|| _td| j d| j
  td d S )Nr   )r<   z Initialized TensorBoard project z logging to aMake sure to log any initial configurations with `self.store_init_configuration` before training!)Ztorch.utilsr<   ModuleNotFoundErrorZtensorboardXsuperr*   r>   ospathjoinr?   ZSummaryWriterwriterloggerdebug)r   r>   r?   r   r<   	__class__r   r   r*      s    
zTensorBoardTracker.__init__c                 C   s   | j S NrF   r3   r   r   r   r$      s    zTensorBoardTracker.trackerr+   c              	   C   s   | j j|i d | j   t }tj| jt|}tj	|dd t
tj|dd<}zt|| W n$ tjjk
r   td  Y nX W 5 Q R X td dS )	a  
        Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. Stores the
        hyperparameters in a yaml file for future use.

        Args:
            values (Dictionary `str` to `bool`, `str`, `float` or `int`):
                Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
                `str`, `float`, `int`, or `None`.
        )Zmetric_dictT)exist_okzhparams.ymlwz-Serialization to store hyperparameters failedzQStored initial configuration hyperparameters to TensorBoard and hparams yaml fileN)rF   Zadd_hparamsflushtimerC   rD   rE   r?   strmakedirsopenyamldumpZrepresenterZRepresenterErrorrG   errorrH   )r   r,   Zproject_run_namedir_nameoutfiler   r   r   r.      s    

z+TensorBoardTracker.store_init_configurationNr/   c                 K   s   t |}| D ]|\}}t|ttfrB| jj||fd|i| qt|trh| jj||fd|i| qt|t	r| jj
||fd|i| q| j  td dS )a  
        Logs `values` to the current run.

        Args:
            values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
                Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
                `str` to `float`/`int`.
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
            kwargs:
                Additional key word arguments passed along to either `SummaryWriter.add_scaler`,
                `SummaryWriter.add_text`, or `SummaryWriter.add_scalers` method based on the contents of `values`.
        global_stepz"Successfully logged to TensorBoardN)r   items
isinstancer:   floatrF   Z
add_scalarrQ   add_textr9   Zadd_scalarsrO   rG   rH   r   r,   r0   r   kvr   r   r   r2      s    


zTensorBoardTracker.logc                 K   s:   |  D ]"\}}| jj||fd|i| qtd dS )a  
        Logs `images` to the current run.

        Args:
            values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):
                Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
            kwargs:
                Additional key word arguments passed along to the `SummaryWriter.add_image` method.
        rY   z)Successfully logged images to TensorBoardN)rZ   rF   Z
add_imagesrG   rH   r^   r   r   r   
log_images   s    zTensorBoardTracker.log_imagesc                 C   s   | j   td dS )z-
        Closes `TensorBoard` writer
        zTensorBoard writer closedN)rF   closerG   rH   r3   r   r   r   r4     s    
zTensorBoardTracker.finish)N)r5   r6   r7   r8   r"   r#   r   rQ   r   rC   PathLiker*   propertyr$   r9   r.   r   r:   r2   ra   r4   __classcell__r   r   rI   r   r;      s    
r;   c                       s   e Zd ZdZdZdZdZeed fddZ	e
dd Zeed	d
dZedeee dddZedeee dddZedeee eee  eee dddZedd Z  ZS )WandBTrackera  
    A `Tracker` class that supports `wandb`. Should be initialized at the start of your script.

    Args:
        run_name (`str`):
            The name of the experiment run.
        kwargs:
            Additional key word arguments passed along to the `wandb.init` method.
    wandbFr>   c                    sP   t    || _dd l}|jf d| ji|| _td| j  td d S )Nr   projectzInitialized WandB project r@   )rB   r*   r>   rg   initrunrG   rH   )r   r>   r   rg   rI   r   r   r*     s    
zWandBTracker.__init__c                 C   s   | j S rK   )rk   r3   r   r   r   r$   (  s    zWandBTracker.trackerr+   c                 C   s&   ddl }|jj|dd td dS )u  
        Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.

        Args:
            values (Dictionary `str` to `bool`, `str`, `float` or `int`):
                Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
                `str`, `float`, `int`, or `None`.
        r   NT)Zallow_val_changez5Stored initial configuration hyperparameters to WandB)rg   configupdaterG   rH   )r   r,   rg   r   r   r   r.   ,  s    
z%WandBTracker.store_init_configurationNr/   c                 K   s&   | j j|fd|i| td dS )a,  
        Logs `values` to the current run.

        Args:
            values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
                Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
                `str` to `float`/`int`.
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
            kwargs:
                Additional key word arguments passed along to the `wandb.log` method.
        r0   zSuccessfully logged to WandBN)rk   r2   rG   rH   r1   r   r   r   r2   ;  s    zWandBTracker.logc                    sP   ddl  | D ]0\}}| j| fdd|D ifd|i| qtd dS )a  
        Logs `images` to the current run.

        Args:
            values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):
                Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
            kwargs:
                Additional key word arguments passed along to the `wandb.log` method.
        r   Nc                    s   g | ]}  |qS r   )ZImage).0imagerg   r   r   
<listcomp>\  s     z+WandBTracker.log_images.<locals>.<listcomp>r0   z#Successfully logged images to WandB)rg   rZ   r2   rG   rH   r^   r   rq   r   ra   L  s    *zWandBTracker.log_images)
table_namecolumnsdata	dataframer0   c           	      K   s6   ddl }||j|||di}| j|fd|i| dS )a  
        Log a Table containing any object type (text, image, audio, video, molecule, html, etc). Can be defined either
        with `columns` and `data` or with `dataframe`.

        Args:
            table_name (`str`):
                The name to give to the logged table on the wandb workspace
            columns (List of `str`'s *optional*):
                The name of the columns on the table
            data (List of List of Any data type *optional*):
                The data to be logged in the table
            dataframe (Any data type *optional*):
                The data to be logged in the table
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
        r   N)rt   ru   rv   r0   )rg   Tabler2   )	r   rs   rt   ru   rv   r0   r   rg   r,   r   r   r   	log_table_  s    zWandBTracker.log_tablec                 C   s   | j   td dS )z'
        Closes `wandb` writer
        zWandB run closedN)rk   r4   rG   rH   r3   r   r   r   r4   ~  s    
zWandBTracker.finish)N)N)NNNN)r5   r6   r7   r8   r"   r#   r   r   rQ   r*   rd   r$   r9   r.   r   r:   r2   ra   r   r   rx   r4   re   r   r   rI   r   rf     s8   

    
rf   c                       sv   e Zd ZdZdZdZeed fddZe	dd Z
eed	d
dZedeee dddZedd Z  ZS )CometMLTrackeraZ  
    A `Tracker` class that supports `comet_ml`. Should be initialized at the start of your script.

    API keys must be stored in a Comet config file.

    Args:
        run_name (`str`):
            The name of the experiment run.
        kwargs:
            Additional key word arguments passed along to the `Experiment.__init__` method.
    comet_mlFrh   c                    sP   t    || _ddlm} |f d|i|| _td| j  td d S )Nr   )
Experimentproject_namezInitialized CometML project r@   )rB   r*   r>   rz   r{   rF   rG   rH   )r   r>   r   r{   rI   r   r   r*     s    
zCometMLTracker.__init__c                 C   s   | j S rK   rL   r3   r   r   r   r$     s    zCometMLTracker.trackerr+   c                 C   s   | j | td dS )rl   z7Stored initial configuration hyperparameters to CometMLN)rF   Zlog_parametersrG   rH   r-   r   r   r   r.     s    
z'CometMLTracker.store_init_configurationNr/   c                 K   s   |dk	r| j | | D ]r\}}t|ttfrN| j j||fd|i| qt|trl| j j||f| qt|t	r| j j
|fd|i| qtd dS )a  
        Logs `values` to the current run.

        Args:
            values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
                Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
                `str` to `float`/`int`.
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
            kwargs:
                Additional key word arguments passed along to either `Experiment.log_metric`, `Experiment.log_other`,
                or `Experiment.log_metrics` method based on the contents of `values`.
        Nr0   zSuccessfully logged to CometML)rF   Zset_steprZ   r[   r:   r\   Z
log_metricrQ   Z	log_otherr9   log_metricsrG   rH   r^   r   r   r   r2     s    

zCometMLTracker.logc                 C   s   | j   td dS )z*
        Closes `comet-ml` writer
        zCometML run closedN)rF   endrG   rH   r3   r   r   r   r4     s    
zCometMLTracker.finish)N)r5   r6   r7   r8   r"   r#   r   rQ   r*   rd   r$   r9   r.   r   r:   r2   r4   re   r   r   rI   r   ry     s   
ry   c                   @   s~   e Zd ZdZdZdZedeee	ee
jf  dddZedd	 Zeed
ddZeeee dddZedd ZdS )
AimTrackera  
    A `Tracker` class that supports `aim`. Should be initialized at the start of your script.

    Args:
        run_name (`str`):
            The name of the experiment run.
        kwargs:
            Additional key word arguments passed along to the `Run.__init__` method.
    aimT.r=   c                 K   sP   || _ ddlm} |f d|i|| _| j | j_td| j   td d S )Nr   )RunrepozInitialized Aim project r@   )r>   r   r   rF   r"   rG   rH   )r   r>   r?   r   r   r   r   r   r*     s    
zAimTracker.__init__c                 C   s   | j S rK   rL   r3   r   r   r   r$     s    zAimTracker.trackerr+   c                 C   s   || j d< dS )
        Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.

        Args:
            values (`dict`):
                Values to be stored as initial hyperparameters as key-value pairs.
        ZhparamsNrL   r-   r   r   r   r.     s    	z#AimTracker.store_init_configurationr/   c                 K   s0   |  D ]"\}}| jj|f||d| qdS )a}  
        Logs `values` to the current run.

        Args:
            values (`dict`):
                Values to be logged as key-value pairs.
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
            kwargs:
                Additional key word arguments passed along to the `Run.track` method.
        )r"   r0   N)rZ   rF   track)r   r,   r0   r   keyvaluer   r   r   r2     s    zAimTracker.logc                 C   s   | j   dS )z%
        Closes `aim` writer
        N)rF   rb   r3   r   r   r   r4     s    zAimTracker.finishN)r   )r5   r6   r7   r8   r"   r#   r   rQ   r   r   rC   rc   r*   rd   r$   r9   r.   r:   r2   r4   r   r   r   r   r     s   
"

r   c                
   @   s   e Zd ZdZdZdZedeee	ee
jf  ee ee	eeef ef  ee ee ee dddZedd	 Zeed
ddZeeee dddZedd ZdS )MLflowTrackeru  
    A `Tracker` class that supports `mlflow`. Should be initialized at the start of your script.

    Args:
        experiment_name (`str`, *optional*):
            Name of the experiment. Environment variable MLFLOW_EXPERIMENT_NAME has priority over this argument.
        logging_dir (`str` or `os.PathLike`, defaults to `"."`):
            Location for mlflow logs to be stored.
        run_id (`str`, *optional*):
            If specified, get the run with the specified UUID and log parameters and metrics under that run. The run’s
            end time is unset and its status is set to running, but the run’s other attributes (source_version,
            source_type, etc.) are not changed. Environment variable MLFLOW_RUN_ID has priority over this argument.
        tags (`Dict[str, str]`, *optional*):
            An optional `dict` of `str` keys and values, or a `str` dump from a `dict`, to set as tags on the run. If a
            run is being resumed, these tags are set on the resumed run. If a new run is being created, these tags are
            set on the new run. Environment variable MLFLOW_TAGS has priority over this argument.
        nested_run (`bool`, *optional*, defaults to `False`):
            Controls whether run is nested in parent run. True creates a nested run. Environment variable
            MLFLOW_NESTED_RUN has priority over this argument.
        run_name (`str`, *optional*):
            Name of new run (stored as a mlflow.runName tag). Used only when `run_id` is unspecified.
        description (`str`, *optional*):
            An optional string that populates the description box of the run. If a run is being resumed, the
            description is set on the resumed run. If a new run is being created, the description is set on the new
            run.
    mlflowFN)experiment_namer?   run_idtags
nested_runr>   descriptionc                 C   s   t d|}t d|}t d|}t|tr8t|}t d|}dd l}|jd| dd}	t|	dkrt|	d	krt	
d
 |	d j}
n|j|||d}
|j||
||||d| _t	d|  t	d d S )NZMLFLOW_EXPERIMENT_NAMEZMLFLOW_RUN_IDZMLFLOW_TAGSZMLFLOW_NESTED_RUNr   zname = '')Zfilter_stringr	   z?Multiple experiments with the same name found. Using first one.)r"   Zartifact_locationr   )r   experiment_idr>   nestedr   r   zInitialized mlflow experiment r@   )rC   getenvr[   rQ   jsonloadsr   Zsearch_experimentsr&   rG   warningr   Zcreate_experimentZ	start_run
active_runrH   )r   r   r?   r   r   r   r>   r   r   Zexpsr   r   r   r   r*   :  s:    


	zMLflowTracker.__init__c                 C   s   | j S rK   )r   r3   r   r   r   r$   i  s    zMLflowTracker.trackerr+   c              
   C   s   ddl }t| D ]H\}}tt||jjjkrt	d| d| d|jjj d ||= qt| }t
dt||jjjD ]$}|t||||jjj   qtd dS )r   r   Nz)Trainer is attempting to log a value of "z" for key "zJ" as a parameter. MLflow's log_param() only accepts values no longer than z) characters so we dropped this attribute.z6Stored initial configuration hyperparameters to MLflow)r   listrZ   r&   rQ   utilsZ
validationZMAX_PARAM_VAL_LENGTHrG   r   rangeZMAX_PARAMS_TAGS_PER_BATCHZ
log_paramsr9   rH   )r   r,   r   r"   r   Zvalues_listir   r   r   r.   m  s    	"z&MLflowTracker.store_init_configurationr/   c              
   C   st   i }|  D ]B\}}t|ttfr,|||< qtd| dt| d| d qddl}|j||d t	d dS )	a  
        Logs `values` to the current run.

        Args:
            values (`dict`):
                Values to be logged as key-value pairs.
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
        z/MLflowTracker is attempting to log a value of "z
" of type z
 for key "zc" as a metric. MLflow's log_metric() only accepts float and int types so we dropped this attribute.r   N)r0   zSuccessfully logged to mlflow)
rZ   r[   r:   r\   rG   r   typer   r}   rH   )r   r,   r0   Zmetricsr_   r`   r   r   r   r   r2     s    
zMLflowTracker.logc                 C   s   ddl }|  dS )z,
        End the active MLflow run.
        r   N)r   Zend_run)r   r   r   r   r   r4     s    zMLflowTracker.finish)NNNNFNN)r5   r6   r7   r8   r"   r#   r   rQ   r   r   rC   rc   r   r   boolr*   rd   r$   r9   r.   r:   r2   r4   r   r   r   r   r     s8          .
r   )r   rz   r   r<   rg   )log_withr?   c                 C   s  g }| dk	rt | ttfs"| g} d| ks4tj| krJdd | D t  }n| D ]}|tkrtt|tst	d| dt  tt|tr|
| qNt|}||krN|t krtt| }t|dr|dkrt	d| d	|
| qNtd
| d qN|S )a  
    Takes in a list of potential tracker types and checks that:
        - The tracker wanted is available in that environment
        - Filters out repeats of tracker types
        - If `all` is in `log_with`, will return all trackers in the environment
        - If a tracker requires a `logging_dir`, ensures that `logging_dir` is not `None`

    Args:
        log_with (list of `str`, [`~utils.LoggerType`] or [`~tracking.GeneralTracker`], *optional*):
            A list of loggers to be setup for experiment tracking. Should be one or several of:

            - `"all"`
            - `"tensorboard"`
            - `"wandb"`
            - `"comet_ml"`
            - `"mlflow"`
            If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can
            also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.
        logging_dir (`str`, `os.PathLike`, *optional*):
            A path to a directory for storing logs of locally-compatible loggers.
    Nallc                 S   s   g | ]}t t|tr|qS r   )
issubclassr   r    )ro   or   r   r   rr     s      z#filter_trackers.<locals>.<listcomp>z Unsupported logging capability: z. Choose between r#   zLogging with `z+` requires a `logging_dir` to be passed in.zTried adding logger z+, but package is unavailable in the system.)r[   r   tupler   ALLr   r   r   r    
ValueErrorappendLOGGER_TYPE_TO_CLASSrQ   r   rG   rH   )r   r?   loggersZlog_typeZtracker_initr   r   r   filter_trackers  s.    



r   )N)-r   rC   rP   	functoolsr   typingr   r   r   r   r   rT   loggingr
   stater   r   r   r   r   r   r   r   r   r   r   ZTENSORBOARDZWANDBZCOMETMLZAIMZMLFLOWr5   rG   r   r   r    r;   rf   ry   r   r   r   rQ   rc   r   r   r   r   r   <module>   sP   $Jo{QC 
  