U
    9%e-x                     @   s2  d Z ddlmZmZ ddlmZ ddlZddlZddlm	Z	 ddl
ZddlmZmZ ddlmZmZ ddlmZ dd	lmZ d
d Zdd Zdd ZG dd deZdd ZG dd deZG dd deZdd Zdd Zdd Z dd Z!G d d! d!eZ"G d"d# d#eZ#G d$d% d%eZ$G d&d' d'ed(Z%dS ))zA
Implements custom ufunc dispatch mechanism for non-CPU devices.
    )ABCMetaabstractmethod)OrderedDictN)reduce)_BaseUFuncBuilderparse_identity)typessigutils)	signatureparse_signaturec                 C   s8   | |kr| S | dkr|S |dkr$| S t d| |dS )=
    Raises
    ------
    ValueError if broadcast fails
       zfailed to broadcast {0} and {1}N)
ValueErrorformat)ab r   Y/var/www/html/Darija-Ai-API/env/lib/python3.8/site-packages/numba/np/ufunc/deviceufunc.py_broadcast_axis   s    r   c                 C   s^   t t| |g\} }t| t|k r,d|  } qt| t|krFd| }q,tdd t| |D S )r   r   c                 s   s   | ]\}}t ||V  qd S N)r   ).0r   r   r   r   r   	<genexpr>1   s     z&_pairwise_broadcast.<locals>.<genexpr>)maptuplelenzip)Zshape1Zshape2r   r   r   _pairwise_broadcast#   s    

r   c                  G   sl   | st | d }| dd }z$t|ddD ]\}}t||}q*W n" tk
rb   td|Y nX |S dS )r   r   r   N)startz!failed to broadcast argument #{0})AssertionError	enumerater   r   r   )	shapelistresultZothersiZeachr   r   r   _multi_broadcast4   s    r%   c                   @   s   e Zd ZdZdZdZdd Zdd Zdd	 Zd
d Z	dd Z
dd Zdd Zdd Zdd Zdd Zdd Zdd Zedd Zdd Zd d! Zd"d# Zd$d% ZdS )&UFuncMechanismz0
    Prepare ufunc arguments for vectorize.
    NFc                 C   s>   || _ || _t| j}dg| | _g | _d| _dg| | _dS )zFNever used directly by user. Invoke by UFuncMechanism.call().
        N)typemapargsr   argtypes	scalarposr
   arrays)selfr'   r(   nargsr   r   r   __init__N   s    
zUFuncMechanism.__init__c                 C   sf   t | jD ]V\}}| |r.| || j|< q
t|tttt	j
frP| j| q
t	|| j|< q
dS )z1
        Get all arguments in array form
        N)r!   r(   is_device_arrayas_device_arrayr+   
isinstanceintfloatcomplexnpnumberr*   appendasarray)r,   r$   argr   r   r   _fill_arraysY   s    
zUFuncMechanism._fill_arraysc                 C   sH   t | jD ]8\}}|dk	r
t|d}|dkr8t|j}|| j|< q
dS )z
        Get dtypes
        Ndtype)r!   r+   getattrr5   r8   r;   r)   )r,   r$   aryr;   r   r   r   _fill_argtypesf   s    
zUFuncMechanism._fill_argtypesc                 C   s   g }| j rr| jD ]`}g }tt|| jD ]4\}\}}|dkrNt| j| j}|	||k q(t
|r|	| q|sg }| jD ],}t
dd t|| jD }|r|	| q|stdt|dkrtd|d | _dS )z<Resolve signature.
        May have ambiguous case.
        Nc                 s   s"   | ]\}}|d kp||kV  qd S r   r   )r   formalactualr   r   r   r      s   z4UFuncMechanism._resolve_signature.<locals>.<genexpr>zNo matching version.  GPU ufunc requires array arguments to have the exact types.  This behaves like regular ufunc with casting='no'.r   zqFailed to resolve ufunc due to ambiguous signature. Too many untyped scalars. Use numpy dtype object to type tag.r   )r*   r'   r!   r   r)   r5   r8   r(   r;   r7   all	TypeErrorr   )r,   matchesZ	formaltysZ	match_mapr$   r?   r@   Zall_matchesr   r   r   _resolve_signatureq   s2    


z!UFuncMechanism._resolve_signaturec                 C   s4   | j D ]&}tj| j| g| j| d| j|< q| jS )zPReturn the actual arguments
        Casts scalar arguments to np.array.
        r;   )r*   r5   arrayr(   r)   r+   )r,   r$   r   r   r   _get_actual_args   s    
$zUFuncMechanism._get_actual_argsc           	         s   dd |D }t | t|D ]\}  jkr2q|  rN|  ||< q fddttD }tt j }dg| t j }|D ]}d||< qt	j
jj |d}| |||< q|S )z)Perform numpy ufunc broadcasting
        c                 S   s   g | ]
}|j qS r   shaper   r   r   r   r   
<listcomp>   s     z-UFuncMechanism._broadcast.<locals>.<listcomp>c                    s,   g | ]$}| j ks$ j| | kr|qS r   )ndimrI   )r   axr=   rI   r   r   rK      s    
r   )rI   strides)r%   r!   rI   r/   broadcast_deviceranger   listrO   r5   libZstride_tricksZ
as_stridedforce_array_layout)	r,   arysr"   r$   Z
ax_differsZ
missingdimrO   rM   Zstridedr   rN   r   
_broadcast   s$    



zUFuncMechanism._broadcastc                 C   s*   |    |   |   |  }| |S )z[Prepare and return the arguments for the ufunc.
        Does not call to_device().
        )r:   r>   rD   rG   rV   )r,   rU   r   r   r   get_arguments   s
    zUFuncMechanism.get_argumentsc                 C   s   | j | j S )z)Returns (result_dtype, function)
        )r'   r)   r,   r   r   r   get_function   s    zUFuncMechanism.get_functionc                 C   s   dS )zBIs the `obj` a device array?
        Override in subclass
        Fr   r,   objr   r   r   r/      s    zUFuncMechanism.is_device_arrayc                 C   s   |S )zConvert the `obj` to a device array
        Override in subclass

        Default implementation is an identity function
        r   rZ   r   r   r   r0      s    zUFuncMechanism.as_device_arrayc                 C   s   t ddS )zTHandles ondevice broadcasting

        Override in subclass to add support.
        z'broadcasting on device is not supportedNNotImplementedErrorr,   r=   rI   r   r   r   rP      s    zUFuncMechanism.broadcast_devicec                 C   s   |S )zSEnsures array layout met device requirement.

        Override in sublcass
        r   )r,   r=   r   r   r   rT      s    z!UFuncMechanism.force_array_layoutc                    s  | d| j| dd}|r2tdd|  | || } \}}|d j}|dk	rv|rv	|}fdd |d j
d	kr fd
d|D }g }d}	|D ]6}
|
r||
 d}	qj|
d}|| q|d j}|dkrLj||d}||g ||d | |	r<||S | |S n|r|j
d	krl |}|}||g ||d | ||S |j|kst|j|kstj||d}||g ||d | |j|d|S dS )z1Perform the entire ufunc call mechanism.
        streamoutNzunrecognized keywords: %s, r   c                    s\    j r
tz
|  W S  tk
rV    | s2 n  |  } | Y S Y nX d S r   )SUPPORT_DEVICE_SLICINGr]   Zravelr/   to_host	to_device)r   hostary)crr_   r   r   attempt_ravel  s    

z*UFuncMechanism.call.<locals>.attempt_ravelr   c                    s   g | ]} |qS r   r   rJ   )rg   r   r   rK     s     z'UFuncMechanism.call.<locals>.<listcomp>FT)r_   )popDEFAULT_STREAMwarningswarnjoinrW   rY   rI   r/   r0   rL   r7   rd   allocate_device_arrayextendlaunchreshapeZcopy_to_hostr    r;   )clsr'   r(   kwsr`   ZrestyfuncZoutshapeZdevarysZ
any_devicer   Zdev_arI   Zdevoutr   )rg   rf   r_   r   call   sT    








zUFuncMechanism.callc                 C   s   t dS )zBImplement to device transfer
        Override in subclass
        Nr\   )r,   re   r_   r   r   r   rd   K  s    zUFuncMechanism.to_devicec                 C   s   t dS )z@Implement to host transfer
        Override in subclass
        Nr\   )r,   devaryr_   r   r   r   rc   Q  s    zUFuncMechanism.to_hostc                 C   s   t dS )zBImplements device allocation
        Override in subclass
        Nr\   )r,   rI   r;   r_   r   r   r   rm   W  s    z$UFuncMechanism.allocate_device_arrayc                 C   s   t dS )zKImplements device function invocation
        Override in subclass
        Nr\   )r,   rs   countr_   r(   r   r   r   ro   ]  s    zUFuncMechanism.launch)__name__
__module____qualname____doc__ri   rb   r.   r:   r>   rD   rG   rV   rW   rY   r/   r0   rP   rT   classmethodrt   rd   rc   rm   ro   r   r   r   r   r&   G   s*   +	!

Yr&   c                 C   s    t | tjr| j} tt| S r   )r1   r   Z
EnumMemberr;   r5   str)tyr   r   r   to_dtyped  s    r~   c                   @   sZ   e Zd Zddi fddZedd ZdddZd	d
 Zdd Zdd Z	dd Z
dd ZdS )DeviceVectorizeNFc                 C   s`   |rt d|D ]2}|dkr*tdt qd}|d7 }t|| q|| _t|| _t | _	d S )Ncaching is not supportednopythonz+nopython kwarg for cuda target is redundantzUnrecognized options. z3cuda vectorize target does not support option: '%s')
rB   rj   rk   RuntimeWarningKeyErrorpy_funcr   identityr   	kernelmap)r,   rs   r   cachetargetoptionsoptfmtr   r   r   r.   k  s    
zDeviceVectorize.__init__c                 C   s   | j S r   r   rX   r   r   r   pyfunc{  s    zDeviceVectorize.pyfuncc                 C   s   t |\}}t|f| }| jj}| | j||}| |\}}| |}tt	j
fdd |D |d d  g  }t|| |d|  }	| |	|}
tdd |jD }t|}||
f| jt|< d S )Nc                 S   s   g | ]}|d d  qS r   r   rJ   r   r   r   rK     s     z'DeviceVectorize.add.<locals>.<listcomp>z__vectorized_%sc                 s   s   | ]}t |V  qd S r   )r~   r   tr   r   r   r     s     z&DeviceVectorize.add.<locals>.<genexpr>)r	   normalize_signaturer
   r   rw   _get_kernel_source_kernel_template_compile_core_get_globalsr   voidexec_compile_kernelr   r(   r~   r   )r,   sigr(   return_typeZdevfnsigfuncnameZkernelsourcecorefnZglblZstagerkernelZ	argdtypesZresdtyper   r   r   add  s      
(
zDeviceVectorize.addc                 C   s   t d S r   r\   rX   r   r   r   build_ufunc  s    zDeviceVectorize.build_ufuncc                 C   sH   dd t t|jD }t|d|ddd |D d}|jf |S )Nc                 S   s   g | ]}d | qS )za%dr   r   r$   r   r   r   rK     s     z6DeviceVectorize._get_kernel_source.<locals>.<listcomp>ra   c                 s   s   | ]}d | V  qdS )z%s[__tid__]Nr   r   r   r   r   r     s     z5DeviceVectorize._get_kernel_source.<locals>.<genexpr>)namer(   argitems)rQ   r   r(   dictrl   r   )r,   templater   r   r(   Zfmtsr   r   r   r     s    z"DeviceVectorize._get_kernel_sourcec                 C   s   t d S r   r\   r,   r   r   r   r   r     s    zDeviceVectorize._compile_corec                 C   s   t d S r   r\   )r,   r   r   r   r   r     s    zDeviceVectorize._get_globalsc                 C   s   t d S r   r\   r,   fnobjr   r   r   r   r     s    zDeviceVectorize._compile_kernel)N)rw   rx   ry   r.   propertyr   r   r   r   r   r   r   r   r   r   r   r   j  s   

r   c                   @   sD   e Zd Zddi dfddZedd Zddd	Zd
d Zdd ZdS )DeviceGUFuncVectorizeNFr   c           	      C   s   |rt d|rt d|dds,t d|rZddd | D }d	}t |||| _t|| _|| _t	| j\| _
| _t | _d S )
Nr   zwritable_args are not supportedr   Tznopython flag must be Truera   c                 S   s   g | ]}t |qS r   )reprr   kr   r   r   rK     s     z2DeviceGUFuncVectorize.__init__.<locals>.<listcomp>z3The following target options are not supported: {0})rB   rh   rl   keysr   r   r   r   r
   r   inputsig	outputsigr   r   )	r,   rs   r   r   r   r   Zwritable_argsoptsr   r   r   r   r.     s    
zDeviceGUFuncVectorize.__init__c                 C   s   | j S r   r   rX   r   r   r   r     s    zDeviceGUFuncVectorize.pyfuncc                 C   s  dd | j D }dd | jD }t|\}}|tjd fk}|sVtd| d| d| jj}t	| j
||||}| |}	t||	 |	dj|d }
tt||| }| j|
t|d	}t|}d
d |D }t|d |  }t|| d  }||f| j|< d S )Nc                 S   s   g | ]}t |qS r   r   r   xr   r   r   rK     s     z-DeviceGUFuncVectorize.add.<locals>.<listcomp>c                 S   s   g | ]}t |qS r   r   r   r   r   r   rK     s     z7guvectorized functions cannot return values: signature z specifies z return typez__gufunc_{name})r   )r   c                 S   s   g | ]}t t|jqS r   )r5   r;   r|   r   r   r   r   rK     s     )r   r   r	   r   r   nonerB   r   rw   expand_gufunc_templater   r   r   r   rR   _determine_gufunc_outer_typesr   r   r   r   )r,   r   indimsoutdimsr(   r   Zvalid_return_typer   srcZglblsr   Zoutertysr   noutZdtypesindtypes	outdtypesr   r   r   r     s,      

zDeviceGUFuncVectorize.addc                 C   s   t d S r   r\   r   r   r   r   r     s    z%DeviceGUFuncVectorize._compile_kernelc                 C   s   t d S r   r\   r   r   r   r   r     s    z"DeviceGUFuncVectorize._get_globals)N)	rw   rx   ry   r.   r   r   r   r   r   r   r   r   r   r     s   


 r   c                 c   sZ   t | |D ]J\}}t|tjr2|j|d dV  q
|dkrBtdtj|dddV  q
d S )Nr   )rL   r   z,gufunc signature mismatch: ndim>0 for scalarA)r;   rL   Zlayout)r   r1   r   Arraycopyr   )ZargtysZdimsatndr   r   r   r     s    r   c                 C   s   || }dd t t|D }dddd |D }dd t|||D }dd t|t|d ||t|d D }	||	 }
| j|d||d|
d	}|S )
z"Expand gufunc source template
    c                 S   s   g | ]}d  |qS )zarg{0}r   r   r   r   r   rK     s     z*expand_gufunc_template.<locals>.<listcomp>zmin({0})ra   c                 S   s   g | ]}d  |qS )z{0}.shape[0]r   rJ   r   r   r   rK     s   c                 S   s   g | ]\}}}t |||qS r   _gen_src_for_indexingr   arefadimsatyper   r   r   rK     s   c                 S   s   g | ]\}}}t |||qS r   r   r   r   r   r   rK     s   N)r   r(   
checkedargr   )rQ   r   r   rl   r   )r   r   r   r   r)   Zargdimsargnamesr   inputsoutputsr   r   r   r   r   r     s&    

r   c                 C   s   dj | t||dS )Nz{aref}[{sliced}])r   Zsliced)r   _gen_src_index)r   r   r   r   r   r   r     s    r   c                 C   sD   | dkrd dgdg|   S t|tjr<|jd | kr<dS dS d S )Nr   ,Z__tid__:r   z__tid__:(__tid__ + 1))rl   r1   r   r   rL   )r   r   r   r   r   r     s
    r   c                   @   s,   e Zd ZdZedd Zdd Zdd ZdS )	GUFuncEnginezZDetermine how to broadcast and execute a gufunc
    base on input shape and signature
    c                 C   s   | t | S r   r   )rq   r
   r   r   r   from_signature  s    zGUFuncEngine.from_signaturec                 C   s(   || _ || _t| j | _t| j| _d S r   )sinsoutr   ninr   )r,   r   r   r   r   r   r.   "  s    zGUFuncEngine.__init__c                 C   s  t || jkrtdi }g }g }tt|| jD ]\}\}}|d7 }t |}t ||k rld}	t|	|f |r|| d  }
|d |  }nd}
|}tt|
|D ]H\}\}}|t |7 }||kr|| |krd}	t|	||f |||< q|| ||
 q2g }| jD ]2}g }|D ]}|||  q|t	| qdd |D }t
|}|| }dg| j }t|D ]H\}}||krv|d	ks|dkrd
||< nd}	t|	|d f qvt| ||||S )Nz invalid number of input argumentr   z%arg #%d: insufficient inner dimensionr   z$arg #%d: shape[%d] mismatch argumentc                 S   s   g | ]}t tj|d qS r   )r   operatormulr   sr   r   r   rK   T  s     z)GUFuncEngine.schedule.<locals>.<listcomp>Fr   Tz!arg #%d: outer dimension mismatch)r   r   rB   r!   r   r   r   r7   r   r   r5   ZargmaxGUFuncSchedule)r,   ishapesZ	symbolmapZouter_shapesZinner_shapesZargnrI   symbolsZ
inner_ndimr   Zinner_shapeZouter_shapeZaxisdimsymoshapesZoutsigoshapesizesZ	largest_iloopdimspinnedr$   dr   r   r   schedule*  sT    





zGUFuncEngine.scheduleN)rw   rx   ry   rz   r{   r   r.   r   r   r   r   r   r     s
   
r   c                   @   s   e Zd Zdd Zdd ZdS )r   c                    sF   || _ || _|| _ | _ttj d| _|| _ fdd|D | _	d S )Nr   c                    s   g | ]} | qS r   r   r   r   r   r   rK   p  s     z+GUFuncSchedule.__init__.<locals>.<listcomp>)
parentr   r   r   r   r   r   loopnr   output_shapes)r,   r   r   r   r   r   r   r   r   r.   e  s    zGUFuncSchedule.__init__c                    s,   dd l }d} fdd|D }|t|S )Nr   )r   r   r   r   r   c                    s   g | ]}|t  |fqS r   )r<   r   rX   r   r   rK   v  s     z*GUFuncSchedule.__str__.<locals>.<listcomp>)pprintpformatr   )r,   r   attrsvaluesr   rX   r   __str__r  s    zGUFuncSchedule.__str__N)rw   rx   ry   r.   r   r   r   r   r   r   d  s   r   c                   @   sL   e Zd Zdd Zdd Zdd Zdd Zd	d
 Zdd Zdd Z	dd Z
dS )GeneralizedUFuncc                 C   s   || _ || _d| _d S )Ni   @)r   engineZmax_blocksize)r,   r   r   r   r   r   r.   {  s    zGeneralizedUFunc.__init__c                 O   sv   |  | jj| jj||}| |j|j\}}}}|| |||}|	 }	| 
||	|}
|||j|
 ||S r   )Z_call_stepsr   r   r   	_scheduler   r   adjust_input_typesprepare_outputsprepare_inputsrV   launch_kernelr   post_process_outputs)r,   r(   rr   Z	callstepsr   r   r   r   r   r   
parametersr   r   r   __call__  s      
zGeneralizedUFunc.__call__c           
      C   s   dd |D }| j |}tdd |D }z| j| \}}W n, tk
rj   | |}| j| \}}Y nX t|j|D ]"\}}	|	d k	rx||	jkrxt	dqx||||fS )Nc                 S   s   g | ]
}|j qS r   rH   rJ   r   r   r   rK     s     z.GeneralizedUFunc._schedule.<locals>.<listcomp>c                 s   s   | ]}|j V  qd S r   rE   r   r   r   r   r     s     z-GeneralizedUFunc._schedule.<locals>.<genexpr>zoutput shape mismatch)
r   r   r   r   r   _search_matching_signaturer   r   rI   r   )
r,   r   ZoutsZinput_shapesr   r   r   r   Zsched_shaper`   r   r   r   r     s    

zGeneralizedUFunc._schedulec                 C   s<   | j  D ]$}tdd t||D r
|  S q
tddS )z
        Given the input types in `idtypes`, return a compatible sequence of
        types that is defined in `kernelmap`.

        Note: Ordering is guaranteed by `kernelmap` being a OrderedDict
        c                 s   s   | ]\}}t ||V  qd S r   )r5   Zcan_cast)r   r@   Zdesiredr   r   r   r     s   z>GeneralizedUFunc._search_matching_signature.<locals>.<genexpr>zno matching signatureN)r   r   rA   r   rB   )r,   Zidtypesr   r   r   r   r     s    
z+GeneralizedUFunc._search_matching_signaturec                 C   s   |j dkstd|jsdn|j }g }t||jD ]B\}}|s`|jdkr`| ||}|| q2|| ||| q2g }	t||j	D ]\}
}|	|
j
|f|  qt|t|	 S )Nr   zzero looping dimensionr   )r   r    r   r   r   size_broadcast_scalar_inputr7   _broadcast_arrayr   rp   r   )r,   r   paramsZretvalsZodim	newparamspcsru   Z
newretvalsretvalr   r   r   r   rV     s    zGeneralizedUFunc._broadcastc                 C   sf   |f| }|j |kr|S t|j t|k rX|t|j  d  |j ksLtd| ||S |j| S d S )Nz+cannot add dim and reshape at the same time)rI   r   r    _broadcast_add_axisrp   )r,   r=   ZnewdimZinnerdimnewshaper   r   r   r     s    

z!GeneralizedUFunc._broadcast_arrayc                 C   s   t dd S )Nzcannot add new axisr\   )r,   r=   r   r   r   r   r     s    z$GeneralizedUFunc._broadcast_add_axisc                 C   s   t d S r   r\   r^   r   r   r   r     s    z(GeneralizedUFunc._broadcast_scalar_inputN)rw   rx   ry   r.   r   r   r   rV   r   r   r   r   r   r   r   r   z  s   r   c                   @   s~   e Zd ZdZdddgZedd Zedd Zed	d
 Zedd Z	edd Z
dd Zdd Zdd Zdd Zdd ZdS )GUFuncCallStepsab  
    Implements memory management and kernel launch operations for GUFunc calls.

    One instance of this class is instantiated for each call, and the instance
    is specific to the arguments given to the GUFunc call.

    The base class implements the overall logic; subclasses provide
    target-specific implementations of individual functions.
    r   r   _copy_result_to_hostc                 C   s   dS )zImplement the kernel launchNr   )r,   r   Znelemr(   r   r   r   r     s    zGUFuncCallSteps.launch_kernelc                 C   s   dS )zb
        Return True if `obj` is a device array for this target, False
        otherwise.
        Nr   rZ   r   r   r   r/     s    zGUFuncCallSteps.is_device_arrayc                 C   s   dS )z
        Return `obj` as a device array on this target.

        May return `obj` directly if it is already on the target.
        Nr   rZ   r   r   r   r0     s    zGUFuncCallSteps.as_device_arrayc                 C   s   dS )zK
        Copy `hostary` to the device and return the device array.
        Nr   )r,   re   r   r   r   rd     s    zGUFuncCallSteps.to_devicec                 C   s   dS )zc
        Allocate a new uninitialized device array with the given shape and
        dtype.
        Nr   )r,   rI   r;   r   r   r   rm   	  s    z%GUFuncCallSteps.allocate_device_arrayc                    s6  | d}|d krbt|||| fkrbdd }d|| d|||  d|t| d}t||d k	rt||krtdn
|g| }d	}g _|D ]2}	|	rj|	 d
}qj|	 qtfdd|D  }
|
o|_	fdd  fdd|D }|d | _
||d  }|r2|_d S )Nr`   c                 S   s   |  dd| dk  S )Nz positional argumentr   r   r   )nr   r   r   pos_argn  s    z*GUFuncCallSteps.__init__.<locals>.pos_argnzThis gufunc accepts z  (when providing input only) or z( (when providing input and output). Got .z<cannot specify argument 'out' as both positional and keywordTFc                    s   g | ]}  |qS r   )r/   rJ   rX   r   r   rK   2  s     z,GUFuncCallSteps.__init__.<locals>.<listcomp>c                    s      | r j}ntj}|| S r   )r/   r0   r5   r8   )r   convertrX   r   r   normalize_arg;  s    
z/GUFuncCallSteps.__init__.<locals>.normalize_argc                    s   g | ]} |qS r   r   rJ   )r  r   r   rK   C  s     )getr   rB   r   r   r/   r7   r0   anyr  r   )r,   r   r   r(   kwargsr   r  msgZall_user_outputs_are_hostoutputZall_host_arraysZnormalized_argsZunused_inputsr   )r  r,   r   r.     s2    
,


zGUFuncCallSteps.__init__c                 C   s\   t t|| jD ]F\}\}}||jkrt|dsFdt|}t|||| j|< qdS )z
        Attempt to cast the inputs to the required types if necessary
        and if they are not device arrays.

        Side effect: Only affects the elements of `inputs` that require
        a type cast.
        astypezNcompatible signature is possible by casting but {0} does not support .astype()N)	r!   r   r   r;   hasattrr   typerB   r  )r,   r   r$   Zityvalr  r   r   r   r   K  s    

z"GUFuncCallSteps.adjust_input_typesc                 C   sH   g }t |j|| jD ].\}}}|dks,| jr8| ||}|| q|S )z
        Returns a list of output parameters that all reside on the target device.

        Outputs that were passed-in to the GUFunc are used if they reside on the
        device; other outputs are allocated as necessary.
        N)r   r   r   r  rm   r7   )r,   r   r   r   rI   r;   r  r   r   r   r   \  s    zGUFuncCallSteps.prepare_outputsc                    s    fdd  fddj D S )zZ
        Returns a list of input parameters that all reside on the target device.
        c                    s      | r j}n j}|| S r   )r/   r0   rd   )Z	parameterr  rX   r   r   ensure_devicep  s    
z5GUFuncCallSteps.prepare_inputs.<locals>.ensure_devicec                    s   g | ]} |qS r   r   )r   r   )r  r   r   rK   x  s     z2GUFuncCallSteps.prepare_inputs.<locals>.<listcomp>)r   rX   r   )r  r,   r   r   l  s    zGUFuncCallSteps.prepare_inputsc                    sV    j r" fddt| jD }n jd dk	r6 j}t|dkrJ|d S t|S dS )a+  
        Moves the given output(s) to the host if necessary.

        Returns a single value (e.g. an array) if there was one output, or a
        tuple of arrays if there were multiple. Although this feels a little
        jarring, it is consistent with the behavior of GUFuncs in general.
        c                    s   g | ]\}}  ||qS r   )rc   )r   r  Zself_outputrX   r   r   rK     s   z8GUFuncCallSteps.post_process_outputs.<locals>.<listcomp>r   Nr   )r  r   r   r   r   )r,   r   r   rX   r   r   z  s    

z$GUFuncCallSteps.post_process_outputsN)rw   rx   ry   rz   	__slots__r   r   r/   r0   rd   rm   r.   r   r   r   r   r   r   r   r   r    s(   




;r  )	metaclass)&rz   abcr   r   collectionsr   r   rj   	functoolsr   numpyr5   Znumba.np.ufunc.ufuncbuilderr   r   Z
numba.corer   r	   Znumba.core.typingr
   Znumba.np.ufunc.sigparser   r   r   r%   objectr&   r~   r   r   r   r   r   r   r   r   r   r  r   r   r   r   <module>   s6     =D
Kd