U
    9%e                   @  sbD  d Z ddlmZ ddlZddlZddlZddlZddlZddlm	Z	m
Z
mZmZmZmZ ddlZddlm  mZ ddlZddlZddlmZ ddlmZmZmZmZmZ ddlmZ ddlmZmZm Z  dd	l!m"Z" d
dddddddddddddddddddddd d!d"d#d$d%d&d'd(d)d*d+d,d-d.d/d0d1d2d3d4d5d6d7d8d9d:d;d<d=d>d?d@dAdBdCdDdEdFdGdHdIdJdKdLdMdNdOdPdQdRdSdTdUdVdWdXdYdZd[d\d]d^d_d`dadbdcdddedfdgdhdidjdkdldmdndodpdqdrdsdtdudvdwdxdydzd{d|d}d~dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd ddddddddd	d
ddgZ#ej$e j%ddZ&dd Z'ddddZ(ej)dd Z*e&dej)ddddZ+e&dej)ddddZ,e&de-dej)ddd dфZ.e&d!e-dej)ddd"dЄZ/e&d#ej)dFddd$dZ0e&d%ej)dGddd&dZ1e&d'ej)dHddd(dՄZ2e&d)ej)ddd*dZ3e&d+ej)ddd,d<Z4e&d-e5d.d.d.d/ej)dIddd1dZ6e5d.d.d2ej)ddd3d4Z7ej)ddd5d6Z8ej)ddd7d8Z9e&d9ej)ddd:dLZ:e&d;ej)ddd<dNZ;e&d=ej)ddd>dZ<e&d?ej)ddd@dȄZ=e&dAe5d.dBej)dddCd$Z>e&dDe5d.dBej)dddEdZ?e&dFej)dddGdHZ@e&dIej)dddJdZAe&dKej)dddLd ZBe&dMej)dddNdZCe&dOe5d.d.d.ddej)dddPdZDe&dQej)dddRdZEe&dSej)dddTdZFe&dUej)dddVdԄZGe&dWej-ddXdYdZej)ddd[dZHe&d\ej)ddd]dބZIe&d^ej)ddd_d6ZJe&d`ej)dddadZKe&dbej)dddcdZLe&ddej)dddedZMe&dfej)dddgdZNe&dhej)dddidZOe&djej-ddkddZej)dddldۄZPe&dmej)dddnd܄ZQe-dej)dddodpZRej)dddqdrZSej)dJdsdtZTej)dud ZUe&dve'dwdxgdye&dze'd{d|gdye&d}e'd~dddgdyej)dKddddddZVe&de5d.dBdej)dddd9ZWe&dej)ddddZXe&dej)ddddZYe&dej)ddddZZe&de-dej)ddddZ[e&de-dej)ddddHZ\e&de-dej)dddd"Z]e&de-ddej)ddddGZ^e&de-de5d.d.dBdd.ej)ddddAZ_e&de-de5d.d.d.dBdBdBd.dBdB	ej)dddd@Z`e&dej-dddej)dLdddd߄Zae&de-de5d.dBdBej)ddddZbe&de5d.dej)ddddZce&de-dej)ddddZde&dej)ddddZee&de5d.dBdBdBej)dMddddZfe&de5d.d.dBdBej)dNddddZge&dej)dOddddZhe&de5d.ddBdBej)dPddddZie&dej)dQddddZje&de5d.dBdBej)dRddddZke&de-de5d.dBd.ej)ddddلZle&dej)ddddZme&dej)dSddddZne&dej)ddddZoe&dej)dddd݄Zpe&dej)ddddZqej)ddddÄZre&dăe-dej)ddddʄZse&dƃe-dej)dddd˄Zte&dȃej)dddd&Zue&dʃej)ddddMZve&d̃ej)ddd͐d΄Zwe&dσe5d.ddej)ddddZxe&dуe-de5d.d/dej)dTddҐdӐddԜddlZye&dփe5d.dBej)ddddVZze&d؃e5d.dBdej)dUddddZ{e&dڃej)ddddZ|e&d܃ej)ddU Z}e&de'dej~jjjddde(d߃gdye&de'dej~jjjddde(dgdye&de'dej~jjjddde(dgdyej)dd Ze&dedej~jjjdddZe&dedej~jjjdddZe&dedej~jjjdddZe&de'dej~jjje(dgdye&de'dej~jjje(dgdye&de'dej~jjje(dgdyej)dd Ze&de'ddej~jjje(dgdye&de'ddej~jjje(dgdye&de'ddej~jjje(dgdye&de'ddej~jjjee(dgdye&de'd dej~jjjee(d gdye&de'ddej~jjjee(dgdyej)dVddZej)ddddZej)d	d
 Ze&dej)dddd+Zej)ddҐdҐdddZe&de&de&dej)ddddɄZe&de&de&dej)ddddτZe&dej)ddҐdҐdҐdҐdddZe&de'ddde(dgdye&de'dd de(dgdye&d!e'd"d#de(d"gdye&d$e'd%ddse(d%gdye&d&e'd'd dse(d'gdye&d(e'd)d#dse(d)gdyej)dddd*d+d,Ze&d-ej)ddd.d/Ze&d0ej)ddd1dZe&d2ej)d3d Zej)d4d Zej)d5d5d6d7d	Ze&d8ej)ddd9d:Ze&d;e-ddej)ddd<dDZe&d=e-ddeej)ddd>dZe&d?e-ddej)ddd@dXZej)dddAdBZe&dCe-ddej)dddDdZej)dddEdFZe&dGe-ddeej)dddHdSZe&dIe-ddeej)dddJdkZe&dKej)dddLdMZe&dNej)dddOdPZe&dQej)dddRdSZe&dTedUej)dddVd{Ze&dWedUej)dddXd}Ze&dYedUej)dddZd~Ze&d[ej)ddd\d|Ze&d]ej)ddd^d_Ze&d`ej)dddadbZe&dce5d.d.d.dBej)dWdddddZe&dee5d.dBdej)dXdddfdvZe&dge5d.dBdBej)dddhdiZe&dje5d.d.d.ddddBddBdBdBdBdBej)dYdddkdlZe&dme5d.d.d.dd2ddBej)dddndoZe&dpe5d.d.d.ddddBddB	ej)dddqd5Ze&dre5d.d.d.dd.ddBej)dddsd1Ze&dte5d.d.d.dd.ddBej)dddud2Ze&dve5d.d.d.dd.ddBej)dddwd3Ze&dxe5d.d.d.ddddBdej)dddyd.Ze&dze5d.d.d.ddddBdej)ddd{d/Ze&d|e5d.d.d.ddddBdej)ddd}d0Ze&d~e5d.d.d.d.d.dBd/d/dB	ej)ddddZe&de-dddde5d.dd.d.d/ej)ddҐddҐdҐdӐddddZe&de-dddde5d.dd.d.d/dej)ddҐddҐdҐdӐddҐdddjZe&de5d.d.d.d.d.dd/d/d	ej)dddddddddZe&de5d.dBdBdBej)ddddZe&de-de5d.dddej)dddd?Ze&de-dej)ddddڄZe&de5d.dBd.ej)ddddbZe&dej)ddddaZe&dej)dddd`Ze&dej)dddd_Ze&de5d.d.ddej)dZdddd#Ze&dej)ddddZe&de5d.d.dBd/ej)dddd7Ze&dej)ddddZe&dej)dddd*Ze&dej)dddd
Ze&dej)ddddwZe&dej)ddddyZe&dej)ddddxZe&dej)ddddZe&dej)dddd)Ze&de5d.d.ej)dddd(Ze&de5d.d.ej)dddd'Ze&dej)d[ddddZe&de-ddej)ddddZe&dej)d\ddddZe&de-ddej)ddddZe&de-de5d.ddBej)ddddZe&de-de5d.ddBej)ddddZe&de-de5d.d.dBej)ddddZe&dej)ddddFZe&dăe&dŃe5d.d/dBej)dddd>Ze&de'dǃgdye&de'dȃgdye&de'dɃgdye&de'dʃgdye&de'd˃gdye&de'd̃gdyej)ddd͐d΄Ze&dσe5d.dddBd.ej)d]ddddZe&dуe5d.d.d.dBej)dddd-Ze&dӃe5d.dBdBej)dddԐdՄZe&dփe5d.dBdBdBej)dddאd؄Ze&dكeېdڐdېdܡej)dddݐdބZe&d߃eېdڐdېdܡej)ddddZe&deېdڐdېdܡej)ddddZe&deېdڐdېdܡej)ddddZe&deېdڐdېdܡej)ddddZe&deېdڐdېdܡej)ddddZe&deېdڐdېdܡej)ddddZe&deېdڐdېdܡej)ddddZe&deېdڐdېdܡej)ddddZe&de5d.dBd.d.d.d.ej)d^ddddCZe&de5d.dBd.d.d.d.ej)d_ddddBZe&dej)d`ddddZe&dej)ddddքZe&dej)daddd dZe&dej)dbddddZe&de5d.dBd.d.d.ej)dcddddZe&de5d.dBd.d.d.d.ej)dddddd
Ze&dej)deddddZe&d	ej)ddd
dZe&de5d.dBd.d.d.ej)dfddddZe&de5d.dBd.d.d.d.ej)dgddddZe&dej)dhddddZe&dej)diddddQZe&dej)djddddPZe&dej)dkddddZe&dej)ddddIZe&dej)ddddZe&de-de5d.d/d/ej)ddҐdӐdӐddd]Ze&de-de5d.ej)dddd\Ze&d ej-ddkddZe5d.ej)ddd!d[Ze&d"e5d.ej)ddd#dZe&d$e5d.d/ej)ddd%dZZe&d&e5d.d/ej)ddd'dZe&d(ej)ddd)dZe&d*e5d.dBej)ddd+dZe&d,e5d.dBdBdej)dlddd-dZe&d.ej)ddd/dZ e&d0e5d.dBdBdBdBdej)dmddd1dZe&d2ej)ddd3d4Ze&d4ej)ddd5dZe&d6ej)ddd7d΄Ze&d8ej)dnddd9d̈́Ze&d:e5d.dBej)ddd;dZe&d<e5d.dBej)ddd=dZej)doddd>d?Ze5d.d.d.dBdBd/dBdBdB	ej)ddd@dAZ	e5d.d.d.d.dBdBd/dBdB	ej)dddBdCZ
e&dDej)dddEdZe&dFej)dddGdZe&dHe'dIe(dJgdye&dKe'dLe(dMgdye&dNe'dOe(dPgdyddQdRdSZe&dTe5d.dBej)dddUdVZe&dWej)dddXd:Ze&dYe5d.dBej)dddZd,Ze&d[e5d.d.dBej)ddd\d]Ze&d^e5d.d.dBdd.ej)ddd_d`Ze&daej)dddbdńZe&dcej)dddddĄZe&deej)dddfdǄZe&dgej)dddhdÄZe&diej)dpdddjdƄZe&dkej)dqdddldZe&dme5d.d/d/dBdej)dddndӄZe&doej)drdddpdZe&dqe5d.ej)dddrduZe&dse5d.ej)dddtdEZe&due-ddde5d.dBdBej)dddvdKZe&dwe5d.ej)dddxdZe&dyej)dsdddzdZe&d{e5d.ej)ddd|dgZ e&d}ej)ddd~dZ!e&dej)ddddZ"e&de5d.dBdBdBej)ddddZ#e&de5d.d.dej)dddddddZ$e&de5d.d.dej)dddddddZ%e&de5d.dBd.d.ej)dddd؄Z&e&de5d.dBd.d.ej)ddddׄZ'e&dej)ddddzZ(e&dej)ddddeZ)e&dej)ddddZ*e&deej)ddddZ+e&dej)ddddZ,e&de5d.dBd.d.ej)dtddddRZ-e5d.ddBdBej)ddddZ.e&dej)ddddZ/e&dej)ddddZ0e&dej)ddddZ1e&dej)ddddZ2e&de5d.ddBej)ddddZ3e&dej)ddddZ4e&dej)ddddtZ5e&dej)ddddnZ6e&dej)ddddZ7e&dej)ddddZ8e&dej)ddddcZ9e&de5d.d.ddd.ej)dddddddddqZ:e&de5d.d/ddd.ej)dddӐddddddrZ;e&de5d.d.ddd.ej)dddddddddpZ<e&de5d.d.dBej)duddddoZ=e&de5d.ddej)dvddddOZ>e&dÃe5d.dBdd.ej)dwddddZ?e&dŃej)ddddZ@e&dǃe5d.d2ej)dxddȐdɜddZAe&d˃ej)dddd̄ZBe&d̓e5d.d2ej)dyddddΜddTZCe&dЃe-dddde5d.dBd.d.d/dBej)ddddWZDe&d҃e5d.d.dBej)dddӐdԄZEe&dՃej)dddd;ZFe&d׃ej)dddؐdلZGe&dڃej)dddېd܄ZHe&d݃ej)ddddhZIe&d߃ej)ddddZJej)ddddZKej)ddddZLe&de5d.d.dBdej)ddddiZMe&de5d.d.dBej)ddddZNe&de-de5d.d.ddBej)dzddddZOe&dej)ddddZPe&dej)ddddZQe&dej)ddddsZRe&de5d.ddBd.d.d.d.ej)d{dddddYZSe&dej)ddddZTe&dej)dddd=ZUe&de5d.ddej)ddddZVe&de5d.d.ej)ddddJZWe&dej)d|ddd d^ZXe&de5d.ddej)dddd҄ZYe&de5d.d.dBej)d}dddd8ZZe&dej)d~dddd%Z[e&d	ej)ddd
dmZ\e&dej)dddd!Z]e&ddddddfZ^e&dej)ddddZ_e&dej)ddddZ`e&dej)ddddZae&dej)ddddZbe&dej)dddddZce&dej)ddddZde&dddddZee&dej)ddddZfe&dej)dd d!d"dZge&d#ej)ddd$dZhe&d%ej)ddd&dZie&d'ej)ddd(dZje&d)ej)ddd*dZke&d+ej)ddd,dZle&d-ej)dd.d!d/dZme&d0ej)dd1d!d2dZne&d3ej)dd1d!d4dZoe&d5ej)ddd6dZpe&d7ej)ddҐd8d9dZqe&d:ej)ddd;dZre&d<e&d=ej)ddҐd>d?dZse&d@e&dAej)ddҐd>dBd Zte&dCej)ddddDdEdZudS (  zhThis file exports ONNX ops for opset 9.

Opset 9 is supported by ONNX release 1.4.1
release on 01/23/19
    )annotationsN)CallableListOptionalSequenceTupleUnion)_C)
_constants_deprecation_type_utilserrorssymbolic_helper)GLOBALS)	_beartype	jit_utilsregistration)Numberabsacosaddaddcmuladdmmaliasamaxaminaminmaxarangeargmaxargmin
as_strided	as_tensorasinatanatan2baddbmm
batch_norm	bernoullibitwise_not
bitwise_orbmmbroadcast_tensorsbroadcast_to	bucketizecatcdistceil	clamp_max	clamp_minclampcloneconstant_pad_nd
contiguousconv_tbcconv_transpose1dconv_transpose2dconv_transpose3dconv1dconv2dconv3dconvert_element_typeconvolutioncoscosine_similaritycrosscumsumdetachdimdivdotdropouteluembedding_bag	embedding
empty_likeemptyeqerfexp	expand_asexpandeyefillflattenfloor_dividefloorfloordivfrobenius_norm	full_likefullgathergegeluget_pool_ceil_paddingglu
group_normgthann_window
hardshrinkhardsigmoid	hardswishhardtanh	index_add
index_copy
index_fill	index_putindex_selectindexinstance_normis_floating_point	is_pinnedisnanitemkl_div
layer_normle
leaky_relulerpliftlinalg_crosslinalg_matrix_normlinalg_normlinalg_vector_normlinearlinspacelog_sigmoidlog_softmaxloglog10log1plog2logical_andlogical_not
logical_orlogical_xorlogit	logsumexp	lstm_celllstmltmasked_fillmasked_fill_matmulmax_pool1d_with_indicesmax_pool2d_with_indicesmax_pool3d_with_indicesmaxmaximummeshgridminminimummishmmmovedimmse_lossmulmultinomialmvnarrownative_layer_normneneg	new_emptynew_fullnew_ones	new_zerosnonzero_numpynonzeronormnumelnumpy_Tone_hot	ones_likeonesonnx_placeholderoverload_by_arg_countpadpairwise_distancepermutepixel_shufflepixel_unshufflepowpreluprim_constant_chunkprim_constant_splitprim_constant	prim_dataprim_device
prim_dtypeprim_ifprim_layoutprim_list_constructprim_list_unpack	prim_loopprim_maxprim_min
prim_shapeprim_tolistprim_tuple_construct	prim_typeprim_unchecked_castprim_uninitialized	rand_likerandrandint_likerandint
randn_likerandn
reciprocalreflection_padrelurelu6	remainderrepeat_interleaverepeatreplication_pad
reshape_asreshaperollrrelursqrtrsubscalar_tensorscatter_addscatterselectselusigmoidsignsilusinsizeslicesoftmaxsoftplus
softshrinksortsplit_with_sizessplitsqrtsquaresqueezestackstd_meanstdsubttaketantanh
tanhshrinktensor	thresholdtotopk	transposetrue_dividetype_asunbindunfoldunsafe_chunkunsafe_split_with_sizesunsafe_split	unsqueezeunsupported_complex_operatorsnoop_complex_operatorsunusedvar_meanvarview_asviewwherewrap_logical_op_with_cast_towrap_logical_op_with_negation
zeros_likezeroszero	   )opsetc                    s    fdd}|S )z_Returns a decorator that calls the decorated (higher-order) function with the given parameters.c                   s
   |  S N fnargskwargsr  Y/var/www/html/Darija-Ai-API/env/lib/python3.8/site-packages/torch/onnx/symbolic_opset9.py_apply+  s    z_apply_params.<locals>._applyr  )r  r  r!  r  r  r   _apply_params(  s    r"  strnamec                   s    fdd}|S )z5Exports the function in the current global namespace.c                   s   | t   < t  | S r  )globals__all__appendfuncr$  r  r   wrapper4  s    

z_export.<locals>.wrapperr  )r%  r+  r  r$  r   _export1  s    r,  c                 C  s   |  d}|tj  |S )z%Represents "missing" optional inputs.prim::Constant)opsetTyper	   OptionalTypeZofTensor)gnr  r  r   r  <  s    
zaten::_shape_as_tensorzjit_utils.GraphContextr1  c                 C  s   |  d|S NShaper.  r1  inputr  r  r   _shape_as_tensorD  s    r9  zaten::_reshape_from_tensorc                 C  s*   t |tr| jd|ddi}t| ||S )NConcataxis_ir   )r:  )
isinstancelistr.  r   )r1  r8  shaper  r  r   _reshape_from_tensorJ  s    
r?  zaten::reshapeTc                 C  s   t | ||S r  )r   _reshape_helperr1  selfr>  r  r  r   r   R  s    zaten::reshape_asc                 C  s   |  d|}t| ||S r4  r.  r   r1  rB  otherr>  r  r  r   r   Y  s    z	aten::addc                 C  sZ   t |r&t |r&t dddd|S |rLt t |dkrL| d||}| d||S )NAddr     z)Add between list of tensors not supported   Mul)r   	_is_value_is_tensor_list _onnx_opset_unsupported_detailed_scalar_maybe_get_scalarr.  r1  rB  rE  alphar  r  r   r   a  s        z	aten::subc                 C  s4   |r&t t |dkr&| d||}| d||S )NrH  rI  Sub)r   rM  rN  r.  rO  r  r  r   r   m  s    z
aten::rsubc                 C  s   t | |||dS )N)rP  )r   rO  r  r  r   r   u  s    z	aten::mulc                 C  s4   t |r"t |r"| d||S | d||S d S )NAndrI  )r   _is_boolr.  r1  rB  rE  r  r  r   r   {  s    z	aten::divc                 G  s.   t |dkrt| ||S t| ||f| S d S Nr   )lenr  _div_rounding_mode)r1  rB  rE  r  r  r  r   rF     s    zaten::addcmulvf      ?c              	   C  s2   | j dt|gd}t| |t| t| |||S NConstantZvalue_t)r.  torchr   r   r   )r1  rB  Ztensor1Ztensor2valueZ
value_tensr  r  r   r     s    sc                 C  sT   |d krt | ||S |dkr(t| ||S |dkr<t| ||S td| d|d S )NrW   trunczUnsupported rounding mode: "z$". Expected None, "floor" or "trunc")r  _floor_divide_trunc_divider   SymbolicValueError)r1  rB  rE  Zrounding_moder  r  r   rW    s    
rW  c                 C  s   |  d||}| j d|tjjd}tj|tjj}|tjjkrt	|sjt	|rj| j d|tjj
d}q| j d|| d}n| j d|tjj
d}|S )NDivCastZto_i)r.  _C_onnxTensorProtoDataTypeINT64r   JitScalarType
from_value	UNDEFINEDr   _is_fpFLOAT	onnx_type)r1  rB  rE  outscalar_typer  r  r   rc    s      rc  c                 C  s   t |st |r,t| ||}| d|S | d||}| jdtjdtjdd}| dt | ||t | ||}| d|| d	||}| d
|| d| d||}| jdtjdtjdd}	| d	||	}
| d||
S d S )NFloorre  r\  r   dtyper]  XorrQ  rI  rR  NotEqualrH  )r   rn  r  r.  r^  r   int64Z
_lt_helper)r1  rB  rE  rq  rF   r  negativemodZ
fixup_maskonefixupr  r  r   rb    s     rb  zaten::floor_dividec                 C  s   t | ||S r  )rc  rT  r  r  r   rV     s    zaten::floordivc                 C  s   t | ||S r  )rV   rT  r  r  r   rX     s    zaten::true_dividec                 C  s   t |st |r"| d||S t }tjj}|tjksJ|tj	ksJt
t tj	kr`tjj}| jd||d}| jd||d}| d||S )a  Division where both inputs are cast to floating types

    If both inputs are floating, performs div as usual
    If only one input is a floating type, the other input is cast to its type
    If neither input is a floating type, both inputs are cast to the default scalar type
    re  rf  rg  )r   rn  r.  r^  get_default_dtyperh  ri  ro  floatdoubleAssertionErrorDOUBLE)r1  rB  rE  rr  Zonnx_scalar_typer  r  r   r    s    zaten::reciprocalc                 C  s*   t |s| jd|tjjd}| d|S )Nrf  rg  
Reciprocal)r   rn  r.  rh  ri  ro  r1  rB  r  r  r   r   	  s    
z	aten::catic                   s   t |}g  |D ]&}t |r.t |ds.q | qt dksJtt fdd D sdt| 	   D ]}| 
| qtt |}| jd|d|iS )Nr   c                 3  sF   | ]>}t  d  dkp<t |dkp<t |t  d  kV  qdS r   N)r   _get_tensor_rank.0r   Znonempty_tensorsr  r   	<genexpr>"  s   zcat.<locals>.<genexpr>r:  r;  )r:  )r   _unpack_list_is_constant_get_tensor_dim_sizer(  rV  r  allnodeZremoveAllInputsZaddInputr.  )r1  tensor_listrE   tensorsr   r  r  r   r.     s$    
 
zaten::stackc                   s.    fddt |D }jd|d iS )Nc                   s   g | ]}t | gqS r  r   _unsqueeze_helperr  rE   r1  r  r   
<listcomp>5  s   zstack.<locals>.<listcomp>r:  r;  )r:  )r   r  r.  )r1  r  rE   Z
unsqueezedr  r  r   r   1  s    z
aten::listc                 C  s   |S r  r  r  r  r  r   _list<  s    r  zaten::mmc                 C  s,   | j dtdgd}| j d|||dddS )Nr\  rH  r]  Gemm        rZ  Zbeta_falpha_fr.  r^  r   )r1  rB  rE  Cr  r  r   r   B  s    z	aten::bmmc                 C  s   |  d||S NMatMulr6  rT  r  r  r   r*   K  s    zaten::matmulc                 C  s   |  d||S r  r6  rT  r  r  r   r   Q  s    zaten::addmmc              	   C  sH  d }t |}t |}t |}	|d k	r0|}n|d k	r>|}n|	d k	rJ|	}t |}
t |}dd }|d k	r&||
ds||dr&| d||}|}t |}t |}|dkr| jdtj|| dd}| d	||}|dkr| jdtjt || dd}| d	||}| d
||S | jd|||t |t |dS )Nc                 S  s   | d k	o| |kS r  r  )rX  ur  r  r   is_not_none_nori  s    zaddmm.<locals>.is_not_none_nor   r  rH  r\  rt  r]  rI  rF  r  r  )r   _try_get_scalar_typer  r.  rM  r^  r   ru  )r1  rB  Zmat1Zmat2betarP  rr  self_scalar_typeZmat1_scalar_typeZmat2_scalar_typeZ	mat1_rankZ	mat2_rankr  Zres1Zres2r  r  r   r   W  s\    







 
 z	aten::negc                 C  s   |  d|S )NZNegr6  r  r  r  r   r     s    z
aten::sqrtc                 C  sT   t j|t jjt jjt jjt jjt jjt jjhkrH| j	d|t
jjd}| 	d|S )Nrf  rg  Sqrt)r   rk  rl  rm  UINT8INT8INT16INTrj  r.  rh  ri  ro  r  r  r  r   r     s     
zaten::rsqrtc                 C  s"   |  dttd|t| |S )Nre  rH  )r.  r   _if_scalar_type_asr^  r   r   r  r  r  r   r     s
      z
aten::tanhg      ?   )scaleZ
zero_pointc                 C  s   |  d|S )NTanhr6  r  r  r  r   r     s    z	aten::sinc                 C  s   |  d|S )NZSinr6  r  r  r  r   r     s    z	aten::cosc                 C  s   |  d|S )NZCosr6  r  r  r  r   r@     s    z	aten::tanc                 C  s   |  d|S )NZTanr6  r  r  r  r   r     s    z
aten::asinc                 C  s   |  d|S )NZAsinr6  r  r  r  r   r"     s    z
aten::acosc                 C  s   |  d|S )NZAcosr6  r  r  r  r   r     s    z
aten::atanc                 C  s   |  d|S )NAtanr6  r  r  r  r   r#     s    zaten::atan2c              
   C  s   |  d||}|  d|}| j dtdd}| j dttjd}|  d||}|  d||  d|||  d	||}|  d
||}	|  d|	||}
|
S )Nre  r  r\  r   r]  GreaterWhererF  rQ  Less)r.  r^  r   mathpi)r1  rB  rE  sloper#   Z
const_zeroZconst_piZ"condition_second_or_third_quadrantZsecond_third_quadrantZcondition_14_or_23_quadrantresultr  r  r   r$     s    zaten::sigmoidg      p?c                 C  s   |  d|S )NSigmoidr6  r  r  r  r   r     s    z
aten::signc                 C  s   |  d|S )NZSignr6  r  r  r  r   r     s    c                 C  sR   t |t |kstt |dkr>|d dkr>|d tjkr>|S | jd||||dS )NrH  r   Slice)axes_iZstarts_iZends_i)rV  r  r
   	INT64_MAXr.  )r1  r8  axesstartsendsr  r  r   _slice   s    &r  c                 C  sL   t j|t jj}|t jjkrHt|sH|t jjkrH| jd|tj	jd}|S Nrf  rg  )
r   rk  rl  rm  r   rn  rj  r.  rh  ri  r1  rB  rr  r  r  r   _maybe_cast_reduce_op_input	  s     
r  c                   s   t jd fdd	}|S )Nc                   sx   t | |}|d ks|t kr*t| |S  r2dnd}t||dt|dd }} r\|n|g}| j|||dS d S )Nisr  rE   keepdimr  
keepdims_i)r  tupler   Z_handle_reduce_dim_none
_get_constr.  )r1  rB  rE   r  descZdim_listallow_multi_dim_supportonnx_op_namer  r   symbolic  s    
  z%_reduce_op_symbolic.<locals>.symbolic)NN)r   beartype)r  r  r  r  r  r   _reduce_op_symbolic  s    r  c                   s    t  tj fdd}|S )Nc                   s`    | f| }|D ],}|j }t|t|kr|| f|   S qtd j dt| dS )Nzaten::with 
 arguments)Z_arg_descriptorsrV  r   _unimplemented__name__)r1  r  Z	overloadsoverloadZarg_descriptorsr  r  r   r+  1  s    
 z&overload_by_arg_count.<locals>.wrapper)	functoolswrapsr   r  )r  r+  r  r  r   r   /  s    
z	aten::sumZ	ReduceSumsum)Zdecoratez
aten::mean
ReduceMeanmeanz
aten::prod
ReduceProdprodFr  bool)onnx_opr%  r  c                   s$   t |  dt fdd}|S )Nr  c                   s`   t dt ddfdd} r,dnd}t dt d|ddfdd	}||fS )
NTrX  nonec                   s   d }|   dkrBt|dd}t| }| jd||d}n|   dkr`t d|S | |}|d k	rtj	| }||kr| jd||d}|S Nonnx::Constantr  ru  rf  rg  r-  
r  kindr   r  r   rk  rp  r.  r  rl  )r1  rB  ru  
dtype_onnxr  result_dtype_onnxr%  r  r  r   reduce_nodimO  s    
z8_reduce_with_dtype.<locals>.reduce.<locals>.reduce_nodimr  r  c                   s   d }|   dkrBt|dd}t| }| jd||d}n|   dkr`t d|S | |||}|d k	rtj	| }||kr| jd||d}|S r  r  )r1  rB  rE   r  ru  r  r  r  r  r  r   
reduce_dimd  s    z6_reduce_with_dtype.<locals>.reduce.<locals>.reduce_dim)r   quantized_args
parse_args)r1  r  r  r  Zdim_descr  r  r%  r  r  r   reduceM  s    
z"_reduce_with_dtype.<locals>.reduce)r  r   )r  r%  r  r  r  r  r   _reduce_with_dtype@  s    	 +r  zaten::cumsumr  c                 C  sJ   t  r6|  dkr&t dd|S | jd||dS t ddd| d S )Nr-  rC   ru  dim_ir  rG  )r   is_caffe2_aten_fallbackr  r  r  at_onnx_opset_unsupported)r1  r8  rE   ru  r  r  r   rC   |  s
    zaten::_sample_dirichletc                 C  s8   t  r,t |s t dd|S | d|S t d|S )N_sample_dirichletz#We are not able to export generatorr   r  _is_noner  r  _onnx_unsupportedr1  rB  	generatorr  r  r   r    s    
  r  zaten::_standard_gammac                 C  s8   t  r,t |s t dd|S | d|S t d|S )N_standard_gammaznot able to export generatorr  r  r  r  r   r    s    
  r  zaten::tc                 C  s6   t |}|d ks|dk r&| d|S | jd|ddS )Nr  Identity	Transpose)rH  r   Zperm_i)r   r  r.  )r1  rB  rankr  r  r   r     s    
zaten::numpy_Tc                 C  s8   t |}|d k	sttttd|}| jd||dS Nr   r  r  )r   r  r  r=  reversedranger.  )r1  r8  ndimpermr  r  r   r     s    
zaten::expandc              	   C  s   t |d}t |s,| jdt|d}n2t |r^t | t| |d| jdt	dgd}t
jj}t| ||}t| || jdt	dd}t| | d||||}| d||S Nr  r\  r]  r   rx  Expandr   _maybe_get_constrJ  r.  r^  
LongTensor_is_packed_listr@  r   r   r   rk  rj  r   r   r  )r1  rB  r   Zimplicitru  r   neg_onesr  r  r   rR     s    

 
 zaten::broadcast_toc              	   C  s   t |d}t |s,| jdt|d}n2t |r^t | t| |d| jdt	dgd}t
jj}t| ||}t| || jdt	dd}t| | d||||}| d||S r  r  )r1  rB  r   ru  r   r	  r  r  r   r,     s    

 
 zaten::expand_asc                 C  s   t |d}t|tjr|j}|tj}g }t|	 D ]J}t
|||||r:|| | jd|j|dd|d}q:| d|}| d||S )Nr   r\  T)r  r]  r5  r  )r   r  r<  r^  Tensorru  r   r  r  rE   equalr  r	  rQ   r(  r.  )r1  rB  rE  Zself_t	orig_typedimsdr>  r  r  r   rQ     s    
 zaten::embeddingbc                 C  s<   |rt jrtd||dkr.t jr.td | d||S )NzUnsupported: ONNX export of embedding with scale_grad_by_freq=True for training mode. ONNX does not support scaling the gradients.r   zWarning: ONNX export of embedding with padding_idx >= 0 for training mode. ONNX does not support not updating the embedding vector at padding_idx during training.Gather)r   Zexport_trainingr   rd  warningswarnr.  )r1  weightindicespadding_idxscale_grad_by_freqsparser  r  r   rK     s    
zaten::embedding_bagc
           
      C  sF   t |st dS t  r:| jd|||d|||||	d
S t d|S )Nz%embedding_bag with per_sample_weightsrJ      )outputsZscale_grad_by_freq_iZmode_iZsparse_iZinclude_last_offset_iZpadding_idx_i)r   r  r  r  r  )
r1  Zembedding_matrixr  offsetsr  moder  Zper_sample_weightsZinclude_last_offsetr  r  r  r   rJ     s$    
z
aten::size)Zquantize_outputc                 C  sh   |d kr|  d|S t|ddk rZt|}|d k	rZt|d| }| j dt|d}t| ||S )Nr5  r  r   r\  r]  )r.  r   r  r  r^  r   Z_size_helperr1  rB  rE   r  r  r  r   r   6  s    
zaten::transposec                 C  s   ||kr|S t |}|d k	rTtt|}|| ||  ||< ||< | jd||dS t  rp| jd|d||dS td|d S )Nr  r  r  int)overload_nameZdim0_iZdim1_izAUnsupported: ONNX export of transpose for tensor of unknown rank.)	r   r  r=  r  r.  r  r  r   rd  )r1  rB  Zdim0Zdim1r  r  r  r  r   r  D  s    
zaten::permuter  c                 C  s*   |t tdt|kr|S | jd||dS r  )r=  r  rV  r.  )r1  rB  r  r  r  r   r   ]  s    z
aten::viewc                 C  s   t | ||S r  )r   )r1  rB  r   r  r  r   r  f  s    zaten::view_asc                 C  s   |  d|}t| ||S r4  rC  rD  r  r  r   r  m  s    zaten::unsafe_chunkc           	      C  s   |d krt dddd|S t ||}|d kr<t dd|S || d | }|g||  }|| }|rp|| | jd||||dS )	Nr  r  rG  'Dynamic number of outputs not supportedunknown dimension sizerH  SplitZsplit_ir;  r  )r   rL  r  r  r(  r.  )	r1  rB  chunksrE   _outputsr   
split_sizesplitsleftoverr  r  r   r  t  s*          
zaten::splitc           
      C  s   t ||st dddd|S t | d}| dkrJt| ||||S t |dd}t ||}|d kr|d k	r~|| }nt dddd	|S |g||  }|| }	|	r|	|	 | j
d
||||dS )Nr   r  rG  r  r_  r   r  r%  z$Unknown dimension size not supportedr!  r"  )r   _is_split_staticrL  	_node_getr  rE   r   r  r  r(  r.  )
r1  rB  split_size_or_sizesrE   r$  Z	split_valr%  r   r&  r'  r  r  r   r     s8        
    
zaten::unsafe_splitc                 C  s   t | ||||S r  )r   )r1  rB  r*  rE   r$  r  r  r   r    s    zaten::split_with_sizesc                 C  s2   t ||st dddd|S | jd||||dS )Nr   r  rG  r  r!  r"  )r   r(  rL  r.  r1  rB  Zsplit_sizesrE   r$  r  r  r   r     s        zaten::unsafe_split_with_sizesc                 C  s   t | ||||S r  )r   r+  r  r  r   r    s    zaten::unbindc                   s^   |d krt dddd|S jd|dg|  |d}|dkrB|gn|} fdd	|D }|S )
Nr  r  rG  r  r!  rH  r"  c                   s   g | ]}t | gqS r  )r   _squeeze_helper)r  rq  r  r  r   r    s    zunbind.<locals>.<listcomp>)r   rL  r.  )r1  rB  rE   r$  r  Zsqueezed_outputsr  r  r   r    s        zaten::selectc                 C  st   t |}t |s^|dk r^|dkr,tj}n|d }t j| ||g|g|gd}t | ||gS | jd|||dS d S )Nr   r  rH  r  r  r  r  r;  )r   rN  rJ  r
   r  _slice_helperr,  r.  )r1  rB  rE   rm   Z	end_indexZ
slice_noder  r  r   r     s    
    zaten::squarec                 C  s   |  d||S NrI  r6  r  r  r  r   r     s    zaten::squeezec                 C  sJ  |d kr|  d|S t|dd}|dk rt|}|d k	rxtdt| d d d t||  d	 d
  ||7 }ntdd|S t||}|d krtdt| d d t| d d d d  tj	| ||gdS |dkrtdt| d d t| d d d d  |S tdt| d d  tj	| ||gdS )NZSqueezer  rE   r   z'ONNX export squeeze with negative axis - might cause the onnx model to be incorrect. (Negative axis is not supported in ONNX. Axis is converted to & based on input shape at export time. CPassing an tensor of different rank in execution will be incorrect.r   %negative axis with unknown input rankz5This model contains a squeeze operation on dimension z on an input z7with unknown shape. Note that if the size of dimension z of the input zVis not 1, the ONNX model will return an error. Opset version 11 supports squeezing on zMnon-singleton dimensions, it is recommended to export this model using opset zversion 11 or higher.r  rH  z. The size of z%this dimension in the given input is z. The model will zWbe exported without the squeeze node. If the model is intended to be used with dynamic z-input shapes, please use opset version 11 to zexport the model.z. If the model is z_intended to be used with dynamic input shapes, please use opset version 11 to export the model.)
r.  r   r  r  r  r  r#  r  r  r,  )r1  rB  rE   Zsqueeze_dimr  dim_sizer  r  r   r     s    



  
zaten::preluc              	   C  s   t |}t |}t|}|d k	rp|dkrJt | |ttd|d }n&|dkrp|dgkrpt | |dg}d}|d k	r|d k	r||kstd| d| | 	d||S )Nr  rH  r   z)rank(x) should be >= rank(slope) but got z < PRelu)
r   r  _get_tensor_sizesrV  r  r=  r  r,  r  r.  )r1  rB  r  	self_rankZweight_sizesZweight_rankr  r  r   r   +  s&    

  z
aten::siluc                 C  s   |  d||  d|S )NrI  r  r6  r7  r  r  r   r   C  s    z
aten::mishc                 C  s   |  d||  d|  d|S )NrI  r  Softplusr6  r7  r  r  r   r   I  s    c              
   O  s  | dd}| dtjj}t|}tj|d }t|d  oT|dkpTtj	|k }|r|D ]F}	|	
 r^tj|	}
|
|kr^td| d|  d|
  |	q^t|D ]2\}}	|	
 rt|	s| jd|	| d	||< q| j|f||}|r| jd|| d	}|S )
a  Some PyTorch operators (e.g., Clip/Min/ReLU/Pad) are super set of ONNX in terms of data types.
    This function maximizes the exportability of PyTorch-ONNX by allowing ONNX-unsupported PyTorch
    operator data type. For example, `Cast<int>(Clip<float>(Cast<float>(INPUT)))` can be used to mimic
    `Clip<int>(INPUT)` (opset version < 12).

    Args:
        g (torch._C.Graph): graph to write the ONNX representation into.
        op_name (str): operator name in ONNX.
        *args (tuple): operands to the operator.
        **kwargs (dict): attributes to the operator along with "opset_before" (optional, None by default)
            indicating the smallest opset version to trigger such casting behavior and "target_float_t"
            (optional, torch.onnx.JitScalarType.FLOAT by default) indicating the data type of internal operator.

    Returns:
        Optional[torch._C.Value, Tuple[torch._C.Value, ...]]: output(s) of the operator.
    opset_beforeNtarget_float_tr   z
Inputs of z must have same dtype.Got z and rf  rg  )popr   rk  ro  r=  rl  r   rn  r   export_onnx_opset_versionZisCompleteTensorr   rd  Zscalar_name	enumerater.  rp  )r1  Zop_namer  r  r=  r>  inputsZdtype_0Zrequire_castr8  Zinput_scalar_typer  rB  r  r  r   _op_with_optional_float_castO  s6    rC  z
aten::reluc                 C  s   t | d|ddS )NRelu   r=  rC  r7  r  r  r   r     s    zaten::relu6c                 C  s   t | |ddS )Nr      )r3   r7  r  r  r   r     s    z
aten::ceilc                 C  s   |  d|S )NCeilr6  r7  r  r  r   r0     s    zaten::floorc                 C  s   |  d|S )Nrs  r6  r7  r  r  r   rW     s    z	aten::lenc                 C  s.   t | || jdtdgd}t| |dgS Nr\  r   r]  )r   r.  r^  r  r   r,  )r1  rB  Zsz_0r  r  r   _len  s    rK  zaten::thresholdc                 C  sD   t |dkrt dd|S t |dkr8t dd|S | d|S )Nr   r   znon-zero thresholdznon-zero valuerD  )r   rM  r  r.  )r1  rB  r   r_  r  r  r   r     s
    zaten::leaky_relu_C.Valuer  r1  r8  Znegative_slopeZinplacec                 C  s   | j d||dS )N	LeakyRelur  r6  rM  r  r  r   rv     s    z	aten::gluc                 C  sP   t ||}|d k	r$|d dks$t| jd||dd\}}| d|| d|S )Nr  r   r!  )r;  r  rI  r  )r   r  r  r.  )r1  r8  rE   r8  firstsecondr  r  r   r`     s
    zaten::softmaxc              
   C  sb  t |}|d k	r|dk r"|| }||d k}|rptt|}|d ||  ||< |d< | jd||d}|d }| jd||d}|r|  dkrt |d	d
}| jd|t	|
 d}|r| jd||d}|S | d|| jd||gdd}| d|}	t j| |	|gd}
| d|	|
}|r^|  dkr^t |d	d
}| jd|t	|
 d}|S )Nr   rH  r  r  r  ZSoftmaxr.  r-  r  ru  rf  rg  rQ  	ReduceMaxr  Expr7  re  )r   r  r=  r  r.  r  r  r  r   rk  rp  _reducesum_helper)r1  r8  rE   ru  	input_dimis_transpose_requiredr  r   parsed_dtyperP   r  r  r  r   r     sB    
  zaten::softplusc                 C  s@   t |d}|dkr4| d| d| d|||S | d|S )NrY  rH  re  r<  rI  )r   r  r.  )r1  rB  r  r   Z
beta_constr  r  r   r     s     zaten::get_pool_ceil_paddingc                   s   t | }|d k	r$|t d  nd d ksBtdd D rPt dd| S fddtdtD   fddtdt D   fd	dtdtD fd
dtdtD S )Nc                 s  s   | ]}|d kV  qd S r  r  r  r  r  r  r   r    s     z(get_pool_ceil_padding.<locals>.<genexpr>r_   input size not accessiblec              	     sB   g | ]:}t t | d |   |  t|  d qS r  rH  )r  r  r0   r  rX  )rE   kernel_sizepaddingstrider  r   r     s   0z)get_pool_ceil_padding.<locals>.<listcomp>r   c                   sD   g | ]<} | d  |  | |  kr8 | d  n | qS rH  r  rX  )ceiled_output_dimrE   r\  r]  r  r   r  &  s   "c                   sP   g | ]H}| d krdn2| | d|    | d  |  d    qS )rH  r   r  r  rX  )r_  rE   r[  r\  r]  r  r   r  ,  s   
*c                   sd   g | ]\}| d |    | krT|  | d k rDt | q^t  | d n
t | qS rZ  r  rX  )r[  r\  padding_ceilr  r   r  6  s   )r   r:  rV  anyr  r  )r8  r[  r]  r\  sizesr  )r_  rE   r[  r\  ra  r]  r   r_     s*    
  

zaten::max_pool1dZ
max_pool1drH  )return_indiceszaten::max_pool2dZ
max_pool2dr  zaten::max_pool3dZ
max_pool3d   c              	     sD   t ddddddt ddddddtj fdd}|S )NTFrX  r  r  c                   s<  t |dhkr t d|S |s(|}t|}|rdt||||}|tdd t||D  }n|d }|||d}r| jd|fddi|\}	}
| jd|dd	d
 tD dd
 tD d\}}tj| |dd
 tD t	dt	dd}t
| |
|}
|	|
fS | jd|fddi|}	|	S d S )NrH  dilationc                 s  s   | ]\}}|| V  qd S r  r  r  ar  r  r  r   r  k  s     z1_max_pool.<locals>.symbolic_fn.<locals>.<genexpr>r  )kernel_shape_ipads_i	strides_iMaxPoolr  c                 S  s   g | ]}d qS r^  r  r  _r  r  r   r    s     z2_max_pool.<locals>.symbolic_fn.<locals>.<listcomp>c                 S  s   g | ]}d qS r^  r  rm  r  r  r   r    s     )r  ri  rk  c                 S  s   g | ]}d | qS )r  r  rX  r  r  r   r    s     r   r-  )setr   r  r  r_   zipr.  r  r/  r=  r   )r1  r8  r[  r]  r\  rf  	ceil_modera  r  rr  rn  Zflattened_indicesr`  r%  ndimsrd  tuple_fnr  r   symbolic_fn`  sB    


z_max_pool.<locals>.symbolic_fnr   r  r  r   r  )r%  ru  rt  rd  rv  r  rs  r   	_max_poolC  s
    4rx  zaten::max_pool1d_with_indiceszaten::max_pool2d_with_indiceszaten::max_pool3d_with_indiceszaten::avg_pool1dZ
avg_pool1dzaten::avg_pool2dZ
avg_pool2dzaten::avg_pool3dZ
avg_pool3dc                   sJ   t dt dddddddtjdddddd	d	d
 fdd}|S )NTrX  r  r  r  rL  Sequence[int]zUnion[int, Sequence[int]]r  )r8  r[  r]  r\  rq  count_include_padc              	     s   |s|}t |||| }t|ts*t|}|r\t| d|d| d dddd}dt| }|rt||||}	|td	d
 t|	|D  }n|d }| j	d||||d}
|
S )NPad)r   r   r  constantr  rG  rj  mode_sZvalue_fr=  r   c                 s  s   | ]\}}|| V  qd S r  r  rg  r  r  r   r    s    z1_avg_pool.<locals>.symbolic_fn.<locals>.<genexpr>AveragePool)ri  rk  rj  )
r   Z_avgpool_helperr<  r  r  rC  rV  r_   rp  r.  )r1  r8  r[  r]  r\  rq  rz  Zdivisor_overrideZadjusted_paddingra  outputr%  ru  r  r   rv    sJ         
	
z_avg_pool.<locals>.symbolic_fn)Nrw  )r%  ru  rv  r  r  r   	_avg_pool  s    	 &1r  zaten::adaptive_avg_pool1dZadaptive_avg_pool1dr  zaten::adaptive_avg_pool2dZadaptive_avg_pool2dzaten::adaptive_avg_pool3dZadaptive_avg_pool3dzaten::adaptive_max_pool1dZadaptive_max_pool1drl  zaten::adaptive_max_pool2dZadaptive_max_pool2dzaten::adaptive_max_pool3dZadaptive_max_pool3dc                   s(   t ddtj fdd}|S )NTFc              	     s  }zt dW n  tk
r4   t d| Y S X dgt kr\dkr\| d|S t |}z|dd   W n tk
r   d  Y nX  d kstdd  D rڈdgt kr| d	|d fS t d
|S  fddt	dt D }|dgt| kr>dgt kr0| d	|d fS t d|S  fddt	dt D }dkr| |||dt  dt  dS | j|||d}|S )Nr  z4adaptive pooling, since output_size is not constant.rH  r  ZGlobalAveragePoolr  c                 s  s   | ]}|d kV  qd S r  r  rX  r  r  r   r  a  s     z6_adaptive_pool.<locals>.symbolic_fn.<locals>.<genexpr>ZGlobalMaxPoolrY  c                   s   g | ]} | |  qS r  r  rX  rE   output_sizer  r   r  h  s     z7_adaptive_pool.<locals>.symbolic_fn.<locals>.<listcomp>r   z-output size that are not factor of input sizec                   s    g | ]}t  | |  qS r  r`  rX  r  r  r   r  o  s     rl  r  r^  F)ri  rk  )
r   
_parse_arg	Exceptionr  rV  r.  r:  rb  r  r  )r1  r8  r  Zoutput_size_valuerc  r{  kr  r  r%  ru  typer  r   rv  C  sJ     


    
$z#_adaptive_pool.<locals>.symbolic_fn)r   r  r   r  )r%  r  ru  r  rv  r  r  r   _adaptive_pool  s    A
1r  r  rE   c                 C  sF   t |dd dg| d t|   }|ddd |ddd  }|S )zGenerate paddings in ONNX order based on pad in pytorch.
    Args:
        dim: the dimension of the tensor.
        pad: the paddings in pytorch.
            The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ...
    Nr   r  r  )r=  rV  )rE   r   paddingsr  r  r   _prepare_onnx_paddingsy  s    &r  c              
   C  sh   t | d}t |rdt |rdt |}zdd |D }W n& tk
rb   t dddd|  Y S X |S )Nr  c                 S  s   g | ]}t |d dqS )r  r\  )r   r  )r  rX  r  r  r   r    s    z)_convert_padding_node.<locals>.<listcomp>r{  r  rG  z)The sizes of the padding must be constant)r   r  rJ  r  r  r  rL  )r8  r\  
input_listr  r  r   _convert_padding_node  s     

    
r  zaten::constant_pad_ndc              
   C  sl   d}zt |dd}W n& tk
r<   t dddd| Y S X t|}tt ||}t| d||||ddS )	Nr|  rY  r_  r{  r  rG  z*The value for the padding must be constantr}  )r   r  r  rL  r  r  r  rC  )r1  r8  r\  r_  r  r  r  r  r   r5     s,        
      )r1  r8  r   c                 C  sH  t |}t|d dkstt|d }|}t|D ]}|d| d   }|d| d   }g }	|dkrtj| |d| g| gtjgd}
|	|
 |dk s|dk rt	
d| }t	
d|  }tj| |d| g|g|gd}|	| n
|	| |dkr*tj| |d| gdg|gd}|	| | jd|	dd| i}q4|S )Nr  r   rH  r-  r:  r;  )r:  )r  rV  r  r  r   r/  r
   r  r(  builtinsr   r.  )r1  r8  r   r\  r   curidxZpad_rZpad_lr  leftstartendmiddlerightr  r  r   _pad_circular  sP        


    
r  zaten::reflection_pad1dzaten::reflection_pad2dzaten::reflection_pad3dc                 C  s0   d}t |}tt||}t| d|||ddS )Nreflectr{  rG  rj  r~  r=  r  r  r   r  rC  r1  r8  r\  r  r  r  r  r   r     s         zaten::replication_pad1dzaten::replication_pad2dzaten::replication_pad3dc                 C  s0   d}t |}tt||}t| d|||ddS )Nedger{  rG  r  r  r  r  r  r   r     s         z	aten::padr1  r8  r   r  r_  c                 C  st   t |d}|dkr t| ||S |dkr4t| ||S |dkrJt| |||S |dkr^t| ||S td| |d S )Nr`  Z	replicater  r|  ZcircularzUnrecognized padding mode )r   r  r   r   r5   r  r   rd  r  r  r  r   r     s    	zaten::upsample_nearest1dZupsample_nearest1dZnearestzaten::upsample_nearest2dZupsample_nearest2dr  zaten::upsample_nearest3dZupsample_nearest3d   zaten::upsample_linear1dZupsample_linear1dzaten::upsample_bilinear2dZupsample_bilinear2dzaten::upsample_trilinear3dZupsample_trilinear3d)r%  rE   interpolate_modec                   s    fdd}|S )Nc                   sb   t | |\}}t  t |}|r8t d|S |d krPt | || }| jd||dS )Nzalign_corners == TrueUpsampler~  )r   Z_get_interpolate_attributesZ_interpolate_warningrN  r  Z_interpolate_size_to_scalesr.  )r1  r8  r  r  scalesalign_cornersrE   r  r%  r  r   rv  7  s"      

   z!_interpolate.<locals>.symbolic_fnr  )r%  rE   r  rv  r  r  r   _interpolate  s    ,r  zaten::__interpolatec           	      C  s*   t | |||||\}}| jd|||dS )Nr  r  )r   Z _interpolate_get_scales_and_moder.  )	r1  r8  r   Zscale_factorr  r  Zrecompute_scale_factorZ	antialiasr  r  r  r   __interpolateH  s         r  zaten::bitwise_notc                 C  s"   t |std|| d|S NzOONNX export does NOT support exporting bitwise Not for non-boolean input valuesrw  r   rS  r   rd  r.  r7  r  r  r   r(   Z  s    
zaten::bitwise_orc                 C  s:   t |std|t |s,td|| d||S )NzVONNX export does NOT support exporting bitwise OR for non-boolean input values. self: zWONNX export does NOT support exporting bitwise OR for non-boolean input values. other: Orr  rT  r  r  r   r)   f  s    

c                   s    fdd}|S )Nc                   s   t   fdd}|S )Nc                   s,   t  d  } | || |d|| |dS )NZ_cast_F)r&  )r1  r8  rE  Zto_cast_func)r  to_typer  r   wrap_with_cast{  s    zGwrap_logical_op_with_cast_to.<locals>.decorator.<locals>.wrap_with_castr  r  )r  r  r  r  r   	decoratorz  s    z/wrap_logical_op_with_cast_to.<locals>.decoratorr  )r  r  r  r  r   r  x  s    r   )r*  returnc                   s   t   fdd}|S )Nc                   s   |  d | ||S )Nrw  r6  r1  r8  rE  r)  r  r   wrap_with_not  s    z4wrap_logical_op_with_negation.<locals>.wrap_with_notr  )r*  r  r  r)  r   r    s    zaten::__not_c                 C  s"   t |std|| d|S r  r  r  r  r  r   __not_  s    
r  zaten::eqc                 C  s   t | tjr:t | tjr:| jdtjdtjddS | }| }|	 |	   krfdkrn nN|
d|
d  krdkrn n*| jdtj|d|dktjddS | d||S )	Nr\  Trt  r]  r  r_  r`  rx  )r<  r  r	   DeviceObjTyper.  r^  r   r  r  r  kindOfr`  )r1  rB  rE  Z	self_nodeZ
other_noder  r  r   rN     s"      $zaten::nec                 C  s   t | ||S r  )rN   rT  r  r  r   r     s    zaten::gtc                 C  s   t | ||S r  _gt_implr  r  r  r   rb     s    c                 C  sJ   t |r<t |r<| jd|tjjd}| jd|tjjd}| d||S )Nrf  rg  r  r   rS  r.  rh  ri  INT32r  r  r  r   r    s    r  zaten::ltc                 C  s   t | ||S r  _lt_implr  r  r  r   r     s    c                 C  sJ   t |r<t |r<| jd|tjjd}| jd|tjjd}| d||S )Nrf  rg  r  r  r  r  r  r   r    s    r  zaten::gec                 C  s   t | ||S r  r  r  r  r  r   r]     s    zaten::lec                 C  s   t | ||S r  r  r  r  r  r   ru     s    zaten::__and_c                 C  s:   t |std|t |s,td|| d||S )NzOONNX export does NOT support exporting bitwise AND for non-boolean input valuesrR  r  r  r  r  r   __and_  s    

r  zaten::__or_c                 C  s:   t |std|t |s,td|| d||S )NzNONNX export does NOT support exporting bitwise OR for non-boolean input valuesr  r  r  r  r  r   __or_  s    

r  zaten::__xor_c                 C  s:   t |std|t |s,td|| d||S )NzOONNX export does NOT support exporting bitwise XOR for non-boolean input valuesrv  r  r  r  r  r   __xor_	  s    

r  zaten::logical_andZBoolc                 C  s   |  d||S )NrR  r6  r  r  r  r   r   "	  s    zaten::logical_orc                 C  s   |  d||S )Nr  r6  r  r  r  r   r   )	  s    zaten::logical_xorc                 C  s   |  d||S )Nrv  r6  r  r  r  r   r   0	  s    zaten::logical_notc                 C  s   |  d| j d|tjjdS )Nrw  rf  rg  r.  rh  ri  BOOLr7  r  r  r   r   7	  s    zaten::__rshift_c                 C  s   t j|}t j|t jj|kr6| jd|| d}| jdtjdtjdd}t	
|sn| jd|tjjd}| d||}| jd|| d}| d||}|S )	Nrf  rg  r\  r  rt  r]  Powre  r   rk  rl  rm  r.  rp  r^  r   float32r   rn  rh  ri  ro  )r1  rB  rE  r  twotwo_powrshiftr  r  r   	__rshift_=	  s*    
r  zaten::__lshift_c                 C  s   t j|}t j|t jj|kr6| jd|| d}| jdtjdtjdd}t	
|sn| jd|tjjd}| d||}| jd|| d}| d||}|S )	Nrf  rg  r\  r  rt  r]  r  rI  r  )r1  rB  rE  r  r  r  lshiftr  r  r   	__lshift_[	  s*    
r  zaten::wherec              	   C  s`   t |s| jd|tjjd}|d krPt| |}t | || jdt	dd|S | d|||S )Nrf  rg  r\  rH  r]  r  )
r   rS  r.  rh  ri  r  r   Z_unbind_helperr^  r   )r1  	conditionrB  rE  r$  r  r  r   r  y	  s    

   zaten::log_softmaxc           	      C  s   t |}|d krt ddS |dk r.|| }||d k}|r|tt|}|d ||  ||< |d< | jd||d}|d }| jd||d	}|r|  d
krt |dd}| jd|t	
| d}|r| jd||d}|S )NrE   fONNX and PyTorch use different strategies to split the input. Input rank must be known at export time.r   rH  r  r  r  Z
LogSoftmaxr.  r-  r  ru  rf  rg  )r   r  r  r=  r  r.  r  r  r  r   rk  rp  )	r1  r8  rE   ru  rU  rV  r  Z	return_oprW  r  r  r   r   	  s2    
  zaten::_log_softmaxc                 C  s>   |r2t j|t jjt jjkr2| jd|tjjd}t	| ||S r  )
r   rk  rl  rm  HALFr.  rh  ri  ro  r   )r1  r8  rE   Zhalf_to_floatr  r  r   _log_softmax	  s     r  zaten::_convolutionc                 C  s"  t |}z|dd  }W n tk
r2   d }Y nX |d ksNtdd |D rZtd|||g}t |st |dkr|| |dd  ||| ||	d}tdd |D r|st	t
|t
|kst	||d< | j|rd	nd
f||}t |st |dkr| d||S |S d S )Nr  c                 s  s   | ]}|d kV  qd S r  r  rX  r  r  r   r  	  s     z_convolution.<locals>.<genexpr>DUnsupported: ONNX export of convolution for kernel of unknown shape.rH  )ri  rk  rj  dilations_igroup_ic                 s  s   | ]}|d kV  qdS r  r  )r  or  r  r   r  	  s     Zoutput_padding_iZConvTransposeConvrF  )r   r:  r  rb  r   rd  r  r  r(  r  rV  r.  )r1  r8  r  biasr]  r\  rf  
transposedoutput_paddinggroupsZ	benchmarkZdeterministiccudnn_enabledZ
allow_tf32weight_sizekernel_shaper  r  r2  r  r  r   _convolution	  sB    




r  zaten::_convolution_modec                 C  s   t |}z|dd  }	W n tk
r2   d }	Y nX |	d ksNtdd |	D rZtd|||g}
t |st |dkr|
| |dkrd}n|dkrd	}|dd  ||||d
}| j	d|
|}t |st |dkr| 	d||S |S d S )Nr  c                 s  s   | ]}|d kV  qd S r  r  rX  r  r  r   r  
  s     z$_convolution_mode.<locals>.<genexpr>r  rH  validZVALIDsameZ
SAME_UPPER)ri  rk  Z
auto_pad_sr  r  r  rF  )r  )
r   r:  r  rb  r   rd  r  r  r(  r.  )r1  r8  r  r  r]  r\  rf  r  r  r  r  r  r2  r  r  r   _convolution_mode
  sB    



r  zaten::convolutionc
           
      C  s"   t | |||||||||	d d d d S r  r  )
r1  r8  r  r  r]  r\  rf  r  r  r  r  r  r   r?   C
  s     zaten::conv1dc           	      C  s\   t |d}|dkr*t| |||||||S t |d}t| ||||||dd|d d d d S d S Nr`  )r  r  r  Fr  r   r  r  r  	r1  r8  r  r  r]  r\  rf  r  Zstr_paddingr  r  r   r;   d
  s:    zaten::conv2dc           	      C  s\   t |d}|dkr*t| |||||||S t |d}t| ||||||dd|d d d d S d S r  r  r  r  r  r   r<   
  s:    zaten::conv3dc           	      C  s\   t |d}|dkr*t| |||||||S t |d}t| ||||||dd|d d d d S d S r  r  r  r  r  r   r=   
  s:    zaten::conv_transpose1dc	           	      C  s"   t | ||||||d||d d d d S NTr  	r1  r8  r  r  r]  r\  r  r  rf  r  r  r   r8   
  s     zaten::conv_transpose2dc	           	      C  s"   t | ||||||d||d d d d S r  r  r  r  r  r   r9   
  s     zaten::conv_transpose3dc	           	      C  s"   t | ||||||d||d d d d S r  r  r  r  r  r   r:     s     zaten::batch_normc
                 C  s   t |d t rDt |||||gsDtjdk rDt dddd|S t | |||||\}}}}| j	d||||||d| |sdndd	}
|s|
S |
\}}}}}|
|  |
|  |d	|   |d	|   |S d S )
Nr&      ZBatchNormalizationr  zaAll input tensors must have the same `dtype`. Turn off Autocast or export using opset version 15.rH  r  )	epsilon_fZ
momentum_fr  zbatch_norm_dead_output-)r   check_training_moder^  Zis_autocast_enabledZargs_have_same_dtyper   r@  rL  Z_batchnorm_helperr.  r/  r  ZsetDebugNameZ	debugName)r1  r8  r  r  running_meanrunning_vartrainingmomentumepsr  rq  resZnew_running_meanZnew_running_varZ
saved_meanZ	saved_varr  r  r   r&   6  sV    	     
zaten::native_layer_normry  z#Tuple[_C.Value, _C.Value, _C.Value])r1  r8  normalized_shaper  r  r  r  c                 C  sv  dd t t|ddD }t| d}t| |}| jd||d}	t| ||	}
tj|
tjj	k}|rtj|}| jd|
t|
 d	}
| jdt| |
||d}t| | d
||}| d|
|}|rtj|}| jd|t|
 d	}|d kst|st| ||}|d ks4t|s4t| ||}|rb| jd|t|
 d	}| d|}n
t| |}||	|fS )Nc                 S  s   g | ]
}| qS r  r  rX  r  r  r   r  |  s     z%native_layer_norm.<locals>.<listcomp>r   r         @r  r7  rf  rg  rF  re  r  )r  rV  r   Z_generate_wrapped_numberr.  r   r   rk  rl  r  rp  r   r   r  r   r   r   )r1  r8  r  r  r  r  r  Ztwo_cstZeps_cstr  	numeratorZis_type_halfZ	eps_dtypeZvariancedenominator
normalizedZinput_dtypeZrdenominatorr  r  r   r   p  sN    
      
zaten::layer_norm)r1  r8  r  r  r  r  cudnn_enabler  c           	   	   C  s<   t  r | jd||||||dS t| |||||\}}}|S )Nrt   )Znormalized_shape_ieps_fZcudnn_enable_i)r   r  r  r   )	r1  r8  r  r  r  r  r  r  rn  r  r  r   rt     s    	zaten::instance_normr   )r1  use_input_statsr  r  r  c
                 C  s,  t |d t |d}
|d ks*t |rl|
d kr>td|tjdg|
 tj	
| d}| jd|d}|d ks~t |r|
d krtd|tjdg|
 tj	
| d}| jd|d}|d kst |s|d kst |r| jd	||||d
S t |}| }|d }|d kr(td||d }d|d< || |d< t| || jdtj|gtjdd}t| || jdtj|gtjdd}t| || jdtj|gtjdd}t| || jdtj|gtjdd}| d|| jdt|d}t| |||||||||	
}t| || jdt|dS d S )Nrn   rH  zCUnsupported: ONNX export of instance_norm for unknown channel size.rZ  rt  r\  r]  r  InstanceNormalizationr  r   zJUnsupported: ONNX export of instance_norm training for unknown batch size.ZReshape)r   r  r  r  r   rd  r^  r   r   rk  rl  ru  r.  r:  copyr   ry  r  r&   r  )r1  r8  r  r  r  r  r  r  r  r  channel_sizeweight_value
bias_value
input_sizeZinput_size_reshaper2  cweight_bias_Zrunning_mean_Zrunning_var_input_reshapedrq  r  r  r   rn     s    

    zaten::unfoldc                   s   t  rjd ||dS t }z|  }W n tk
rJ   d }Y nX |d k	rtd||}t||d |} fddt||D }	t|}
ttd|
	
   fdd|	D }jd|d	 iS t d
dS d S )Nr  )Zdimension_iZsize_iZstep_ir   rH  c              	     s*   g | ]"\}}t j g|g|gd qS )r-  r   r/  )r  lowhi)	dimensionr1  r8  r  r   r  <  s       zunfold.<locals>.<listcomp>c              
     s(   g | ] }t jd |d gqS )r  r  )r   r  r.  r  )r  r1  r  r  r   r  E  s     r:  r;  ZUnfoldrY  )r:  )r   r  r  r:  r  r  rp  rV  r=  r(  r?  r.  r  )r1  r8  r  r   steprc  ZsizedimZlow_indicesZ
hi_indicesr   r   r	  r  )r  r1  r8  r  r   r  +  s2    

  z	aten::eluc                 C  sJ   |r|dkrt dd|S |r4|dkr4t dd|S | jd|t |dS )NrZ  r  zdoes not support scale in Eluinput_scalez#does not support input_scale in EluElurO  )r   r  r.  rM  )r1  r8  rP  r  r  r  r  r   rI   R  s        z
aten::seluc                 C  s   |  d|S )NZSelur6  r7  r  r  r   r   c  s    zaten::index_selectc                 C  s   t | |||S r  )r   _select_helper)r1  rB  rE   rm   r  r  r   rl   j  s    zaten::index_putc                 C  s   t |rt |}n|g}t  rD|g| ||g }| jd| S t |d}t|dkrp|rlt| ||S |S t ddd| d S )Nrk   r  r   r  rG  )rk   )	r   r  r  r  r  r  rV  r   r  )r1  rB  Zindices_list_valuevalues
accumulateZindices_listr  r  r  r   rk   t  s    
zaten::index_fillc           	      C  sr   t |d}t  r*| jd|||d|dS t | |||\}}t |}t ||}t| ||d }t| ||||S )Nr  rj   Z
int_Scalar)r  r  )	r   r  r  r  _index_fill_reshape_helperrN  r  rR   r   )	r1  rB  rE   rm   r_  	dim_valueexpanded_index_shapeexpanded_indexZexpanded_valuer  r  r   rj     s(    	   
zaten::index_copyc                 C  sL   t |d}t  r(| jd||||dS t | |||\}}t| ||||S )Nr  ri   r  )r   r  r  r  r  r   )r1  rB  rE   rm   sourcer  r  r  r  r  r   ri     s       zaten::bucketizec                 C  s   t jj}|rt jj}| jd| d|| d|dd}t|}|d k	sLttt	d|d }t
| t| |||d }	|rt| ||	}
nt| ||	}
| jd|
|d}tj| |dgddS )	Nr:  r5  r   r.  rH  rf  rg  r  )rh  ri  rj  r  r.  r   r  r  r=  r  rR   r  r]   rb   rT  )r1  rB  Z
boundariesZ	out_int32r  Zout_type	new_shapeZtensor_rankZunsqueeze_axesZexpanded_boundariescondZcond_outr  r  r   r-     s$    "

zaten::type_asc                 C  sj   t |}t |}||kr(|d k	r(|S |d k	rD| jd|| dS t  rZ| d||S td|d S )Nrf  rg  r  zUnsupported: ONNX export of type_as for tensor of unknown dtype. Please check if the dtype of the parameter passed to the type_as function is correct.)r   r  r.  rp  r  r  r   rd  )r1  rB  rE  
self_dtypeZother_dtyper  r  r   r    s     

zaten::cosine_similarityc           	      C  s   t  r| jd||||dS t j| t| |||gdd}t j| t| |||gdd}t j| t| |||gdd}t| t| t| ||| jdt	|gd}t
| ||S )NrA   )r  r  r   r  r\  r]  )r   r  r  rT  r   r   r   r.  r^  r   rF   )	r1  x1x2rE   r  rB   Zx1_l2Zx2_l2Zdiv_tensr  r  r   rA     s4     
   
   
    zaten::pairwise_distancec                 C  s   t |s | jdt|gd}t| | jdtjdgtjddt| ||}t j| t	| t
| |||dgt |dd}t	| ||S )Nr\  r]  rH  rt  r  r  r  )r   rJ  r.  r^  r   rF   r  r   rT  r   r   r  )r1  Zinput1Zinput2pr  r  Zinv_pZ	summationr  r  r   r     s    


zaten::clonec                 C  s   |S r  r  )r1  r8  Zunused_memory_formatr  r  r   r4     s    z	aten::absc                 C  s   |  d|S )NAbsr6  r  r  r  r   r     s    z	aten::logc                 C  s   |  d|S )NLogr6  r  r  r  r   r   !  s    zaten::log1pc              	   C  s    t | t| ttd||S )NrH  )r   r   r   r  r^  r   r  r  r  r   r   '  s    zaten::log10c              	   C  s*   d}|  dt| || j dt|gdS )NgUk@re  r\  r]  r.  r   r^  r   )r1  rB  Z_ln10r  r  r   r   -  s    z	aten::powc                 C  sb   t j|}t|s2t jj}| jd|| d}t|sP| jd|| d}| d||}|S )Nrf  rg  r  )r   rk  rl  r   rn  ro  r.  rp  )r1  rB  exponentZf_dtyper   r  r  r   r   4  s    

zaten::clampc              	   C  s|   t |rt| ||S t |r,t| ||S t |rdt |rdt| d|t |dt |dddS t| t| |||S d S )NCliprY     min_fmax_fr=  )r   r  r1   r2   r  rC  r  )r1  rB  r   r   r  r  r   r3   E  s    



	zaten::clamp_minc                 C  sZ   t |r$t| d|t |dddS tj|}| jd|| d}t| d||ddS d S )	Nr  rY  r  )r  r=  rf  rg  MaxrF  	r   r  rC  r  r   rk  rl  r.  rp  )r1  rB  r   ru  r  r  r   r2   \  s    
   
 zaten::clamp_maxc                 C  sZ   t |r$t| d|t |dddS tj|}| jd|| d}t| d||ddS d S )	Nr  rY  r  )r  r=  rf  rg  MinrF  r!  )r1  rB  r   ru  r  r  r   r1   j  s    
   
 z	aten::maxc                 C  s   |d kr |d kr | j d|ddS |d kr:t| d||ddS t|dd}t|dd	}| j d||g|d
}| j d|||d}||fS d S )NrR  r   r  r   r  rF  r  rE   r  r  ArgMaxr;  r  r.  rC  r   r  )r1  rB  dim_or_yr  rE   r   r  r  r  r   r   x  s    zaten::maximumc                 C  s   t | ||dS N)r'  )r   r  r  r  r   r     s    z	aten::minc                 C  s   |d kr |d kr | j d|ddS |d kr:t| d||ddS t|dd}t|dd	}| j d||g|d
}| j d|||d}||fS d S )N	ReduceMinr   r#  r"  r  rF  r  rE   r  r  ArgMinr%  r&  )r1  rB  r'  r  rE   r   r  r  r  r   r     s    zaten::minimumc                 C  s   t | ||dS r(  )r   r  r  r  r   r     s    z
aten::amaxc                 C  s   | j d|||dS )NrR  r  r6  r1  rB  rE   r  r  r  r   r     s    z
aten::aminc                 C  s   | j d|||dS )Nr)  r  r6  r+  r  r  r   r     s    zaten::aminmaxc                 C  sJ   d|i}t |s*t |dd}|g|d< | jd|f|| jd|f|fS )Nr  r  rE   r  r)  rR  )r   r  r  r.  )r1  rB  rE   r  Zreduce_kwargsr  r  r   r     s    

 z	aten::expc                 C  s   |  d|S )NrS  r6  r  r  r  r   rP     s    zaten::dropout_zaten::dropoutc                 C  s.   t |d |s|S | jd||dd\}}|S )NrH   ZDropoutr  )Zratio_fr  )r   r  r.  )r1  r8  r  trainrr  rn  r  r  r   rH     s
    zaten::alpha_dropout_zaten::feature_alpha_dropout_zaten::feature_dropout_zaten::feature_alpha_dropoutzaten::alpha_dropoutzaten::feature_dropoutc                   s$   t dddtj fdd}|S )NrX  r  r  c                   s   |rt  d|S |S )Nztraining mode)r   r  )r1  r8  r  r,  r$  r  r   feature_dropout  s    z-_unsupported_dropout.<locals>.feature_dropoutr   r  r   r  )r%  r-  r  r$  r   _unsupported_dropout  s    r/  z
aten::normc                 C  st   |dkrt d}n|dkr$t d}ntd||| |||d}|d k	rpt|dd}| jd	|t| d
}|S )NrH  ZReduceL1r  ZReduceL2z)ONNX export only p-norms with p of 1 or 2)rE   r  r  ru  rf  rg  )	r  r   rd  r   r  r.  r   rk  rp  )r1  rB  r  rE   r  ru  rY  r  r  r  r   r     s    

 zaten::conv_tbcc              	   C  s~   t  r| jd||||dS | jd|dddgd}| jd|dddgd}t| |||dg|gdgd}| jd|dddgdS d S )Nr7   )Zpad_ir  rH  r  r   r  )r   r  r  r.  r;   )r1  r8  r  r  r   convr  r  r   r7     s    zaten::_uniquec                 C  s,   t  r| jd|||ddS t d|S d S )N_uniquer  )sorted_ireturn_inverse_ir  )r   r  r  r  )r1  r8  sortedreturn_inverser  r  r   r1  $  s    r1  zaten::_unique2c                 C  s2   t  r| jd||||ddS t ddd| d S )N_unique2re  )r2  r3  Zreturn_counts_ir  r  rG  )r   r  r  r  )r1  r8  r4  r5  Zreturn_countsr  r  r   r6  4  s    	r6  zaten::_cast_Bytez2.0z
the futurez8Avoid using this function and create a Cast node insteadc                 C  s   | j d|tjjdS r  )r.  rh  ri  r  r1  r8  Znon_blockingr  r  r   
_cast_ByteE  s    r8  zaten::_cast_Charc                 C  s   | j d|tjjdS r  )r.  rh  ri  r  r7  r  r  r   
_cast_CharP  s    r9  zaten::_cast_Shortc                 C  s   | j d|tjjdS r  )r.  rh  ri  r  r7  r  r  r   _cast_Short[  s    r:  zaten::_cast_Intc                 C  s   | j d|tjjdS r  )r.  rh  ri  r  r7  r  r  r   	_cast_Intf  s    r;  zaten::_cast_Longc                 C  s   | j d|tjjdS r  )r.  rh  ri  rj  r7  r  r  r   
_cast_Longq  s    r<  zaten::_cast_Halfc                 C  s   | j d|tjjdS r  )r.  rh  ri  ZFLOAT16r7  r  r  r   
_cast_Half|  s    r=  zaten::_cast_Floatc                 C  s   | j d|tjjdS r  )r.  rh  ri  ro  r7  r  r  r   _cast_Float  s    r>  zaten::_cast_Doublec                 C  s   | j d|tjjdS r  )r.  rh  ri  r  r7  r  r  r   _cast_Double  s    r?  zaten::_cast_Boolc                 C  s   | j d|tjjdS r  r  r7  r  r  r   
_cast_Bool  s    r@  zaten::emptyc                 C  s   t | |||||S r  )r  )r1  rc  ru  layoutdevice
pin_memorymemory_formatr  r  r   rM     s    zaten::empty_likec                 C  s   t | |||||S r  )r  )r1  r8  ru  rA  rB  rC  rD  r  r  r   rL     s    zaten::new_emptyc                 C  s0   t |}|d kr|d k	r|}t| |||||S r  )r   r  rM   r1  rB  rc  ru  rA  rB  rC  r  r  r  r   r     s    
zaten::scalar_tensorc                 G  s<   t |dd}|d krtjj}| jd|t| d}|S )Nr  ru  rf  rg  )r   r  r   rk  ro  r.  rp  )r1  Zscalarru  optionsr  r  r   r     s
    zaten::tensorc                 C  s  t |dd}t |r|d kr6tjt |d }t }t |D ]L}| jdt	
dgd}t | ||}| jd|t| d}|| qF| jd|d
diS |d krtj|}t |rt |st |r| jd|ddd}| jd|t| dS )Nr  ru  r   r\  rH  r]  rf  rg  r:  r;  ZConcatFromSequence)r;  Z
new_axis_i)r:  )r   r  r  r   rk  rl  r  r=  r.  r^  r  r@  rp  r(  Z_is_listrK  Z_is_scalar_list)r1  dataru  rB  requires_gradr  r   Zshape_referencer  r  r   r     s,    

zaten::as_tensorc                 C  s   t | |||S r  )r   )r1  rG  ru  rB  r  r  r   r!     s    zaten::zerosc                 C  sz   |d krt jj}n
t |}t|d}t|trZt|dkrZ| jdt	
g t	jd}| jd|t	j
dg| ddS )Nr  r   r\  r]  ConstantOfShapert  r   rk  ro  r   r  r<  r=  rV  r.  r^  r   r   ry  ru  r1  rc  ru  rA  rB  rC  rr  sizes_r  r  r   r    s    

zaten::zeros_likec           	      C  sR   |  d|}|d kr(tj|tjj}n
t|}| j d|tjdg| ddS )Nr5  rI  r   rt  r]  r.  r   rk  rl  ro  r^  r   ru  	r1  r8  ru  rA  rB  rC  rD  r>  rr  r  r  r   r    s     
zaten::new_zerosc                 C  s0   t |}|d kr|d k	r|}t| |||||S r  )r   r  r  rE  r  r  r   r   +  s    
z
aten::zeroc                 C  s   t |}t| ||S r  )r   r  r  )r1  rB  r  r  r  r   r  6  s    
z
aten::onesc                 C  sz   |d krt jj}n
t |}t|d}t|trZt|dkrZ| jdt	
g t	jd}| jd|t	j
dg| ddS )Nr  r   r\  r]  rI  rH  rt  rJ  rK  r  r  r   r   =  s    

zaten::ones_likec           	      C  sR   |  d|}|d kr(tj|tjj}n
t|}| j d|tjdg| ddS )Nr5  rI  rH  rt  r]  rM  rN  r  r  r   r   O  s     
zaten::new_onesc                 C  s0   t |}|d kr|d k	r|}t| |||||S r  )r   r  r   rE  r  r  r   r   i  s    
z
aten::fullc              	   C  s   t |d}t |rX|d kr&tjjn|}t| ||||}t| ||| jdt	
ddS t |dd}|d krxtjj}	n
t|}	t |d}
t|
trt|
dkr| jdt	
g t	jd}| jd	||d|	 dS d S )
Nr   r\  rH  r]  r  ru  r  r   rI  )r   r  rJ  r   rk  ro  r  r   r.  r^  r   r  r<  r=  rV  r   ry  r  ru  )r1  rc  r_  ru  rA  rB  rC  const_valuetmprr  rL  r  r  r   r[   t  s"    


zaten::full_likec              	   C  s   t |d}t |dd}|d kr6tj|tjj}n
t|}t |rt| ||||}	| j	d||
 d}t| |	|| j	dtddS | 	d	|}
| j	d
|
tj|g| ddS d S )NrY  r  ru  rf  rg  r\  rH  r]  r5  rI  rt  )r   r  r  r   rk  rl  ro  rJ  r  r.  rp  r   r^  r   ru  )r1  r8  
fill_valueru  rA  rB  rC  rD  rr  rP  r>  r  r  r   rZ     s$     

zaten::new_fullc           	      C  s2   t |}|d kr|d k	r|}t| ||||||S r  )r   r  r[   )	r1  rB  r   rQ  ru  rA  rB  rC  r  r  r  r   r     s    
	aten::eyec                 G  s   t |dkrX|\}}}}}t| |dg}| jd||dd}t| ||||}	| d|	S t |dkr|\}}
}}}}| jdt| |dgt| |
dgdd}t| ||||}	| d|	S tddt | d	S )
Nr  r   r:  r.  ZEyeLikerH  rR  r  r  )rV  r   r  r.  r  r  )r1  r  r2  ru  rA  rB  rC  r8  r>  r   mr  r  r   rS     s"    aten::slicec                 G  s2  t |dkrr|\}}}}t|d}|dkr:td||  dkoXt| t	j
}|  dkoxt| t	j
}|  dk}	|  dk}
|s|	r|s|
r|  dkrtjtjjkrtd|nBt| |dg}t| |dg}t| |dg}| d	||||S nT|r&dn
t|d}|r>tjn
t|d}t|d}tj| ||g|g|gd
S nt |dkr|\}}}d}|  dkot| t	j
}|  dkot| t	j
}|rdn
t|d}|rtjn
t|d}tj| ||g|g|gd
S tddt | dS )Nr  r  rH  z"step!=1 is currently not supportedr-  r  zUnsupported: ONNX export of Slice with dynamic inputs. DynamicSlice is a deprecated experimental op. Please use statically allocated variables or export to a higher opset version.r   ZDynamicSlicer-  re  rT  r  r  )rV  r   r  r   rd  r  r  r<  r  r	   ZNoneTyper   operator_export_typerh  ZOperatorExportTypesZONNXr  r.  r
   r  r/  r  )r1  rB  r  rE   r  r  r  Zis_start_noneZis_end_noneZis_start_onnx_constZis_end_onnx_constZstart_unsqueezedZend_unsqueezedZdim_unsqueezedr  r  r   r     s      

    
  

    zaten::hardtanhr1  rB  Zmin_valZmax_valc                 C  s   t | d|||ddS )Nr  r  r  rG  rV  r  r  r   rg   $  s         zaten::hardswishc                 C  s   t | |}| d||S r0  )re   r.  )r1  rB  hsr  r  r   rf   .  s    
zaten::hardsigmoidc                 C  s   | j d|ddS )NHardSigmoidgUUUUUU?rO  r6  r  r  r  r   re   7  s    zaten::tanhshrinkc                 C  s   |  d|t| |S )NrQ  )r.  r   r  r  r  r   r   B  s    zaten::hardshrinkc                 C  sx   t j|t jj}| jdtj|| dd}t| t	| ||t
| |t| |}| d||| jdtjd| ddS Nr\  rt  r]  r  r   )r   rk  rl  ro  r.  r^  r   ru  r   rb   r   r   )r1  rB  lambdrr  lambd_opr  r  r  r   rd   I  s$     "zaten::softshrinkc           	      C  s   t j|t jj}| jdtj|| dd}t| ||}| d|t	| ||| jdtjd| dd}t
| |t| |}| d|t| ||| jdtjd| dd}t| ||S rY  )r   rk  rl  ro  r.  r^  r   ru  rb   r   r   r   r   )	r1  rB  rZ  rr  r[  Zgt_condZgt_outZlt_condZlt_outr  r  r   r   `  s:     
	
	zaten::aliasc                 C  s   |S r  r  r  r  r  r   r     s    zaten::unsqueezec                 C  s~   |dk rlt |}|d k	r^tdt| d d d t|| d  d d  || d }nt d	d
|S t j| ||gdS )Nr   z)ONNX export unsqueeze with negative axis r1  r2  r3  rH  r4  r5  r	  r6  r7  )r   r  r  r  r#  r  r  r  r  r  r   r	    s6    

  z
aten::sortc                 C  sp   |d k	rt dd| t |}z|| }W n tk
rD   d }Y nX |d kr\t dd|S | jd|||ddS )NZSortz'Out parameter is not supported for sortrY  TopKr  Zk_ir;  r  )r   r  r:  r  r.  )r1  rB  rE   Z	decendingrq  Z
self_sizesr8  r  r  r   r     s      

zaten::numelc                 C  s   |  d|}| j d|ddS )Nr5  r  r   r#  r6  rA  r  r  r   r     s    z
aten::topkc                 C  s<   |d k	rt dd| |s(t dd| | jd|||ddS )Nr\  z'Out parameter is not supported for topkzAscending TopK is not supportedr  r]  )r   r  r.  )r1  rB  r  rE   largestr4  rq  r  r  r   r     s      zprim::convert_element_typec                 G  s,   t |d dd}| jd|t| dS )Nr   r  ru  rf  rg  )r   r  r.  r   rk  rp  )r1  rB  r  ru  r  r  r   r>     s    zaten::toc                 G  s  t jdd }||r|S t|dkr|d }t|d r|d   dkrt|d  d}t|t	j
rt|jdkr| }t|}n|}t|st|t	j
rtj|d }| jd|| dS | jd|t| dS nt|d	kr$t|d
 dd}| jd|t| dS t|dkr^t|d dd}| jd|t| dS t|dkrt|d dd}| jd|t| dS td|S )Nc                 S  s   t | dkrL| d   dkpJ| d  tj pJt| d  tj	S t | dkrrt
| d dd}|d kS t | dkrt
| d dd}|d kS d	S )
Nr  r   prim::devicer  rH  r  ru  )rH     F)rV  r  r  r  isSubtypeOfr	   ListTypeofIntsr<  r  r   r  )r  ru  r  r  r   is_aten_to_device_only  s    z"to.<locals>.is_aten_to_device_onlyr  r   r  r_  rf  rg  r  rH  r  ru  rH  r`  zUnknown aten::to signature)r   r  rV  r   rJ  r  r  r)  r<  r^  r
  r>  rr   r  r   rk  rl  r.  rp  r  r  )r1  rB  r  rd  ru  Ztvalr  r  r   r     sD    

zaten::repeatc                 C  s0   t jj}t| ||}| d||}| d||S )Nr  ZTile)r   rk  rj  r   r.  )r1  rB  repeatsru  Zshape_r  r  r   r   $  s    zaten::repeat_interleavec              
   C  s  |}t |r@t | || jdtdgd}tjdtjd}n
t |}t |}t 	|}t 	|}|d kr|t
d||d krt
d||d krt
d||dk r|t|7 }| }	t|D ] \}
}|d krd	\||
< |	|
< q|dks|d
kr<|d d
kr<|| dkr,t dddd|S t | |||S |d
kr|| dkrft dddd|S |d d krt dddd|S |d || kstd|d }nt
d|t }t | ||d}t | |||}d\||< |	|< t|D ]\}
}t| ||
 |d
 }| jdt|	d |d
  d|| jdt|	|d
 d  dg}| jd|ddi}t| ||d }t j| || jdt|ddd}|| q| jd|d|iS )Nr\  r  r]  r   rt  zGUnsupported: ONNX export of repeat_interleave for unknown repeats rank.zGUnsupported: ONNX export of repeat_interleave for unknown repeats size.zEUnsupported: ONNX export of repeat_interleave for unknown input size.)r   r  rH  r   r     z3Unsupported along dimension with unknown input sizez*Unsupported for cases with dynamic repeatsz2repeats must have the same size as input along dimz%repeats must be 0-dim or 1-dim tensor)r  rH  r:  r;  Z	allowzero)r:  )r:  )r   r  r@  r.  r^  r   ry  rN  r  r:  r   rd  rV  r  rA  rL  Z-_repeat_interleave_single_value_repeat_helperr  r=  Z_repeat_interleave_split_helperr	  r  rR   r(  )r1  rB  re  rE   r  r8  Zrepeats_dimZrepeats_sizesZinput_sizesZinput_sizes_tempr  r  ZrepsZfinal_splitsZr_splitsZi_splitsZr_splitZi_splitZr_concatr  r  r   r   -  s    
  



"   

zaten::pixel_shufflec           	      C  s  t |}t|dkr$t dd|S tdd |dd  D rt j| t | |ddg| jd	t	d
d||d
d
gdd
d}| jd|d
dddddgd}t j| || jd	t	d
d
ddd
d
gdd
d}t j| || jd	t	d
d
d
d
ddgdd
d}t 
| |ddgS |d | | }t j| || jd	t	d||||d |d gdd
d}| jd|d
dddddgd}t j| || jd	t	d||d | |d | gdd
dS d S )Nr  r   only support 4d inputc                 s  s   | ]}|d kV  qd S r  r  rX  r  r  r   r    s     z pixel_shuffle.<locals>.<genexpr>rH  r  re  r\  r   r  r]  rg  r  r  r  r   r:  rV  r  rb  r@  r  r.  r^  r   r,  )	r1  rB  Zupscale_factorr  
after_viewafter_transpose	reshape_h	reshape_woutput_channelr  r  r   r     s    
  	

zaten::pixel_unshufflec           
      C  s  t |}t|dkr$t dd|S tdd |dd  D rt j| t | |dg| jdt	d	d	d
|d	gdd	d}t j| || jdt	d	d	d	d	d
|gdd	d}| jd|d	dddddgd}t j| || jdt	d	d
ddd	d	gdd	d}t 
| |ddgS |d | | }t j| || jdt	d
|d |d | ||d | |gdd	d}	| jd|	d	dddddgd}t j| || jdt	d
||d | |d | gdd	dS d S )Nr  r   rh  c                 s  s   | ]}|d kV  qd S r  r  rX  r  r  r   r    s     z"pixel_unshuffle.<locals>.<genexpr>rH  re  r\  r   r  r]  rg  r  r  r  r  ri  )
r1  rB  Zdownscale_factorr  rl  rm  rk  Zfinal_reshapern  rj  r  r  r   r     s|    
  



c           *        s  t d d d d d  dddd	d
ddddddg}ttdd |D |}|rXdnddkrt  d|	  krtdd|S t  d|	  kst fddtdt D |
rj	d|dddgd}|r|rtdd|S 
dr|dd    }d d }t|dd krFtdd|S |	 }|}g }dkshd krn|}ndkr|\}}g }|d krtn|}d krd!d"d#gndkrd"d$d%gtjd&d' tjfd(d)}tjfd*d+}tjfd,d-}tD ]R}|rndkrL||\}}}n||\}}t}||d f}ndkr|d| \}} }!|d| d \}"}#}$j	d.|!|$dd/}n,|d| \}} |d| d \}"}#t}j	d.||"dd/}j	d.| |#dd/}d| d| d f}|||||g}%|%||f|  dkrX|%||f|  |rbi nd0d1i}&dkr|	r||g}'n|g}'j	d;|%d|'d2|&\}}(nVd kr؈j	d<|%ddd3|&\}}(n*dkrj	d=|%d4d5|&\}}(})|	rJj	d|dddd4gd}tj|j	d6tddd7gd8dd9}nt|dg}||( dkr&||) q&|
rj	d|dddgd}dkr|(nj	d>|d:di}dksΈd kr||fS dkrdkr|)nj	d?|d:di}|||fS d S )@NzVExporting a model to ONNX with a batch_size other than 1, with a variable length with z can cause an error z9when running the ONNX model with a different batch size. z4Make sure to save the model with a batch size of 1, z=or define the initial states (h0/c0) as inputs of the model. rD  r  r  ZAffinerN  ZThresholdedReluZ
ScaledTanhrX  r  ZSoftsignr<  c                 S  s   g | ]}|  qS r  )lower)r  Zact_funr  r  r   r  M  s     z _generic_rnn.<locals>.<listcomp>r  r  LSTMrH  zLSTMs with projectionsc                   s   g | ]} ||  qS r  r  rX  )all_weightsweights_per_layerr  r   r  V  s   r   r  r  zRNN/GRU/LSTMzdropout in training modeRNNzunknown hidden sizeGRU)rH  r  )r   rH  )r  re  )re  r  )rH  re  c                   s*    fdd|D } j d|ddiS )Nc              	     s2   g | ]*\}}t j d g| g| gdqS )r   r-  r  )r  xyr1  r2  wr  r   r    s   z8_generic_rnn.<locals>.reform_weights.<locals>.<listcomp>r:  r;  r   )r:  r6  )r1  rx  r2  Z	intervalsZslicesr  rw  r   reform_weights  s    z$_generic_rnn.<locals>.reform_weightsc                   s`   |  }dkr|\}}n,dks*dkrF fdd|D \}}t  fdd||fD S )Nrs  rt  rp  c                 3  s   | ]} |V  qd S r  r  r  rx  r1  hidden_sizereform_permutationry  r  r   r    s    zB_generic_rnn.<locals>.transform_weights_no_bias.<locals>.<genexpr>c                 3  s   | ]}t  |d gV  qdS r  r  r  ru  r3  r  r   r    s    )r  )layer_indexweights	weight_ih	weight_hhr1  r|  layer_weightsr}  ry  variantr  r   transform_weights_no_bias  s    

z/_generic_rnn.<locals>.transform_weights_no_biasc                   s|   |  }dkr|\}}}}n0dks.dkrN fdd|D \}}}} j d||dd}t fd	d|||fD S )
Nrs  rt  rp  c                 3  s   | ]} |V  qd S r  r  rz  r{  r  r   r    s    z:_generic_rnn.<locals>.transform_weights.<locals>.<genexpr>r:  r   r.  c                 3  s   | ]}t  |d gV  qdS r  r  r~  r3  r  r   r    s   )r.  r  )r  r  r  r  Zbias_ihZbias_hhbias_concatr  r  r   transform_weights  s    z'_generic_rnn.<locals>.transform_weightsc                   s&   dkr| S t j | dg|g|gdS )NrH  r   r-  r  )ru  r  r  )r1  
num_layersr  r   retrieve_state  s        z$_generic_rnn.<locals>.retrieve_stater:  r.  Zdirection_sbidirectional)r  hidden_size_iZactivations_s)r  r  Zlinear_before_reset_ire  )r  r  r\  r  r]  rg  r;  )rs  )rt  )rp  )r:  )r:  )r  r  dictrp  rV  r   r  r  r  r.  
startswithro  r  r  r   r  r(  r@  r^  r  r,  )*r1  r  r8  Zinitial_statesrq  
has_biasesr  rH   r,  r  batch_firstbatch_sizesZonnxActivationsZvariantToOnnxActivationMapZnonlinearityw_hhZunidirectionalZprev_outputh_outsZh0Zc0c_outsZsequence_lensr  r  r  r  r  r  r  Zstate_indicesZweight_ih_fZweight_hh_fZbias_fZweight_ih_bZweight_hh_bZbias_brB  extra_kwargsZ
activationZh_outZc_outr  )	rq  r1  r|  r  r  r}  ry  r  rr  r   _generic_rnn&  s:     
  




	




 
 
  

"
"r  c
                 C  s2   t |t | }
}t| d||
|||||||	S )Nrp  r   r  r  )r1  r8  hidden_vweight_vr  r  rH   r,  r  r  hiddenr  r  r  r   
_lstm_full  s$    r  c
                 C  s4   t |t | }
}t| d||
||||||	|dS )Nrp  r  r  )r1  r8  r  r  r  r  r  rH   r,  r  r  r  r  r  r   _lstm_packed,  s$    r  z
aten::lstmc                 G  s.   t |d rt| f| S t| f| S d S Nre  )r   rK  r  r  r1  r  r  r  r   r   L  s    zaten::lstm_cellc                   s   t  |dg}t |} fdd|D }t |rB||||fn||f}t |rXdnd}	t d||||	dddddd\}
}}t  |dgt  |dgfS )	Nr   c                   s   g | ]}t  |d gqS r  r  r~  r3  r  r   r  Z  s     zlstm_cell.<locals>.<listcomp>TFrp  rH  )r  rH   r,  r  r  )r   r  r  Z
_is_tensorr  r,  )r1  rB  r  Zw_ihr  Zb_ihZb_hhr8  r  r  rn  r  r  r  r3  r   r   U  s4    
  z	aten::grurt  Zgruzaten::rnn_tanhZRNN_TANHZrnn_tanhzaten::rnn_reluZRNN_RELUZrnn_relur  c                   sd   t ddddddddd	tjfdd t ddddddddd	fdd fdd	}|S )
NrX  r  rY  c
                   s&   t |}
t|  |||
||||||	S r  r  )r1  r8  r  r  r  r  rH   r,  r  r  r  r  r  r   	_rnn_fully  s    
z"_one_hidden_rnn.<locals>._rnn_fullc
                   s(   t |}
t|  |||
|||||	|dS )Nr  r  )r1  r8  r  r  r  r  r  rH   r,  r  r  r  r  r   _rnn_packed  s    
z$_one_hidden_rnn.<locals>._rnn_packedc                   s.   t |d r| f| S  | f| S d S r  )r   rK  r  )r  r  r  r   r    s    z!_one_hidden_rnn.<locals>.symbolicr.  )r  r  r  )r  r  r  r   _one_hidden_rnnq  s    r  zaten::_dim_arangec                 C  sX   |  d|}| j d|| j dt|ddd}t rB|  d|S t| |dd d d S d S )	Nr5  r  r\  r]  r   r.  z_caffe2::Ranger  )r.  r^  r   r   r  r   )r1  likerE   Z
like_shapestopr  r  r   _dim_arange  s       r  zaten::detachc                 C  s   |S r  r  r7  r  r  r   rD     s    zaten::contiguousc                 C  s   |dkrt d||S )Nr  z-onnx memory_format support is not implemented)r   rd  )r1  r8  rD  r  r  r   r6     s     zaten::_pack_padded_sequencec                 C  s|   |r| j d|dddgd}| tjj s<td|t	j
|t	j
jt	j
jkrj| j d|tjjd}| j d	||dd
S )Nr  rH  r   r  r  z*'lengths' must be a Tensor for ONNX exportrf  rg  zprim::PackPaddedr  )r.  r  ra  r^  r	   Z
TensorTypegetr   rd  r   rk  rl  rm  r  rh  ri  r  )r1  r8  lengthsr  r  r  r   _pack_padded_sequence  s       r  zaten::_pad_packed_sequencec                 C  s8   | j d||dd\}}|r0| j d|dddgd}||fS )Nzprim::PadPackedr  r  r  rH  r   r  r6  )r1  rG  r  r  Zpadding_valuetotal_lengthr  r  r  r   _pad_packed_sequence  s    r  zaten::randintc                 G  s  t |dd}t |dd}t |dd}|d kr<tjj}n
t|}|d krZt d||d krnt d|t |d}	t |	r| jd|t	j
dgt	jd	d
}
| jd|
||d}n| jd|	||d}tjj}| jd|| d}||kr| jd|| d}|S )Nr  ru  r  highr   r  rI  r   rt  r]  RandomUniformLikelow_fhigh_fRandomUniform)shape_ir  r  rf  rg  )r   r  r   rk  rj  r  r  rJ  r.  r^  r   r  rp  )r1  r  r  shapesru  rF  low_ihigh_irr  r>  shape_constr   	int_dtyper   r  r  r   r   
  sD    



zaten::randint_likec                 G  s   t |dd}t |dd}t |dd}|d kr<tjj}n
t|}|d krZt d||d krnt d|| jd|||d}	tjj}
| jd|	|
 d	}|
|kr| jd|| d	}|S )
Nr  ru  r  r  r   r  r  rf  rg  )r   r  r   rk  rj  r  r.  rp  )r1  rB  r  r  ru  rF  r  r  rr  r   r  r   r  r  r   r   6  s*    

zaten::randnc                 G  s   t |dd}|d kr tjj}n
t|}t |d}t |rr| jd|tj	dgtj
dd}| jd|| d	S | jd
|| dS )Nr  ru  r  rI  r   rt  r]  RandomNormalLikedtype_iZRandomNormalr  r  r   r  r   rk  ro  r  rJ  r.  r^  r   r  rp  r1  r  ru  rF  rr  r>  r  r  r  r   r   T  s*    


z
aten::randc                 G  s   t |dd}|d kr tjj}n
t|}t |d}t |rr| jd|tj	dgtj
dd}| jd|| d	S | jd
|| dS )Nr  ru  r  rI  r   rt  r]  r  r  r  r  r  r  r  r  r   r   o  s*    


zaten::randn_likec                 C  sH   t |dd}|d kr*tj|tjj}n
t|}| jd|| dS )Nr  ru  r  r  r   r  r   rk  rl  ro  r.  rp  )r1  rB  ru  rA  rB  rC  rD  rr  r  r  r   r     s     
zaten::rand_likec                 C  sB   t |dd}|d kr(tj|tjj}| jd|t| dS )Nr  ru  r  r  r  )r1  rB  ru  rA  rB  rC  rD  r  r  r   r     s       zaten::rreluc                 C  s@   |s || d }| j d||dS | j d|||d}|  d||S )Nr  rN  rO  r  )r  r  r9  r6  )r1  r8  ro  upperr  r  r  r  r  r  r   r     s
    zaten::bernoullic           	      C  s   |d k	r t |s t dd| |d k	r@t |s@t dd| tj|tjj}|tjjkrlt dd|S | jd|dd| d}|d k	rt |s|n|}| d	||}| jd
|| dS )NZ	Bernoulliz,out parameter is not supported for bernoulliz(generator is not supported for bernoulliinput dtype not accessibler  rZ  r  )r  r  r  r  rf  rg  )	r   r  r  r   rk  rl  rm  r.  rp  )	r1  r8  r  r  rq  ru  ZrandsZprobr  r  r  r   r'     s@           zaten::log_sigmoidc                 C  s   |  d|}|  d|S )Nr  r  r6  )r1  r8  r  r  r  r   r     s    z	aten::erfc                 C  s   |  d|S )NErfr6  r7  r  r  r   rO     s    zaten::flattenc                 C  s   t |}|d kr t dd|S |dkr8t | |dgS |dkrL| d|S |dk r\|| }|dkr||d kr| jd||dS |dkr||d kr| jd||d dS t | ||||S )	NrE   r  r   rH  r  Flattenr.  r  )r   r  r  r@  r.  Z_flatten_helper)r1  r8  Z	start_dimZend_dimrE   r  r  r   rU     s$    
zaten::nonzeroc                 C  s   t | | d|S )z/Emitted from `torch.nonzero(x, as_tuple=False)`ZNonZero)r   r.  r7  r  r  r   r     s    zaten::nonzero_numpyc                 C  s   t | t| |d|dS )NrH  )r$  )r  r   )r1  r8  r$  r  r  r   r     s    zaten::isnanc                 C  s   |  d|}|S )NZIsNaNr6  )r1  r8  r  r  r  r   rq     s    z	aten::anyc              	   G  s   t |dkr|d }d\}}n$|\}}}t|dg}t|d}| jd|tjjd}tj| |||d}t| || jdt	j
dt	jd	d
S )NrH  r   rU  r  rf  rg  r  r\  rt  r]  )rV  r   r  r.  rh  ri  rj  rT  rb   r^  r   long)r1  r  r8  rE   r  Z	input_sumr  r  r   _any#  s    

   r  z	aten::allc              	   G  sP   |  d|d }t|dkr.|  dt| |S |  dt| ||d |d S d S )Nrw  r   rH  r  )r.  rV  r  )r1  r  r8  r  r  r   _all6  s    r  zaten::narrowc                 C  s   t j| ||g|g|| gdS )Nr-  r  )r1  r8  rE   r  lengthr  r  r   r   B  s        zaten::argmaxztorch._C.Valuer1  r8  rE   r  c                 C  s   t | |||dS )Nr$  r   Z_argmin_argmax_helperr  r  r  r   r   K  s    	zaten::argminc                 C  s   t | |||dS )Nr*  r  r  r  r  r   r   W  s    	zaten::scatterc                 C  s   t j|t jj}t|}t|r:| jd||||dS t j|}||krb| jd|| d}| jd||t	| |||dS d S )NZScatterr.  rf  rg  )
r   rk  rl  rm  r   rN  rJ  r.  rp  rQ   )r1  rB  rE   rm   srcZsrc_typer  r  r  r   r   c  s     

zaten::scatter_addc                 C  sz   t |}|d kr t dd|S t j|dd}|rP| jdtj|| dd}nt| ||}t 	| ||||}t
| ||S )Nr   r  F)Zallow_nonstaticr\  rt  r]  )r   r  r  r:  r.  r^  r  ru  r  Z_scatter_helperr   )r1  rB  rE   rm   r  rr  rc  Zto_addr  r  r   r   v  s    
  z
aten::log2c              	   C  s(   d}|  dt| || j dt|dS )Ng9B.?re  r\  r]  r  )r1  rB  Z_ln2r  r  r   r     s    zaten::is_floating_pointc                 C  s6   t |r | jdtdgdS | jdtdgdS Nr\  rH  r]  r   )r   rn  r.  r^  
BoolTensorr  r  r  r   ro     s    
zaten::__is_c                 C  sL   t |r@t |r*| jdtdgdS | jdtdgdS t| ||S r  )r   r  r.  r^  r  rN   rT  r  r  r   __is_  s
    

r  zaten::__isnot_c                 C  s   t | ||S r  )r  rT  r  r  r   __isnot_  s    r  zaten::one_hotc                 C  sn   | j dtddgd}tj|tjjtjjtjjtjj	tjj
hkrZ| j d|tjjd}| j d|||dd	S )
Nr\  r   rH  r]  rf  rg  OneHotr  r.  )r.  r^  r  r   rk  rl  rm  r  r  r  r  rh  ri  rj  )r1  rB  Znum_classesr
  r  r  r   r     s     zaten::gatherc           	   	   C  s   t |drt dd|S tj|}| jdtddgd}t	| || jdt|gd}| jd| jd	||||d
|
 d}| dt | ||d g|}t j| ||gddS )Nr  r\   zsparse_grad == Truer\  r   rH  r]  rf  r  r.  rg  rI  r  )r   r  r  r   rk  rl  r.  r^  r  r   rp  r  rT  )	r1  rB  rE   rm   Zsparse_gradrr  r
  depthr   r  r  r   r\     s    c              	   C  s:  |d kr(| j d|dd}|}t| |}nb| j d|||d}| j d||dd}|  d|}| j d|| j dt|d	dd
}| j d|dd}|  d||}	|  d|	|	}
|d krdn|}| j d|
||d}|d krd}|dkr2| j d|tjjd}| j dtj|tjdd	}|  d||}|  d||  d||}||fS )Nr  r   r#  r  rH  r5  r  r\  r]  r.  r  rQ  rI  rf  rg  rt  re  )r.  r   r^  r   rh  ri  ro  r  )r1  r8  rE   Z
correctionr  r  Zt_meanZnum_elementsZredudced_dimsZsub_vZsqr_subZkeepdim_meanr  r|  r   r  r  r   	_var_mean  s<    
  r  z	aten::stdc                 G  s    t | |f| \}}| d|S Nr  r  r.  r1  r8  r  r  rn  r  r  r   r     s    z	aten::varc                 G  s   t | |f| \}}|S r  )r  r  r  r  r   r    s    zaten::var_meanc                 G  s4   t |dkr t| |d |d d S t| |f| S d S )NrH  r   )rV  r  )r1  r8  r  r  r  r   r    s    zaten::std_meanc                 G  s$   t | |f| \}}| d||fS r  r  )r1  r8  r  r  r  r  r  r   r     s    zaten::logsumexpc                 C  s   | j d|||dS )NZReduceLogSumExpr  r6  r  r  r  r   r     s    aten::arangec           
        s  t  r jd| S tjdd }tj fdd}t|dksNt|dkrt|dkr`d }n||d }t j |d	 |d
\}}}}t  |d	g}||}t  t	 t
 ||d d dg}	 jd|	t| dS t|dkst|dkrt|dkr
d }n||d }t j |d	 |d |d |d\}}}}t  |d	g}t  |d	g}t  |d	g}| d d|||}t  t	 t
 |d d d dg}	 d d|	||}	 jd|	t| dS t|dkr||d }t j |d	 |d |d\}}}}t  |d	g}t  |d	g}| d||} dt  t	 t
 ||f|dd   dg|}	 jd|	t| dS t ddt| dS )Nr   c                 S  s   t | d} | S )Nr  )r   r  rt  r  r  r   _get_arange_dtype!  s    z!arange.<locals>._get_arange_dtypec                   s.   t | r* jd d| tjj d} | S )Nrf  rI  rg  )r   rn  r.  r   rk  rj  rp  )range_tensorr3  r  r   _float_step_convert&  s    


z#arange.<locals>._float_step_convertr  r  rH  r   )r  ru  rf  rg  r  r`  re  )r  r  r  ru  re  rQ  rF  rI  rH  )r  r  ru  r  r  r  )r   )r   r  r  r   r  rV  Z_arange_cast_helperr  r,  r   r   r.  r   rk  rp  r  )
r1  r  r  r  ru  r  r  r  r  Zarange_tensorr  r3  r   r     s    
	                     zaten::linspacec           
      C  sT   t | |d }t| t| ||t| || jdtjdtjdd}	t| t	| ||	|S )Nr\  rH  rt  r]  )
r   Z_arange_helperrF   r   r.  r^  r   ry  r   r   )
r1  r  r  Zstepsru  rA  rB  rC  r  r  r  r  r   r~   n  s    
 z
aten::liftc                 C  s   |S r  r  r  r  r  r   rx   |  s    zaten::masked_fillc                 C  s6   | j d|tjjd}t|}|  d|t|||S )Nrf  rg  r  )r.  rh  ri  r  r   rN  r  r1  rB  maskr_  r  r  r   r     s    
zaten::masked_fill_c                 C  s   t | |||S r  )r   r  r  r  r   r     s    aten::indexc                   s,  t  rjd|ddS t |r0t |}n|g}tjfddfdd|D }t|dkrt jd	|d	 d
dS dd t	|D  t d	krS t dkrt
 d	 | d	  S t }|d krt ddS tdtj d t }tfddt|D jd  fddt|D  djd|d| d  } d  }t|d ddD ]@}d| |  |}	d||	}d| |  }qt
d	|t|}
 tt d	  d d krjdtdgdg fddt|D  }jd#|dd	i}t |ttd d	 d d	g tt d	 d || d  }jd|dfd dt d	 D |
g  fd!dt d	 |D  }jd$|dd	i}n,jd|
f fd"dt|D dd	i}t |S d S )%Nrm   r
  )r  c                   sh   t | sdtj| tjjtjjks.t | rd jdk rDt	
dtd t  t | dg} | S )Nr  z?Exporting masked indices are only supported after ONNX opset 9.zExporting aten::index operator with indices of type Byte. Only 1-D indices are supported. In any other case, this will produce an incorrect ONNX graph.rH  )r   r  r   rk  rl  rm  r  rS  r  r   rd  r  r  r,  r   )rm   r  r  r   try_mask_to_index  s(    
 
z index.<locals>.try_mask_to_indexc                   s   g | ]} |qS r  r  )r  r  )r  r  r   r    s     zindex.<locals>.<listcomp>rH  r   F)Zapply_reshapec                 S  s   g | ]\}}t |s|qS r  )r   r  )r  r  r  r  r  r   r    s    
 r  zoperator of advanced indexing on tensor of unknown rank. Try turning on shape inference during export: torch.onnx._export(..., onnx_shape_inference=True).z=Exporting aten::index operator of advanced indexing in opset z is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.c              
     s0   g | ](} j d  j dt|gdddqS )r  r\  r]  r   r.  )r.  r^  r  r  rE   )r1  shape_tensorr  r   r    s   r  c                   s   g | ]}| kr|qS r  r  rX  )adv_idx_indicesr  r   r    s      r  r  r.  r  r  rI  rF  r\  r]  c                   s   g | ]}| kr| qS r  r  rX  r  dim_tensor_listr  r   r    s     r:  r;  c                   s   g | ]} | qS r  r  rX  )r  r  r   r  $  s     c                   s   g | ]}| kr| qS r  r  rX  r  r  r   r  &  s   c                   s   g | ]}| kr| qS r  r  rX  r  r  r   r  1  s   )r:  )r:  )r   r  r  r  r  r   r  rV  r	  rA  rl   r  r  r  r  r   r@  r9  r  r.  r=  r^  r  r@  )r1  rB  rm   r  r  Zadv_idx_countZcum_adv_index
multiplierr  Z	adv_indexZcum_adv_index_shape_tensorZfolded_adv_idx_shape_listZfolded_adv_idx_shapeZadv_idx_permuteZfinal_shape_listZfinal_shaper  )r  r  r1  rB  r  r  r   rm     s    
       

	

  

 	zaten::linalg_normzOptional[Sequence[int]])r1  rB  ordrE   r  ru  c                 C  s   d }|d kr|t |r<t | |dg}| jdtdgd}t |}|d kr\t dd|S |dkrrt |d}qd	dg}n8t	|dkrt |r| jdtdgd}t |d}|rt
| |||||S t| |||||S )
Nr  r\  r  r]  rE   (Input rank must be known at export time.rH  rY  r   )r   r  r@  r.  r^  r  r  r  r  rV  r|   rz   )r1  rB  r  rE   r  ru  	ord_valueself_dimr  r  r   r{   <  s,    

  

zaten::linalg_vector_normc                 C  s   |d krt | |dg}d}|tjkrB| jd| d|||d}n|tj krj| jd| d|||d}n|dkrt dd	d
d|S | jdtj|tjdd}t j	| | d| d||||d}| d|| d| jdtjdtjdd|}|S )Nr  FrR  r  r  r)  r   r|   r  rG  zord=0 not supportedr\  rt  r]  r  re  rH  )
r   r@  r  infr.  rL  r^  r   r  rT  )r1  rB  r  rE   r  ru  r  Zord_opr  r  r   r|   `  s@    
       	zaten::linalg_matrix_normz	List[int]c              	   C  s  t |d}|dkr"t| |||S |dkr8t dd|S t |d}|d krZt| |||S |dksj|dkrxt dd	|S t |}|d krt dd
|S |d dk r|d  |7  < |d dk r|d  |7  < |tjks|tj kr|d |d  |d< |d< |d |d kr*|s*|d  d8  < t j| | d||d g|d}|dkrt	| || jdt
|d gd|d\}	}
n*t| || jdt
|d gd|d\}	}
|	S d S )Nr`  ZfroZnuczlinalg.matrix_normzord==nucrY  r  r  zord==2r  r   rH  r  r  r\  r]  )r'  r  )r   r  rY   r  r  r  r  rT  r.  r   r^  r  r   )r1  rB  r  rE   r  ru  r  r  r  r  r  r  r  r   rz     sZ    
   
  

zaten::linalg_crossr  c                 C  s   t | |||S r  )rB   )r1  r8  rE  rE   r  r  r   ry     s    zaten::frobenius_normc                 C  s,   |  d||}tj| |||d}|  d|S )NrI  r  r  )r.  r   rT  )r1  rB  rE   r  ZsqrZsumsqrr  r  r   rY     s    zaten::multinomialc                 C  sZ   |d k	r t |s t dd| |s:|dkr:t dd| t| |}| jd|tjj|dS )NZMultinomialz*generator is not supported for multinomialrH  zGreplacement=False when num_samples > 1 is not supported for multinomial)r  Zsample_size_i)r   r  r  r   r.  rh  ri  rj  )r1  r8  Znum_samplesreplacementr  Z	log_inputr  r  r   r     s&      
zaten::baddbmmc           
      C  s\   t j|}t| ||}t| || jd|| d}t| || jd|| d}	t| ||	S r  )r   rk  rl  r   r   r.  rp  r   )
r1  rB  Zbatch1Zbatch2r  rP  rr  Z	batch_mulZmul_aZmul_br  r  r   r%     s    zaten::meshgridzOptional[str])r1  indexingc                   s0  |d krd}n|dkr(t d| ||dkrJ|d |d  |d< |d<  fddt|D } fd	d|D } jd|ddi}g }t|D ]h\}} jdtjdtjddgt	| }	|| |	|< t
 | jd|	ddi}
| d|
| q|dkr"|d |d  |d< |d<  jd| S )Nij>   xyr  zUnsupported indexing: r  rH  r   c                   s,   g | ]$}t  | jd tdgdqS )r\  r  r]  )r   r@  r.  r^  r  r  r3  r  r   r    s     zmeshgrid.<locals>.<listcomp>c                   s   g | ]}  d |qS )r5  r6  r  r3  r  r   r    s     r:  r;  r\  rt  r]  r  prim::ListConstruct)r:  )r:  )r  )r   rd  r   r  r.  rA  r^  r   ry  rV  r?  r(  )r1  r  r  r  Ztensors_shapeZ	out_shaperq  r  r   r  Z
t_reshapedr  r3  r   r     s2     

zaten::remainderc                 C  s(   t | ||}| d||}| d||S )NrI  rQ  )rb  r.  )r1  r8  rE  rF   Zquor  r  r   r   $  s    z
aten::gelu)r1  rB  approximatec                 C  s&  |dkrt dt j }d}tj|tjd}tj|tjd}tjdtjd}tjdtjd}t| |t| ||}	t| |t| |t| ||	}
t| |t| |t| || d|
S d}| d	| d
|tj|tjd}t| || jdtjdtjdd}t| t| ||| jdtjdtjddS d S )Nr   r  gHm?rt  rZ        ?r  g;f?r  re  r\  rH  r]  )	r  r   r  r^  r   r  r   r   r.  )r1  rB  r  ZkBetaZkKappar  kappar|  ZhalfZ	self_cubeinnerZ_sqrt2rO   Zerf_plusoner  r  r   r^   ,  s,    $"  
zaten::group_normc              
   C  s  t  r | jd||||||dS t |d}|d k	rD|| dksDtt |}|d krdt dd|S d|dg}	t | || jdt	
|	d}
| jdt	jd	g| tj| d
d}| jdt	jdg| tj| d
d}| jd|
|||d}t | || d|}|d ks"|  rLt	jd	gtj| d
}| jd|d}|d ksd|  rt	jdgtj| d
}| jd|d}ttd|d }t| t| |t | ||t | ||S )Nra   )Znum_groups_ir  Zcudnn_enabled_irH  r   zunknown input rankr  r\  r]  rZ  rt  r  r  r  r5  )r   r  r  r  r  r  r  r@  r.  r^  r  r   r   rk  rl  ru  r  
mustBeNoner=  r  r   r   r  )r1  r8  Z
num_groupsr  r  r  r  r  Z
input_rankr>  r  r  r   Znorm_reshapedr   r  r  r  r  r  r   ra   I  s|    


        zaten::_weight_normc                 C  s   t |}|d k	rttt|}|d k	rH|dk r6||7 }|dkrH|| t| |d|d}| d||}| d||S t  r| jd|||dS t	
d|d S )	Nr  r  rH  re  rI  _weight_normr  zDUnsupported: ONNX export of _weight_norm for tensor of unknown rank.)r   r  r=  r  remover   r.  r  r  r   rd  )r1  r  Zweight_grE   r  r  Znorm_vrF   r  r  r   r    s"    

r  z	aten::dimc                 C  s   |  d|}|  d|S )zFImplement the dim functionality available for a pytorch tensor in ONNXr5  Sizer6  rA  r  r  r   rE     s    zaten::__contains_c                 C  sd   t |}tdd |D rTt |rT| jdtt | ddd |D kdS t	
d|d S )Nc                 s  s   | ]}t |V  qd S r  )r   r  r~  r  r  r   r    s    z__contains_.<locals>.<genexpr>r\  r_  c                 s  s   | ]}t | d V  qdS )r_  N)r   r)  r  r~  r  r  r   r    s     r]  zJUnsupported: ONNX export of __contains__ for non-constant list or element.)r   r  r  r  r.  r^  r   r)  r  r   rd  )r1  rB  elementZunpacked_listr  r  r   __contains_  s$    
r  zaten::__getitem_c                 C  s    t | || jdtdgd|S rJ  )r   r.  r^  r   )r1  rB  r  r  r  r   
__getitem_  s    r  z
aten::itemc                 C  s   |S r  r  r  r  r  r   rr     s    z
aten::takec              
   C  sD   t | || jdtjdgtjdd}t| |d|}t| ||}|S )Nr\  r  rt  r]  r   )r   r@  r.  r^  r   ry  rl   r   )r1  rB  rm   Zself_flattenedrq  r  r  r   r     s      c                 C  s&   t | ||}t| |}t| ||}|S r  )r   rP   r   )r1  r8  targetdiff_Zexp_r  r  r  r   _kl_div_log_target_impl  s    
r  c           	      C  sZ   t | |}t| ||}t| ||}t| |}t| || jdtdd}t| |||}|S rJ  )	r   r   r   r  rb   r.  r^  r   r  )	r1  r8  r  Zlog_r  Z
output_posZzeros_Zmask_r  r  r  r   _kl_div_non_log_target_impl  s    

r  zaten::kl_divc                 C  sj   |rt | ||}nt| ||}|dkr*|S |dkrB| jd|ddS |dkrZtj| |ddS td|S d S )Nr   rH  r  r#  r  z4kl_div with reduction other than none, mean, or sum.)r  r  r.  r   rT  r  )r1  r8  r  	reductionZ
log_targetr  r  r  r   rs     s     zaten::mse_lossc                 C  sh   t | t| ||t| ||}|dkr(|S |dkr@| jd|ddS |dkrXtj| |ddS td|S d S )Nr   rH  r  r#  r  z6mse_loss with reduction other than none, mean, or sum.)r   r   r.  r   rT  r  )r1  r8  r  r  r  r  r  r   r     s     zaten::as_stridedc                 C  s  t |d}t|}t | || jdtjdgtjdd}t |stjdgtj	d}t
t||D ]6\}\}	}
dg| }d||< |t|	||
  }qd|r|| }| d|| jd|dS d }t
|D ]\}}
dg| }d||< t| || jdtdgd| jdt|d}	t | t| |	d	d d d | jdt|d}| d
|| jdt|
gd}|d krr|}q| d||}q|r| d|| dt|g}| d||S d S )Nr  r\  r  rt  r]  r   rH  r  r  rI  rF  )r   r  rV  r@  r.  r^  r   ry  rJ  r  rA  rp  r   r  r   )r1  rB  rc  stridesoffsetr  Zself_1dindr  r   r]  Zr_sizeZtmp_indr  r  r   r      sT      


  
zaten::__derive_indexc              	   C  s   |  d||  d||S )NrF  rI  r6  )r1  rm   r  r  r  r  r   __derive_indexM  s    r   zaten::__range_lengthc                 C  s6   |  d||}|  dt| ||}| j d|tjjdS )NrQ  rI  rf  rg  )r.  r  rh  ri  rj  )r1  lor  r  r   rF   r  r  r   __range_lengthS  s    r  zaten::linearc                 C  s   t |}t| |}|dkrp|  sp| jdtjdtjdd}| jdtjdtjdd}t	| |||||}n$t
| ||}|  st| ||}|S )Nr  r\  rH  rt  r]  )r   r  r   r  r  r.  r^  r   ry  r   r   r   )r1  r8  r  r  r  rP  r  r  r  r  r   r}   c  s    

zaten::hann_windowzOptional[int])r1  ru  c              	   C  s   |d kr.t  }|r|js t j}tj|}	n
t|}	t| |dd d d }
| jd|
t	j
jd}t| | jdt jtjt jdd|}|dkrt| || jdt jdt jdd}t| ||}| jdt| t| ||	 d}|S )	Nr  rf  rg  r\  rt  r]  FrH  )r^  r~  ro   r  r   rk  Z
from_dtyper   r.  rh  ri  ro  r   r   r  r  r   r  rF   r   r   rp  )r1  Zwindow_lengthZperiodicru  rA  rB  rC  rH  Zdtype_rr  Zn_arrayr  r  r  r   rc   t  s4    

    zaten::mvc                 C  s   t | ||S r  r   )r1  rB  Zvecr  r  r   r     s    z	aten::dotc                 C  s   t | ||S r  r  rT  r  r  r   rG     s    zaten::movedimc           
      C  s   | d}| d}| | ks(t||k r8|S t|}|d k	sNttt|}| }| }t	|
 |
 D ] \}}	|||	< d||< d||	< q|dd |D }dd |D }t	||D ]\}}	|||	< q| jd||dS )Nr  c                 S  s   g | ]}|d kr|qS r  r  r  r  r  r   r    s      zmovedim.<locals>.<listcomp>c                 S  s   g | ]}|d kr|qS r  r  r  r  r  r   r    s      r  r  )r  r   r  r  r   r  r=  r  r  rp  tolistr.  )
r1  rB  r  destinationr;  r  Zsrc_dimsZdst_dimsr  dstr  r  r   r     s&    




z
aten::fillc                 C  s    t j|t jj}t| |||S r  )r   rk  rl  ro  rZ   )r1  rB  r_  rr  r  r  r   rT     s
     zaten::index_addc                   s  t d |r0tt|dkr0tdd|S t d  d krPtd|t	|}t	|}|d kst|d krtd|||kr|| }t
|D ]}	t| |t	|g}qt| }
t| }|
d k	r|d k	r|
|krtd|tt
|}d	d
 t
|D } fdd
t
|D }tj| ||||d}t| ||}t
 D ]}	t| |dg}qLt
|  d D ]}	t| |t	|g}qtt| | t| |||S )NzyWarning: ONNX export does not support duplicated values in 'index' field, this will cause the ONNX model to be incorrect.rH  rh   z
alpha != 1r  zXONNX export does NOT support exporting 'index_add_()' function with unknown 'dim' value.z~ONNX export does NOT support exporting 'index_add_()' function while the rank of self tensor or tensor to be added is unknown.zoONNX export does not support exporting 'index_add_()' function with duplicated values in 'index' parameter yet.c                 S  s   g | ]}d qS r  r  rX  r  r  r   r  
  s     zindex_add.<locals>.<listcomp>c                   s   g | ]}| krt jnd qS r^  )sysmaxsizerX  r  r  r   r    s     r-  r   )r  r  r   rM  rN  r  r  r   rd  r  r  r  r  r=  r/  rQ   r   )r1  rB  rE   rm   rE  rP  Zself_dim_rankZother_dim_rankdeltar  Zother_dim_sizeZself_dim_sizeZnew_shape_axesZnew_shape_startsZnew_shape_endsr  r  r  r   rh     sl    

  
      
z
aten::rollc                 C  s   t |t |kst|}tt |D ]~}g }tj| ||| g||  gtjgd}|| tj| ||| gdg||  gd}|| | jd|d|| i}q$|S )Nr-  r   r:  r;  )r:  )	rV  r  r  r   r/  r  r	  r(  r.  )r1  rB  Zshiftsr  r  r  r  r>  r  r  r   r     s,       
 
    

zaten::crossc                 C  sp   t ||}t| |dg|g}t| |dg|g}t| |dg|g}t| |dg|g}t| t| ||t| ||S )Nr  rH  )r   Z_get_dim_for_crossr   r   r   )r1  r8  rE  rE   Zroll_x_1Zroll_y_1Zroll_x_2Zroll_y_2r  r  r   rB   3  s    zaten::cdistr  #use_mm_for_euclid_dist_if_necessaryc                 C  sR   t |}|d k	stt | ||d g}t | ||d g}t| |||dddS )NrH  r  gư>F)r  r  )r   r  r  r  r   )r1  r  r  r  Zcompute_moder  Zbroadcasted_x1Zbroadcasted_x2r  r  r   r/   G  s    
     z
aten::lerpc                 C  sx   |  d||}t| |  d|| j dtdd|  d||  d|||  d||  d||  d| j dtdd|S )	NrQ  r  r\  r  r]  rF  rI  rZ  )r.  r  r^  r   )r1  rB  r  r  diffr  r  r   rw   _  s    zaten::broadcast_tensorsc                   sP   t |}t |d |D ]}t |q fdd|D } jd| S )Nr   c                   s   g | ]}t  |qS r  )rQ   r  r1  Zt_with_final_shaper  r   r    s     z%broadcast_tensors.<locals>.<listcomp>r  )r  )r   r  r  r   r.  )r1  rB  Zall_tensorsr   Zt_listr  r  r   r+   u  s    
zaten::is_pinnedc                 C  s   d S r  r  )r1  rB  rB  r  r  r   rp     s    prim::ConstantSplitc                 C  s^   t ||}|d kr"t dd|S |g||  }|| }|rF|| | jd|||t|dS )Nr  r   r!  r"  )r   r  r  r(  r.  rV  )r1  rB  r%  rE   r   r&  r'  r  r  r   r     s      
prim::ConstantChunkc                 C  s@   t ||}|d kr"t dd|S || d | }t| |||S )Nr  r   rH  )r   r  r  r   )r1  rB  r#  rE   r8  r%  r  r  r   r     s      zprim::shapec                 C  s   |  d|S r4  r6  r  r  r  r   r     s    z	prim::maxc                 C  s   t | d||ddS )Nr   r  rF  rG  rT  r  r  r   r     s    z	prim::minc                 C  sB   |s6t |r,t| || jdtdgd}t| |S t| ||S rJ  )r   r  r   r.  r^  r   r   rT  r  r  r   r     s
    

z
prim::datac                 C  s   |S r  r  r  r  r  r   r     s    zprim::layoutc                 C  s   | j dtddS rJ  r  r  r  r  r   r     s    r  c                 O  s   d S r  r  r1  rB  r  r  r  r   r     s    zprim::ListUnpackzOptional[List[_C.Value]])r1  r  c                 O  s2   t |dkr.|d   dkr.t|d S d S )NrH  r   r  )rV  r  r  r   r  r  r  r  r   r     s     zprim::TupleConstructc                 O  s   d S r  r  r  r  r  r   r     s    zprim::Uninitializedc                 O  s   d S r  r  r  r  r  r   r     s    zprim::unchecked_castc                 C  s   |S r  r  r  r  r  r   r     s    zprim::dtypec                 C  s.   t |}|d krtjj}| jdt|dS r[  )r   r  r   rk  ro  r.  r^  r   r  r  r  r   r     s    
prim::tolistc                 C  s&   t |d}|dkr"t dd|S |S )ztolist is currently supported only for 1D input tensors.

    dim_val and elem_ty_val represent dimension and type annotations
    that need to match dimension and type of the input tensor.
    r  rH  r  zdim_val > 1)r   r  r  )r1  r8  Zdim_valZelem_ty_valrE   r  r  r   r     s    r_  Nonec                 O  s>   | j   }t|tjrd S tdd|  d| j  S )Nr_  z,output type should be 'DeviceObjType', not '')	original_noder  r  r<  r	   r  r   r  r  )r1  rB  r  output_typer  r  r   r     s    z
prim::LoopzList[_C.Value]c                 O  s   | j }| j}| j}tj}tj}t| }tj	| df||
 t|d\}	}
}t||
D ]\}}t| D ]l\}}|dkr|t|k r|||   |dkrp|d t|k rpt| tjsp|||d    qptj||j||d q\tj||}tjrtj||| |S )NZLoopr  Zn_blocksr   rH  F)r  envparams_dictr   rU  r@  r  blocksr   add_op_with_blocksoutputsSizerV  rp  rA  rB  r/  r  r<  r	   r0  r^  _jit_pass_onnx_blockblock%_jit_pass_fixup_onnx_controlflow_nodeonnx_shape_inference(_jit_pass_onnx_node_shape_type_inference)r1  rB  attrsr  r  r  rU  opset_version
old_blocksnew_op_outputsnew_block_contextsnew_node	old_blocknew_block_contextr  Zb_infixed_outputsr  r  r   r     sX         zprim::Ifc                 O  s  | j }| j}| j}| j}tj}tj}|d   dk}	|	rt	
|d  d }
t|
trht|
nt|
}|rxdnd}t| | }tj||||d}t| }t| }g }tt|D ]B}|| |krtd||  d|| |||  }|| q|S t| }tj| df|| t|d	\}}}t||D ] \}}tj||j||d
 qNtj||}tj rtj!||| |S d S )Nr   r  r_  rH  TzThe sub block ATen output z is not in env.Ifr  F)"r  r  r  r  r   rU  r@  r  r  r   r)  r  r<  r=  r  r  r  r^  r	   r  r  r  rV  r   rd  r(  r  r   r  r  rp  r  r  r   )r1  rB  r!  r2  r  r  r  rU  r"  Z	static_ifZ
input_flagrO  Z	block_idxZ	current_bZif_output_listZcurrent_b_listZfinal_b_listr  Zonnx_br#  r$  r%  r&  r'  r(  r)  r  r  r   r   R  sz         r-  c                   s*   j }| rd S t|  tjr*d S |ddkrN jdt	
|ddS |ddkrr jdt	
|ddS |  tj s|  tj r jdtt	
|ddS |  tj r  fddt	
|dD } jd| S td
|d dtj d| d S )Nr_  r   r\  r]  r`  Zvalue_sc                   s   g | ]} j d |dqS )r\  r+  r6  )r  r`  r3  r  r   r    s   z!prim_constant.<locals>.<listcomp>r  z"Unsupported prim::Constant kind: 'z'. Please send a bug report at .)r  )r  r  r<  r  r  r	   r  r  r.  r   r)  ra  rb  rc  ZofFloatsr^  r   Z	ofStringsr   rd  r
   ZPYTORCH_GITHUB_ISSUES_URL)r1  rB  r!  r  Zstr_constantsr  r3  r   r     s6     


prim::type)r1  device_valuec                 O  sJ   |   dkr<t|   }|d k	r<| jdt|dS tdd|S )Nr_  r\  r+  r-  z,Device type cannot be statically determined.)	r  r  r   Zget_device_from_valuer8  r.  r#  r   r  )r1  r.  r  r  rB  r  r  r   r     s    zonnx::Placeholderc                 O  s"   | j }| j}| j}tj|||S r  )r  r  r  r^  r	   Z'_jit_onnx_convert_pattern_from_subblock)r1  rB  r!  r  r  r  r  r  r   r     s    zaten::resolve_conjzaten::resolve_negr7  c                 C  s   |S r  r  r7  r  r  r   r    s    	zaten::_conjzaten::conj_physicalc                 C  s    t |rt d|S t| |S )Nz aten::_conj, aten::conj_physical)r   Zis_complex_valuer  r  r7  r  r  r   r
    s    	
zaten::logit)r1  rB  r  c                 C  s   | j dtdd}t|s| j d|tj| d}|  d||}|  d||}|  d|||}|  d	||}|  d|||}n|}|  d||}	|  d
||	}
|  d|
S )Nr\  rZ  r]  rf  rg  rQ  r  r  r  re  r  )	r.  r^  r   r   r  r   rk  rl  rp  )r1  rB  r  r|  Zone_sub_epsZself_less_equal_one_sub_epsZtemporary_selfZtemporary_self_less_epszr   rF   r  r  r   r     s     
  )N)N)N)rZ  )T)T)N)N)N)N)N)N)r   N)N)F)N)N)NNN)N)N)FF)NN)NN)N)FN)NNNFN)F)NNF)NN)F)NNNFN)F)F)NNNFN)F)F)NNNFN)F)N)N)NN)NN)NNFN)NNFN)NNN)N)F)r  )NF)FN)N)r  )N)TNNNNF)N)N)r  r  )N)N(v  __doc__
__future__r   r  r  r  r  r  typingr   r   r   r   r   r   r^  Ztorch._C._onnxr	   Z_onnxrh  Ztorch.nn.modules.utilsZ
torch.onnxr
   r   r   r   r   Ztorch.onnx._globalsr   Ztorch.onnx._internalr   r   r   Ztorch.typesr   r'  partialZonnx_symbolicZ_onnx_symbolicr"  r,  r  r  r9  r?  r  r   r   r   r   r   r   rF   r  r   rW  rc  rb  rV   rX   r  r   r.   r   r  r   r*   r   r   r   r   r   r   r   r@   r   r"   r   r#   r$   r   r   r  r  r  r   r  rC   r  r  r   r   rR   r,   rQ   rK   rJ   r   r  r   r  r  r  r   r  r   r  r  r   r   r   r   r   r   rC  r   r   r0   rW   rK  r   rv   r`   r   r   r_   nnmodulesutilsZ_singleZ_pairZ_triplerx  r   r   r   r  r  r  r  r5   r  r   r   r   r  r  r(   r)   r  r  r  rN   r   rb   r  r   r  r]   ru   r  r  r  r   r   r   r   r  r  r  r   r  r  r  r?   r;   r<   r=   r8   r9   r:   r&   r   rt   rn   r  rI   r   rl   rk   rj   ri   r-   r  rA   r   r4   r   r   r   r   r   r3   r2   r1   r   r   r   r   r   r   r   rP   rH   r/  r   r7   r1  r6  
deprecatedr8  r9  r:  r;  r<  r=  r>  r?  r@  rM   rL   r   r   r   r!   r  r  r   r  r   r   r   r[   rZ   r   rS   r   rg   rf   re   r   rd   r   r   r	  r   r   r   r>   r   r   r   r   r   r  r  r  r   r   r  r  rD   r6   r  r  r   r   r   r   r   r   r   r'   r   rO   rU   r   r   rq   r  r  r   r   r   r   r   r   ro   r  r  r   r\   r  r   r  r  r   r   r   r~   rx   r   r   rm   r{   r|   rz   ry   rY   r   r%   r   r   r^   ra   r  rE   r  r  rr   r   r  r  rs   r   r    r   r  r}   rc   r   rG   r   rT   rh   r   rB   r/   rw   r+   rp   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r
  r   r  r  r  r   <module>   sT       		


5
*4	






*!


>5

	
*
	>+ 
  	 
  	 
  	;



8  
	  
	  
	


7 )		(,

             C	7*"#"#"#&&&*74:8**_$


  #


	 

 

 
$     	         	    	    	H
(



 K  hFC  f**  C*      




(	(	 "Q *0!0&0;    $(E	

",	"     "& E  
2W""	"