U
    d^                     @   s   d dl mZmZ d dlmZmZ d dlmZ d dlmZ d dl	m
Z
 dd ZG dd	 d	ejZG d
d dejZG dd deZdd Zdd ZG dd dejZG dd deZdd ZG dd deZG dd dejZG dd deZdS )    )corecontext)Fieldfrom_blob_list)defaultdict)copy)	viewitemsc                 C   s.   | d kr|S |d kr| S t | }|| |S N)r   update)abc r   6/tmp/pip-unpacked-wheel-ua33x9lu/caffe2/python/task.py_merge_node_kwargs   s    
r   c                   @   s8   e Zd ZdZdd Zdd Zdd Zdd	 Zd
d ZdS )Clusterz
    Context that keeps track of all the node names used.
    Users shouldn't have to use them directly, since a Cluster is automatically
    generated at the first usage of 'Node'.
    c                 C   s   g | _ i | _d S r	   )_nodes_node_kwargsselfr   r   r   __init__   s    zCluster.__init__c                 C   sF   t || jkr| jt | t| | jt || jt |< d S r	   )strr   appendr   kwargsr   get)r   noder   r   r   add_node"   s    zCluster.add_nodec                 C   s   | j S )zQ
        Returns the list of unique node names used within this context.
        )r   r   r   r   r   nodes)   s    zCluster.nodesc                 C   s   | j S r	   )r   r   r   r   r   node_kwargs/   s    zCluster.node_kwargsc                 C   s   d |  |  S )Nz!Cluster(nodes={}, node_kwargs={}))formatr   r   r   r   r   r   __repr__2   s     zCluster.__repr__N)	__name__
__module____qualname____doc__r   r   r   r   r    r   r   r   r   r      s   r   c                   @   s2   e Zd ZdZdddZdd Zdd Zd	d
 ZdS )Nodea  
    A Node context is used to indicate that all Tasks instantiated within will
    run on the given node name. (Only the name of the node actually counts.)
    Example:

        with TaskGroup() as tg:
            with Node('node1'):
                s1 = execution_step(...)
                Task(step=s1)
            with Node('node2'):
                s2 = execution_step(...)
            with Node('node1'):
                s3 = execution_step(...)

        In this example, all three execution steps will run in parallel.
        Moreover, s1 and s3 will run on the same node, and can see each
        others blobs.

        Additionally, a Node can be passed implementation-specific kwargs,
        in order to specify properties of the node.
    localc                 K   s"   t || _|| _t |  d S r	   )r   _name_kwargsr   currentr   )r   r   r   r   r   r   r   N   s    
zNode.__init__c                 C   s   | j S r	   )r'   r   r   r   r   __str__S   s    zNode.__str__c                 C   s   d | j| jS )NzNode(name={}, kwargs={}))r   r'   r(   r   r   r   r   r    V   s    zNode.__repr__c                 C   s   | j S r	   )r(   r   r   r   r   r   Y   s    zNode.kwargsN)r&   )r!   r"   r#   r$   r   r*   r    r   r   r   r   r   r%   7   s
   
r%   c                   @   s   e Zd ZdZdZdZdS )WorkspaceTypez
    Determines whether tasks of a TaskGroup will run directly at the global
    workspace, which is kept alive across runs, or whether a new child
    workspace will be created for the run and destroyed afterwards.
    privateglobalN)r!   r"   r#   r$   PRIVATEGLOBALr   r   r   r   r+   ]   s   r+   c                 C   s  t | d }t | d }g }g }g }|D ]6}t|drJ||| 7 }q,t|dr,||| 7 }q,|D ]}	t|	dr|	jrqht|	dr|	j|krqht|	dr|	|}
t|
t	t
fr||
7 }n6t|
t jt jfr||
 n|
d k	rtdt|
 d	|	_t|	d
rh|	|}
t|
t	t
fr,||
7 }n:t|
t jt jfrL||
 n|
d k	rftdt|
 d	|	_qht| jdkr|d| t| jdkr|d| ||fS )Nz/initz/exitget_all_attributesget_attributes_setup_used_setup_targetsetupzUnsupported type for setup: %sTexitr   )r   Nethasattrr0   r1   r2   r3   r4   
isinstancelisttupleExecutionStepr   	TypeErrortyper5   lenProtoopinsert)keyZsteps_or_netstargetinit_netexit_net	init_nets	exit_netsobjsZstep_or_netobjnetsr   r   r   get_setup_netsg   sL    









rK   c                 C   sb   |s|s| S g }|r*| td| | | |  t|dkrV| td| | t||S )N%s:initr   %s:exit)r   r   execution_stepr>   )steprF   rG   namestepsr   r   r   add_setup_steps   s    
rR   c                   @   s   e Zd ZdZdZd ddZdd Zdd	 Zd
d Zdd Z	dd Z
dd Zd!ddZd"ddZd#ddZd$ddZdd Zdd ZdS )%	TaskGroupa  
    Context that gathers tasks which will run concurrently, potentially on
    multiple nodes. All tasks in the same node will share the same workspace
    and thus can share blobs, while tasks running in different nodes won't
    be able to directly share data.

    All tasks of the task group will start concurrently, and the task group
    will finish execution when the last task of the group finishes.

    Example:
        # suppose that s1 ... s5 are execution steps or nets.
        with TaskGroup() as tg:
            # these tasks go to default node 'local'
            Task(step=s1)
            Task(step=s2)

            with Node('n2'):
                Task(step=s3)
            with Node('n1'):
                Task(step=s4)
            with Node('n2'):
                Task(step=s5)

        # this will run all steps in parallel.
        # s1 and s2 will run at default node 'local'
        # s3 and s5 will run at node 'n2'
        # s4 will run at node 'n1'
        session.run(tg)
    Zlocal_setupNc                 C   s@   d | _ g | _d| _d | _g | _i | _g | _|| _d | _g | _	d S NF)
Z_plan_cache_tasks_already_usedZ_prev_active_tasks_to_add_report_nets_report_steps_workspace_type_tasks_by_node_remote_nets)r   workspace_typer   r   r   r      s    zTaskGroup.__init__c                 C   s   | j | d S r	   )r\   r   )r   netr   r   r   add_remote_net   s    zTaskGroup.add_remote_netc                 C   s   | j S r	   )r\   r   r   r   r   remote_nets   s    zTaskGroup.remote_netsc                 C   st   | j rtd| jd ks2|jd ks2| j|jks2t|jd krJ| jpFtj|_| jd kr\|j| _|  | j| d S )Nz-Cannot add Task to an already used TaskGroup.)rV   AssertionErrorrZ   r+   r.   _notify_usedrU   r   r   taskr   r   r   add   s     



zTaskGroup.addc                 C   s(   | j D ]}| | qg | _ d| _| jS NT)rW   re   rV   rU   rc   r   r   r   tasks   s
    
zTaskGroup.tasksc                 C   s   t | jt | j S r	   )r>   rW   rU   r   r   r   r   num_registered_tasks   s    zTaskGroup.num_registered_tasksc                 C   s0   g }| j | j D ]}|j|kr||j q|S r	   )rU   rW   r   r   )r   usedrd   r   r   r   
used_nodes   s
    
zTaskGroup.used_nodes  c                 C   s6   t |}|| | jt|p(t||f dS )a6  
        Add a "report step" to this TaskGroup. This step will run repeatedly
        every `interval_ms` milliseconds for the duration of the TaskGroup
        execution on each of the nodes. It is guaranteed that this step
        will be run at least once after every Task in the node has finished.
        N)r   to_execution_stepRunEveryMillisrY   r   r   r%   r)   )r   rO   r   interval_msr   r   r   report_step   s    

zTaskGroup.report_step   c                 C   s`   t |pt|}|dks(|| jks(t|| jkrR|r:|ntd| |f| j|< | j| d S )z6
        DEPRECATED. Use report_step instead.
        Nz%s/reporterr   )r   r%   r)   rX   ra   r   r6   )r   r^   r   Zreport_intervalr   r   r   
report_net   s    

zTaskGroup.report_netc              	   C   sx  i }|   D ]}|r||jn|j||j< q| jd k	rT| j\}}||ksPtd|S t| jD ] \}\}}| j|||d d q^i | _tt}|   D ]}||j }	||	 	| qtt}
| j
D ]\}}|
||  	| qt }t|D ]z\}}|
| }ttjdd |D | | \}}|}g }tj}|D ]X}| }|| tjk |d k	rd|	| || 7 }| tjkr0tj}q0t|dkr|	tdg  t|dkr|d }ntjd	| |d
d}t|dkst|dkrTg }t|dkr|	td| | |	| t|dkrH|	td| | t||}t|||d||d q||f| _|S )Nz)Cannot call tasks_by_node multiple times.rk   )r   rn   c                 S   s   g | ]}|  qS r   )get_step.0tr   r   r   
<listcomp>)  s     z+TaskGroup.tasks_by_node.<locals>.<listcomp>r   empty   %s:bodyT)Zconcurrent_substepsrL   rM   grouped_by_node)r   rO   outputsrP   groupr]   )rg   r   r[   ra   r   rX   ro   r   r9   r   rY   rS   rK   LOCAL_SETUPr+   r.   rr   SetCreateWorkspacer]   r{   r/   r>   r   rN   Task)r   Z
node_remapZnode_maprd   tasks_by_nodeZprev_node_mapr   r^   intervalZmapped_nodeZreport_steps_by_nodeZoriginal_noderO   rz   rg   report_stepsZ
node_initsZ
node_exitsrQ   r{   Zgrouped_workspace_typer   r   r   r   	  s    








  
   
zTaskGroup.tasks_by_nodec                    s>   t t  |  fdd }t|dkr6t S |d S )Nc                    s    S r	   r   )xr   r   r   <lambda>T      z#TaskGroup.to_task.<locals>.<lambda>r   )r   r%   r)   r   rg   r>   r   )r   r   rg   r   r   r   to_taskR  s
    zTaskGroup.to_taskc                 C   s   | j S r	   rZ   r   r   r   r   r]   Y  s    zTaskGroup.workspace_typec                 C   s   d | j| j |  |  S )Nz6TaskGroup(tasks={}, workspace_type={}, remote_nets={}))r   rU   rW   r]   r`   r   r   r   r   r    \  s
    
zTaskGroup.__repr__)N)NNrk   )NNrp   )N)N)r!   r"   r#   r$   r}   r   r_   r`   re   rg   rh   rj   ro   rq   r   r   r]   r    r   r   r   r   rS      s   



I
rS   c                   @   s:   e Zd ZdZdd ZdddZdd Zd	d
 Zdd ZdS )
TaskOutputzd
    Represents the output of a task. An output can be a blob,
    a list of blob, or a record.
    c                 C   sT   d | _ d| _t|tr&|| _ | j  }t|ttfk| _| jrD|g}|| _d | _	d S rT   )
_schema
_is_scalarr8   r   Zfield_blobsr=   r:   r9   names_values)r   r   r   r   r   r   i  s    

zTaskOutput.__init__Nc                 C   s&   t |t | jkst|| _|| _d S r	   )r>   r   ra   r   _fetch_func)r   valuesr   r   r   r   setu  s    zTaskOutput.setc                 C   s@   | j d k	std| jr"| j d S | jr6t| j| j S | j S d S )NzOutput value not set yet.r   )r   ra   r   r   r   r   r   r   r   r   z  s    
zTaskOutput.getc                    sN    j d k	std fdd jD } jr4|d S  jrFt j|S |S d S )Nz#Cannot fetch value for this output.c                    s   g | ]}  |qS r   )r   )rt   vr   r   r   rv     s     z$TaskOutput.fetch.<locals>.<listcomp>r   )r   ra   r   r   r   r   )r   Zfetched_valsr   r   r   fetch  s    zTaskOutput.fetchc                 C   s   d | j| jS )NzTaskOutput(names={}, values={}))r   r   r   r   r   r   r   r      s    zTaskOutput.__repr__)N)	r!   r"   r#   r$   r   r   r   r   r    r   r   r   r   r   c  s   
	r   c                 C   s   t jddpt  }|| S )a  
    Adds an output to the current Task, or if no task is active,
    create a dummy task that returns the given blob or record
    to the client. This will return the value of the blob or record when
    the last task of the TaskGroup for a given node finishes.
    Frequired)r   r)   
add_output)Zblob_or_recordZcur_taskr   r   r   final_output  s    r   c                   @   s4   e Zd ZdZdddZdd ZdddZd	d
 ZdS )TaskOutputListz$ Keeps a list of outputs for a task Nc                 C   s   |pg | _ d S r	   )r{   r   r{   r   r   r   r     s    zTaskOutputList.__init__c                 C   s   g }| j D ]}||j7 }q
|S )z[
        Retrive the output names.
        TODO(azzolini): make this schema-based.
        )r{   r   )r   r   or   r   r   r     s    
zTaskOutputList.namesc                 C   sR   d}| j D ].}t|j}|||||  | ||7 }q
|t|ksNtdd S )Nr   zWrong number of output values.)r{   r>   r   r   ra   )r   r   r   offsetr   numr   r   r   
set_values  s    


zTaskOutputList.set_valuesc                 C   s   d | jS )NzTaskOutputList(outputs={}))r   r{   r   r   r   r   r      s    zTaskOutputList.__repr__)N)N)r!   r"   r#   r$   r   r   r   r    r   r   r   r   r     s
   


r   c                       s   e Zd ZdZdZdZdZe Ze	dd Z
d"dd	Z fd
dZ fddZdd Zdd Zdd Zdd Zdd Zdd Zdd Zdd Zdd Zd d! Z  ZS )#r   aA  
    A Task is composed of an execution step and zero or more outputs.
    Tasks are executed in the context of a TaskGroup, which, in turn, can
    be run by a Session.

    Task outputs are fetched by the session at the end of the run.

    The recommended way of creating a task is by using `net_builder.ops`.
    Example:

        from net_builder import ops
        with Node('trainer'), Task(name='my_task', num_instances=2):
            with ops.task_init():
                globl = ops.Const(0)
            with ops.task_instance_init():
                local = ops.Const(0)
            with ops.loop(100):
                ops.Copy(globl, local)
            with ops.task_instance_exit():
                ops.Add([globl, local], [globl])
            with ops.task_exit():
                ops.Mul([globl, globl], [globl])

    The task above will create 2 instances that will run in parallel.
    Each instance will copy `local` to `globl` 100 times, Then Add `local`
    to `globl` once. The `Mul` will only execute once, after all the instances
    of the task have finished.
    Z
task_setupZtask_instance_setupro   c                 C   s`   t | d t | }|d kr"tjntdd |jD }|}d}||kr\|d7 }d||f }q>|S )N/c                 s   s   | ]}|j V  qd S r	   )rP   rs   r   r   r   	<genexpr>  s     z&Task._get_next_name.<locals>.<genexpr>r   rx   z%s:%d)r   r   _global_names_usedr   rW   )r   r|   rP   basenameZ
names_usedZcur_nameir   r   r   _get_next_name  s    zTask._get_next_nameNc                 C   s   |st |tjr| j}|s"d}tt|dkr4dnt|| _t	j|dd| _
t| j| j
|| _| j
dk	r~| j
j|  d| _d| _d| _g | _|dk	r| | |dk	r| | d| _d| _|| _d| _|| _dS )aV  
        Instantiate a Task and add it to the current TaskGroup and Node.

        Args:
           step:    If provided, this task will run this ExecutionStep.
           outputs: If provided, the task will return the provided outputs
                    to the client at completion time.
           node:    If provided, force task execution on the given node.
           name:    Name of the Task.
           num_instances: If provided, this task will be cloned num_instances
                          times at runtime, and all instances will run
                          concurrently.
        rd   NFr   )r8   r   r;   r?   rP   r   r%   r)   r   rS   r|   r   r   rW   r   rV   _step_step_with_setup_outputsset_stepadd_outputsZ	_pipelineZ_is_pipeline_contextrZ   Z_report_net_num_instances)r   rO   r{   r]   r|   r   rP   Znum_instancesr   r   r   r     s,    
 


zTask.__init__c                    sj   t t|   | jd k	r&| jj|  |   | jd ks@tdddl	m
} |j| jd| _| j  | S )Nz(This Task already has an execution step.r   )net_builder)Z	_fullname)superr   	__enter__r|   rW   remove_assert_not_usedr   ra   caffe2.pythonr   Z
NetBuilderrP   _net_builder)r   r   	__class__r   r   r     s    

zTask.__enter__c                    sZ   t t| ||| | j||| |d kr8| | j | jd k	rP| jj|  d | _d S r	   )r   r   __exit__r   r   r|   rW   r   )r   r=   value	tracebackr   r   r   r   '  s    
zTask.__exit__c                 C   s   | j S r	   r   r   r   r   r   r]   1  s    zTask.workspace_typec                 C   s   | j rtdd S )Nz1Cannot modify task since it is already been used.)rV   ra   r   r   r   r   r   4  s    zTask._assert_not_usedc                 C   s.   |    t|tr|nt|}| j| |S r	   )r   r8   r   r   r   )r   outputr   r   r   r   8  s
    zTask.add_outputc                    s8       t|ttfkr" |S  fdd|D S d S )Nc                    s   g | ]}  |qS r   )r   )rt   r   r   r   r   rv   D  s     z$Task.add_outputs.<locals>.<listcomp>)r   r=   r9   r:   r   r   r   r   r   r   ?  s    
zTask.add_outputsc                 C   s   |    t|| _d S r	   )r   r   rl   r   )r   rO   r   r   r   r   F  s    zTask.set_stepc           
      C   sf  | j d k	r| j S | jd kr0t| jg | _ | j S dd | jtjD }|D ]}d|_|	 j
sL|d qLttj| jg| | \}}ttj| jg| | \}}t| jdkrtd| j }| |jg dtjjdd || |s| jntd	| j || jg }t|||| jd
 }	| jrN| jdkrN|	d tjd|	g| jd}	t|	||| j| _ | j S )Nc                 S   s   g | ]}t |d s|qS )_report_step_used)r7   )rt   sr   r   r   rv   R  s   
z!Task.get_step.<locals>.<listcomp>Trk   r   z	%s:outputrx   )Zdtyper   ry   z	:instancez%s:parallel)Znum_concurrent_instances)r   r   r   rN   rP   r0   r   REPORT_STEPr   r?   Zrun_every_msrm   rK   
TASK_SETUPTASK_INSTANCE_SETUPr>   r   r6   r   ZConstantFillZDataTypeZINT32r   rR   r   r~   )
r   r   rO   Ztask_init_netsZtask_exit_netsZinstance_init_netsZinstance_exit_netsZ
output_netbodyZstep_with_instance_setupr   r   r   rr   J  sn    


 
  
    
 
  
   zTask.get_stepc                 C   s
   t | jS r	   )r   r   r   r   r   r   output_listy  s    zTask.output_listc                 C   s   | j S r	   )r   r   r   r   r   r{   |  s    zTask.outputsc                 C   s   |    d| _d S rf   )rr   rV   r   r   r   r   rb     s    zTask._notify_usedc                 C   s   d | j| j|  S )Nz"Task(name={}, node={}, outputs={}))r   rP   r   r{   r   r   r   r   r      s
      zTask.__repr__)NNNNNNN)r!   r"   r#   r$   r   r   r   r   r   staticmethodr   r   r   r   r]   r   r   r   r   rr   r   r{   rb   r    __classcell__r   r   r   r   r     s6   
           
.
/r   c                   @   s2   e Zd ZdZdddZdd Zdd Zd	d
 ZdS )	SetupNetsa  
    Allow to register a list of nets to be run at initialization
    and finalization of Tasks or TaskGroups.
    For example, let's say you have the following:

        init_net = core.Net('init')
        my_val = init_net.ConstantFill([], 'my_val', value=0)

        net = core.Net('counter')
        net.Add([my_val, net.Const(1),], [my_val])

        with TaskGroup() as task_group:
            with Node('trainer'):
                my_task = Task(step=[net])

    In order to have `init_net` run once before `net` runs for the
    first time, you can do one of the following:

        net.add_attribute(Task.TASK_SETUP, SetupNets([init_net]))

    or

        net.add_attribute(TaskGroup.LOCAL_SETUP, SetupNets([init_net]))

    - With Task.TASK_SETUP, init_net will run once at my_task startup.
    - With TaskGroup.LOCAL_SETUP, init_net will run once on node 'trainer',
      before any task of the task group is run on that node.

    The same SetupNets object can be added to multiple nets. It will only
    run once per Task/TaskGroup run.
    Nc                 C   s   || _ || _d S r	   )rF   rG   )r   rF   rG   r   r   r   r     s    zSetupNets.__init__c                 C   s   | j S r	   )rF   )r   rD   r   r   r   r4     s    zSetupNets.setupc                 C   s   | j S r	   )rG   )r   rE   r   r   r   r5     s    zSetupNets.exitc                 C   s   d | j| jS )Nz%SetupNets(init_nets={}, exit_nets={}))r   rF   rG   r   r   r   r   r      s     zSetupNets.__repr__)NN)r!   r"   r#   r$   r   r4   r5   r    r   r   r   r   r     s
    
r   N)r   r   r   Zcaffe2.python.schemar   r   collectionsr   r   Zfuture.utilsr   r   ZDefaultManagedr   r%   objectr+   rK   rR   ZManagedrS   r   r   r   r   r   r   r   r   r   <module>   s$   !&
, E/ Q