U
    9%e                     @   s  d Z ddlZddlZddlZddlZddlZddlZddlZddlm	Z	m
Z
 ddlmZmZ ddlmZ ddlmZmZmZ ddlmZ ddlmZmZmZmZmZmZ d	d
lmZmZ d	dlmZ  d	dl!m"Z"m#Z#m$Z$m%Z%m&Z& ddl'm(Z( e% rddl)m*Z+ e$ rddl,m-Z. e" r"ddl/Z/e# r6ddl0m1Z2 e3 dkrRddl4m5Z6 nddl4m6Z6 e&7e8Z9da:edddddddgZ;eg df e<eg df dddZ=dd  Z>G d!d" d"eZ?G d#d$ d$eZ@G d%d& d&eZAG d'd( d(eZBG d)d* d*eZCee@ ZDd<eg df eEd,d-d.ZFd=eeeGeeG f  eeeGeeG f  eGeeeE  eDd0d1d2ZHd>eeD e<eeC d4d5d6ZIeEeEd7d8d9ZJG d:d; d;e	ZKdS )?z5
Utilities for working with the local dataset cache.
    N)ABCabstractmethod)defaultdict
namedtuple)datetime)PipeProcessQueue)
Connection)CallableIterableList
NamedTupleOptionalUnion   )
AutoConfigPretrainedConfig)__version__)is_psutil_availableis_py3nvml_availableis_tf_availableis_torch_availablelogging   )BenchmarkArguments)empty_cache)contextWindows)CTRL_C_EVENT)SIGKILLFBenchmarkOutputZtime_inference_resultZmemory_inference_resultZtime_train_resultZmemory_train_resultinference_summarytrain_summary)funcdo_multi_processingreturnc                    s.    fdd}|r&t d  d |S  S dS )a  
    This function wraps another function into its own separated process. In order to ensure accurate memory
    measurements it is important that the function is executed in a separate process

    Args:
        - `func`: (`callable`): function() -> ... generic function which will be executed in its own separate process
        - `do_multi_processing`: (`bool`) Whether to run function on separate process or not
    c                     sJ   t d fdd}t  }t||gt|  d}|  | }|  |S )N)queuec              
      sT   z | }W n8 t k
rD } zt| t| d}W 5 d }~X Y nX | | d S )NN/A)	Exceptionloggererrorprintput)r'   argsresulter$    e/var/www/html/Darija-Ai-API/env/lib/python3.8/site-packages/transformers/benchmark/benchmark_utils.pywrapper_funcY   s    
zMseparate_process_wrapper_fn.<locals>.multi_process_func.<locals>.wrapper_func)targetr.   )r	   r   liststartgetjoin)r.   kwargsr4   r'   pr/   r1   r2   r3   multi_process_funcV   s    	z7separate_process_wrapper_fn.<locals>.multi_process_funcz	Function z" is executed in its own process...N)r*   info)r$   r%   r<   r2   r1   r3   separate_process_wrapper_fnL   s
    
r>   c                   C   s   t S N)_is_memory_tracing_enabledr2   r2   r2   r3   is_memory_tracing_enabledp   s    rA   c                   @   s:   e Zd ZU dZeed< eed< eed< eed< eed< dS )Framea  
    `Frame` is a NamedTuple used to gather the current frame state. `Frame` has the following fields:

        - 'filename' (string): Name of the file currently executed
        - 'module' (string): Name of the module currently executed
        - 'line_number' (int): Number of the line currently executed
        - 'event' (string): Event that triggered the tracing (default will be "line")
        - 'line_text' (string): Text of the line in the python script
    filenamemoduleline_numberevent	line_textN)__name__
__module____qualname____doc__str__annotations__intr2   r2   r2   r3   rB   u   s   

rB   c                   @   s*   e Zd ZU dZeed< eed< eed< dS )UsedMemoryStatea  
    `UsedMemoryState` are named tuples with the following fields:

        - 'frame': a `Frame` namedtuple (see below) storing information on the current tracing frame (current file,
          location in current file)
        - 'cpu_memory': CPU RSS memory state *before* executing the line
        - 'gpu_memory': GPU used memory *before* executing the line (sum for all GPUs or for only `gpus_to_trace` if
          provided)
    frameZ
cpu_memoryZ
gpu_memoryN)rH   rI   rJ   rK   rB   rM   rN   r2   r2   r2   r3   rO      s   

rO   c                   @   s(   e Zd ZU dZeed< edddZdS )Memoryz
    `Memory` NamedTuple have a single field `bytes` and you can get a human readable str of the number of mega bytes by
    calling `__repr__`

        - `byte` (integer): number of bytes,
    bytesr&   c                 C   s   t t| jS r?   )rL   bytes_to_mega_bytesrR   selfr2   r2   r3   __repr__   s    zMemory.__repr__N)rH   rI   rJ   rK   rN   rM   rL   rW   r2   r2   r2   r3   rQ      s   
rQ   c                   @   s2   e Zd ZU dZeed< eed< eed< eed< dS )MemoryStatea  
    `MemoryState` are namedtuples listing frame + CPU/GPU memory with the following fields:

        - `frame` (`Frame`): the current frame (see above)
        - `cpu`: CPU memory consumed at during the current frame as a `Memory` named tuple
        - `gpu`: GPU memory consumed at during the current frame as a `Memory` named tuple
        - `cpu_gpu`: CPU + GPU memory consumed at during the current frame as a `Memory` named tuple
    rP   cpugpucpu_gpuN)rH   rI   rJ   rK   rB   rM   rQ   r2   r2   r2   r3   rX      s
   
	rX   c                   @   s>   e Zd ZU dZee ed< ee ed< ee ed< eed< dS )MemorySummaryau  
    `MemorySummary` namedtuple otherwise with the fields:

        - `sequential`: a list of `MemoryState` namedtuple (see below) computed from the provided `memory_trace` by
          subtracting the memory after executing each line from the memory before executing said line.
        - `cumulative`: a list of `MemoryState` namedtuple (see below) with cumulative increase in memory for each line
          obtained by summing repeated memory increase for a line if it's executed several times. The list is sorted
          from the frame with the largest memory consumption to the frame with the smallest (can be negative if memory
          is released)
        - `total`: total memory increase during the full tracing as a `Memory` named tuple (see below). Line with
          memory release (negative consumption) are ignored if `ignore_released_memory` is `True` (default).
    
sequential
cumulativecurrenttotalN)rH   rI   rJ   rK   r   rX   rM   rQ   r2   r2   r2   r3   r\      s
   
r\         ?)functionr&   c                    s  t t ddd t s&td d}nG  fdddt}t \}}|t ||}|  |	  z$|   |
d |	 }|	 }W nT tk
r   tt }	|	jd	d
D ]}
t|
jt q|d tdY nX |d|  |dks|dk rq|d }q:|S dS )a@  
    measures peak cpu memory consumption of a given `function` running the function for at least interval seconds and
    at most 20 * interval seconds. This function is heavily inspired by: `memory_usage` of the package
    `memory_profiler`:
    https://github.com/pythonprofilers/memory_profiler/blob/895c4ac7a08020d66ae001e24067da6dcea42451/memory_profiler.py#L239

    Args:
        - `function`: (`callable`): function() -> ... function without any arguments to measure for which to measure
          the peak memory

        - `interval`: (`float`, `optional`, defaults to `0.5`) interval in second for which to measure the memory usage

        - `device_idx`: (`int`, `optional`, defaults to `None`) device id for which to measure gpu usage

    Returns:

        - `max_memory`: (`int`) consumed memory peak in Bytes
    )
process_idr&   c                 S   sT   t | }z&t|drdnd}t|| d }W n t jk
rN   tdY nX |S )z
        measures current cpu memory usage of a given `process_id`

        Args:
            - `process_id`: (`int`) process_id for which to measure memory

        Returns

            - `memory`: (`int`) consumed memory in Bytes
        memory_infoZget_memory_infor   zError with Psutil.)psutilr   hasattrgetattrZAccessDenied
ValueError)rc   processZmeminfo_attrmemoryr2   r2   r3   get_cpu_memory   s    
z/measure_peak_memory_cpu.<locals>.get_cpu_memoryzsPsutil not installed, we won't log CPU memory usage. Install Psutil (pip install psutil) to use CPU memory tracing.r(   c                       s8   e Zd ZdZeeed fddZfddZ  Z	S )z5measure_peak_memory_cpu.<locals>.MemoryMeasureProcessz
            `MemoryMeasureProcess` inherits from `Process` and overwrites its `run()` method. Used to measure the
            memory usage of a process
            )rc   child_connectionintervalc                    s2   t    || _|| _|| _d| _| j| _d S )Nr   )super__init__rc   rm   
connectionnum_measurements	mem_usage)rV   rc   rl   rm   )	__class__rk   r2   r3   ro     s    
z>measure_peak_memory_cpu.<locals>.MemoryMeasureProcess.__init__c                    sh   | j d d}t| j | j| _|  jd7  _|r8qH| j | j}q| j | j | j | j d S )Nr   Fr   )rp   sendmaxrr   rc   rq   pollrm   )rV   stoprk   r2   r3   run
  s    z9measure_peak_memory_cpu.<locals>.MemoryMeasureProcess.run)
rH   rI   rJ   rK   rN   r
   floatro   ry   __classcell__r2   rx   )rs   r3   MemoryMeasureProcess   s   r|   r   T)	recursivez Process killed. Error in Process      gư>
   N)rN   r   r*   warningr   r   osgetpidr7   recvrt   r)   re   childrenkillpidr    r9   RuntimeError)rb   rm   
device_idxZ
max_memoryr|   rl   Zparent_connectionZmem_processrq   parentchildr2   rx   r3   measure_peak_memory_cpu   s6    !



r   line)modules_to_tracemodules_not_to_traceevents_to_tracegpus_to_tracer&   c              	      s   t  rtt ntd dt rz0t	  |dkrLt
tt n| t  W n( ttjfk
r   td dY qX t pt ntd dg  fddt daS )	u  
    Setup line-by-line tracing to record rss mem (RAM) at each line of a module or sub-module. See `./benchmark.py` for
    usage examples. Current memory consumption is returned using psutil and in particular is the RSS memory "Resident
    Set Size” (the non-swapped physical memory the process is using). See
    https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info

    Args:
        - `modules_to_trace`: (None, string, list/tuple of string) if None, all events are recorded if string or list
          of strings: only events from the listed module/sub-module will be recorded (e.g. 'fairseq' or
          'transformers.models.gpt2.modeling_gpt2')
        - `modules_not_to_trace`: (None, string, list/tuple of string) if None, no module is avoided if string or list
          of strings: events from the listed module/sub-module will not be recorded (e.g. 'torch')
        - `events_to_trace`: string or list of string of events to be recorded (see official python doc for
          `sys.settrace` for the list of events) default to line
        - `gpus_to_trace`: (optional list, default None) list of GPUs to trace. Default to tracing all GPUs

    Return:

        - `memory_trace` is a list of `UsedMemoryState` for each event (default each line of the traced script).

            - `UsedMemoryState` are named tuples with the following fields:

                - 'frame': a `Frame` namedtuple (see below) storing information on the current tracing frame (current
                  file, location in current file)
                - 'cpu_memory': CPU RSS memory state *before* executing the line
                - 'gpu_memory': GPU used memory *before* executing the line (sum for all GPUs or for only
                  `gpus_to_trace` if provided)

    `Frame` is a namedtuple used by `UsedMemoryState` to list the current frame state. `Frame` has the following
    fields: - 'filename' (string): Name of the file currently executed - 'module' (string): Name of the module
    currently executed - 'line_number' (int): Number of the line currently executed - 'event' (string): Event that
    triggered the tracing (default will be "line") - 'line_text' (string): Text of the line in the python script

    zsPsutil not installed, we won't log CPU memory usage. Install psutil (pip install psutil) to use CPU memory tracing.NzUError while initializing communication with GPU. We won't perform GPU memory tracing.Fzvpy3nvml not installed, we won't log GPU memory usage. Install py3nvml (pip install py3nvml) to use GPU memory tracing.c                    s  t sS dk	r@ttr&|kr&S tttfr@|kr@S d| jkrNS | jd  t tsfS dk	rttr krS tttfrt fddD rS dk	rttrʈ krʈS tttfrt fddD rS | j}| jd }|	ds|	dr&|dd	 }t
|| }t| |||}d
}dk	rb }|j}d
}	rt rzt  t rt   t  D ]$}
t|
}t|}|	|j7 }	qt  t|||	}| S )z
        Tracing method executed before running each line in a module or sub-module Record memory allocated in a list
        with debugging information
        NrH   c                 3   s   | ]}| kV  qd S r?   r2   .0mnamer2   r3   	<genexpr>  s     z8start_memory_tracing.<locals>.traceit.<locals>.<genexpr>c                 3   s   | ]}| kV  qd S r?   r2   r   r   r2   r3   r     s     __file__z.pycz.pyor   )r@   
isinstancerL   r6   tuple	f_globalsallanyf_linenoendswith	linecachegetlinerstriprB   rd   Zrssr   torch_empty_cacher   
tf_contextr   Z_clear_cachesnvmlnvmlInitnvmlDeviceGetHandleByIndexnvmlDeviceGetMemoryInfousednvmlShutdownrO   append)rP   rF   r.   linenorC   r   Ztraced_statecpu_memZmemgpu_memihandleZmeminfoZ	mem_stateZdevicesr   Zlog_gpumemory_tracer   r   ri   traceitr   r3   r     s^    


$$




z%start_memory_tracing.<locals>.traceitT)r   re   r   r   r   r*   r   r   r   r   r6   rangeZnvmlDeviceGetCountr   OSErrorZ	NVMLErrorr   r   syssettracer@   )r   r   r   r   r2   r   r3   start_memory_tracingC  s0    (

L
r   T)r   ignore_released_memoryr&   c              
   C   sz  da | dk	rvt| dkrvg }g }tdd }t| dd | dd D ]\\}}}\}}	}
|	| }|
| }|| }|t|t|t|t|d |t|t|	t|
t|
|	 d || d  |7  < || d  |7  < || d	  |7  < qJt| d
d dd}dd |D }t|dd dd}|rLt	dd |D }nt	dd |D }t|}t
||||dS dS )aW	  
    Stop memory tracing cleanly and return a summary of the memory trace if a trace is given.

    Args:
        `memory_trace` (optional output of start_memory_tracing, default: None):
            memory trace to convert in summary
        `ignore_released_memory` (boolean, default: None):
            if True we only sum memory increase to compute total memory

    Return:

        - None if `memory_trace` is None
        - `MemorySummary` namedtuple otherwise with the fields:

            - `sequential`: a list of `MemoryState` namedtuple (see below) computed from the provided `memory_trace` by
              subtracting the memory after executing each line from the memory before executing said line.
            - `cumulative`: a list of `MemoryState` namedtuple (see below) with cumulative increase in memory for each
              line obtained by summing repeated memory increase for a line if it's executed several times. The list is
              sorted from the frame with the largest memory consumption to the frame with the smallest (can be negative
              if memory is released)
            - `total`: total memory increase during the full tracing as a `Memory` named tuple (see below). Line with
              memory release (negative consumption) are ignored if `ignore_released_memory` is `True` (default).

    `Memory` named tuple have fields

        - `byte` (integer): number of bytes,
        - `string` (string): same as human readable string (ex: "3.5MB")

    `Frame` are namedtuple used to list the current frame state and have the following fields:

        - 'filename' (string): Name of the file currently executed
        - 'module' (string): Name of the module currently executed
        - 'line_number' (int): Number of the line currently executed
        - 'event' (string): Event that triggered the tracing (default will be "line")
        - 'line_text' (string): Text of the line in the python script

    `MemoryState` are namedtuples listing frame + CPU/GPU memory with the following fields:

        - `frame` (`Frame`): the current frame (see above)
        - `cpu`: CPU memory consumed at during the current frame as a `Memory` named tuple
        - `gpu`: GPU memory consumed at during the current frame as a `Memory` named tuple
        - `cpu_gpu`: CPU + GPU memory consumed at during the current frame as a `Memory` named tuple
    FNr   c                   S   s
   dddgS )Nr   r2   r2   r2   r2   r3   <lambda>      z%stop_memory_tracing.<locals>.<lambda>r   rP   rY   rZ   r[   r   r   c                 S   s   | d d S )Nr   r   r2   xr2   r2   r3   r   0  r   T)keyreversec                 S   s2   g | ]*\}\}}}t |t|t|t|d qS )r   )rX   rQ   )r   rP   cpu_mem_incgpu_mem_inccpu_gpu_mem_incr2   r2   r3   
<listcomp>2  s   z'stop_memory_tracing.<locals>.<listcomp>c                 S   s   | j jS r?   r[   rR   r   r2   r2   r3   r   <  r   c                 s   s   | ]}t d |jjV  qdS )r   N)ru   r[   rR   r   Z
step_tracer2   r2   r3   r   ?  s     z&stop_memory_tracing.<locals>.<genexpr>c                 s   s   | ]}|j jV  qd S r?   r   r   r2   r2   r3   r   A  s     )r]   r^   r_   r`   )r@   lenr   zipr   rX   rQ   sorteditemssumr\   )r   r   Zmemory_diff_traceZmemory_curr_traceZcumulative_memory_dictrP   r   r   Z
next_frameZnext_cpu_memZnext_gpu_memr   r   r   Zcumulative_memoryZtotal_memoryr2   r2   r3   stop_memory_tracing  sf    /	
	  
r   )memory_amountr&   c                 C   s   | d? S )zLUtility to convert a number of bytes (int) into a number of mega bytes (int)r~   r2   )r   r2   r2   r3   rT   O  s    rT   c                   @   sB  e Zd ZU dZeed< eed< eed< d*eedddZe	d	d
 Z
e	edd ZeeeeedddZeeeeedddZeeeeeee gdddZeeeeeee gdddZedddZedddZeee gdddZeee gdddZdd  Ze	d!d" Zd#d$ Zed%d&d'Zd(d) ZdS )+	Benchmarkz
    Benchmarks is a simple but feature-complete benchmarking script to compare memory and time performance of models in
    Transformers.
    r.   configs	frameworkN)r.   r   c                 C   s   || _ |d kr$dd | j jD | _ntt| j j|| _td| j dt | j j	rnt
ddkrntd d | _d | _d | _d S )Nc                 S   s   i | ]}|t |qS r2   )r   Zfrom_pretrainedr   
model_namer2   r2   r3   
<dictcomp>a  s     z&Benchmark.__init__.<locals>.<dictcomp>z
The class z is deprecated. Hugging Face Benchmarking utils are deprecated in general and it is advised to use external Benchmarking libraries  to benchmark Transformer models.Z TRANSFORMERS_USE_MULTIPROCESSINGr   zMemory consumption will not be measured accurately if `args.multi_process` is set to `False.` The flag 'TRANSFORMERS_USE_MULTIPROCESSING' should only be disabled for debugging / testing.)r.   model_namesZconfig_dictdictr   warningswarnrs   FutureWarningrj   r   getenvr*   r   	_print_fnZ_framework_version_environment_info)rV   r.   r   r2   r2   r3   ro   ^  s"    
zBenchmark.__init__c                    s2    j d kr, jjr& fdd}| _ nt _  j S )Nc               	      s:   t  jjd}|d| d  W 5 Q R X t|   d S )Na 
)openr.   Zlog_filenamewriter9   r,   )r.   Zlog_filerU   r2   r3   print_and_log}  s    z)Benchmark.print_fn.<locals>.print_and_log)r   r.   Z	log_printr,   )rV   r   r2   rU   r3   print_fnx  s    
zBenchmark.print_fnc                 C   s   d S r?   r2   rU   r2   r2   r3   framework_version  s    zBenchmark.framework_version)r   
batch_sizesequence_lengthr&   c                 C   s   d S r?   r2   rV   r   r   r   r2   r2   r3   _inference_speed  s    zBenchmark._inference_speedc                 C   s   d S r?   r2   r   r2   r2   r3   _train_speed  s    zBenchmark._train_speedc                 C   s   d S r?   r2   r   r2   r2   r3   _inference_memory  s    zBenchmark._inference_memoryc                 C   s   d S r?   r2   r   r2   r2   r3   _train_memory  s    zBenchmark._train_memoryrS   c                 O   s   t | j| jj||S r?   )r>   r   r.   r%   rV   r.   r:   r2   r2   r3   inference_speed  s    zBenchmark.inference_speedc                 O   s   t | j| jj||S r?   )r>   r   r.   r%   r   r2   r2   r3   train_speed  s    zBenchmark.train_speedc                 O   s   t | j| jj||S r?   )r>   r   r.   r%   r   r2   r2   r3   inference_memory  s    zBenchmark.inference_memoryc                 O   s   t | j| jj||S r?   )r>   r   r.   r%   r   r2   r2   r3   train_memory  s    zBenchmark.train_memoryc              	   C   s  dd | j jD }t|}t|}t|}t|}t| j jD ]n\}}| |d  dt| j j  | j j| j jdd | j jD d}t|||< t|||< t|||< t|||< d  }	}
| j jD ]}| j jD ]}| j j	rP| j j
r$| |||\}}	||| d | |< | j jrP| |||}||| d | |< | j jr| j j
r| |||\}}
||| d | |< | j jr| |||}||| d | |< qqqF| j j	r| j jr| dd	d
 d  | j|dd | || j j | j jr| d | j j
rV| ddd
 d  | j|dd | || j j | j jr| ddd
 d  | |	 | j jrJ| j jr| ddd
 d  | |d | || j j | j jr| d | j j
r| ddd
 d  | j|dd | || j j | j jrJ| ddd
 d  | |
 | j jr| ddd
 d  | ddd | j D d  | j jrt| j j ddd4}t!"|}| j D ]\}}|#||g qW 5 Q R X t$|||||	|
S )Nc                 S   s   i | ]
}|i qS r2   r2   r   r2   r2   r3   r     s      z!Benchmark.run.<locals>.<dictcomp>r   z / c                 S   s   i | ]
}|i qS r2   r2   )r   r   r2   r2   r3   r     s      )bsssr/   r/   z
====================zINFERENCE - SPEED - RESULT(   z====================z	Time in s)
type_labelzTPU was used for inference. Note that the time after compilation stabilized (after ~10 inferences model.forward(..) calls) was measured.zINFERENCE - MEMORY - RESULTzMemory in MBz,INFERENCE - MEMOMRY - LINE BY LINE - SUMMARYzTRAIN - SPEED - RESULTSzTPU was used for training. Note that the time after compilation stabilized (after ~10 train loss=model.forward(...) + loss.backward() calls) was measured.zTRAIN - MEMORY - RESULTSz(TRAIN - MEMOMRY - LINE BY LINE - SUMMARYzENVIRONMENT INFORMATIONr   c                 S   s    g | ]\}}d | d| qS )z- : r2   )r   propvalr2   r2   r3   r     s     z!Benchmark.run.<locals>.<listcomp>wr   )modenewline)%r.   r   copydeepcopy	enumerater   r   Zbatch_sizesZsequence_lengthsZ	inferencerj   r   speedr   Ztrainingr   r   centerprint_resultssave_to_csvZinference_time_csv_fileis_tpuZinference_memory_csv_fileZtrace_memory_line_by_lineprint_memory_trace_statisticsZtrain_time_csv_fileZtrain_memory_csv_fileZ	env_printr9   environment_infor   r   Zenv_info_csv_filecsvwriterwriterowr!   )rV   result_dictZinference_result_timeZinference_result_memoryZtrain_result_timeZtrain_result_memorycr   Z
model_dictr"   r#   r   r   rj   timecsv_filer  r   valuer2   r2   r3   ry     s    



 
















$

zBenchmark.runc                 C   s  | j d kri }t|d< | j|d< | jdkr8| jj|d< | jdkrZ| jj|d< | jj|d< | j|d< t	 |d	< t
 |d
< t |d< t d |d< tt |d< tt |d< | jj|d< | jj|d< | jj|d< t rtt j|d< ntd d|d< | jj|d< | jjrd|d< t rt  t| jj }t!||d< tt"|j|d< t#|d |d< t$||d< t%  n*td d|d< d|d< d|d< d|d< | jj&|d< || _ | j S ) NZtransformers_versionr   ZPyTorchZuse_torchscriptZ
TensorFlow
eager_modeuse_xlar   python_versionsystemrY   r   architecturedater
  fp16Zuse_multiprocessingonly_pretrain_modelZ
cpu_ram_mbzyPsutil not installed, we won't log available CPU memory. Install psutil (pip install psutil) to log available CPU memory.r(   Zuse_gpur   Znum_gpusrZ   Z
gpu_ram_mb  Zgpu_power_wattsZgpu_performance_statezypy3nvml not installed, we won't log GPU memory usage. Install py3nvml (pip install py3nvml) to log information about GPU.Zuse_tpu)'r   versionr   r.   Ztorchscriptr  r  r   platformr  r  	processorr  r   r  nowr
  r  r%   r  r   rT   re   Zvirtual_memoryr`   r*   r   Zis_gpur   r   r   r   r   ZnvmlDeviceGetNamer   Z!nvmlDeviceGetPowerManagementLimitZnvmlDeviceGetPerformanceStater   r  )rV   r=   r   r2   r2   r3   r    s\    





zBenchmark.environment_infoc              
   C   s  |  d |  dddd dd |d  |  d | jjD ]}|| d D ]}|| d D ]}|| d	 | | }t|trtd
| d
 }|dkrdnt|}nt|}|  |d d dt|d t|d|d qjqZqJ|  d d S )NzP--------------------------------------------------------------------------------z
Model Name   z
Batch Size   z
Seq Lengthr   r   r/   r  g        z< 0.001)r   r   r.   r   r   rz   roundrL   )rV   r  r   r   r   r   r/   r2   r2   r3   r   J  s&    
&


zBenchmark.print_results)summaryc              	   C   s   |  dddd |jD   |  dddd |jd d D   |  ddd	d |jd
d  D   |  d|j  d S )Nz"
Line by line memory consumption:
r   c                 s   s6   | ].}|j j d |j j d|j d|j j V  qdS ):: mem r   NrP   rC   rE   r[   rG   r   stater2   r2   r3   r   c  s   z:Benchmark.print_memory_trace_statistics.<locals>.<genexpr>z$
Lines with top memory consumption:
c              	   s   s8   | ]0}d |j j d|j j d|j d|j j V  qdS z=> r  r  r   Nr   r!  r2   r2   r3   r   j  s      z'
Lines with lowest memory consumption:
c              	   s   s8   | ]0}d |j j d|j j d|j d|j j V  qdS r#  r   r!  r2   r2   r3   r   q  s   iz
Total memory increase: )r   r9   r]   r^   r`   )rV   r  r2   r2   r3   r  `  s,    


z'Benchmark.print_memory_trace_statisticsc                 C   s   | j jsd S | d t|dd}t| j jdkrDtd| j dddg}tj||d	g d
}|	  | j jD ]\}|| d	 }|D ]F}|| D ]8}	|| |	 }
|
|||	t|
tsdnd|
d qqqrW 5 Q R X d S )NzSaving results to csv.r   )r   r   z,At least 1 model should be defined, but got modelr   r   r/   )
fieldnamesz{}z{:.4f})r%  r   r   r/   )r.   r  r   r   r   r   rh   r  
DictWriterwriteheaderr  r   rz   format)rV   r  rC   r  r&  r  r   Zresult_dict_modelr   r   Zresult_modelr2   r2   r3   r  x  s.    

zBenchmark.save_to_csv)NN)rH   rI   rJ   rK   r   rM   r   rL   ro   propertyr   r   r   rN   rz   r   r   rQ   r   r\   r   r   r   r   r   r   ry   r  r   r  r  r2   r2   r2   r3   r   T  sH   

  
  
c
:r   )ra   N)NNr   N)NT)LrK   r   r  r   r   r  r   r   abcr   r   collectionsr   r   r   multiprocessingr   r   r	   Zmultiprocessing.connectionr
   typingr   r   r   r   r   r   r   r   r   r   r  utilsr   r   r   r   r   Zbenchmark_args_utilsr   Z
torch.cudar   r   Ztensorflow.python.eagerr   r   re   Zpy3nvml.py3nvmlZpy3nvmlr   r  signalr   r    Z
get_loggerrH   r*   r@   r!   boolr>   rA   rB   rO   rQ   rX   r\   ZMemoryTracerN   r   rL   r   r   rT   r   r2   r2   r2   r3   <module>   s    
"$x    
     t