U
    9%eG                  	   @   s  d dl Z d dlZd dlm  mZ d dlmZmZm	Z	m
Z
mZmZmZ d dlmZmZmZ e jejddZdd Zede
d	d
dejdkejdddZede
d	d
dejdlejdddZede
d	d	d
ejdmejdddZede
d	d	d
d
ejdnejdddZedejdoejdddZedejdpejdddZ edejdqejdd d!Z!ed"e
d	d	d
d
ejdrejdd#d$Z"ed%e
d	d
d
ejdsejdd&d'Z#ed(ejdtejdd)d*Z$ed+e
d	d	d	d
ejduejdd,d-Z%ed.e
d	d	d	d
d
d
ejdvejdd1d2Z&ed3e
d	d	d	d
d
ejdwejdd4d5Z'ejd6d7 Z(ed8ed9d:gd;ejd<d= Z)ed>ejejdd?d@Z*edAe
d	d
d
d
ejdxejddBdCZ+edDejejddEdFZ,edGejdyejddHdIZ-edJe
d	d
d
d
ejejddKdLZ.edMejejddNdOZ/edPejejddQdRZ0edSejejddTdUZ1edVejejddWdXZ2edYejejddZd[Z3ed\ejejdd]d^Z4ed_ejejdd`daZ5edbejejddcddZ6edeejejddfdgZ7edhejejddidjZ8dS )z    N)
_constants_type_utilserrorssymbolic_helpersymbolic_opset11symbolic_opset9utils)	_beartype	jit_utilsregistration   )Zopsetc                     s    fdd}|S )z_Returns a decorator that calls the decorated (higher-order) function with the given parameters.c                    s
   |  S N )fnargskwargsr   Z/var/www/html/Darija-Ai-API/env/lib/python3.8/site-packages/torch/onnx/symbolic_opset13.py_apply   s    z_apply_params.<locals>._applyr   )r   r   r   r   r   r   _apply_params   s    r   zaten::softmaxvinonegc                 C   sP   | j d||d}|rL|  dkrLt|dd}| j d|t| d}|S )NZSoftmaxaxis_iprim::Constantr   dtypeCastZto_iopnodekindr   
_get_constr   JitScalarType	onnx_type)r   inputdimr   softmaxparsed_dtyper   r   r   r*   !   s      r*   zaten::log_softmaxc                 C   sP   | j d||d}|rL|  dkrLt|dd}| j d|t| d}|S )NZ
LogSoftmaxr   r   r   r   r   r    r!   )r   r(   r)   r   Z	return_opr+   r   r   r   log_softmax/   s      r,   zaten::frobenius_normFc                 C   s^   t |d}t |s2t|dkr2| jd|ddS | d||}t j| |||d}| d|S )Nisr   ZReduceL2
keepdims_iMulZSqrt)r   Z_maybe_get_constZ	_is_valuelenr"   _reducesum_helper)r   selfr)   keepdimZdim_valZsqrZsumsqrr   r   r   frobenius_norm<   s    r5   zaten::splitc              
      s  t ||s jd|||d|d kr*S t |rtt ||krڇ fddt |D } jdtjdgtjdd} jdtj|gtjdd}g }t	|D ]2}	 d	|||	 }
|
 d
|||
| |
}q|S  fddt	|D S t | d}| dkr$ jd||||dS t |dd}t ||}|d krh|d k	r\|| }ntd||g||  }|| }|r|
|  jdt|d} jd||||dS )NSplitToSequencer   c                    s   g | ]}t  |d gqS )r   )r   _unsqueeze_helper).0r   r   r   r   
<listcomp>U   s   zsplit.<locals>.<listcomp>Constantr   r   Zvalue_tAddSlicec                    s2   g | ]*}  d  j dtj|gtjddqS )
SequenceAtr:   r;   r<   r"   torchtensorlong)r8   r   r   Z	split_outr   r   r9   d   s   valueSplitr   outputsr   
split_size$Unknown dimension size not supported)r   _is_split_staticr"   Z_is_packed_listr1   Z_unpack_listrA   rB   rC   rangeappend	_node_getr#   r)   r%   _get_tensor_dim_sizer   SymbolicValueError)r   r3   split_size_or_sizesr)   _outputssplit_sizesstartaxisresr   end	split_valrI   sizesplitsleftoverr   rD   r   splitH   sX    
  	


 
r\   zaten::split_with_sizesc                 C   s   t | ||||S r   r\   r   r3   rS   r)   rR   r   r   r   split_with_sizes   s    r_   zaten::unsafe_splitc                 C   s   t | ||||S r   r]   )r   r3   rQ   r)   rR   r   r   r   unsafe_split   s    r`   zaten::unsafe_split_with_sizesc                 C   s   t | ||||S r   )r_   r^   r   r   r   unsafe_split_with_sizes   s    ra   zaten::tensor_splitc           "      C   s&  | j dtj|tjdd}t| |d}| j dtjdtjdd}t||rt|	 d}|
 dkr| j dtjdgtjdd}g }	|d k	stt|d D ]J}
| j d|| j dtj|
gtjdddd}|	|  d	|||| |}qt| ||}|	|  d	|||| |	S t|d
d}t||}|d kr`|d k	rT|| }ntd||| }|| }||d g }|| |g }| j dtj|| tjdd}| j d||||dS t|rxt|dkrxt| || j dtdd}t| |d}| j d|tjjd}| j dtjdgtjdd}| j d||dd}|  d}tj| d|||ddd\}\}}|j}t|}t|}t|}|j d||dd}|j d|| d||dd}| d	||||}| d||}| d|}t|| t|| |	  }| j d|| j dtjdtjdddd}t| |d}t| ||}|  d	||||}|  d||S t| ||} |  d| |}|  d||}!|  d| |}|  d|!|}|  d||  dt| |d|}| j d||dd}|d kr| j d|||dS | j d||||dS d S )Nr:   r;   r<   r      rE   ZGatherr   r>   r   indices_or_sectionsrJ   rF   rG   r   r    ConcatSequenceEmptyLoop)rH   n_blocksr=   SequenceInsertIdentityDivModTileSubr6   )r"   rA   rB   rC   opset11	unsqueezer   rK   rN   r#   r)   AssertionErrorrL   rM   _size_helperr%   rO   r   rP   Z
_is_tensor_get_tensor_rank_C_onnxTensorProtoDataTypeBOOLr
   add_op_with_blocksblockr   _add_input_to_block_add_output_to_blockoutput)"r   r3   rc   r)   rR   rU   Zconst_1rX   rT   rV   r   rW   rI   rY   Zmin_split_sizeZnum_splits_one_extrarZ   r[   loop_lenloop_conditionZ	padding_0final_splitslooploop_context_
loop_blockblock_input_itercondslicecond_outloop_outZ
last_sliceZdim_sizeZmin_split_size_plus_1r   r   r   tensor_split   s      


    
      


   

r   zaten::unbindc              	      s   |d kr2j d|j dtjdtjdd ddS j dtdg| d}j d|| |d	}|dkrn|gn|} fd
d|D }|S )Nr6   r:   rb   r;   r<   r   r   r/   rF   rG   c                    s,   g | ]$} d |j dt gdqS )ZSqueezer:   r<   )r"   rA   rB   )r8   outr)   r   r   r   r9   3  s   zunbind.<locals>.<listcomp>r@   )r   r3   r)   rR   rZ   rH   Zsqueezed_outputsr   r   r   unbind#  s    r   zaten::nonzero_numpyc                 C   s   t | t| |d|dS )Nrb   )rR   )r   opset9nonzero)r   r(   rR   r   r   r   nonzero_numpy:  s    r   zaten::wherec              	   C   sb   t |s| jd|tjjd}|d krRt| |}t | || jdt	
dd|S | d|||S )Nr   r    r:   rb   r<   ZWhere)r   Z_is_boolr"   rt   ru   rv   r   r   Z_unbind_helperrA   rB   )r   	conditionr3   otherrR   r   r   r   whereA  s    
   r   z&aten::fake_quantize_per_channel_affine   c                 C   s   ||fdkr&t d| d| d||dkrD| jd|tjjd}n| jd|tjjd}| jd||||d	}||fd
kr| d|t| | jdt	j
dt	jdd}| jd||||d	S )N)r      )r   r   r   r   VFor (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). Got (, )r   r   r    QuantizeLinearr   r   Clipr:   r   r;   r<   DequantizeLinear)r   rP   r"   rt   ru   UINT8INT8r   unusedrA   rB   uint8)r   inputsscale
zero_pointrU   	quant_min	quant_max	quantizedr   r   r    fake_quantize_per_channel_affineP  s"    r   z%aten::fake_quantize_per_tensor_affinec                 C   s   ||fdkr&t d| d| d||dkrD| jd|tjjd}n| jd|tjjd}tj	|tjj
tjjkr| jd|tjjd}| d|||}||fd	kr| d
|t| | jdtjdtjdd}| d|||S )Nr   r   r   r   r   r   r    r   r   r   r:   r   r;   r<   r   )r   rP   r"   rt   ru   r   r   r   r&   
from_value	UNDEFINEDFLOATr   r   rA   rB   r   )r   r   r   r   r   r   r   r   r   r   fake_quantize_per_tensor_affinet  s,    r   c                    s   t jd fdd	}|S )Nc                    sF   t | |}|d kr"t| | S t|dd}| j |||dS d S )Nr   r4   r.   )r   Z_maybe_cast_reduce_op_inputr   Z_handle_reduce_dim_noner%   r"   )r   r3   r)   r4   onnx_op_namer   r   symbolic  s
    z%_reduce_op_symbolic.<locals>.symbolic)NN)r	   beartype)r   r   r   r   r   _reduce_op_symbolic  s    	r   z	aten::sumZ	ReduceSumsum)Zdecoratec                    s&   t | tjtj fdd}|S )Nc                    sL   t ddtj fdd}t ddddtj fdd}||fS )Nr   r   c                    s   d }|   dkrBt|dd}t| }| jd||d}n|   dkr`t d|S | |}|d k	rtj	| }||kr| jd||d}|S Nzonnx::Constantr   r   r   r    r   
r#   r$   r   r%   r   r&   r'   r"   _unimplementedr   )r   r3   r   
dtype_onnxresultresult_dtype_onnxnamer   r   r   reduce_nodim  s    
z8_reduce_with_dtype.<locals>.reduce.<locals>.reduce_nodimr   c                    s   d }|   dkrBt|dd}t| }| jd||d}n|   dkr`t d|S | |||}|d k	rtj	| }||kr| jd||d}|S r   r   )r   r3   r)   r4   r   r   r   r   r   r   r   
reduce_dim  s    z6_reduce_with_dtype.<locals>.reduce.<locals>.reduce_dim)r   
parse_argsr	   r   )r   r   r   r   r   r   r   r   reduce  s    
z"_reduce_with_dtype.<locals>.reduce)r   r   Zoverload_by_arg_countr	   r   )Zonnx_opr   r   r   r   r   _reduce_with_dtype  s
    )r   zaten::unflattenc              
   C   sP  t |}|d krt ddS | jdtj|gtjdd}| d||}| d||}| d|}| jdtjd	gtjdd}| d
|| jdtjdgtjdd}| d|||}| d|| jdtjdgtjdd}	| d
|	| jdtjdgtjdd}
| jdtjtjgtjdd}| d||
|}| jd|||d	d}t 	| ||S )Nr)   zfONNX and PyTorch use different strategies to split the input. Input rank must be known at export time.r:   r;   r<   r=   rl   Shaper   Reshaperb   r>   rd   r   )
r   rs   r   r"   rA   rB   int64r   Z	INT64_MAX_reshape_helper)r   r(   r)   Zunflattened_sizeZ	input_dim
input_sizeZhead_start_idxZhead_end_idxZhead_part_rankZdim_plus_oneZtail_start_idxZtail_end_idxZtail_part_rankZfinal_shaper   r   r   	unflatten  sP    
         r   zaten::unsafe_chunkc           	   	   C   s   |d kr2| j d|| j dtjdtjdd|ddS t||}|d krRtdd	S || d | }|g||  }|| }|r|| | j dtj|tjdd}| j d
||||dS )Nr6   r:   rb   r;   r<   r   r   unsafe_chunkzunknown dimension sizerF   rG   )r"   rA   rB   rC   r   rO   r   rM   )	r   r3   chunksr)   rR   rY   rI   rZ   r[   r   r   r   r     s$    
r   z
aten::tilec                 C   s  |  d|}|  d|}|  d|}|  d||}| j dtdgd}|  d||}tj| d|d	d
d\}	\}
}}|
j dtd
gd}|
 d||}|
 d||}|
j d||dd}t|
j| | d|}t|j| |		 
 }|  d||}tj| d|d	d
d\}\}}}|j dtd
gd}| d| d||}| d||}|j d||dd}| d||}t|j| | d|}t|j| |	 
 }| j d|tjjd}|  d||S )Nr   Sizern   r:   r   r<   ZGreaterIf   rb   )rg   rH   r   Expandrd   r   ri   ZLessZAbsr   r    rm   )r"   rA   rB   r
   rw   
LongTensorr   rz   rx   r#   r{   rt   ru   ZINT64)r   r3   ZdimsZ
self_shapeZ	self_rankZ	dims_rankdiffZ
const_zeroZdims_shorter_than_self_shapeZif_op_greaterZif_context_greaterZelse_context_greaterr   Z	const_oneZdiff_1d_greaterZexapnd_ones_greaterZdims_Zidentity_dimZ
dims_finalZdims_longer_than_self_shapeZ
if_op_lessZif_context_lessZelse_context_lessZdiff_1d_lessZexapnd_ones_lessZself_final_shapeZself_Zidentity_selfZ
self_finalr   r   r   tile,  sr            
   r   zaten::repeat_interleavec           "   	   C   s  |}|}t |rDt | || jdtdgd}tjdtjd}n
t |}t |}t 	|}t 	|}	|d krt
d||d krt
d||	d krt
d||dk r|t|	7 }|	 }
t|	D ] \}}|d krd	\|
|< |	|< q|dks|d
kr |d d
kr t | |||S |d
ko4|d d k}|
| dksJ|rt | ||}t| |d}|rt | || jdtdgd}| d|| jdtd
gd}t| || d|||}nt| |||S | jd| d|tjd
gtjdd}t| ||d}t| |||}d\|
|< |	|< | jdtd
d}| jd|tjjd}|}| d}tj| d|||d
d\}\}}|j}t |}t |}t |}|d||}|d||}t|||d
 }|jdt|	d |d
  d||jdt|	|d
 d  dg}|jd|ddi}t!|||d }t ||| jdt|
d}|d||}|jd|tjjd} t"||  t"|| |# $ }!| jd|!|d}!|!S )Nr:   rj   r<   r   r;   zGUnsupported: ONNX export of repeat_interleave for unknown repeats rank.zGUnsupported: ONNX export of repeat_interleave for unknown repeats size.zEUnsupported: ONNX export of repeat_interleave for unknown input size.)r   rj   rb   Equalr   ZConstantOfShaper   )rj   rb   r   r    re   rf   rg   r?   rd   r   rh   ZConcatFromSequencer   )rd   )%r   Z_is_noner   r"   rA   rB   r   Z_maybe_get_scalarrs   Z_get_tensor_sizesr   rP   r1   copy	enumerateZ-_repeat_interleave_single_value_repeat_helperrr   ro   rp   r   r   r   repeat_interleaverC   r\   rt   ru   rv   r
   rw   rx   r   ry   expandrz   r#   r{   )"r   r3   Zrepeatsr)   Zoutput_sizer(   Z	final_dimZrepeats_dimZrepeats_sizesZinput_sizesZoutput_sizesidxr   Zcond_dynamic_repeatsZrepsZ
repeat_dimZrepeat_condZ	reps_likeZr_splitsZi_splitsr}   r|   r~   r   r   r   r   r   r   Zr_splitZi_splitZr_concatr   r   r   r   r   r   d  s    
  



"       

     


    r   zaten::diagonalc                    s  t j | jdt|gdd}t j | jdt|gdd} jd||dd}t  |d d d } jd||d}t|}	|	d k	rtt	|	}
|

| |

|  jd	||
||g d
}ntddS  d||tj dgdd jdt|gd}|dkrL d d| d|| jdtdgd}d}n4 d d d||| jdtdgd} jd|dd}t  |dd d } d| jdtdgd} d| jdtt|d gd} fddtt	|	d d D }||  jd$|ddi}t  |dd d } d d| jdtjdtjdd}tj d|dd \}\}}}|d||}t|||	d g}|jd!||	d d"}t ||d#d d }t|j| t|j| |S )%Nr:   r<   r)   rd   r   r   ZEyeLike)Zk_iZ	Transpose)Zperm_idiagonalzunknown input rankr0   rj   )Zaxes_ir/   ZMaxZMinrn   r=      ZCumSumrb   c              
      s.   g | ]&}t j  jd t|gddqS )r:   r<   r   )r   rY   r"   rA   r   )r8   rU   r   r   r   r   r9   ;  s   zdiagonal.<locals>.<listcomp>r   Notr   r;   r   r   r   ZGatherND)Zbatch_dims_i   )rd   )r   rY   r"   rA   r   Zzerosr   rs   listrL   remover   r2   ZonesabsrM   rB   r   r
   rw   r7   r   rz   rx   )r   r3   offsetZdim1Zdim2Z	dim1_sizeZ	dim2_sizeZ
mask_shapemaskZrankZaxesZ	offset_opZ	diag_sizeZselect_window_ones_fillZselect_windowZgather_shapeZgather_indicesZoverrun_condZif_opZ
if_contextZelse_contextr   Zgather_indices_if_blockZfinal_non_overrunZfinal_overrunr   r   r   r     s        



	
	        r   zquantized::linearc                 C   sn   t | |\}}}}t | |\}	}
}}t | |||
|}t | |\}}}}t| ||	|}t | |||S r   )r   dequantize_helperrequantize_bias_helperr   Zlinearquantize_helper)r   q_inputq_weightbiasop_scaleop_zero_pointr(   input_scaler   weightweight_scalerU   q_biasr{   r   r   r   quantized_linearj  s        r   zquantized::conv1d_reluc
              
   C   s   t | |\}
}}}t | |\}}}}t | ||||}t | |\}}}}t| |
||||||}t| |}t | |||	S r   )r   r   r   r   conv1drelur   r   r   r   r   stridepaddingdilationgroupsr   r   r(   r   r   r   r   rU   r   r{   r   r   r   quantized_conv1d_relu{  s        r   zquantized::conv2d_reluc
              
   C   s   t | |\}
}}}t | |\}}}}t | ||||}t | |\}}}}t| |
||||||}t| |}t | |||	S r   )r   r   r   r   conv2dr   r   r   r   r   r   quantized_conv2d_relu  s        r   zquantized::conv3d_reluc
              
   C   s   t | |\}
}}}t | |\}}}}t | ||||}t | |\}}}}t| |
||||||}t| |}t | |||	S r   )r   r   r   r   conv3dr   r   r   r   r   r   quantized_conv3d_relu  s        r   zquantized::conv1dc
              
   C   sv   t | |\}
}}}t | |\}}}}t | ||||}t | |\}}}}t| |
||||||}t | |||	S r   )r   r   r   r   r   r   r   r   r   r   quantized_conv1d  s        r   zquantized::conv2dc
              
   C   sv   t | |\}
}}}t | |\}}}}t | ||||}t | |\}}}}t| |
||||||}t | |||	S r   )r   r   r   r   r   r   r   r   r   r   quantized_conv2d  s        r   zquantized::conv3dc
              
   C   sv   t | |\}
}}}t | |\}}}}t | ||||}t | |\}}}}t| |
||||||}t | |||	S r   )r   r   r   r   r   r   r   r   r   r   quantized_conv3d   s        r   zquantized::conv_transpose1dc                 C   sx   t | |\}}}}t | |\}}}}t | ||||}t | |\}}}}t| ||||||||	}t | ||	|
S r   r   r   r   r   Zconv_transpose2dr   r   r   r   r   r   r   Zoutput_paddingr   r   r   r   r(   r   r   r   r   rU   r   r{   r   r   r   quantized_conv_transpose1d  s,                r   zquantized::conv_transpose2dc                 C   sx   t | |\}}}}t | |\}}}}t | ||||}t | |\}}}}t| ||||||||	}t | ||	|
S r   r   r   r   r   r   quantized_conv_transpose2d7  s,                r  zquantized::conv_transpose3dc                 C   sx   t | |\}}}}t | |\}}}}t | ||||}t | |\}}}}t| ||||||||	}t | ||	|
S r   )r   r   r   r   Zconv_transpose3dr   r   r   r   r   quantized_conv_transpose3dT  s,                r  )N)N)NF)N)N)N)N)N)r   N)N)NNN)r   r   )r   r   )N)NN)9	functoolsrA   Ztorch._C._onnxZ_CZ_onnxrt   Z
torch.onnxr   r   r   r   r   ro   r   r   r   Ztorch.onnx._internalr	   r
   r   partialZonnx_symbolicZ_onnx_symbolicr   r   r   ZGraphContextr*   r,   r5   r\   r_   r`   ra   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r   r   r   r   <module>   s  $		
	7    	  !  $

4*6    u