U
    9%e4                  9   @  s  d Z ddlmZ ddlZddlZddlZddlmZmZ ddl	Z	ddl	m
Z
 ddlmZ ddlmZmZmZmZmZmZ ddlmZ dd	lmZmZmZ d
dddddddddddddddddddddd d!d"d#d$d%d&d'd(d)d*d+d,d-d.d/d0d1d2d3d4d5d6d7d8d9d:d;d<d=d>d?d@dAdBg9ZejejdCdDZ dEdF Z!e dGe"dHe#dIdJdJej$dKdLdMdMdNdOdZ%e dPej$dKdQdRdZ&e dSe#dIdIej$dKdQdTdZ'e dUe#dIdIej$dKdQdVdZ(e dWej$dKdQdXd2Z)e dYe"dHe#dIdZdIej$dKdQd[d7Z*e d\ej$ddKdQd^d#Z+e d_e#dIdZej$dKdQd`d.Z,e dae!dbdcddgdee dfe!dgdhddgdee die!djdkddgdee dle!dmdcdngdee doe!dpdhdngdee dqe!drdkdngdee dse!dtdhdugdeej$dvdwdvdxdydzZ-e d{e"dHd]d]d]d]d]d]ej$dKdQd|d}Z.e d~e#dIdZdIdIej$ddKdQddZ/e de#dIdZdIdIej$dKdQdd6Z0e de#dIdZdej$ddKdQddZ1e dej$dKdQdd)Z2e dej$dKdQdd(Z3e dej$dKdQddZ4e dej$dKdQddZ5e dej$dKdQddZ6e dej$dKdQddZ7e dej$ddKdQdd
Z8e dej$dKdQdd$Z9e dej$dKdQdd/Z:e dej$dKdQddZ;e de"dHej$dKdQddZ<e dej$dKdQdd=Z=e de#dIdZdZdZej$dKdQddZ>e de#dIdZdZdZdZej$dKdQdd@Z?e de#dIdIdZdZdZdej$ddKdQdd>Z@e de#dIdZdZdej$ddKdQdd9ZAe de#dIdZdZdej$ddKdQddZBe de#dIdZej$ddKdQdd5ZCe dej$dKdQdd3ZDe de#dIdIdZdZej$ddKdQdd;ZEe de#dIdIdZdZej$ddKdQdd:ZFe de#dIdZdZej$ddKdQdd?ZGej$dKdQddZHe dej$ddKdQddZIe de de dej$dKdQdd1ZJe de de dej$dKdQdd4ZKe dej$dKdLdLdLdLddd-ZLe dăej$dKdQdd%ZMe dƃej$dKdQdd'ZNe dȃej$dKdQddZOe dʃe#dIdZej$dKdQdd̄ZPe d̓ej"dHd]d΍ej$ddKdQdd8ZQe dЃej$ddKdQdd<ZRe d҃ej$dKdQddAZSe dԃej$dKdQdd*ZTe dփej$dKdQdd!ZUe d؃ej$dKdQdd ZVe dڃej$dKdQdd"ZWe d܃ej$dKdQddބZXe d߃ej$dKdQddZYej$dKdQddZZej$dKdQddZ[ej$dKdQddZ\e de#dIddddej$dKdQddZ]e dej$dKdQdd+Z^e de"dHd]d]e#dIdZdZej$dKdQddZ_e de#dIdJdddIej$dKddddd&Z`e de#dIdIdIdZdZdZdIdZdZ	ej$dKdQddZae de#dIdIdJdJej$dKdQddZbe dej$dKdQddZce dej$ddKdQdd,Zde dej$dKddd dZee dej$dKddddZfe dej$dKddddZge dej$dKdQdd0Zhe dej$dKdLdd	dZie d
ej$dKdLdddBZjdS (  z(This file exports ONNX ops for opset 11.    )annotationsN)OptionalSequence)_C)_onnx)_type_utilserrorssymbolic_helpersymbolic_opset10symbolic_opset9utils)GLOBALS)	_beartype	jit_utilsregistrationaddappendarangeargsort
atleast_1d
atleast_2d
atleast_3dcatchunk	clamp_max	clamp_minclampconstant_pad_ndcumsumDeleteembedding_bagembedding_renormflattengatherhardtanhhstackim2col
index_fillindex
index_copy	index_putinsert
linalg_detlinalg_vector_normlogdetmasked_scattermasked_selectmmnarrownormalpadpixel_shufflepopprim_constant_chunkreflection_padrelu6	remainderreplication_padroundscatterselectsizesortsplit_with_sizessplitsqueezestacktopkunbind
unique_dim	unsqueezevstack   )Zopsetc                    s    fdd}|S )z_Returns a decorator that calls the decorated (higher-order) function with the given parameters.c                   s
   |  S N )fnargskwargsrL   Z/var/www/html/Darija-Ai-API/env/lib/python3.8/site-packages/torch/onnx/symbolic_opset11.py_apply\   s    z_apply_params.<locals>._applyrL   )rO   rP   rR   rL   rN   rQ   _apply_paramsY   s    rS   zaten::hardtanhTvfzjit_utils.GraphContextz_C.Valuefloat)gselfmin_valmax_valc                 C  s`   t j|t jj}| jdtj|| dd}| jdtj|| dd}tj	| d|||ddS )NConstantdtypeZvalue_tClip   Zopset_before)
r   JitScalarType
from_valueFLOAToptorchtensorr]   opset9_op_with_optional_float_cast)rW   rX   rY   rZ   scalar_typerL   rL   rQ   r$   b   s(          zaten::clamprW   c                   s   t j fdd}tj|tjj}|tjjkrD|||}|||}t|rZt ||S t|rpt	 ||S t
|dkrt
|dkrtj d|||ddS t t	 |||S d S )Nc                   s.   | d k	r&t | s& jd| | dS | S d S )NCastZto_i)r	   _is_nonere   	onnx_type)rg   r]   rk   rL   rQ   _cast_if_not_nonez   s    z clamp.<locals>._cast_if_not_noner   r_   r`   ra   )r   beartyper   rb   rc   	UNDEFINEDr	   rn   r   r   _get_tensor_rankrh   ri   )rW   rX   minmaxrp   rj   rL   rk   rQ   r   w   s4    
 



     zaten::clamp_minc                 C  sb   | j d|tj| d}t|dkrJt| }tj	| d|||ddS tj	| d||ddS d S )Nrl   rm   r   r_   r`   ra   ZMax
re   r   rb   rc   ro   r	   rs   rh   Zunusedri   )rW   rX   rt   ru   rL   rL   rQ   r      s    
     zaten::clamp_maxc                 C  sb   | j d|tj| d}t|dkrJt| }tj	| d|||ddS tj	| d||ddS d S )Nrl   rm   r   r_   r`   ra   ZMinrv   )rW   rX   ru   rt   rL   rL   rQ   r      s    
     zaten::relu6c                 C  sX   t j|t jj}| jdtjd| dd}| jdtjd| dd}t| |||S )Nr[   r   r\   r^      )	r   rb   rc   rd   re   rf   rg   r]   r   )rW   inputrj   rY   rZ   rL   rL   rQ   r9      s     zaten::selectic                 C  s   | j d|||dS )NGatheraxis_ire   )rW   rX   dimr(   rL   rL   rQ   r>      s    zaten::index_putFc                   s  t |rt |}n|g}t  rD|g| ||g }jd| S t |d}t|dkr`|S t|dkr tt|D ]&}t || rz	d|| ||< qz|d }|dd  D ]}	t
||	}q	d|  fdd|D }j	d|d
di}n|d }|}
t |
rt |}|d k	rF|dkrFt
||
|S t |
}t |}|d k	r|d k	r||krt |
tt||}
t||
|S 	d| t |dg}t j	d|dgt|gtjgd}j	d	 |dd}t |}|d k	r"|dkr"t
||d }t ||}tj|tjj}|tjjkrtj|tjj}||krj	d|| d}n|rtd||rj	d	d|tjdg| dd}	d|||}t||}n	d|||}|S )Nr*   br      ZNonZeroShapec                   s(   g | ] }t t| d dgqS )N)r	   _unsqueeze_helperrh   expand).0indZbroadcast_index_shaperW   rL   rQ   
<listcomp>   s     zindex_put.<locals>.<listcomp>Concatr|   r   axesstartsendsr{   rl   rm   z'self does not have a valid scalar type.ConstantOfShaper\   r^   	ScatterND)r*   )r   ) r	   _is_packed_list_unpack_listis_caffe2_aten_fallbackat
_parse_arglenrange_is_boolre   rh   r   rs   Zmasked_fillr   listr/   _slice_helpersysmaxsizer   _reshape_helperr   rb   rc   rr   ro   r   SymbolicValueErrorrf   rg   r]   )rW   rX   Zindices_list_valuevalues
accumulateZindices_listrO   Zidx_r(   r   Zbool_inprankZ	mask_rankZ	self_rankZsub_data_shapeZvalues_shapeZself_scalar_typeZvalues_scalar_typeZzerosresultrL   r   rQ   r*      s    
(


   
   
  

zaten::pixel_shufflec                 C  s8   t |}|d k	r&|dkr&t ddS | jd||ddS )N   r5   zonly support 4d inputZDepthToSpaceZCRD)Zblocksize_imode_s)r	   rs   _unimplementedre   )rW   rX   Zupscale_factorr   rL   rL   rQ   r5   S  s    
zaten::upsample_nearest1dZupsample_nearest1d   Znearest)Zdecoratezaten::upsample_nearest2dZupsample_nearest2dr   zaten::upsample_nearest3dZupsample_nearest3d   zaten::upsample_linear1dZupsample_linear1dZlinearzaten::upsample_bilinear2dZupsample_bilinear2dzaten::upsample_trilinear3dZupsample_trilinear3dzaten::upsample_bicubic2dZupsample_bicubic2dZcubicstrintnamer~   Zinterpolate_modec                 C  s   t | ||S rK   )r	   Z_interpolate_helperr   rL   rL   rQ   _interpolate]  s    r   zaten::__interpolatec              	   C  s   t | ||||||S rK   )r	   Z__interpolate_helper)rW   rx   r?   Zscale_factormodeZalign_cornersZrecompute_scale_factorZ	antialiasrL   rL   rQ   __interpolate~  s          r   zaten::gatherc                 C  sD   t |drt ddS t  r2| d||||S | jd|||dS )Nry   r#   zsparse_grad == TrueZGatherElementsr{   )r	   _maybe_get_constr   r   r   re   )rW   rX   r~   r(   Zsparse_gradrL   rL   rQ   r#     s
    zaten::scatterc              	   C  s   t  r| jd||||ddS tj|}t |}t |rR| jd||||dS tj||kr~| jd|tj|	 d}| jd||t
| |||dS d S )Nr=   srcoverload_nameZScatterElementsr{   rl   rm   )r	   r   r   r   rb   rc   _maybe_get_scalar	_is_valuere   ro   rh   	expand_as)rW   rX   r~   r(   r   Zsrc_typerL   rL   rQ   r=     s&    

    zaten::cumsumnonec                 C  sn   | j dtj|tjdd}|rX|  dkrXt|dd}| j d|t	|
 d}n|}|  d	||}|S )
Nr[   r\   r^   zprim::Constantry   r]   rl   rm   ZCumSum)re   rf   rg   r   nodekindr	   
_get_constr   rb   ro   )rW   rX   r~   r]   Z
dim_tensorZparsed_dtypecastZcsumrL   rL   rQ   r     s      zaten::masked_selectc                 C  s$   t | t | ||}| d||S )NGatherND)rh   nonzeror   re   )rW   rX   maskr(   rL   rL   rQ   r0     s    zaten::masked_scatterc                 C  sr   t | t | ||}t| |tdg}tj| |tdgtdgt | |tdgd}| 	d|||S )Nr   r   r   r   )
rh   r   r   r	   r   rf   
LongTensorr   r?   re   )rW   rX   r   sourcer(   rL   rL   rQ   r/     s    

z	aten::lenc                 C  sT   t |s|  dkr&| d|S t| || jdtdgd}t | |dgS )Nzonnx::SplitToSequenceZSequenceLengthr[   r   r^   )	r	   _is_tensor_listr   r   re   r?   rf   r   _squeeze_helper)rW   rX   Zsz_0rL   rL   rQ   _len  s    r   zaten::__getitem_c                 C  s4   t |r| d||S ddlm} || ||S d S )N
SequenceAtr   )
__getitem_)r	   r   re   Ztorch.onnx.symbolic_opset9r   )rW   rX   ry   getitemrL   rL   rQ   r     s    
r   zaten::_set_itemc                 C  s   |  d||}|  d|||S )NSequenceEraseSequenceInsertr}   )rW   tensor_listry   rT   rL   rL   rQ   	_set_item  s    r   zaten::appendc                 C  s   |  d||S Nr   r}   )rW   rX   rg   rL   rL   rQ   r     s    z	aten::addc                 C  sn   t |r^t |r^| }| dkr4t ddS t |}|}|D ]}| d||}qF|S t	| |||S )Nzprim::ListConstructr   z6does not support adding dynamic tensor list to anotherr   )
r	   r   r   r   r   r   r   re   rh   r   )rW   rX   otheralphaZtensor_list_nodeZtensorsltrL   rL   rQ   r     s     
zaten::insertc                 C  s   |  d|||S r   r}   )rW   rX   posrg   rL   rL   rQ   r+     s    z	aten::popc                 C  s   |  d||S Nr   r}   rW   r   r~   rL   rL   rQ   r6     s    zaten::Deletec                 C  s   |  d||S r   r}   r   rL   rL   rQ   r     s    z	aten::catc                 C  s:   t |rt| ||S t |dd}| jd||dS d S )Nry   r~   ConcatFromSequencer{   )r	   r   rh   r   r   re   r   rL   rL   rQ   r   %  s    
zaten::stackc                 C  s<   t |rt| ||S t |dd}| jd||ddS d S )Nry   r~   r   r   r|   Z
new_axis_i)r	   r   rh   rD   r   re   r   rL   rL   rQ   rD   0  s    
zaten::_unique2c           	      C  s$   | j d||dd\}}}}|||fS )NUniquer   )sorted_ioutputsr}   )	rW   rX   sortedreturn_inversereturn_countsuindicesinverse_indicescountsrL   rL   rQ   _unique2:  s       r   zaten::unique_dimc           
      C  s&   | j d|||dd\}}}}	|||	fS )Nr   r   )r|   r   r   r}   )
rW   rX   r~   r   r   r   r   r   r   r   rL   rL   rQ   rG   D  s        z
aten::topkc              	   C  s   t j| ||||||dS )N)largestr   out)r	   Z_topk_helper)rW   rX   kr~   r   r   r   rL   rL   rQ   rE   P  s          z
aten::sortc                 C  s   t j| ||||dS N)	decendingr   r	   Z_sort_helper)rW   rX   r~   r   r   rL   rL   rQ   r@   Y  s    zaten::argsortc                 C  s   t j| ||||d\}}|S r   r   )rW   rX   r~   r   r   _r   rL   rL   rQ   r   `  s        
zaten::roundc                 C  sz   t |s|S |dkr"| d|S | d|| jdttd|d}| d|}| d|| jdttdd| dS )Nr   ZRoundMulr[   
   r^   r   )r	   _is_fpre   rf   rg   pow)rW   rX   Zdecimalsmulr<   rL   rL   rQ   r<   j  s    
$  zaten::remainderc                 C  s4   t |st |r"t| ||S | jd||ddS )NModr   )Zfmod_i)r	   r   rh   r:   re   )rW   rx   r   rL   rL   rQ   r:   y  s    zaten::splitc              
     s  t ||s jd|||d|d kr*S t |rtt ||krڇ fddt |D } jdtjdgtjdd} jdtj|gtjdd}g }t	|D ]2}	 d	|||	 }
|
 d
|||
| |
}q|S  fddt	|D S t ||||S d S )NSplitToSequencer{   c                   s   g | ]}t  |d gqS )r   )r	   r   )r   rT   rk   rL   rQ   r     s   zsplit.<locals>.<listcomp>r[   r   r\   r^   AddSlicec                   s2   g | ]*}  d  j dtj|gtjddqS )r   r[   r\   r^   )re   rf   rg   long)r   ry   rW   Z	split_outrL   rQ   r     s   )r	   Z_is_split_staticre   r   r   r   rf   rg   r   r   r   rh   rB   )rW   rX   Zsplit_size_or_sizesr~   _outputssplit_sizesstartaxisresry   endrL   r   rQ   rB     s6    
  	zaten::split_with_sizesc                 C  s   t | ||||S rK   )rB   )rW   rX   r   r~   r   rL   rL   rQ   rA     s    zaten::unbindc              	   C  sF   |d kr2| j d|| j dtjdtjdd|ddS t| |||S d S )Nr   r[   r   r\   r^   r   )r|   
keepdims_i)re   rf   rg   r   rh   rF   )rW   rX   r~   r   rL   rL   rQ   rF     s    c                 C  sz  t |s0t |r0t |r0| jd|ddd}t| || jdtdgd}t 	|}|dkrx| d| d	|}n| jdtj|tj
d
d}| d| d|| jdtjdtj
d
d|}| jd|tjjd}| jd|| jd|tjdgtj
d
ddd}t | || jdtddgd}| jdt| |dgddgd}t | || jdtdgd}| jd|tjjd}|S )a!  Generate paddings in ONNX order based on pad in pytorch.

    Args:
        input: the input tensor.
        pad: the paddings in pytorch.
            The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end,
            where m is in range [0, n].
    r   r   r   r   r[   r^   NSizer   r\   Subr      rl   rm   r   r   r{   r   	TransposeZperm_i)r	   r   Z_is_listZ_is_scalar_listre   rh   r?   rf   rg   rs   int64_C_onnxTensorProtoDataTypeZINT64r   opset10flip)rW   rx   r4   Zpad_lenr   	extensionpaddingsZ	padding_crL   rL   rQ   _prepare_onnx_paddings  sR     
"       r  zaten::constant_pad_ndc                 C  s:   d}t |}t ||}t| ||}| jd||||dS )NconstantPadr   )r	   r   _if_scalar_type_asr  re   )rW   rx   paddingvaluer   r4   rL   rL   rQ   r     s
    
zaten::reflection_pad1dzaten::reflection_pad2dzaten::reflection_pad3dc                 C  s"   d}t | ||}| jd|||dS )Nreflectr  r  r  re   rW   rx   r  r   r  rL   rL   rQ   r8     s    zaten::replication_pad1dzaten::replication_pad2dzaten::replication_pad3dc                 C  s"   d}t | ||}| jd|||dS )Nedger  r  r  r  rL   rL   rQ   r;     s    z	aten::padrW   rx   r4   r   r	  c                 C  sv   t |d}|dkr t| ||S |dkr4t| ||S |dkrJt| |||S |dkr`t| ||S td| |d S )NsZ	replicater
  r  ZcircularzUnrecognized padding mode )	r	   r   r;   r8   r   rh   Z_pad_circularr   r   r  rL   rL   rQ   r4     s    	zaten::linalg_detc                 C  s   |  d|S )NZDetr}   rW   rX   rL   rL   rQ   r,   -  s    zaten::logdetc                 C  s   t | t| |S rK   )rh   logr,   )rW   rx   rL   rL   rQ   r.   3  s    aten::arangec                 G  s   dd }t |dkrtdd |D rtj}| jdtj|d |dd	}| jdtj|d
 |dd	}| jdtjd
|dd	}| d|||S t |dkst |dkr(t |dkrd }n||d
 }tj| |d |d\}}}}| jdtjd| dd	}	| jdtjd
| dd	}| d|	||S t |dksDt |dkrt |dkrXd }n||d }tj| |d |d
 |d |d\}
}}}| d|||S t |dkr||d }tj| |d |d
 |d\}}}}| jdtjd
| dd	}| d|||S t	ddt | dS d S )Nc                 S  s   t | d} | S )Nry   )r	   r   r\   rL   rL   rQ   _get_arange_dtype<  s    z!arange.<locals>._get_arange_dtyper   c                 s  s   | ]}t |tV  qd S rK   )
isinstancer   )r   valrL   rL   rQ   	<genexpr>@  s     zarange.<locals>.<genexpr>r[   r   r\   r^   r   Ranger   )r   r]   r      r   )r   r   stepr]   rw   )r   r   r]   r  zwith z
 arguments)
r   allrf   r   re   rg   r	   Z_arange_cast_helperr]   r   )rW   rO   r  r]   r   r   Zdelta_defaulttype_r  Zstart_defaultr   rL   rL   rQ   r   9  s~              zaten::_dim_arangec                 C  sT   |  d|}| j d|| j dt|ddd}t rB|  d|S t| |dd d d S )	Nr   rz   r[   r^   r   r{   z_caffe2::Ranger   )re   rf   rg   r	   r   r   )rW   liker~   Z
like_shapestoprL   rL   rQ   _dim_arange  s       r  z
aten::size)Zquantize_outputc                 C  s"   |d kr|  d|S t| ||S )Nr   )re   r	   _size_helperrW   rX   r~   rL   rL   rQ   r?     s    zaten::squeezec                 C  s|  |d kr|  d|S t|s.t| ||gS t|dd}t|}|}|d k	rb|dk rb||7 }t||}|dk r~|d ks|d kr,| j dt|gd}t	| ||}| j dtj
dtjdd}|  d	||}	tj| d
|	dd\}
\}}}t|||g}t|j| | d|}t|j| |
S |}|dkrltdt| d d t| d d d  |S t| ||gS )NZSqueezery   r~   r   r[   r^   r   r\   EqualIfr   n_blocksZIdentityz5This model contains a squeeze operation on dimension z. The size of z%this dimension in the given input is z. The model will zWbe exported without the squeeze node. If the model is intended to be used with dynamic z7input shapes, please export with dynamic_axes argument.)re   r	   _is_constantr   r   rs   Z_get_tensor_dim_sizerf   rg   r  Zonesr   r   add_op_with_blocksr   _add_output_to_blockblockwarningswarnr   )rW   rX   r~   Z
input_rankZadjusted_dimdim_sizeZdim_constantr?   	const_onecondZif_opZ
if_contextZelse_contextr   Zsqueeze_Z	identity_rL   rL   rQ   rC     s^    

   

zaten::unsqueezec                 C  s(   t |rt |dd}t | ||gS )Nry   r~   )r	   r%  r   r   r   rL   rL   rQ   rH     s    
zaten::mmc                 C  s   | j d||dddS )NZGemmg        g      ?)Zbeta_fZalpha_fr}   )rW   rX   r   rL   rL   rQ   r1     s    zaten::indexc                 C  s   t  r| jd||ddS t |r0t |}n|g}t|dkr|d }t |st |srtj	
|tj	jkrt| |}| d||S t| ||S )Nr(   ZTensorr   r   r   r   )r	   r   r   r   r   r   rn   r   r   rb   rc   UINT8rh   r   re   r(   )rW   rX   r(   r   rL   rL   rQ   r(     s"    


zaten::index_fillc           	      C  st   t |d}t  r*| jd|||d|dS t | |||\}}t |}t ||}t| ||d }t	| ||||S )Nry   r'   Z
int_Scalar)r   dim_i)
r	   r   r   r   _index_fill_reshape_helperr   r  rh   r   r=   )	rW   rX   r~   r(   r	  	dim_valueexpanded_index_shapeexpanded_indexZexpanded_valuerL   rL   rQ   r'     s(    	   
zaten::index_copyc                 C  sL   t |d}t  r(| jd||||dS t | |||\}}t| ||||S )Nry   r)   )r/  )r	   r   r   r   r0  r=   )rW   rX   r~   r(   r   r1  r2  r3  rL   rL   rQ   r)   
  s       zaten::__rshift_c                 C  s   t j|t jjt j|kr:| jd|t j| d}t j|t jjt jjkrf| jd||ddS | jdtjdtj	dd	}t
|s| jd|tjjd}| d
||}| jd|t j| d}| d||}|S )Nrl   rm   BitShiftRIGHTZdirection_sr[   r   r\   r^   PowDivr   rb   rc   rr   re   ro   r.  rf   rg   Zfloat32r	   r   r   r   rd   )rW   rX   r   twotwo_powrshiftrL   rL   rQ   	__rshift_  s6     

r=  zaten::__lshift_c                 C  s   t j|t jjt j|kr:| jd|t j| d}t j|t jjt jjkrf| jd||ddS | jdtjdtj	dd	}t
|s| jd|tjjd}| d
||}| jd|t j| d}| d||}|S )Nrl   rm   r4  LEFTr6  r[   r   r\   r^   r7  r   r9  )rW   rX   r   r:  r;  lshiftrL   rL   rQ   	__lshift_8  s6     

r@  c                 C  s   |  d|| j dt|d d}|  d|| j dt||d  d}|  d| j dtdd|| j dt|d}td|| |}| j d|dd}t| |dg}t| || j dtd	dgd}	|  d||	}
|
S )
Nr   r[   r   r^   r   r   r  r   r   )re   rf   rg   r   rH   r	   r   r   )rW   Zinput_dZkernel_size_dZ
dilation_dZ	padding_dZstride_dZblocks_dZblocks_d_indicesZkernel_gridZkernel_maskZ
block_maskrL   rL   rQ   _get_im2col_indices_along_dimZ  s<    
      rA  c                 C  s.   | j dtdd||gd d}|  d||S )Nr[   r   r   r^   r  )re   rf   r   )rW   rx   	padding_h	padding_wr4   rL   rL   rQ   _get_im2col_padded_input  s     rD  c              
   C  s   t | || jdtdd}t | || jdtdd}| d|| jdt|| d}| jdt| |dgt| |dg| jdtdgdddS )	Nr[   r   r^   r   r   r   r   r{   )r?   re   rf   rg   r	   r   )rW   rx   kernel_hkernel_wZ	batch_dimZchannel_dimZchannel_unfoldedrL   rL   rQ   _get_im2col_output_shape  s      rG  zaten::im2colisc              	   C  s  t | || jdtdd}t | || jdtdd}|d |d  }}	|d |d  }
}|d |d  }}|d |d  }}t| ||||
|}t| |||||	}t| |||}t| ||
|}| jd||dd}| jd||d	d}| jd
|dddd	ddgd}t| ||S )Nr[   r   r^   r   r   r   rz   r{   r   r   r   r   )	r?   re   rf   rg   rA  rG  rD  r	   r   )rW   rx   Zkernel_sizeZdilationr  ZstrideZinput_hZinput_wZstride_hZstride_wrB  rC  Z
dilation_hZ
dilation_wrE  rF  Zblocks_row_indicesZblocks_col_indicesZoutput_shapeZpadded_inputoutputrL   rL   rQ   r&     s8              zaten::narrowc                 C  s"   |  d||}tj| ||||dS )Nr   r   )re   r	   r   )rW   rx   r~   r   lengthr   rL   rL   rQ   r2     s    zaten::flattenc                 C  s   t |}|dkr|S |dkrL|dks:|d k	r||d kr| jd||dS n8|dkr|dksp|d k	r||d kr| jd||d dS |d krt dd	S |dk r|| }t | ||||S )
Nr   r   ZFlattenr{   r   r   r~   zfONNX and PyTorch use different strategies to split the input. Input rank must be known at export time.)r	   rs   re   r   Z_flatten_helper)rW   rx   Z	start_dimZend_dimr~   rL   rL   rQ   r"     s"    
zaten::linalg_vector_normr   zOptional[Sequence[int]]bool)rW   r~   keepdimc                 C  s   |dkr|d kr:t | || jdtjdgtjdd}d}| d| d|| jdtdgd}| jd	|tj	|
 d
}t j| |||dS t| |||||S d S )Nr   r[   r   r\   r^   FNotr!  rl   rm   axes_ir   )r	   r   re   rf   rg   r   r   r   rb   rc   ro   _reducesum_helperrh   r-   )rW   rX   ordr~   rM  r]   Zcond_oprL   rL   rQ   r-     s0          zaten::embedding_bagc
                 C  s  |rt jrtdS |	d k	r,|	dkr,td| jdtdd}
| jd|
tj	j
d}
| jdtdgd}t| t| || jdtdddg}|s||g}| jd|d
di}tj| |dgdgtjgdgd}tj| |dgdgtjgdgd}t| || jdtdd}tj| d||
dd\}\}}|j}t|}t|}|jd||dd}|jd||dd}t||dg}t||dg}|d||||}|jd||dd}t|s|d||||}t||dg}|d||}|dkrtj||dgdd}n4|dkr,|jd|dgdd}n|jd|dgdd}|jd|
tj	j
d}t|| t|| |  d d d fS )Nz7embedding_bag with scale_grad_by_freq for training moder   zembedding_bag with padding_idxr[   r   r^   rl   rm   r   r|   )r   r   r   ZstepsZLoopr#  rz   r{   r   r   rO  Z
ReduceMeanZ	ReduceMax)r   )r   Zexport_trainingr	   Z_onnx_unsupportedRuntimeErrorre   rf   rg   r   r   ZBOOLr   r  r   r   r   r   r&  r(  r   Z_add_input_to_blockrn   rQ  r'  r   rI  )rW   Zembedding_matrixr   offsetsZscale_grad_by_freqr   sparseZper_sample_weightsZinclude_last_offsetZpadding_idxZloop_conditionzeroZindices_lenZoffsets_startsZoffsets_endsZloop_lenloopZloop_contextr   Z
loop_blockZblock_input_iterr-  Zindices_startZindices_endZindices_rowZ
embeddingsZper_sample_weights_rowZcond_outrL   rL   rQ   r      s    
                  

         
   
  zaten::embedding_renormc              	   C  s   |  d|}|  d||}t|}|dkr0d}n"|dkr>d}ntd| d|| j ||dgdd	}|  d
|| j dtdd}	t|}|  d||	}
|  d||
}|  d|  d||||}|  d|t| |dg|S )Nr   rz   r   ZReduceL1r   ZReduceL2z8Unsupported: ONNX export of embedding_renorm with norm: z. Only 1. and 2. are supported.rO  r   r[   gHz>r^   r8  r   ZWhereZGreaterr   )re   r   r   r   rf   rg   r	   r   )rW   weightr   Zmax_normZ	norm_typeZunique_indicesZpartial_weightZnorm_iZpartial_weight_normZpartial_weight_norm_scalesZpartial_weight_renormrL   rL   rQ   r!   {  s@    
  
zaten::chunkc              
   C  s   | j d|  d||dd}|  d|| j dtjdgtjdd	}|  d
|  d|||}t| ||d |  d||  d||g}| j d|ddi}t| |||S )Nrz   r   r   r{   r   r[   r   r\   r^   r8  r   r   r   r|   )r   )re   rf   rg   r   rh   r   rB   )rW   rX   chunksr~   r+  Zchunk_size_s
chunk_sizeZ	chunk_vecrL   rL   rQ   r     s      zaten::normalc	           
      C  sD   |d k	r"t |s"t| ||d }t| || d|}	t| |	|S )NZRandomNormalLike)r	   rn   rh   r   r   re   r   )
rW   ZmeanZstdsizes	generatorr]   ZlayoutZdeviceZ
pin_memoryr   rL   rL   rQ   r3     s    zaten::atleast_1dztorch._C.Valuer  c              
   C  s   t |rzt |rzt |}g }|D ]D}|}t |}|dkr`t | || jdtdgd}|	| q&| jd| S t |}|dkrt | || jdtdgd}|S )Nr   r[   r   r^   SequenceConstruct)r^  )
r	   r   r   r   rs   r   re   rf   rg   r   rW   rX   r   Znew_tensor_listrg   Z
new_tensorZtensor_rankrL   rL   rQ   r     s,    

  
  zaten::atleast_2dc                 C  s   t |rt |rt |}g }|D ]b}|}t |}|dkrdt | || jdtddgd}n|dkr~t j	| |dgd}|
| q&| jd| S t |}|dkrt | || jdtddgd}n|dkrt j	| |dgd}|S )Nr   r[   r   r^   rP  r^  )r^  r	   r   r   r   rs   r   re   rf   rg   r   r   r_  rL   rL   rQ   r     s<    

    
  zaten::atleast_3dc                 C  sR  t |rt |rt |}g }|D ]}|}t |}|dkrft | || jdtdddgd}nH|dkrt j	| |dgd}t j	| |dgd}n|dkrt j	| |dgd}|
| q&| jd	| S t |}|dkrt | || jdtdddgd}nL|dkr2t j	| |dgd}t j	| |dgd}n|dkrNt j	| |dgd}|S )
Nr   r[   r   r^   r`  r   r   r^  )r^  ra  r_  rL   rL   rQ   r   
  sX    

        

  

zprim::ConstantChunkc              
   C  s  |  d|}| j dtj|gtjdd}| j d||dd}| j dtjdgtjdd}| j dtj|gtjdd}| j dtj|d gtjdd}	|  d	||	}
|  d
|
|}g }t|D ]N}| j dtj|d gtjdd}|  d||}||  d|||| |}q|S )Nr   r[   r\   r^   rz   r   r{   r   r   r8  r   r   )re   rf   rg   r   r   r   )rW   rX   rZ  r~   Zinput_shaper   Zinput_shape_dimr   r[  Zchunk_size_minus_1Zinput_shape_dim_shiftZ	chunk_dimr   ry   r(   r   rL   rL   rQ   r7   7  s$      zaten::hstackrW   r   c              
   C  s   t | |}| d|| jdtjdtjdd}| d|}| d|}| jdtjdtjdd}| d	||}tj| d
|ddd\}\}}	}
|jd|ddd}t|j	| |	jd|ddd}t|	j	| |
  }|S )Nr   r[   r   r\   r^   r   r   r   r!  r"  r   )r$  r   r   r   )r   re   rf   rg   r   r   r&  r   r'  r(  r   rI  )rW   r   Zfirst_tensorZfirst_tensor_shapeZfirst_tensor_dimr,  Zequal_to_oneZif_op_greaterZif_context_equalZelse_context_equalr   Z	result_ifZresult_elser   rL   rL   rQ   r%   M  s>    
      zaten::vstackc                 C  s   t | |}| jd|dddS )Nr   r   r   )r   re   rb  rL   rL   rQ   rI   n  s    
)F)F)N)N)N)N)N)r   )N)N)r   N)N)N)N)NNNNNN)k__doc__
__future__r   	functoolsr   r)  typingr   r   rf   r   Ztorch._Cr   r   Z
torch.onnxr   r   r	   r
   r   r   rh   r   Ztorch.onnx._globalsr   Ztorch.onnx._internalr   r   r   __all__partialZonnx_symbolicZ_onnx_symbolicrS   Zquantized_args
parse_argsrq   r$   r   r   r   r9   r>   r*   r5   r   r   r#   r=   r   r0   r/   r   r   r   r   r   r+   r6   r   r   rD   r   rG   rE   r@   r   r<   r:   rB   rA   rF   r  r   r8   r;   r4   r,   r.   r   r  r?   rC   rH   r1   r(   r'   r)   r=  r@  rA  rD  rG  r&   r2   r"   r-   r    r!   r   r3   r   r   r   r7   r%   rI   rL   rL   rL   rQ   <module>   s   <	#


"

	
$9G

2
  +3^%      +