U
    9%e9w                     @   s  d dl Zd dlZd dlZd dlZd dlmZ d dlmZ d dl	Z	d dl
mZmZmZ ddlmZ ddlmZmZmZmZ dd	lmZmZmZmZmZ dd
lmZmZ ddlmZ e e!Z"e rd dl#Z#e rddl$m%Z%m&Z& ddlm'Z' ne(Z%da)e*e+e,e-e.e/dZ0eG dd dZ1i a2ddddgZ3d/ddZ4dd Z5d0ddZ6d1ddZ7dd  Z8d!d" Z9G d#d$ d$Z:G d%d& d&e:Z;G d'd( d(e:Z<G d)d* d*e:Z=G d+d, d,e:Z>G d-d. d.e%Z?dS )2    N)	dataclass)Dict)HfFolderhf_hub_downloadlist_spaces   )AutoTokenizer)is_offline_modeis_openai_availableis_torch_availablelogging   )TASK_MAPPINGTOOL_CONFIG_FILETool	load_toolsupports_remote)CHAT_MESSAGE_PROMPTdownload_prompt)evaluate)StoppingCriteriaStoppingCriteriaList)AutoModelForCausalLMF)printrangefloatintboolstrc                   @   s&   e Zd ZU eed< eed< eed< dS )PreTooltaskdescriptionrepo_idN)__name__
__module____qualname__r   __annotations__ r'   r'   X/var/www/html/Darija-Ai-API/env/lib/python3.8/site-packages/transformers/tools/agents.pyr   ;   s   
r   zimage-transformationztext-downloadztext-to-imageztext-to-videohuggingface-toolsc           	   
   C   s   t  rtd i S t| d}i }|D ]b}|j}t|tdd}t|dd}t	|}W 5 Q R X |
dd }t||d	 |d
||d < q&|S )Nz;You are in offline mode, so remote tools are not available.)authorspace)Z	repo_typezutf-8)encoding/r!   r    r!   r"   name)r	   loggerinfor   idr   r   openjsonloadsplitr   )	ZorganizationspacestoolsZ
space_infor"   Zresolved_config_filereaderconfigr    r'   r'   r(   get_remote_toolsM   s    

r<   c            
      C   s   t rd S td} | j}t }t D ],\}}t||}|j}t	||d dt
|j< q&t stD ]F}d}| D ]"\}}	|	j|krn|	t
|< d} qqn|s^t| dq^da d S )NZtransformersr/   FTz is not implemented on the Hub.)_tools_are_initialized	importlibimport_moduler9   r<   r   itemsgetattrr!   r   HUGGINGFACE_DEFAULT_TOOLSr0   r	   "HUGGINGFACE_DEFAULT_TOOLS_FROM_HUBr    
ValueError)
Zmain_moduleZtools_moduleZremote_toolsZ	task_nameZtool_class_nameZ
tool_classr!   foundZ	tool_nametoolr'   r'   r(   _setup_default_tools`   s(    


rG   c           	      C   s   |d krt  }n|}| D ]`\}}|| ks||kr8qt|trL|||< q|jd kr\|jn|j}|olt|}t||d||< q|S )Nremote)	BASE_PYTHON_TOOLScopyr@   
isinstancer   r"   r    r   r   )	codetoolboxrI   cached_toolsZresolved_toolsr0   rF   task_or_repo_idZ_remoter'   r'   r(   resolve_tools   s    


rQ   c                 C   s   ddg}|  D ]`\}}|| kst|tr,q|jd kr<|jn|j}| d| d}|r^|d7 }|d7 }|| qd|d S )Nz"from transformers import load_tool z = load_tool(""z, remote=True)
)r@   rL   r   r"   r    appendjoin)rM   rN   rI   
code_linesr0   rF   rP   liner'   r'   r(   get_tool_creation_code   s    rZ   c                 C   s   |  d}d}|t|k r6||  ds6|d7 }qd|d |  }|t|kr`|d fS |d7 }|}||  ds|d7 }qld|||  }||fS )NrU   r   ```r   )r7   lenlstrip
startswithrW   strip)resultlinesidxexplanationZ	start_idxrM   r'   r'   r(   clean_code_for_chat   s    


rd   c                 C   st   d|  } |  d\}}| }| }| d}|d dkrJ|dd  }|d dkrb|d d }d|}||fS )	NzI will use the following zAnswer:rU   r   )r[   z```pyz	```pythonr   r.   r[   )r7   r_   rW   )r`   rc   rM   rX   r'   r'   r(   clean_code_for_run   s    


re   c                   @   s~   e Zd ZdZdddZeeeef dddZ	dd	d
Z
dd ZdddddZdd ZdddddZdd Zdd ZdS )Agenta&  
    Base class for all agents which contains the main API methods.

    Args:
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `chat_prompt_template.txt` in this repo in this case.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `run_prompt_template.txt` in this repo in this case.
        additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
            Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
            one of the default tools, that default tool will be overridden.
    Nc                 C   s  t   | jj}t||dd| _t||dd| _t | _t	| _
|d k	rt|ttfrfdd |D }nt|tsz|j|i}dd | D }| j| t|dkrdd	d
 | D }td| d n8t|dkrt| d }t| d||  d |   d S )Nchat)moderunc                 S   s   i | ]}|j |qS r'   )r0   ).0tr'   r'   r(   
<dictcomp>   s      z"Agent.__init__.<locals>.<dictcomp>c                 S   s   i | ]\}}|t kr||qS r'   )rB   rj   r0   rF   r'   r'   r(   rl      s       r   rU   c                 S   s    g | ]\}}d | d| qS z- : r'   )rj   nrk   r'   r'   r(   
<listcomp>   s     z"Agent.__init__.<locals>.<listcomp>zSThe following tools have been replaced by the ones provided in `additional_tools`:
.r   z has been replaced by z# as provided in `additional_tools`.)rG   	__class__r#   r   chat_prompt_templaterun_prompt_templaterB   rK   _toolboxr   logrL   listtupledictr0   r@   updater\   rW   r1   warningkeysprepare_for_new_chat)selfrt   ru   additional_toolsZ
agent_nameZreplacementsnamesr0   r'   r'   r(   __init__   s,    




zAgent.__init__returnc                 C   s   | j S )z-Get all tool currently available to the agent)rv   r   r'   r'   r(   rN      s    zAgent.toolboxFc                 C   sn   d dd | j D }|rP| jd kr8| jd|}n| j}|td|7 }n| jd|}|d|}|S )NrU   c                 S   s"   g | ]\}}d | d|j  qS rn   )r!   rm   r'   r'   r(   rq      s     z'Agent.format_prompt.<locals>.<listcomp>z<<all_tools>>z<<task>>z
<<prompt>>)rW   rN   r@   chat_historyrt   replacer   ru   )r   r    	chat_moder!   promptr'   r'   r(   format_prompt   s    
zAgent.format_promptc                 C   s
   || _ dS )z
        Set the function use to stream results (which is `print` by default).

        Args:
            streamer (`callable`): The function to call when streaming results from the LLM.
        N)rw   )r   streamerr'   r'   r(   
set_stream  s    zAgent.set_stream)return_coderI   c          
      K   s   | j |dd}| j|ddgd}||  d | _t|\}}| d|  |dk	r| d	|  |s| d
 t|| j|| jd| _| j	
| t|| j| j	ddS t|| j|d}	|	 d| S dS )a  
        Sends a new request to the agent in a chat. Will use the previous ones in its history.

        Args:
            task (`str`): The task to perform
            return_code (`bool`, *optional*, defaults to `False`):
                Whether to just return code and not evaluate it.
            remote (`bool`, *optional*, defaults to `False`):
                Whether or not to use remote tools (inference endpoints) instead of local ones.
            kwargs (additional keyword arguments, *optional*):
                Any keyword argument to send to the agent when evaluating the code.

        Example:

        ```py
        from transformers import HfAgent

        agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
        agent.chat("Draw me a picture of rivers and lakes")

        agent.chat("Transform the picture so that there is a rock in there")
        ```
        T)r   zHuman:z=====stoprU   ==Explanation from the agent==
N"

==Code generated by the agent==


==Result==rI   rO   rH   )r   generate_oner_   r   rd   rw   rQ   rN   rO   
chat_stater{   r   rZ   
r   r    r   rI   kwargsr   r`   rc   rM   Z	tool_coder'   r'   r(   rg     s    
z
Agent.chatc                 C   s   d| _ i | _d| _dS )zG
        Clears the history of prior calls to [`~Agent.chat`].
        N)r   r   rO   r   r'   r'   r(   r~   6  s    zAgent.prepare_for_new_chatc          
      K   s   |  |}| j|dgd}t|\}}| d|  | d|  |s~| d t|| j|| jd| _t|| j| dS t	|| j|d}	|	 d	| S d
S )a  
        Sends a request to the agent.

        Args:
            task (`str`): The task to perform
            return_code (`bool`, *optional*, defaults to `False`):
                Whether to just return code and not evaluate it.
            remote (`bool`, *optional*, defaults to `False`):
                Whether or not to use remote tools (inference endpoints) instead of local ones.
            kwargs (additional keyword arguments, *optional*):
                Any keyword argument to send to the agent when evaluating the code.

        Example:

        ```py
        from transformers import HfAgent

        agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
        agent.run("Draw me a picture of rivers and lakes")
        ```
        zTask:r   r   r   r   r   )staterH   rU   N)
r   r   re   rw   rQ   rN   rO   r   rK   rZ   r   r'   r'   r(   ri   >  s    

z	Agent.runc                 C   s   t d S N)NotImplementedErrorr   r   r   r'   r'   r(   r   c  s    zAgent.generate_onec                    s    fdd|D S )Nc                    s   g | ]}  |qS r'   )r   rj   r   r   r   r'   r(   rq   i  s     z'Agent.generate_many.<locals>.<listcomp>r'   r   promptsr   r'   r   r(   generate_manyg  s    zAgent.generate_many)NNN)F)r#   r$   r%   __doc__r   propertyr   r   r   rN   r   r   rg   r~   ri   r   r   r'   r'   r'   r(   rf      s   

	*%rf   c                       sB   e Zd ZdZd fdd	Zdd Zdd	 Zd
d Zdd Z  Z	S )OpenAiAgentu!  
    Agent that uses the openai API to generate code.

    <Tip warning={true}>

    The openAI models are used in generation mode, so even for the `chat()` API, it's better to use models like
    `"text-davinci-003"` over the chat-GPT variant. Proper support for chat-GPT models will come in a next version.

    </Tip>

    Args:
        model (`str`, *optional*, defaults to `"text-davinci-003"`):
            The name of the OpenAI model to use.
        api_key (`str`, *optional*):
            The API key to use. If unset, will look for the environment variable `"OPENAI_API_KEY"`.
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `chat_prompt_template.txt` in this repo in this case.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `run_prompt_template.txt` in this repo in this case.
        additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
            Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
            one of the default tools, that default tool will be overridden.

    Example:

    ```py
    from transformers import OpenAiAgent

    agent = OpenAiAgent(model="text-davinci-003", api_key=xxx)
    agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
    ```
    text-davinci-003Nc                    sX   t  std|d kr$tjdd }|d kr6tdn|t_|| _t	 j
|||d d S )N<Using `OpenAiAgent` requires `openai`: `pip install openai`.ZOPENAI_API_KEYzYou need an openai key to use `OpenAIAgent`. You can get one here: Get one here https://openai.com/api/`. If you have one, set it in your env with `os.environ['OPENAI_API_KEY'] = xxx.rt   ru   r   )r
   ImportErrorosenvirongetrD   openaiapi_keymodelsuperr   )r   r   r   rt   ru   r   rs   r'   r(   r     s    zOpenAiAgent.__init__c                    s.   d j kr fdd|D S  |S d S )Ngptc                    s   g | ]}  |qS r'   _chat_generater   r   r'   r(   rq     s     z-OpenAiAgent.generate_many.<locals>.<listcomp>)r   _completion_generater   r'   r   r(   r     s    
zOpenAiAgent.generate_manyc                 C   s,   d| j kr| ||S | |g|d S d S )Nr   r   )r   r   r   r   r'   r'   r(   r     s    
zOpenAiAgent.generate_onec                 C   s2   t jj| jd|dgd|d}|d d d d S )NuserZrolecontentr   )r   messagestemperaturer   choicesmessager   )r   ChatCompletioncreater   r   r   r   r`   r'   r'   r(   r     s    
zOpenAiAgent._chat_generatec                 C   s*   t jj| j|d|dd}dd |d D S )Nr      )r   r   r   r   
max_tokensc                 S   s   g | ]}|d  qS textr'   rj   Zanswerr'   r'   r(   rq     s     z4OpenAiAgent._completion_generate.<locals>.<listcomp>r   )r   
Completionr   r   r   r   r   r`   r'   r'   r(   r     s    z OpenAiAgent._completion_generate)r   NNNN
r#   r$   r%   r   r   r   r   r   r   __classcell__r'   r'   r   r(   r   l  s   '     	r   c                       sB   e Zd ZdZd fdd	Zdd Zdd	 Zd
d Zdd Z  Z	S )AzureOpenAiAgentu	  
    Agent that uses Azure OpenAI to generate code. See the [official
    documentation](https://learn.microsoft.com/en-us/azure/cognitive-services/openai/) to learn how to deploy an openAI
    model on Azure

    <Tip warning={true}>

    The openAI models are used in generation mode, so even for the `chat()` API, it's better to use models like
    `"text-davinci-003"` over the chat-GPT variant. Proper support for chat-GPT models will come in a next version.

    </Tip>

    Args:
        deployment_id (`str`):
            The name of the deployed Azure openAI model to use.
        api_key (`str`, *optional*):
            The API key to use. If unset, will look for the environment variable `"AZURE_OPENAI_API_KEY"`.
        resource_name (`str`, *optional*):
            The name of your Azure OpenAI Resource. If unset, will look for the environment variable
            `"AZURE_OPENAI_RESOURCE_NAME"`.
        api_version (`str`, *optional*, default to `"2022-12-01"`):
            The API version to use for this agent.
        is_chat_mode (`bool`, *optional*):
            Whether you are using a completion model or a chat model (see note above, chat models won't be as
            efficient). Will default to `gpt` being in the `deployment_id` or not.
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `chat_prompt_template.txt` in this repo in this case.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `run_prompt_template.txt` in this repo in this case.
        additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
            Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
            one of the default tools, that default tool will be overridden.

    Example:

    ```py
    from transformers import AzureOpenAiAgent

    agent = AzureAiAgent(deployment_id="Davinci-003", api_key=xxx, resource_name=yyy)
    agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
    ```
    N
2022-12-01c	           	         s   t  std|| _dt_|d kr0tjdd }|d krBtdn|t_	|d kr^tjdd }|d krptdnd| dt_
|t_|d krd	| k}|| _t j|||d
 d S )Nr   ZazureZAZURE_OPENAI_API_KEYzYou need an Azure openAI key to use `AzureOpenAIAgent`. If you have one, set it in your env with `os.environ['AZURE_OPENAI_API_KEY'] = xxx.ZAZURE_OPENAI_RESOURCE_NAMEzYou need a resource_name to use `AzureOpenAIAgent`. If you have one, set it in your env with `os.environ['AZURE_OPENAI_RESOURCE_NAME'] = xxx.zhttps://z.openai.azure.comr   r   )r
   r   deployment_idr   Zapi_typer   r   r   rD   r   Zapi_baseapi_versionloweris_chat_modelr   r   )	r   r   r   resource_namer   r   rt   ru   r   r   r'   r(   r     s6    zAzureOpenAiAgent.__init__c                    s*    j r fdd|D S  |S d S )Nc                    s   g | ]}  |qS r'   r   r   r   r'   r(   rq   .  s     z2AzureOpenAiAgent.generate_many.<locals>.<listcomp>)r   r   r   r'   r   r(   r   ,  s    zAzureOpenAiAgent.generate_manyc                 C   s(   | j r| ||S | |g|d S d S )Nr   )r   r   r   r   r'   r'   r(   r   2  s    zAzureOpenAiAgent.generate_onec                 C   s2   t jj| jd|dgd|d}|d d d d S )Nr   r   r   )enginer   r   r   r   r   r   )r   r   r   r   r   r'   r'   r(   r   8  s    
zAzureOpenAiAgent._chat_generatec                 C   s*   t jj| j|d|dd}dd |d D S )Nr   r   )r   r   r   r   r   c                 S   s   g | ]}|d  qS r   r'   r   r'   r'   r(   rq   I  s     z9AzureOpenAiAgent._completion_generate.<locals>.<listcomp>r   )r   r   r   r   r   r'   r'   r(   r   A  s    z%AzureOpenAiAgent._completion_generate)NNr   NNNNr   r'   r'   r   r(   r     s   2       .	r   c                       s*   e Zd ZdZd fdd	Zdd Z  ZS )HfAgentuc  
    Agent that uses an inference endpoint to generate code.

    Args:
        url_endpoint (`str`):
            The name of the url endpoint to use.
        token (`str`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
            running `huggingface-cli login` (stored in `~/.huggingface`).
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `chat_prompt_template.txt` in this repo in this case.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `run_prompt_template.txt` in this repo in this case.
        additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
            Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
            one of the default tools, that default tool will be overridden.

    Example:

    ```py
    from transformers import HfAgent

    agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
    agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
    ```
    Nc                    s`   || _ |d kr"dt   | _n(|ds6|dr>|| _nd| | _t j|||d d S )NzBearer ZBearerZBasicr   )url_endpointr   	get_tokentokenr^   r   r   )r   r   r   rt   ru   r   r   r'   r(   r   l  s    zHfAgent.__init__c                 C   s   d| j i}|dd|dd}tj| j||d}|jdkrVtd td	 | 	|S |jdkrzt
d
|j d|  | d d }|D ]$}||r|d t|    S q|S )NAuthorizationr   F)max_new_tokensZreturn_full_textr   )inputs
parameters)r5   headersi  z=Getting rate-limited, waiting a tiny bit before trying again.r   zError ro   r   Zgenerated_text)r   requestspostr   status_coder1   r2   timesleepZ_generate_onerD   r5   endswithr\   )r   r   r   r   r   responser`   stop_seqr'   r'   r(   r   |  s     







zHfAgent.generate_one)NNNN)r#   r$   r%   r   r   r   r   r'   r'   r   r(   r   L  s           r   c                       sB   e Zd ZdZd fdd	Zedd Zedd Zd	d
 Z	  Z
S )
LocalAgenta  
    Agent that uses a local model and tokenizer to generate code.

    Args:
        model ([`PreTrainedModel`]):
            The model to use for the agent.
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer to use for the agent.
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `chat_prompt_template.txt` in this repo in this case.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `run_prompt_template.txt` in this repo in this case.
        additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
            Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
            one of the default tools, that default tool will be overridden.

    Example:

    ```py
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, LocalAgent

    checkpoint = "bigcode/starcoder"
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    agent = LocalAgent(model, tokenizer)
    agent.run("Draw me a picture of rivers and lakes.")
    ```
    Nc                    s"   || _ || _t j|||d d S )Nr   )r   	tokenizerr   r   )r   r   r   rt   ru   r   r   r'   r(   r     s    zLocalAgent.__init__c                 K   s&   t j|f|}tj|f|}| ||S )a  
        Convenience method to build a `LocalAgent` from a pretrained checkpoint.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                The name of a repo on the Hub or a local path to a folder containing both model and tokenizer.
            kwargs (`Dict[str, Any]`, *optional*):
                Keyword arguments passed along to [`~PreTrainedModel.from_pretrained`].

        Example:

        ```py
        import torch
        from transformers import LocalAgent

        agent = LocalAgent.from_pretrained("bigcode/starcoder", device_map="auto", torch_dtype=torch.bfloat16)
        agent.run("Draw me a picture of rivers and lakes.")
        ```
        )r   from_pretrainedr   )clsZpretrained_model_name_or_pathr   r   r   r'   r'   r(   r     s    zLocalAgent.from_pretrainedc                 C   s<   t | jdr t| jj d S | j D ]}|j  S d S )Nhf_device_mapr   )hasattrr   rx   r   valuesr   Zdevice)r   paramr'   r'   r(   _model_device  s    zLocalAgent._model_devicec           	      C   s   | j |dd| j}|d jd }tt|| j g}| jj|d d|d}| j |d 	 |d  }|D ] }|
|rl|d t|  }ql|S )Npt)Zreturn_tensors	input_idsr   r   )r   stopping_criteriar   )r   tor   shaper   StopSequenceCriteriar   generatedecodetolistr   r\   )	r   r   r   Zencoded_inputsZsrc_lenr   outputsr`   r   r'   r'   r(   r     s      
zLocalAgent.generate_one)NNN)r#   r$   r%   r   r   classmethodr   r   r   r   r   r'   r'   r   r(   r     s   #	

r   c                   @   s&   e Zd ZdZdd ZedddZdS )r   a6  
    This class can be used to stop generation whenever a sequence of tokens is encountered.

    Args:
        stop_sequences (`str` or `List[str]`):
            The sequence (or list of sequences) on which to stop execution.
        tokenizer:
            The tokenizer used to decode the model outputs.
    c                 C   s    t |tr|g}|| _|| _d S r   )rL   r   stop_sequencesr   )r   r   r   r'   r'   r(   r     s    
zStopSequenceCriteria.__init__r   c                    s,   | j | d  t fdd| jD S )Nr   c                 3   s   | ]}  |V  qd S r   )r   )rj   Zstop_sequenceZdecoded_outputr'   r(   	<genexpr>  s     z0StopSequenceCriteria.__call__.<locals>.<genexpr>)r   r   r   anyr   )r   r   Zscoresr   r'   r   r(   __call__  s    zStopSequenceCriteria.__call__N)r#   r$   r%   r   r   r   r   r'   r'   r'   r(   r     s   
r   )r)   )FN)F)@importlib.utilr>   r5   r   r   dataclassesr   typingr   r   Zhuggingface_hubr   r   r   Zmodels.autor   utilsr	   r
   r   r   baser   r   r   r   r   r   r   r   Zpython_interpreterr   Z
get_loggerr#   r1   r   Z
generationr   r   r   objectr=   r   r   r   r   r   r   rJ   r   rB   rC   r<   rG   rQ   rZ   rd   re   rf   r   r   r   r   r   r'   r'   r'   r(   <module>   s`   




 )b~G]