U
    d5P                     @   s   d Z ddlZddlmZmZmZ dd Zdd Zdd	 Zd
d Z	dd Z
dd Zdd Zdd Zdd Zedkre  \ZZejrejrejse   nXeddge ejrdgng  ejrdej gng   eee	eedZeeej e dS )a  
Benchmark for common convnets.

Speed on Titan X, with 10 warmup steps and 10 main steps and with different
versions of cudnn, are as follows (time reported below is per-batch time,
forward / forward+backward):

                    CuDNN V3        CuDNN v4
AlexNet         32.5 / 108.0    27.4 /  90.1
OverFeat       113.0 / 342.3    91.7 / 276.5
Inception      134.5 / 485.8   125.7 / 450.6
VGG (batch 64) 200.8 / 650.0   164.1 / 551.7

Speed on Inception with varied batch sizes and CuDNN v4 is as follows:

Batch Size   Speed per batch     Speed per image
 16             22.8 /  72.7         1.43 / 4.54
 32             38.0 / 127.5         1.19 / 3.98
 64             67.2 / 233.6         1.05 / 3.65
128            125.7 / 450.6         0.98 / 3.52

Speed on Tesla M40, which 10 warmup steps and 10 main steps and with cudnn
v4, is as follows:

AlexNet         68.4 / 218.1
OverFeat       210.5 / 630.3
Inception      300.2 / 1122.2
VGG (batch 64) 405.8 / 1327.7

(Note that these numbers involve a "full" backprop, i.e. the gradient
with respect to the input image is also computed.)

To get the numbers, simply run:

for MODEL in AlexNet OverFeat Inception; do
  PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py     --batch_size 128 --model $MODEL --forward_only True
done
for MODEL in AlexNet OverFeat Inception; do
  PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py     --batch_size 128 --model $MODEL
done
PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py   --batch_size 64 --model VGGA --forward_only True
PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py   --batch_size 64 --model VGGA

for BS in 16 32 64 128; do
  PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py     --batch_size $BS --model Inception --forward_only True
  PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py     --batch_size $BS --model Inception
done

Note that VGG needs to be run at batch 64 due to memory limit on the backward
pass.
    N)	workspacebrewmodel_helperc           
         s   t jdd}d}d d}t D ]Z}t|D ]L}|dkrDd||nd}d|d	 |}tj|||||d
i fd
i fd q,q t| fddt|D dg tj|dd|dd
i fd
i fd |jddgd}	|j	|	d ||fS )NMLP)name         r   fc_{}_{}data   
XavierFill)Zdim_inZdim_outZweight_initZ	bias_initc                    s   g | ]}d   |qS )r
   )format).0jdepth D/tmp/pip-unpacked-wheel-ua33x9lu/caffe2/python/convnet_benchmarks.py
<listcomp>U   s     zMLP.<locals>.<listcomp>sumlast  labelxentloss)
r   ModelHelperranger   r   fcr   netLabelCrossEntropyAveragedLoss)
ordercudnn_wsmodeldwidthir   currentZnext_r   r   r   r   r   B   sD    
	  	r   c                 C   s
  | ddd}|r||d< t jd|d}tj|dddd	d
di fdi fddd
}t||d}tj||dddd}tj||dd	dddi fdi fdd	}t||d}tj||dddd}	tj||	dddddi fdi fdd	}
t||
d}tj||dddddi fdi fdd	}t||d}tj||dddddi fdi fdd	}t||d}tj||dddd}t||ddd di fdi f}t||d}t||d!d d di fdi f}t||d!}t||d"d d#di fdi f}t||d$}|j	|d%gd&}|j
|d' |d(fS ))NTr"   Z	use_cudnnZcudnn_exhaustive_searchws_nbytes_limitZalexnetr   Z	arg_scoper   conv1r	   @      r   ConstantFill      stridepadpool1kernelr3   conv2      r4   pool2conv3  r   conv4r   conv5pool5fc6i $     fc7fc8r   predr   r   r      r   r   r   convrelumax_poolr   softmaxr   r    r!   r"   r#   my_arg_scoper$   r,   relu1r5   r8   relu2r<   r=   relu3r?   relu4r@   relu5rA   rB   relu6rD   relu7rE   rF   r   r   r   r   AlexNete   s      
                rV   c                 C   s  | ddd}|r||d< t jd|d}tj|dddd	d
di fdi fdd	}t||d}tj||dddd}t||dd	dddi fdi f}t||d}tj||dddd}	tj||	dddddi fdi fdd	}
t||
d}tj||dddddi fdi fdd	}t||d}tj||dddddi fdi fdd	}t||d}tj||dddd}t||ddd di fdi f}t||d}t||d!d d"di fdi f}t||d!}t||d#d"d$di fdi f}t||d%}|j	|d&gd'}|j
|d( |d)fS )*NTr)   r*   Zoverfeatr+   r   r,   r	   `   r.   r   r/   r0   )r3   r5   r1   r6   r8   r   r:   r<   r=      r   r;   r?      r@   rA   rB   i   i   rD   rC   rE   r   rF   r   r   r      rH   rM   r   r   r   OverFeat   s                           r[   c                  C   s  | ddd}|r||d< t jd|d}tj|dddd	dd
i fdi fdd	}t||d}tj||dddd}tj||dd	ddd
i fdi fdd	}t||d}tj||dddd}	tj||	ddddd
i fdi fdd	}
t||
d}tj||ddddd
i fdi fdd	}t||d}tj||dddd}tj||ddddd
i fdi fdd	}t||d}tj||ddddd
i fdi fdd	}t||d}tj||dddd}tj||ddddd
i fdi fdd	}t||d}tj||ddddd
i fdi fdd	}t||d}tj||dddd}t||dd d!d
i fdi f}t||d}t||d"d!d!d
i fdi f}t||d"}t||d#d!d$d
i fdi f}t||d%}|j	|d&gd'}|j
|d( |d)fS )*NTr)   r*   Zvggar+   r   r,   r	   r-   r   r/   r   r;   r5   r1   r6   r8      r<   r=   r   r?   pool4r@   rX   conv6pool6conv7conv8pool8fcixi b  rC   fcxfcxir   rF   r   r   r   rZ   rH   ) r"   r#   rN   r$   r,   rO   r5   r8   rP   r<   r=   rQ   r?   rR   r]   r@   rS   r^   rT   r_   r`   rU   ra   Zrelu8rb   rc   Zreluixrd   Zreluxre   rF   r   r   r   r   VGGA  s                   rf   c                 C   s  t | ||d ||ddi fdi f}t | ||}t | ||d ||d ddi fdi f}	t | |	|	}	t j| |	|d |d |d ddi fdi fdd		}
t | |
|
}
t | ||d
 ||d ddi fdi f}t | ||}t j| ||d |d |d ddi fdi fdd		}t | ||}t j| ||d dddd}t | ||d ||ddi fdi f}t | ||}t | ||
||g|}|S )Nz:conv1r   r   r/   z:conv3_reducer   z:conv3r	   r;   z:conv5_reducez:conv5r:   r1   z:poolr7   r3   r4   z
:pool_proj)r   rI   rJ   rK   concat)r$   Z
input_blobZinput_depthZoutput_nameZconv1_depthZconv3_depthsZconv5_depthsZ
pool_depthr,   Zconv3_reducer=   Zconv5_reducer@   poolZ	pool_projoutputr   r   r   _InceptionModule  s                            rk   c                 C   s  | ddd}|r||d< t jd|d}tj|dddd	d
di fdi fddd
}t||d}tj||ddddd}t||dd	d	ddi fdi f}t|||}tj||dd	dddi fdi fdd	}t||d}	tj||	ddddd}
t||
ddd	ddgddgd}t||dddddgddgd	}tj||ddddd}t||dd ddd!gdd"gd	}t||d#d$d%d&d'gd(d	gd	}t||d#d)dddgd(d	gd	}t||d#d*d&d+d,gdd	gd	}t||d-d.dd%d/gddgd}tj||d0dddd}t||d1d2dd%d/gddgd}t||d1d3d4dd4gd"dgd}tj||d5d
dd6}t||d7d8d9di fdi f}t	||d:}|j
|d;gd<}|j
|d= |d'fS )>NTr)   r*   Z	inceptionr+   r   r,   r	   r-      r   r/   r1   r2   r5   r   rg   conv2ar8   r9   r;   r<   inc3rW   r\          r   inc4rA   i  inc5   0   rX   inc6   p   rG      inc7inc8   i   i  inc9i@  pool9i@  inc10inc11r>   pool11r6   r   rY   r   rF   r   r   r   )r   r   r   rI   rJ   rK   rk   Zaverage_poolr   rL   r   r    r!   )r"   r#   rN   r$   r,   rO   r5   rm   r8   rP   r<   rn   rq   rA   rr   ru   ry   rz   r|   r}   r~   r   r   r   rF   r   r   r   r   	Inception  s4                                                                             r   c                 C   sj   t | d}| jj|dddddd}| jjg dd	gd
d}| jD ]$}| j| }| j||||g| q@dS )zC Simple plain SGD update -- not tuned to actually train the models iterLRg:0yEstepi'  g+?)Zbase_lrpolicyZstepsizegammaONEr         ?)shapevalueN)	r   r   r   ZLearningRateparam_init_netr/   paramsZparam_to_gradZWeightedSum)r$   ZITERr   r   paramZ
param_gradr   r   r   AddParameterUpdate6  s         

r   c              	   C   s  | |j |j\}}|j| _|j| _|j dkrD|jd||g}n|j||dg}|jdkrf|j|g}|jj	g d|ddd |jj
g d|jgd	d
d |jrtd|j n6td|j |dg t| |j dkrtd |js|j  |j  |jr"|j jD ]}|j|_q|jrtd|j|jd}|t|j  W 5 Q R X td|j|jd}|t|j  W 5 Q R X t|j t|j t|j j|j|j|j  d S )NNCHWr	   r   r   g        r   )r   ZmeanZstdr   r   i  )r   minmaxz{}: running forward only.z{}: running forward-backward.r   ZNHWCzU==WARNING==
NHWC order with CuDNN may not be supported yet, so I might
exit suddenly.z{0}_init_batch_{1}.pbtxtwz	{0}.pbtxt)!r"   r#   Znet_typeZPrototypeZnum_workers
batch_sizer$   r   ZGaussianFillZUniformIntFillZforward_onlyprintr   ZAddGradientOperatorsr   cpuZRunAllOnGPUr   ZengineopZ
dump_modelopenwritestrr   Z
RunNetOnceZ	CreateNetZBenchmarkNetr   Zwarmup_iterationsZ
iterationsZlayer_wise_benchmark)Z	model_genargr$   Z
input_sizeZinput_shaper   Zfidr   r   r   	BenchmarkA  sl    





 
  r   c                  C   s  t jdd} | jdtddd | jdtdd	 | jd
tddd | jdtdd	 | jdtddd | jdtddd | jdddd | jdddd | jdddd | jdtddd | jddd d | jd!td"d# | jd$td%d# | jd&d'dd( | jd)td* | S )+NzCaffe2 benchmark.)descriptionz--batch_sizer\   zThe batch size.)r   defaulthelpz--modelzThe model to benchmark.)r   r   z--orderr   zThe order to evaluate.z
--cudnn_wszThe cudnn workspace size.z--iterations
   z(Number of iterations to run the network.z--warmup_iterationsz1Number of warm-up iterations before benchmarking.z--forward_only
store_truez"If set, only run the forward pass.)actionr   z--layer_wise_benchmarkz.If True, run the layer-wise benchmark as well.z--cpuz+If True, run testing on CPU instead of GPU.z--engine z8If set, blindly prefer the given engine(s) for every op.z--dump_modelz*If True, dump the model prototxts to disk.z
--net_typeZdag)r   r   z--num_workersr1   z
--use-nvtxF)r   r   z--htrace_span_log_path)r   )argparseArgumentParseradd_argumentintr   )parserr   r   r   GetArgumentParser  s|    r   __main__Zcaffe2z--caffe2_log_level=0z--caffe2_use_nvtxz--caffe2_htrace_span_log_path=)rV   r[   rf   r   r   )__doc__r   Zcaffe2.pythonr   r   r   r   rV   r[   rf   rk   r   r   r   r   __name__parse_known_argsargs
extra_argsr   r$   r"   
print_helpZ
GlobalInitZuse_nvtxZhtrace_span_log_pathZ	model_mapr   r   r   r   <module>   sJ   :#]V @XB@
