U
    d                     @   s   d Z ddlZddlZddlZddlZddlZddlZddlmZm	Z	 e
dZeej dZddd	ZG d
d deZG dd dZG dd deZG dd deZG dd deZe add ZdS )a<  
This module provides a python-land multithreaded mechanism for executing work.

Basic usage is as follows:
   coordinator = parallel_workers.init_workers(
      my_worker_fun,
      worker_name="train"
   )
   ...
   coordinator.start()

First argument is the function to run in a loop on potentially multiple threads.
It has the call signature
    worker_fun(worker_id)

Argument 'worker_name' is used to distinguish different workers,
such as workers processing train data or workers processing test data.

Optionally, one can define an "init function" that is called once before
threads start, and has call signature:
   my_init_fun(worker_coordinator, global_coordinator)

Note that for data_parallel_models, init_workers will be called
for each GPU. Note that the 'coordinator' returned by the function is same
each time.
    N)ABCMetaabstractmethodZparallel_workers<      trainc                    sT   t |dd t|D }t||||d  fdd|D }| _t  tS )Nc                 S   s   g | ]}t  qS  )global_coordinatorget_new_worker_id).0ir   r   B/tmp/pip-unpacked-wheel-ua33x9lu/caffe2/python/parallel_workers.py
<listcomp>>   s   z init_workers.<locals>.<listcomp>)shutdown_func                    s0   g | ](}t jtd | t |gdqS )zparallel_workers worker id {})targetnameargs)	threadingThread
run_workerformatWorker)r
   	worker_idcoordinatormetrics
worker_funr   r   r   H   s   )MetricsrangeWorkerCoordinator_workersr   add)r   Znum_worker_threadsworker_nameinit_funexternal_loggersr   
worker_idsworkersr   r   r   init_workers2   s     
   	
r&   c                   @   s.   e Zd Zdd Zdd Zdd Zddd	Zd
S )r   c                 C   s   t dd | _|| _d S )Nc                   S   s   dS Nr   r   r   r   r   r   <lambda>Y       z"Metrics.__init__.<locals>.<lambda>)collectionsdefaultdict_metrics_external_loggers)selfr#   r   r   r   __init__X   s    zMetrics.__init__c                 C   s   t dd | _d S )Nc                   S   s   dS r'   r   r   r   r   r   r(   ]   r)   z'Metrics.reset_metrics.<locals>.<lambda>)r*   r+   r,   r.   r   r   r   reset_metrics\   s    zMetrics.reset_metricsc                 C   s\   | j s
d S | j D ]F}z|| j W q tk
rT } ztd| W 5 d }~X Y qX qd S )Nz!Failed to call ExternalLogger: {})r-   logr,   	Exceptionprintr   )r.   loggerer   r   r   log_metrics_   s    
zMetrics.log_metricsTc                 C   s6   | j |  |7  < |r2d|}| j |  d7  < d S )Nz{}_count   )r,   r   )r.   keyvaluecountZ	count_keyr   r   r   
put_metrich   s    
zMetrics.put_metricN)T)__name__
__module____qualname__r/   r1   r7   r<   r   r   r   r   r   W   s   	r   c                   @   s4   e Zd ZeZedd Zedd Zedd ZdS )Statec                 C   s   d S Nr   r0   r   r   r   startr   s    zState.startc                 C   s   d S rA   r   r0   r   r   r   stopv   s    z
State.stopc                 C   s   d S rA   r   r0   r   r   r   cleanupz   s    zState.cleanupN)	r=   r>   r?   r   __metaclass__r   rB   rC   rD   r   r   r   r   r@   o   s   

r@   c                   @   sJ   e Zd ZdddZdd Zdd Zdd	 Zdd
dZdddZdd Z	dS )r   Nc                 C   s4   d| _ d| _g | _|| _|| _|| _|| _|| _d S )NTF)_active_startedr   _worker_name_worker_ids	_init_fun_state_shutdown_fun)r.   r!   r$   r"   stater   r   r   r   r/      s    zWorkerCoordinator.__init__c                 C   s   | j S rA   )rF   r0   r   r   r   	is_active   s    zWorkerCoordinator.is_activec                 C   s    | j r| js| }|  || d S rA   )rJ   rG   )r.   r   Zdata_coordinatorr   r   r   init   s    zWorkerCoordinator.initc                 C   sD   | j r
d S d| _d| _ | jr&| j  | jD ]}d|_|  q,d S NT)rG   rF   rK   rB   r   daemon)r.   wr   r   r   _start   s    

zWorkerCoordinator._startc                 C   sL   d| _ |d k	rtd| | jr2| jr2|   | jrB| j  d| _d S )NFz%Data input failed due to an error: {})rF   r2   errorr   rL   rG   rK   rC   )r.   reasonr   r   r   _stop   s    
zWorkerCoordinator._stopc                 C   s   t d| j | jD ]}|t kr|d qd}| jD ]}| r<t d| d}q<|rp| jrp| j	  t d| |S )NzWait for workers to die: {}g      @Tz'Worker {} failed to close while waitingFzAll workers terminated: {})
r4   r   rH   r   r   current_threadjoinis_aliverK   rD   )r.   rD   rR   successr   r   r   _wait_finish   s    



zWorkerCoordinator._wait_finishc                 C   s   | j S rA   rI   r0   r   r   r   get_worker_ids   s    z WorkerCoordinator.get_worker_ids)NN)N)N)
r=   r>   r?   r/   rN   rO   rS   rV   r[   r]   r   r   r   r   r      s      


r   c                   @   sL   e Zd Zdd Zdd Zdd Zdd Zd	d
 Zdd Zdd Z	dd Z
dS )GlobalWorkerCoordinatorc                 C   s   g | _ d| _g | _|   d S r'   )_coordinators_fetcher_id_seqrI   register_shutdown_handlerr0   r   r   r   r/      s    z GlobalWorkerCoordinator.__init__c                 C   s   | j | d S rA   )r_   append)r.   r   r   r   r   r       s    zGlobalWorkerCoordinator.addc                 C   s$   | j }| j| |  j d7  _ |S )Nr8   )r`   rI   rb   )r.   r   r   r   r   r	      s    z)GlobalWorkerCoordinator.get_new_worker_idc                 C   s   | j S rA   r\   r0   r   r   r   r]      s    z&GlobalWorkerCoordinator.get_worker_idsc                 C   s.   | j D ]}||  q| j D ]}|  qd S rA   )r_   rO   rS   )r.   cr   r   r   rB      s    

zGlobalWorkerCoordinator.startc                 C   s>   d}| j D ]}|  q
| j D ]}| }|o0|}qg | _ |S rP   )r_   rV   r[   )r.   Zall_successrc   rZ   r   r   r   rC      s    



zGlobalWorkerCoordinator.stopc                    s@   | j D ]}|j kr|  |  q fdd| j D | _ dS )z-
        Stop a specific coordinator
        c                    s   g | ]}|j  kr|qS r   )rH   )r
   rc   r!   r   r   r      s   
z<GlobalWorkerCoordinator.stop_coordinator.<locals>.<listcomp>N)r_   rH   rV   r[   )r.   r!   rc   r   rd   r   stop_coordinator   s    



z(GlobalWorkerCoordinator.stop_coordinatorc                    s    fdd}t | d S )Nc                      s       d S rA   )rC   r   r0   r   r   rD      s    zBGlobalWorkerCoordinator.register_shutdown_handler.<locals>.cleanup)atexitregister)r.   rD   r   r0   r   ra      s    z1GlobalWorkerCoordinator.register_shutdown_handlerN)r=   r>   r?   r/   r    r	   r]   rB   rC   re   ra   r   r   r   r   r^      s   
r^   c                   @   s6   e Zd ZdddZdd Zdd Zdd	 Zd
d ZdS )r   Nc                 C   s   || _ || _|| _|| _d S rA   )_coordinator
_worker_id_worker_funr,   )r.   r   r   r   r   r   r   r   r/      s    zWorker.__init__c                 C   s   t   | _d S rA   )time_start_timer0   r   r   r   rB     s    zWorker.startc                 C   s   |  | j d S rA   )rj   ri   r0   r   r   r   run  s    z
Worker.runc                 C   s.   t   td| | jd| j| d S )NzException in workerzException in worker {}: {})	traceback	print_exclogging	exceptionrh   rV   r   ri   )r.   r6   r   r   r   handle_exception  s    
 zWorker.handle_exceptionc                 C   s&   | j dt | j  | j   d S )NZworker_time)r,   r<   rk   rl   r7   r0   r   r   r   finish  s
     zWorker.finish)NN)r=   r>   r?   r/   rB   rm   rr   rs   r   r   r   r   r      s     
r   c              
   C   s`   |   r\|  z>z|  W n, tk
rJ } z|| W 5 d }~X Y nX W 5 |  X q d S rA   )rN   rB   rs   rm   r3   rr   )r   Zworkerr6   r   r   r   r     s     r   )r   r   NNN)__doc__rp   r   rf   rk   r*   rn   abcr   r   	getLoggerr2   setLevelINFOZLOG_INT_SECSr&   objectr   r@   r   r^   r   r   r   r   r   r   r   <module>	   s.   
     
%C9 