U
    9%e                     @   s  d dl Z d dlZd dlZd dlZd dlZd dlZd dlZd dlZd dlZd dl	Z	d dl
Z
d dlZd dlZd dlZd dlZd dlZd dlmZmZ d dlmZmZmZmZmZmZmZmZ d dlmZ zd dlZW n ek
r   dZY nX d dlZd dl Zd dl!m"Z"m#Z#m$Z$ d dl%m&Z&m'Z'm(Z(m)Z)m*Z*m+Z+ d dl,m-Z-m.Z. d dl/m0Z0m1Z1 d d	l2m3Z3m4Z4 d
dl5m6Z6m7Z7m8Z8 d
dl9m:Z:m;Z; d
dl<m=Z= d
dl>m?Z? d
dlm@Z@mAZAmBZB d
dlCmDZDmEZEmFZFmGZGmHZHmIZImJZJmKZKmLZLmMZM e	NeOZPejQReOdZSejQReOdZTejUjVjWjXZXejUjVjWjYZYejUjVjWjZZZe[ddd Z\e]deZfdeYfdej^fdej_fdejUj`fdej]jafdeFfdeDfdeMfdeLfde
jbfd ecd fd!d"d# fd$ejCjdfd%ejefd&ejffgZgejhdd' d(krz"d dliZie jjekd)d*d+Zld,ZmW n enk
r   d-ZmY nX nd,Zme jjekd)d.d+Zld/d0 Zod1d2 ZpejqG d3d4 d4ZrG d5d6 d6e(ZsG d7d8 d8ZtG d9d: d:Zueekekf d;d<d=ZvdaweBejxeekeyf eze{dd>d?d@Z|eBejxeekeyf eze{dAdBdCZ}e:e} dDdE Z~dFdG ZdS )H    N)currentframegetframeinfo)AnyCallableDictListOptionalTupleTypeUnion)ReferenceType)is_from_local_sourceTensorPropertyTensorPropertySource)DuplicateInputsGuardGuardBuilderBaseGuardEnvExprGuardSourceSource)EqualityConstraintSYMPY_INTERP)format_framereport_compile_source_on_error)TensorWeakRef	WeakIdRef   )configconvert_framemutation_guard)set_guard_error_hookset_guard_fail_hook)unimplemented)
TypeSource)GuardedCode	GuardFailGuardFn)
dict_const_keysdict_const_keys_reprdict_param_key_idsguard_failures"is_guard_failure_reporting_enabledistypeorig_code_maptensor_always_has_static_shapetuple_iterator_getitemtuple_iterator_lenguardsZverbose_guardsc                  C   s    dd l } | jjg}dd |D S )Nr   c                 S   s   h | ]}t |qS  )inspectgetfile).0mr2   r2   S/var/www/html/Darija-Ai-API/env/lib/python3.8/site-packages/torch/_dynamo/guards.py	<setcomp>R   s     z&uninteresting_files.<locals>.<setcomp>)Ztorch._dynamo.external_utils_dynamoZexternal_utils)torchmodsr2   r2   r7   uninteresting_filesK   s    r<   Z___check_type_idZ___check_obj_idZ___is_grad_enabledZ'___are_deterministic_algorithms_enabledZ___is_torch_function_enabledZ___odict_getitemZ___dict_param_key_idsZ___dict_const_keysZ___tuple_iterator_lenZ___tuple_iterator_getitemZ__math_isnaninfZ__load_modulec                 C   s
   t | S N)	importlibimport_modulenamer2   r2   r7   <lambda>f       rC   Zutils_devicedeviceZ__as_tensor   )      nodereturnc                 C   s   t | ddS N
 )
astunparseunparsereplacerJ   r2   r2   r7   _ast_unparseu   s    rS   TFc                 C   s   t | ddS rL   )astrP   rQ   rR   r2   r2   r7   rS      s    c                 C   sb   t d}d}| D ]F}|dkr$d}q|dkrP|rX|dkrX||rXt|  S q||7 }qt| S )z
    "___odict_getitem(a, 1)" => "a"
    "a.layers[slice(2)][0]._xyz" ==> "a"
    "getattr(a.layers[slice(2)][0]._abc, '0')" ==> "a"
    "getattr(getattr(a.x[3], '0'), '3')" ==> "a"
    "a.layers[slice(None, -1, None)][0]._xyz" ==> "a"
    z[A-Za-z_].*rN   z (z),[]None)recompilematchstrip_function_callstrip_getattr_getitem)rB   Z
valid_namecurrcharr2   r2   r7   rY      s    	

rY   c                 C   s   t d| d S )z*
    "a[1]" => "a"
    "a.foo" => "a"
    z[.\[]r   )rV   splitrA   r2   r2   r7   rZ      s    rZ   c                   @   s"   e Zd ZU ee ed< eed< dS )GuardCodeList	code_listguardN)__name__
__module____qualname__r   str__annotations__r   r2   r2   r2   r7   r^      s   
r^   c                   @   s  e Zd Zeee gef eegef ee	eef  de
dddZeedddZeeef edd	d
ZedddZedddZedddZedddZedddZedddZedddZedddZedddZedddZedd d!Zedd"d#Zd$d% Zd&d' Zd(d) Zd*d+ Z d,d- Z!d.d/ Z"d0d1 Z#edd2d3Z$edd4d5Z%edd6d7Z&edd8d9Z'edd:d;Z(edd<d=Z)dDedd?d@Z*dEdBdCZ+d>S )FGuardBuilderCheckFunctionManager)id_ref
source_ref
user_scopecheck_fn_managerlocalc          	      C   s   || _ || _|| _|r(|rdnd|i}n|r0dndt i}|| _tj | jd< tj	j
j D ]<\}}|dddddd}|| jd |< || j|< q^g | _g | _g | _g | _g | _g | _|| _d S )	NLG__builtins__>_<.Z_dot_)rl   rh   ri   dictscopebuiltins__dict__copyr:   packageZpackage_importerZ_package_imported_modulesitemsrQ   argnamescodeshape_env_codetensor_check_namestensor_check_examplestensor_check_guardsrk   )	selfrh   ri   rj   rk   rl   ru   rB   Zpackage_moduler2   r2   r7   __init__   s,    	zGuardBuilder.__init__)rB   rK   c                 C   s   t || jtS r>   )evalru   CLOSURE_VARS)r   rB   r2   r2   r7   get   s    zGuardBuilder.get)r`   rK   c                 C   s`   t |tr|}n|j}tt|}|| jkr\td|r\td|rPt	d| | j
| |S )Nz[a-zA-Z0-9_]+z^\d+$zinvalid var name: %s)
isinstancerd   rB   rZ   rY   r{   rV   rX   logwarningappend)r   r`   rB   baser2   r2   r7   arg_ref   s    

zGuardBuilder.arg_ref)r`   c                 C   sD   t | |j}| |}d| | d| d}| ||g d S )N___check_type_id(, ))typer   rB   rh   r   _produce_guard_code)r   r`   tobj_idr|   r2   r2   r7   
TYPE_MATCH   s    
zGuardBuilder.TYPE_MATCHc                 C   s&   |  |}d| }| ||g d S )Nznot )r   r   )r   r`   refr|   r2   r2   r7   
BOOL_FALSE  s    

zGuardBuilder.BOOL_FALSEc                 C   s^   t |jtr&| t|jj|jtjS d| | d| 	| 
|j d}| ||g d S )Nz___check_obj_id(r   r   )r   Zoriginating_sourcer#   r   r   r   sourcerf   r   rh   r   rB   r   r   r`   r|   r2   r2   r7   ID_MATCH  s      &zGuardBuilder.ID_MATCHc                 C   s6   |  |j}| | d|j d}| ||g d S )Nz.__name__ == '')r   rB   r   ra   r   r   r`   objr|   r2   r2   r7   
NAME_MATCH$  s    zGuardBuilder.NAME_MATCHc                 C   s6   |  |j}| | d|  }| ||g d S )Nz.data_ptr() == )r   rB   r   Zdata_ptrr   r   r2   r2   r7   DATA_PTR_MATCH)  s    zGuardBuilder.DATA_PTR_MATCHc                 C   s   t d|j}|s"td|j |dd\}}| |}t| ||}d }|rhd| d|d}nd| d|d}| j||g| |d	 d S )
Nz^(.*)[.]([a-zA-Z0-9_]+)$zinvalid hasattr check r   rF   hasattr(r   r   znot hasattr()provided_guarded_object)	rV   rX   rB   AssertionErrorgroupr   hasattrr   r   )r   r`   r6   r   attrr   valr|   r2   r2   r7   HASATTR.  s    
zGuardBuilder.HASATTRc           	         s   |  |}| |j}t|}trTtjtjtjtjtj	tj
tjtjtjtjtjf}nd}ttttd tttttttttjtjtjf| t|trt  fddt!"|# |$ D st%nt| st%|j&t|tjtjfrd| dt|g}| '|| d S t|trjt()|rjt }|*d| d| +| d |*d	| d | '|| d S t }t|ttfr| ,| t-|D ]2\}}|*d| d
| d| +t| d qn|*d| d| +| d t|tjrt|}|*| d| | '|| d S )Nr2   c                 3   s   | ]}t | V  qd S r>   )r,   )r5   xZok_typesr2   r7   	<genexpr>c  s    z,GuardBuilder.EQUALS_MATCH.<locals>.<genexpr>str() == r   r   r   z__math_isnan([z],  == ).r   r   rB   r   npZint8Zint16Zint32Zint64Zuint8Zuint16Zuint32Zuint64Zfloat16Zfloat32Zfloat64intfloatboolrd   listtuplesetslice	frozensetranger:   SizerE   dtyper,   rt   all	itertoolschainkeysvaluesr   ra   r   mathisnanr   rh   LIST_LENGTH	enumerate)	r   r`   r   r   r   Znp_typesr|   idxelemr2   r   r7   EQUALS_MATCH<  s    


 
zGuardBuilder.EQUALS_MATCHc                 C   s8   |  |j}t|ttd fr*| | n
| | d S r>   )r   rB   r,   r   r   r   r   )r   r`   r   r2   r2   r7   CONSTANT_MATCH  s    zGuardBuilder.CONSTANT_MATCHc                    sZ        j fdd}tdrD|  ntdt  d S )Nc                      s4   t jtstjt dj g  d S )Nz.training == )r,   trainingr   r   r|   r   r^   r2   r`   r   r   r   r2   r7   setup_guard  s    z+GuardBuilder.NN_MODULE.<locals>.setup_guardr   z$Guard setup for uninitialized class )r   r   r   rB   r   r"   r   )r   r`   r   r2   r   r7   	NN_MODULE  s    


zGuardBuilder.NN_MODULEc                 C   s   |  r| |S dS )z0things like torch.add and user defined functionsN)is_localr   r   r`   r2   r2   r7   FUNCTION_MATCH  s    zGuardBuilder.FUNCTION_MATCHc                 C   s
   |  |S r>   r   r   r2   r2   r7   BUILTIN_MATCH  s    zGuardBuilder.BUILTIN_MATCHc                 C   s
   |  |S r>   r   r   r2   r2   r7   PYMODULE_MATCH  s    zGuardBuilder.PYMODULE_MATCHc                 C   sl   |  |}| |j}t|}t }|d| d| | d |d| dt|  | || d S )Nr   r   r   zlen(r   )	r   r   rB   r   r   r   rh   lenr   r   r`   r   valuer   r|   r2   r2   r7   r     s    
zGuardBuilder.LIST_LENGTHc                 C   sl   |  |}| |j}t|}t }|d| d| | d |d| dt|  | || d S )Nr   r   r   z___tuple_iterator_len(r   )	r   r   rB   r   r   r   rh   r0   r   r   r2   r2   r7   TUPLE_ITERATOR_LEN  s    
zGuardBuilder.TUPLE_ITERATOR_LENc                 C   s8   |  |}|  | }| d| g}| || d S )N is )r   rB   r   )r   r`   source_bZref_aZref_br|   r2   r2   r7   DUPLICATE_INPUT  s    
zGuardBuilder.DUPLICATE_INPUTc           	      C   s   |  |}| |j}t|}t }|d| d| | d tt|}tt	|}t
|| jd}|r|d| d| |d| d|  n|d| d	|  | || d S )
Nr   r   r   rl   z___dict_param_key_ids(r   z___dict_const_keys(zset(.keys()) == )r   r   rB   r   r   r   rh   r   r)   r'   r(   rl   r   )	r   r`   r   r   r   r|   Zparam_key_idsZ
const_keysZconst_keys_reprr2   r2   r7   	DICT_KEYS  s    
zGuardBuilder.DICT_KEYSc                 C   s   |  || | dg d S )Nz is not None)r   r   r   r2   r2   r7   WEAKREF_ALIVE  s    zGuardBuilder.WEAKREF_ALIVEc                 C   sz   |  |}| |j}t|}dd | D }t }|d| d| | d |d| d| | || d S )Nc                 S   s   h | ]\}}|qS r2   r2   )r5   kvr2   r2   r7   r8     s     z5GuardBuilder.NN_MODULE_PARAM_NAMES.<locals>.<setcomp>r   r   r   z{k for k, v in z.named_parameters()} == )	r   r   rB   r   Znamed_parametersr   r   rh   r   )r   r`   r   r   r   r   r|   r2   r2   r7   NN_MODULE_PARAM_NAMES  s    
z"GuardBuilder.NN_MODULE_PARAM_NAMESc                 C   sp   |  |}| |j}t|}t }|d| d| | d |d| dt|  | 	|| dS )zOrderedDict keys matchr   r   r   r   r   N)
r   r   rB   r   r   r   rh   rd   r   r   r   r2   r2   r7   
ODICT_KEYS  s    
zGuardBuilder.ODICT_KEYSc                 C   s   t | |j| j d S r>   )r   watchr   rB   rk   r   r2   r2   r7   OBJECT_MUTATION  s    zGuardBuilder.OBJECT_MUTATIONc                 C   sD   |j dkst|jtjkstd}tjr.d}nd}| ||g dS )zGuard on the initial grad staterN   Nz___is_grad_enabled()znot ___is_grad_enabled())rB   r   r   r   GLOBALr   Zinitial_grad_stater   r   r2   r2   r7   	GRAD_MODE  s    zGuardBuilder.GRAD_MODEc                 C   s6   |j tjkstd}tjr d}nd}| ||g dS )z1Guard on the initial determinism algorithms stateNz)___are_deterministic_algorithms_enabled()z-not ___are_deterministic_algorithms_enabled())r   r   r   r   r   Z&initial_deterministic_algorithms_stater   r   r2   r2   r7   DETERMINISTIC_ALGORITHMS  s    z%GuardBuilder.DETERMINISTIC_ALGORITHMSc                 C   s6   |j tjkstd }tjr d}nd}| ||g d S )Nz___is_torch_function_enabled()z"not ___is_torch_function_enabled())r   r   r   r   r   Zinitial_torch_function_stater   r   r2   r2   r7   TORCH_FUNCTION_STATE  s    z!GuardBuilder.TORCH_FUNCTION_STATEc                 C   s<   |j tjkstddlm  m} | |d|jg dS )z/Guard on CURRENT_DEVICE per torch.utils._devicer   Nzutils_device.CURRENT_DEVICE == )	r   r   r   r   torch.utils._deviceutils_devicer   ZCURRENT_DEVICE)r   r`   r6   r2   r2   r7   DEFAULT_DEVICE#  s     zGuardBuilder.DEFAULT_DEVICEc                    s  |j dkst| jj  j}dd |D } fdd} jrg } jD ]b}||j|j^}|fdd|D  |j	d k	rF||j	j|j	j}|fdd|D  qFt
|d	d
}nd } jjdd |D dd |D ||| j| jjj d}	 j  |	D ]}
| j||
gdd q d S )NrN   c                 S   s   g | ]
}|j qS r2   )Zconstraint_dimsr5   ar2   r2   r7   
<listcomp>4  s     z*GuardBuilder.SHAPE_ENV.<locals>.<listcomp>c                    s    fddj |  D S )Nc                    s   g | ]}t |tj qS r2   )r   r   ZSIZE)r5   r   dimr2   r7   r   9  s   z?GuardBuilder.SHAPE_ENV.<locals>.get_sources.<locals>.<listcomp>)Ztracked_fakes_id_to_source)t_idr   )output_graphr   r7   get_sources6  s    
z+GuardBuilder.SHAPE_ENV.<locals>.get_sourcesc                 3   s   | ]} |fV  qd S r>   r2   r5   Zother_sourcer   r2   r7   r   D  s    z)GuardBuilder.SHAPE_ENV.<locals>.<genexpr>c                 3   s   | ]} |fV  qd S r>   r2   r   r   r2   r7   r   N  s    F)source_pairsZ	warn_onlyc                 S   s   g | ]
}|j qS r2   )Zfaker   r2   r2   r7   r   X  s     c                 S   s   g | ]
}|j qS r2   r   r   r2   r2   r7   r   Y  s     )constraint_inputsequalities_inputsri   Zignore_staticT)	shape_env)rB   r   rk   r   Ztracked_fakesZexport_constraintsr   r   extendZsharedr   r   Zproduce_guardsri   exportfreezer   )r   r`   fsr   r   r   
constraintZother_sourcesr   r1   Zshape_guardr2   )r   r   r7   	SHAPE_ENV,  sH    

 

	
zGuardBuilder.SHAPE_ENVNc           
   
   C   s  |  r| | n|t|tr&| }|d k	r2|n
| |j}t|tjsNt| 	|}t
 }| jjjr| | ddddg}|D ]d}| |d | }t|tjtjfr|d| d| dt| q|| d| d|  qn$| j| | j| | j| |jd k	stt|d	|jd
\}}	|sxt|drf|d| d|j d| d n|d| d t|dkr| || d S )Nr   rE   requires_gradzndimension()rs   r   r   r   T)Z	is_tensorguard_source_dynamo_dynamic_indicesz((z"._dynamo_dynamic_indices.issubset(z)) if hasattr(z', '_dynamo_dynamic_indices') else True)r   z%, '_dynamo_dynamic_indices') == Falser   )is_nn_moduler   r   r   r   rB   r:   ZTensorr   r   r   rk   r   r   r   r,   rE   r   r   rd   r~   r   r   r   r.   r   r   r   r   )
r   r`   r   Ztensor_namer|   Ztermsterm
real_valueZstaticreasonr2   r2   r7   TENSOR_MATCHd  sN    



"  

zGuardBuilder.TENSOR_MATCHFc                 C   s  t  }|d k	st|j}~|d k	s&tt|d }~|t| jksPtd| |rh| jt|| n| j	t|| |d kr|j
d k	o|j
dk}|r| |j
nd }	n|}	|	d k	rtt|	nd }
d }t|	jdrt|	tjst|	}|||
|| d S )NrF   zH_produce_guard_code must be called from inside GuardedCode. Called from rN   __weakref__)r   r   f_backr   dir	__class__r}   r   r^   r|   rB   r   weakrefr   r   r   r   enumEnumZset_export_info)r   r`   r_   r   r   Z	cur_frameZcaller	func_nameZ
name_validZguarded_objectZguarded_object_typeZobj_refr2   r2   r7   r     sB     
z GuardBuilder._produce_guard_code)N)NF),ra   rb   rc   r   r
   objectrd   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r   r2   r2   r2   r7   rf      sJ   >W
		8e   rf   c                   @   s   e Zd ZdZejejejfZe	j
G dd dZG dd dejZG dd dejZdd	d
dZdeedddZee ddddZeeee ef dddZdS )PyExprCSEPassr   c                   @   s.   e Zd ZU eeef ed< eeef ed< dS )PyExprCSEPass.Config
expr_countexpr_to_nameN)ra   rb   rc   r   rd   r   re   r2   r2   r2   r7   Config  s   
r  c                       s6   e Zd ZdddddZejed fddZ  ZS )	zPyExprCSEPass.ExprCounterr  N)r   rK   c                 C   s
   || _ d S r>   )_config)r   r   r2   r2   r7   r     s    z"PyExprCSEPass.ExprCounter.__init__rI   c                    s4   t |tjr$| jjt|  d7  < t | d S Nr   )r   r  ALLOWED_NODE_TYPESr  r  rS   supervisit)r   rJ   r  r2   r7   r    s    zPyExprCSEPass.ExprCounter.visit)	ra   rb   rc   r   rT   ASTr   r  __classcell__r2   r2   r  r7   ExprCounter  s   r  c                       sD   e Zd Zdeg ef dd fddZejed fddZ	  Z
S )	zPyExprCSEPass.Replacerr  N)r   gen_namerK   c                    s    t    || _|| _g | _d S r>   )r  r   r  	_gen_namepreface)r   r   r  r  r2   r7   r   "  s    
zPyExprCSEPass.Replacer.__init__rI   c                    s   t |tjrt|}| jj| tjkr|| jjkrrt 	|}t|}| 
 }| j| d|  || jj|< n| jj| }t|t S t 	|S )Nz = )r   r  r  rS   r  r  USE_THRESHOLDr  r  r  r  r  r   rT   NameLoad)r   rJ   exprZnode_Zexpr_var_namer  r2   r7   r  ,  s    zPyExprCSEPass.Replacer.visit)ra   rb   rc   r   rd   r   rT   r  r   r  r  r2   r2   r  r7   Replacer!  s
   

r$  NrK   c                 C   s$   d| _ | jtdd i d| _d S )Nr   c                   S   s   dS )Nr   r2   r2   r2   r2   r7   rC   I  rD   z(PyExprCSEPass.__init__.<locals>.<lambda>)r  r  )_counterr  collectionsdefaultdictr  r   r2   r2   r7   r   F  s
     zPyExprCSEPass.__init___var)prefixrK   c                 C   s    | | j  }|  j d7  _ |S r  )r&  )r   r+  rB   r2   r2   r7   _new_varL  s    zPyExprCSEPass._new_var)exprsrK   c                 C   s*   |  | j}|D ]}|t| qd S r>   )r  r  r  rT   parse)r   r-  counterer2   r2   r7   countQ  s    zPyExprCSEPass.countr"  rK   c                 C   s.   |  | j| j}|t|}|jt|fS r>   )r$  r  r,  r  rT   r.  r  rS   )r   r"  replacernew_noder2   r2   r7   rQ   V  s    zPyExprCSEPass.replace)r*  )ra   rb   rc   r  rT   	AttributeCall	Subscriptr  dataclasses	dataclassr  NodeVisitorr  NodeTransformerr$  r   rd   r,  r   r1  r	   rQ   r2   r2   r2   r7   r    s   	%r  c                   @   sJ   e Zd Zdeeeeef gdf  dddZdd Zdd Z	d	d
 Z
dS )rg   N)guard_fail_fnc           	         s  |r
|j nd }d| _i | _|| _dd } fdd}t| j|||j|j| dd}t| j||j| dd}d|jkr|jd |jd< t	
|t	
| t|pg tjd	D ]D}tjs| rd
|jkrd|jkrtjsd|jkrq||| q| ||||| _| j  d S )NTc                 S   s    | d kr|S |d kr| S | |S r>   r2   )leftrightr2   r2   r7   combine_scopesr  s
    z5CheckFunctionManager.__init__.<locals>.combine_scopesc                    sD   |   }|tjkr|  S |   }|d k	s6t||  S r>   )r   r   ZCONSTANTrB   selectr   r   )r   r   ZbuilderZw_globalZw_localr2   r7   ri   {  s    
z1CheckFunctionManager.__init__.<locals>.source_refr   Frn   )key__defaults____kwdefaults__hooks)r1   valid	_weakrefsr   rf   rh   global_scopeZlocal_scoperu   r	  r   sortedr   sort_keyr   Zguard_nn_modulesr   rB   Zskip_nnmodule_hook_guardscreatecompile_check_fnZcheck_fnclear)	r   r   r<  r1   r?  ri   local_builderglobal_builderr`   r2   rA  r7   r   g  s\    		    


	   zCheckFunctionManager.__init__c           )         s  t |jt |j@ rt|j}|dg7 }d|}td dg tjt	}d% fdd	}|j
D ]}	|	jD ]}
||
|	j qlqb|j
D ]}	|	jD ]}
||
|	j qq|j|j }d }d }|r
jjrtd|j|j }d }d }d	d
 fdd|D }fdd|D }t|||d}|j}|j}d|dg } d| d |j|j }t|D ]\}}|| }t|}tj|tj B tj  }|j}|jj}|j }|| }|| }|d| d|j! d| d| d| d| d| d| d|| dd qdjrjj"j#j$ng }|D ]J}t%|t&r`|j'} |j(}!|| )  d|!)  d  nt*d| q&|j+D ] }	|	jD ]}
||
|	j qqx|j+rtt,-dfd|fd|fd|fgt.t/0  }"|"1t2 t.t3 }#d|"4 }$t5|#|$\}%}&tj67d d d!kr$t8d"|% t9 s6|d k	r>t:t; t< }'t=|&|j>|' |'d# |"?  }(|"|(_@||(_A |(_Bd$|j>d$ i|(_C||(_D|(S )&Nz**___kwargs_ignored,zGUARDS:z___guarded_code.validFc                    s   t tjrd}|d k	rz|jr\t|jD ]}|jt kr( qHq(|jd }dt|dd }n|j	rzdt|j	
 d  }t d| d|  ttjrd}d}|d k	rdd|j	  }|jrd	d|j  }td
| || |s |  d S )NrN   z  # T)linez%sz<60z
Stack:
z
User stack:
zGuard: %s%s%s)
guards_logisEnabledForloggingDEBUGZ
user_stackreversedfilenamer<   r   stacksummarydebugverbose_guards_logjoinformatr   )r|   r`   log_onlyextrar   Zmaybe_stackZmaybe_user_stack)
code_partsr2   r7   add_code_part  s8    
z<CheckFunctionManager.compile_check_fn.<locals>.add_code_partz,Illegal to set tensor_check_names in export.c                 S   sH   g }| D ]:}t |tr"|| qt |tjs2t||j  q|S r>   )r   r   r   r:   ZSymIntr   rJ   Zmaybe_as_int)Zsize_or_strideZ	convertedr   r2   r2   r7   convert  s    
z6CheckFunctionManager.compile_check_fn.<locals>.convertc                    s$   g | ]} j jt| d  qS )sizer   Ztensor_weakref_to_sizes_stridesr   r5   r   rc  r   r2   r7   r     s   z9CheckFunctionManager.compile_check_fn.<locals>.<listcomp>c                    s$   g | ]} j jt| d  qS )Zstridere  rf  rg  r2   r7   r     s   )dynamic_dims_sizesdynamic_dims_stridesr   z%tensor_check_names=tensor_check_namesz___check_tensors(r   zcheck_tensor(z	, device=z, requires_grad=z, size=z	, stride=T)r_  r   zUnknown GuardEnvExpr: Z___guarded_code___check_tensors___check_tensors_verboser~   ZTORCHDYNAMO_PRINT_GUARDS1zGUARDS
Z___make_guard_fnrn   )F)Er   r{   r   r]  rS  r[  ospathdirname__file__r|   r_   r`   r~   r   r   r   TensorGuardscheckZcheck_verboser   r   r   r   r:   _CZ_dispatch_keysZ_dispatch_tls_local_include_setZ_dispatch_tls_local_exclude_setr   rE   indexr   rc   Ztracing_contextZguards_contextaotautograd_guardsr   r   Zinput_source_aZinput_source_brB   RuntimeErrorr}   r'  OrderedDictr   r   rz   updater   uniquer   build_guard_functionenvironr   printr+   r!   guard_fail_hookrt   execru   r   closure_varsargsra  rH  r<  ))r   rN  rO  Z
guards_outr<  largsr  r   rb  Zgclr|   r~   Zcheck_tensors_fnZcheck_tensors_verbose_fnr   rh  ri  Ztensor_guardsZtensor_check_argsr   irB   r   ZpytypeZdispatch_keyr   Zdevice_indexr   sizesstridesru  r`   Zsource_ar   r  Zunique_code_partsZmake_guard_fn_args
guard_bodyZpycodeoutguard_fnr2   )ra  rc  r   r7   rL    s    


%





		

6
	


	

 z%CheckFunctionManager.compile_check_fnc                 C   s
   d| _ d S )NF)rF  r)  r2   r2   r7   
invalidatep  s    zCheckFunctionManager.invalidatec                 C   sR   z4t || jkr2t|| jt |< t|| j W n tk
rH   Y nX t |S )zadd a weakref, return the id)idrG  r	  r   finalizer  	TypeError)r   r   r2   r2   r7   rh   t  s    zCheckFunctionManager.id_ref)NN)ra   rb   rc   r   r   r	   rd   r   rL  r  rh   r2   r2   r2   r7   rg   f  s     @ Jrg   r%  c           	   
      s<  ddl m} trBt   |  tttt tf d fdd}ntttt tf ddd}| }| D ]J}||\}}|| |	d| d |
  |	d	 W 5 Q R X qh| }|	d
 |
  || |	d W 5 Q R X | }|	d| d |
  || |	d W 5 Q R X | | fS )Nr   )IndentedBufferr2  c                    s
     | S r>   )rQ   r"  Zcsepassr2   r7   rQ     s    z%build_guard_function.<locals>.replacec                 S   s   g | fS r>   r2   r  r2   r2   r7   rQ     s    zif not (z):zreturn Falsezdef guard(L):zreturn Truezdef ___make_guard_fn(zreturn guard)Ztorch._inductor.utilsr  HAS_UNPARSE_FUNCTIONSr  r1  rd   r	   r   
writelinesZ	writelineindentZsplicegetvalue)	ra  Zclosure_argsr  rQ   r  r"  r  r`   Zmake_guard_fnr2   r  r7   rz    s0    
"






rz  )r  r|   f_localsrt  lastrK   c              
   C   s8  |dk}| j s|s|sdS || jd d}|| j |d |d< d}| jD ]^}t| j}	||	d< t  t||	|}
W 5 Q R X t|
t	r|
} qqLt|
t
rL|
sL|} qqL|r|a|sdS tsttt|  t daz&| j dk	r|  t|pdt|  W n2 tk
r2 } ztjd	d
d W 5 d}~X Y nX dS )z(
    called whenever a guard fails.
    r   Nrn   )rm   rn   rk  rj  Z__compile_source__zunknown reasonzVFailure in guard_fail_fn callback - raising here will cause a NULL Error on guard evalT)exc_info)r<  rH  rx  r  ra  rt   r   r   r   rd   r   stashed_first_fail_reasonr   r*   r-   r   r%   	Exceptionr   error)r  r|   r  rt  r  firstru   r  partrH  Zfail_reasonr0  r2   r2   r7   r}    sF    




r}  r  r|   r  rt  r  c                 C   sN   t d|j d|j d|j  t dd| j d  t dd| j d S )NzERROR RUNNING GUARDS  :zlambda r   z and
  )r|  co_nameco_filenameco_firstlinenor]  r  ra  r  r2   r2   r7   guard_error_hook  s
    r  c                 c   s,   t  }| D ]}||kr
|V  || q
d S r>   )r   add)seqseenr   r2   r2   r7   ry    s
    ry  c                 C   s8   |r4|| kr4t |}t | }||kr4tjtj|dS d S )N)r   )r   	functoolspartialrf   r   )Z
obj_sourceZdupe_sourceZser_source_is_localZsource_is_localr2   r2   r7   make_dupe_guard	  s    r  )rT   rv   r'  r8  r
  r  r?   r3   r   rU  r   rm  rV   systypesr	  r   r   typingr   r   r   r   r   r	   r
   r   r   numpyr   ModuleNotFoundErrorr:   r   Ztorch._dynamo.sourcer   r   r   Ztorch._guardsr   r   r   r   r   r   Z%torch.fx.experimental.symbolic_shapesr   r   Ztorch.utils._tracebackr   r   Ztorch.utils.weakr   r   rN   r   r   r   Z
eval_framer    r!   excr"   r   r#   r$   r%   r&   r   r'   r(   r)   r*   r+   r,   r-   r.   r/   r0   	getLoggerra   r   Z_loggingZgetArtifactLoggerrS  r\  rs  r9   r1   rq  Zcheck_obj_idZcheck_type_id	lru_cacher<   rw  Zis_grad_enabledZ$are_deterministic_algorithms_enabledZ_is_torch_function_enabled__getitem__r   r   r   rE   Z	as_tensorr   version_inforO   r  rd   rS   r  ImportErrorrY   rZ   r9  r^   rf   r  rg   rz  r  CodeTyper  r   r   r}  r  ry  r  r2   r2   r2   r7   <module>   s   (
 0

	





    e[  +
=
