U
    9%e_U                     @   s  d Z ddlZddlZddlmZmZ ddlmZ ddlm	Z	m
Z
 ddlmZmZ ddlmZ ddlmZmZmZmZ ddlZd	d
lmZmZmZmZ e rddlmZ G dd deZdd Z dd Z!dd Z"dd Z#dd Z$dd Z%dd Z&dd Z'dd Z(dd  Z)d!d" Z*d#d$ Z+d%d& Z,d'd( Z-d)d* Z.d+d, Z/d-d. Z0d/d0 Z1d1d2 Z2d3d4 Z3G d5d6 d6eZ4G d7d8 d8e5eZ6G d9d: d:e6Z7G d;d< d<e6Z8G d=d> d>Z9d?d@ Z:dAdB Z;dZee5e5dEdFdGZ<e
d[e=dIdJdKZ>d\dLdMZ?dNdO Z@d]dPdQZAdRdS ZBdTdU ZCdVdW ZDdXdY ZEdS )^z
Generic utilities
    N)OrderedDictUserDict)MutableMapping)	ExitStackcontextmanager)fieldsis_dataclass)Enum)AnyContextManagerListTuple   )is_flax_availableis_tf_availableis_torch_availableis_torch_fx_proxyc                   @   s   e Zd ZdZdddZdS )cached_propertyz
    Descriptor that mimics @property but caches output in member variable.

    From tensorflow_datasets

    Built-in in functools from Python 3.8.
    Nc                 C   sX   |d kr| S | j d krtdd| j j }t||d }|d krT|  |}t||| |S )Nzunreadable attributeZ	__cached_)fgetAttributeError__name__getattrsetattr)selfobjobjtypeattrcached r   Y/var/www/html/Darija-Ai-API/env/lib/python3.8/site-packages/transformers/utils/generic.py__get__-   s    

zcached_property.__get__)N)r   
__module____qualname____doc__r    r   r   r   r   r   $   s   r   c                 C   s2   |   } | dkrdS | dkr dS td| dS )zConvert a string representation of truth to true (1) or false (0).

    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'.
    Raises ValueError if 'val' is anything else.
    >   yesony1truetr   >   falsenonofff0r   zinvalid truth value N)lower
ValueError)valr   r   r   	strtobool<   s    r3   c                 C   sH   t t| }|drdS |dr(dS |dr6dS |drDdS d	S )
z
    Tries to guess the framework of an object `x` from its repr (brittle but will help in `is_tensor` to try the
    frameworks in a smart order, without the need to import the frameworks).
    z<class 'torch.ptz<class 'tensorflow.tfz<class 'jaxjaxz<class 'numpy.npN)strtype
startswith)xZrepresentationr   r   r   infer_framework_from_reprJ   s    



r<   c                    sd   t tttd t| dkr"g ng}dkr:|d |fdd D   fdd|D S )z
    Returns an (ordered since we are in Python 3.7+) dictionary framework to test function, which places the framework
    we can guess from the repr first, then Numpy, then the others.
    r4   r5   r6   r7   Nr7   c                    s   g | ]}| d fkr|qS )r7   r   .0r.   )preferred_frameworkr   r   
<listcomp>j   s      z1_get_frameworks_and_test_func.<locals>.<listcomp>c                    s   i | ]}| | qS r   r   r>   )framework_to_testr   r   
<dictcomp>k   s      z1_get_frameworks_and_test_func.<locals>.<dictcomp>)is_torch_tensoris_tf_tensoris_jax_tensoris_numpy_arrayr<   appendextend)r;   Z
frameworksr   )rB   r@   r   _get_frameworks_and_test_funcZ   s    
rJ   c                 C   sT   t | }| D ]}|| r dS qt| r0dS t rPddlm} t| |rPdS dS )z
    Tests if `x` is a `torch.Tensor`, `tf.Tensor`, `jaxlib.xla_extension.DeviceArray` or `np.ndarray` in the order
    defined by `infer_framework_from_repr`
    Tr   )TracerF)rJ   valuesr   r   Zjax.corerK   
isinstance)r;   framework_to_test_func	test_funcrK   r   r   r   	is_tensorn   s    
rP   c                 C   s   t | tjS N)rM   r7   ndarrayr;   r   r   r   	_is_numpy   s    rT   c                 C   s   t | S )z/
    Tests if `x` is a numpy array or not.
    )rT   rS   r   r   r   rG      s    rG   c                 C   s   dd l }t| |jS Nr   )torchrM   Tensorr;   rV   r   r   r   	_is_torch   s    rY   c                 C   s   t  s
dS t| S )z]
    Tests if `x` is a torch tensor or not. Safe to call even if torch is not installed.
    F)r   rY   rS   r   r   r   rD      s    rD   c                 C   s   dd l }t| |jS rU   )rV   rM   ZdevicerX   r   r   r   _is_torch_device   s    rZ   c                 C   s   t  s
dS t| S )z]
    Tests if `x` is a torch device or not. Safe to call even if torch is not installed.
    F)r   rZ   rS   r   r   r   is_torch_device   s    r[   c                 C   s8   dd l }t| tr,t|| r(t|| } ndS t| |jS )Nr   F)rV   rM   r8   hasattrr   ZdtyperX   r   r   r   _is_torch_dtype   s    

r]   c                 C   s   t  s
dS t| S )z\
    Tests if `x` is a torch dtype or not. Safe to call even if torch is not installed.
    F)r   r]   rS   r   r   r   is_torch_dtype   s    r^   c                 C   s   dd l }t| |jS rU   )
tensorflowrM   rW   r;   r5   r   r   r   _is_tensorflow   s    ra   c                 C   s   t  s
dS t| S )zg
    Tests if `x` is a tensorflow tensor or not. Safe to call even if tensorflow is not installed.
    F)r   ra   rS   r   r   r   rE      s    rE   c                 C   s*   dd l }t|dr|| S t| |jkS )Nr   is_symbolic_tensor)r_   r\   rb   r9   rW   r`   r   r   r   _is_tf_symbolic_tensor   s    

rc   c                 C   s   t  s
dS t| S )z
    Tests if `x` is a tensorflow symbolic tensor or not (ie. not eager). Safe to call even if tensorflow is not
    installed.
    F)r   rc   rS   r   r   r   is_tf_symbolic_tensor   s    rd   c                 C   s   dd l m} t| |jS rU   )	jax.numpynumpyrM   rR   )r;   jnpr   r   r   _is_jax   s    rh   c                 C   s   t  s
dS t| S )zY
    Tests if `x` is a Jax tensor or not. Safe to call even if jax is not installed.
    F)r   rh   rS   r   r   r   rF      s    rF   c                 C   s   dd dd dd dd d}t | ttfr>dd |  D S t | ttfrZd	d
 | D S t| }| D ] \}}|| rj|| |   S qjt | tjr| 	 S | S dS )zc
    Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.
    c                 S   s   |     S rQ   )detachcputolistr   r   r   r   <lambda>       zto_py_obj.<locals>.<lambda>c                 S   s   |    S rQ   )rf   rk   rl   r   r   r   rm      rn   c                 S   s   t |  S rQ   )r7   asarrayrk   rl   r   r   r   rm      rn   c                 S   s   |   S rQ   )rk   rl   r   r   r   rm      rn   r=   c                 S   s   i | ]\}}|t |qS r   	to_py_objr?   kvr   r   r   rC      s      zto_py_obj.<locals>.<dictcomp>c                 S   s   g | ]}t |qS r   rp   )r?   or   r   r   rA      s     zto_py_obj.<locals>.<listcomp>N)
rM   dictr   itemslisttuplerJ   r7   numberrk   )r   Zframework_to_py_objrN   	frameworkrO   r   r   r   rq      s     rq   c                 C   s   dd dd dd dd d}t | ttfr>dd |  D S t | ttfrVt| S t| }| D ] \}}|| rf|| |   S qf| S )	zc
    Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a Numpy array.
    c                 S   s   |     S rQ   )ri   rj   rf   rl   r   r   r   rm     rn   zto_numpy.<locals>.<lambda>c                 S   s   |   S rQ   )rf   rl   r   r   r   rm     rn   c                 S   s
   t | S rQ   )r7   ro   rl   r   r   r   rm     rn   c                 S   s   | S rQ   r   rl   r   r   r   rm     rn   r=   c                 S   s   i | ]\}}|t |qS r   )to_numpyrr   r   r   r   rC     s      zto_numpy.<locals>.<dictcomp>)	rM   rv   r   rw   rx   ry   r7   arrayrJ   )r   Zframework_to_numpyrN   r{   rO   r   r   r   r|     s    
r|   c                       s   e Zd ZdZddddZ fddZdd	 Zd
d Zdd Zdd Z	dd Z
dd Z fddZ fddZee dddZ  ZS )ModelOutputa  
    Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
    tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
    python dictionary.

    <Tip warning={true}>

    You can't unpack a `ModelOutput` directly. Use the [`~utils.ModelOutput.to_tuple`] method to convert it to a tuple
    before.

    </Tip>
    N)returnc                    s4   t  r0ddljj jjj fdd dS )zRegister subclasses as pytree nodes.

        This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with
        `static_graph=True` with modules that output `ModelOutput` subclasses.
        r   Nc                    s    f j j| |S rQ   )utils_pytreeZ_dict_unflatten)rL   contextclsrV   r   r   rm   :  rn   z/ModelOutput.__init_subclass__.<locals>.<lambda>)r   Ztorch.utils._pytreer   r   Z_register_pytree_nodeZ_dict_flatten)r   r   r   r   __init_subclass__.  s    zModelOutput.__init_subclass__c                    sB   t  j|| | jtk}|r>t| s>t| j d| jj dd S )N.za is not a dataclasss. This is a subclass of ModelOutput and so must use the @dataclass decorator.)super__init__	__class__r~   r   	TypeErrorr!   r   )r   argskwargsZis_modeloutput_subclassr   r   r   r   =  s    
zModelOutput.__init__c           
         s  t  }t|s"t jj dtdd |dd D sNt jj dt |d j}t fdd|dd D }|rt|st	|t
r| }d	}n*zt|}d	}W n tk
r   d
}Y nX |rvt|D ]\}}t	|ttfrt|dkrt	|d tsB|dkr,| |d j< ntd| d qt |d |d  |d dk	r|d  |d < qn|dk	r| |d j< n,|D ]&}t |j}	|	dk	r|	 |j< qdS )zeCheck the ModelOutput dataclass.

        Only occurs if @dataclass decorator has been used.
        z has no fields.c                 s   s   | ]}|j d kV  qd S rQ   )defaultr?   fieldr   r   r   	<genexpr>V  s     z,ModelOutput.__post_init__.<locals>.<genexpr>r   Nz. should not have more than one required field.r   c                 3   s   | ]}t  |jd kV  qd S rQ   )r   namer   r   r   r   r   Z  s     TF   zCannot set key/value for z&. It needs to be a tuple (key, value).)r   lenr1   r   r   allr   r   rP   rM   rv   rw   iterr   	enumeraterx   ry   r8   r   )
r   Zclass_fieldsZfirst_fieldZother_fields_are_noneiteratorZfirst_field_iteratoridxelementr   rt   r   r   r   __post_init__L  sN    






zModelOutput.__post_init__c                 O   s   t d| jj dd S )Nz$You cannot use ``__delitem__`` on a 
 instance.	Exceptionr   r   r   r   r   r   r   r   __delitem__  s    zModelOutput.__delitem__c                 O   s   t d| jj dd S )Nz#You cannot use ``setdefault`` on a r   r   r   r   r   r   
setdefault  s    zModelOutput.setdefaultc                 O   s   t d| jj dd S )NzYou cannot use ``pop`` on a r   r   r   r   r   r   pop  s    zModelOutput.popc                 O   s   t d| jj dd S )NzYou cannot use ``update`` on a r   r   r   r   r   r   update  s    zModelOutput.updatec                 C   s.   t |trt|  }|| S |  | S d S rQ   )rM   r8   rv   rw   to_tuple)r   rs   Z
inner_dictr   r   r   __getitem__  s    
zModelOutput.__getitem__c                    s4   ||   kr"|d k	r"t || t || d S rQ   )keysr   __setitem____setattr__)r   r   valuer   r   r   r     s    zModelOutput.__setattr__c                    s    t  || t  || d S rQ   )r   r   r   )r   keyr   r   r   r   r     s    zModelOutput.__setitem__c                    s   t  fdd  D S )za
        Convert self to a tuple containing all the attributes/keys that are not `None`.
        c                 3   s   | ]} | V  qd S rQ   r   )r?   rs   r   r   r   r     s     z'ModelOutput.to_tuple.<locals>.<genexpr>)ry   r   r   r   r   r   r     s    zModelOutput.to_tuple)r   r!   r"   r#   r   r   r   r   r   r   r   r   r   r   r   r
   r   __classcell__r   r   r   r   r~      s   8r~   c                   @   s   e Zd ZdZedd ZdS )ExplicitEnumzC
    Enum with more explicit error message for missing values.
    c                 C   s(   t | d| j dt| j  d S )Nz is not a valid z, please select one of )r1   r   rx   _value2member_map_r   )r   r   r   r   r   	_missing_  s    zExplicitEnum._missing_N)r   r!   r"   r#   classmethodr   r   r   r   r   r     s   r   c                   @   s   e Zd ZdZdZdZdZdS )PaddingStrategyz
    Possible values for the `padding` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for tab-completion in an
    IDE.
    longest
max_lengthZ
do_not_padN)r   r!   r"   r#   ZLONGESTZ
MAX_LENGTHZ
DO_NOT_PADr   r   r   r   r     s   r   c                   @   s    e Zd ZdZdZdZdZdZdS )
TensorTypez
    Possible values for the `return_tensors` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for
    tab-completion in an IDE.
    r4   r5   r7   r6   N)r   r!   r"   r#   ZPYTORCHZ
TENSORFLOWZNUMPYZJAXr   r   r   r   r     s
   r   c                   @   s2   e Zd ZdZee dddZdd Zdd Zd	S )
ContextManagersz
    Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
    in the `fastcore` library.
    )context_managersc                 C   s   || _ t | _d S rQ   )r   r   stack)r   r   r   r   r   r     s    zContextManagers.__init__c                 C   s   | j D ]}| j| qd S rQ   )r   r   enter_context)r   Zcontext_managerr   r   r   	__enter__  s    
zContextManagers.__enter__c                 O   s   | j j|| d S rQ   )r   __exit__r   r   r   r   r     s    zContextManagers.__exit__N)	r   r!   r"   r#   r   r   r   r   r   r   r   r   r   r     s   r   c                 C   sn   t | }|dkrt| j}n"|dkr4t| j}nt| j}|jD ]"}|dkrF|j| jdkrF dS qFdS )zr
    Check if a given model can return loss.

    Args:
        model_class (`type`): The class of the model.
    r5   r4   Zreturn_lossTF)infer_frameworkinspect	signaturecallforward__call__
parametersr   )model_classr{   r   pr   r   r   can_return_loss  s    
r   c                 C   sr   | j }t| }|dkr$t| j}n"|dkr:t| j}nt| j}d|kr^dd |jD S dd |jD S dS )zq
    Find the labels used by a given model.

    Args:
        model_class (`type`): The class of the model.
    r5   r4   ZQuestionAnsweringc                 S   s    g | ]}d |ks|dkr|qS )label)Zstart_positionsZend_positionsr   r?   r   r   r   r   rA     s       zfind_labels.<locals>.<listcomp>c                 S   s   g | ]}d |kr|qS )r   r   r   r   r   r   rA     s      N)r   r   r   r   r   r   r   r   )r   Z
model_namer{   r   r   r   r   find_labels  s    r    r   )d
parent_key	delimiterc                 C   s   ddd}t || ||S )z/Flatten a nested dict into a single level dict.r   r   c                 s   sd   |   D ]V\}}|r(t|| t| n|}|rTt|trTt|||d  E d H  q||fV  qd S )N)r   )rw   r8   rM   r   flatten_dict)r   r   r   rs   rt   r   r   r   r   _flatten_dict  s
    z#flatten_dict.<locals>._flatten_dict)r   r   )rv   )r   r   r   r   r   r   r   r     s    
r   F)use_temp_dirc              	   c   s*   |r t  }|V  W 5 Q R X n| V  d S rQ   )tempfileTemporaryDirectory)Zworking_dirr   Ztmp_dirr   r   r   working_or_temp_dir  s    
r   c                 C   s   t | rtj| |dS t| r6|dkr,| jS | j| S t| rTddl}|j| |dS t| rjt	j| |dS t
dt|  ddS )z
    Framework-agnostic version of `numpy.transpose` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.
    )axesNr   )permz"Type not supported for transpose: r   )rG   r7   	transposerD   TZpermuterE   r_   rF   rg   r1   r9   )r}   r   r5   r   r   r   r   "  s    r   c                 C   sn   t | rt| |S t| r&| j| S t| rBddl}|| |S t| rVt| |S tdt	|  ddS )z
    Framework-agnostic version of `numpy.reshape` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.
    r   Nz Type not supported for reshape: r   )
rG   r7   reshaperD   rE   r_   rF   rg   r1   r9   )r}   Znewshaper5   r   r   r   r   5  s    
r   c                 C   s   t | rtj| |dS t| r:|dkr.|  S | j|dS t| rXddl}|j| |dS t| rntj| |dS tdt	|  ddS )z
    Framework-agnostic version of `numpy.squeeze` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.
    axisNdimr   z Type not supported for squeeze: r   )
rG   r7   squeezerD   rE   r_   rF   rg   r1   r9   r}   r   r5   r   r   r   r   H  s    r   c                 C   st   t | rt| |S t| r(| j|dS t| rFddl}|j| |dS t| r\tj| |dS t	dt
|  ddS )z
    Framework-agnostic version of `numpy.expand_dims` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.
    r   r   Nr   $Type not supported for expand_dims: r   )rG   r7   expand_dimsrD   Z	unsqueezerE   r_   rF   rg   r1   r9   r   r   r   r   r   [  s    r   c                 C   sb   t | rt| S t| r"|  S t| r<ddl}|| S t| rJ| jS tdt	|  ddS )z|
    Framework-agnostic version of `numpy.size` that will work on torch/TensorFlow/Jax tensors as well as NumPy arrays.
    r   Nr   r   )
rG   r7   sizerD   ZnumelrE   r_   rF   r1   r9   )r}   r5   r   r   r   tensor_sizen  s    

r   c                    s^   |   D ]P\}}t|ttfr6 fdd|D | |< q|dk	rd|kr  d| | |< q| S )zB
    Adds the information of the repo_id to a given auto map.
    c                    s.   g | ]&}|d k	r&d|kr&  d| n|qS )N--r   )r?   rt   repo_idr   r   rA     s     z.add_model_info_to_auto_map.<locals>.<listcomp>Nr   )rw   rM   ry   rx   )Zauto_mapr   r   r   r   r   r   add_model_info_to_auto_map  s    r   c                 C   s   t | D ]l}|j}|j}|ds6|ds6|dkr< dS |dsN|dkrT dS |dsp|d	sp|d
kr
 dS q
td|  ddS )z
    Infers the framework of a given model without using isinstance(), because we cannot guarantee that the relevant
    classes are imported or available.
    r_   ZkerasZTFPreTrainedModelr5   rV   ZPreTrainedModelr4   Zflaxr6   ZFlaxPreTrainedModelz%Could not infer framework from class r   N)r   getmror!   r   r:   r   )r   Z
base_classmoduler   r   r   r   r     s    r   )r   r   )F)N)N)Fr#   r   r   collectionsr   r   collections.abcr   
contextlibr   r   dataclassesr   r   enumr	   typingr
   r   r   r   rf   r7   Zimport_utilsr   r   r   r   re   rg   propertyr   r3   r<   rJ   rP   rT   rG   rY   rD   rZ   r[   r]   r^   ra   rE   rc   rd   rh   rF   rq   r|   r~   r8   r   r   r   r   r   r   r   boolr   r   r   r   r   r   r   r   r   r   r   r   <module>   sf   	 

