import json
import os
import re
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union

import sentencepiece

from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {
    "source_spm": "source.spm",
    "target_spm": "target.spm",
    "vocab": "vocab.json",
    "target_vocab_file": "target_vocab.json",
    "tokenizer_config_file": "tokenizer_config.json",
}

PRETRAINED_VOCAB_FILES_MAP = {
    "source_spm": {
        "Helsinki-NLP/opus-mt-en-de": "https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/source.spm"
    },
    "target_spm": {
        "Helsinki-NLP/opus-mt-en-de": "https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/target.spm"
    },
    "vocab": {
        "Helsinki-NLP/opus-mt-en-de": "https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/vocab.json"
    },
    "tokenizer_config_file": {
        "Helsinki-NLP/opus-mt-en-de": "https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/tokenizer_config.json"
    },
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"Helsinki-NLP/opus-mt-en-de": 512}

PRETRAINED_INIT_CONFIGURATION = {}

SPIECE_UNDERLINE = "▁"


class MarianTokenizer(PreTrainedTokenizer):
    r"""
    Construct a Marian tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        source_spm (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .spm extension) that
            contains the vocabulary for the source language.
        target_spm (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .spm extension) that
            contains the vocabulary for the target language.
        source_lang (`str`, *optional*):
            A string representing the source language.
        target_lang (`str`, *optional*):
            A string representing the target language.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        model_max_length (`int`, *optional*, defaults to 512):
            The maximum sentence length the model accepts.
        additional_special_tokens (`List[str]`, *optional*, defaults to `["<eop>", "<eod>"]`):
            Additional special tokens used by the tokenizer.
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite, samples from all hypotheses (lattice)
                using the forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.

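            For example, subword-regularization sampling can be switched on when loading the tokenizer;
            the values below are purely illustrative, not tuned defaults:

            ```python
            >>> tokenizer = MarianTokenizer.from_pretrained(
            ...     "Helsinki-NLP/opus-mt-en-de",
            ...     sp_model_kwargs={"enable_sampling": True, "nbest_size": -1, "alpha": 0.1},
            ... )
            ```
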
    Examples:

    ```python
    >>> from transformers import MarianForCausalLM, MarianTokenizer

    >>> model = MarianForCausalLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")
    >>> tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
    >>> src_texts = ["I am a small frog.", "Tom asked his teacher for advice."]
    >>> tgt_texts = ["Ich bin ein kleiner Frosch.", "Tom bat seinen Lehrer um Rat."]  # optional
    >>> inputs = tokenizer(src_texts, text_target=tgt_texts, return_tensors="pt", padding=True)

    >>> outputs = model(**inputs)  # should work
    ```"""

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]
    language_code_re = re.compile(">>.+<<")  # type: re.Pattern

    def __init__(
        self,
        source_spm,
        target_spm,
        vocab,
        target_vocab_file=None,
        source_lang=None,
        target_lang=None,
        unk_token="<unk>",
        eos_token="</s>",
        pad_token="<pad>",
        model_max_length=512,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        separate_vocabs=False,
        **kwargs,
    ) -> None:
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

        assert Path(source_spm).exists(), f"cannot find spm source {source_spm}"

        self.separate_vocabs = separate_vocabs
        self.encoder = load_json(vocab)
        if unk_token not in self.encoder:
            raise KeyError("<unk> token must be in the vocab")
        assert pad_token in self.encoder

        if separate_vocabs:
            self.target_encoder = load_json(target_vocab_file)
            self.decoder = {v: k for k, v in self.target_encoder.items()}
            self.supported_language_codes = []
        else:
            self.decoder = {v: k for k, v in self.encoder.items()}
            self.supported_language_codes: list = [k for k in self.encoder if k.startswith(">>") and k.endswith("<<")]

        self.source_lang = source_lang
        self.target_lang = target_lang
        self.spm_files = [source_spm, target_spm]

        # load SentencePiece models for pre-processing
        self.spm_source = load_spm(source_spm, self.sp_model_kwargs)
        self.spm_target = load_spm(target_spm, self.sp_model_kwargs)
        self.current_spm = self.spm_source
        self.current_encoder = self.encoder

        self._setup_normalizer()

        super().__init__(
            # bos_token=bos_token,  unused. Start decoding with config.decoder_start_token_id
            source_lang=source_lang,
            target_lang=target_lang,
            unk_token=unk_token,
            eos_token=eos_token,
            pad_token=pad_token,
            model_max_length=model_max_length,
            sp_model_kwargs=self.sp_model_kwargs,
            target_vocab_file=target_vocab_file,
            separate_vocabs=separate_vocabs,
            **kwargs,
        )

    def _setup_normalizer(self):
        try:
            from sacremoses import MosesPunctNormalizer

            self.punc_normalizer = MosesPunctNormalizer(self.source_lang).normalize
        except (ImportError, FileNotFoundError):
            warnings.warn("Recommended: pip install sacremoses.")
            self.punc_normalizer = lambda x: x

    def normalize(self, x: str) -> str:
        """Cover moses empty string edge case. They return empty list for '' input!"""
        return self.punc_normalizer(x) if x else ""

    def _convert_token_to_id(self, token):
        return self.current_encoder.get(token, self.current_encoder[self.unk_token])

    def remove_language_code(self, text: str):
        """Remove language codes like >>fr<< before sentencepiece"""
        match = self.language_code_re.match(text)
        code: list = [match.group(0)] if match else []
        return code, self.language_code_re.sub("", text)

    def _tokenize(self, text: str) -> List[str]:
        code, text = self.remove_language_code(text)
        pieces = self.current_spm.encode(text, out_type=str)
        return code + pieces

    def _convert_id_to_token(self, index: int) -> str:
        """Converts an index (integer) in a token (str) using the decoder."""
        return self.decoder.get(index, self.unk_token)

    def batch_decode(self, sequences, **kwargs):
        """
        Convert a list of lists of token ids into a list of strings by calling decode.

        Args:
            sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
                List of tokenized input ids. Can be obtained using the `__call__` method.
            skip_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not to remove special tokens in the decoding.
            clean_up_tokenization_spaces (`bool`, *optional*):
                Whether or not to clean up the tokenization spaces. If `None`, will default to
                `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
            use_source_tokenizer (`bool`, *optional*, defaults to `False`):
                Whether or not to use the source tokenizer to decode sequences (only applicable in sequence-to-sequence
                problems).
            kwargs (additional keyword arguments, *optional*):
                Will be passed to the underlying model specific decode method.

        Returns:
            `List[str]`: The list of decoded sentences.
        """
        return super().batch_decode(sequences, **kwargs)

    def decode(self, token_ids, **kwargs):
        """
        Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
        tokens and clean up tokenization spaces.

        Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.

        Args:
            token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
                List of tokenized input ids. Can be obtained using the `__call__` method.
            skip_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not to remove special tokens in the decoding.
            clean_up_tokenization_spaces (`bool`, *optional*):
                Whether or not to clean up the tokenization spaces. If `None`, will default to
                `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
            use_source_tokenizer (`bool`, *optional*, defaults to `False`):
                Whether or not to use the source tokenizer to decode sequences (only applicable in sequence-to-sequence
                problems).
            kwargs (additional keyword arguments, *optional*):
                Will be passed to the underlying model specific decode method.

        Returns:
            `str`: The decoded sentence.
        """
        return super().decode(token_ids, **kwargs)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Uses source spm if _decode_use_source_tokenizer is True, and target spm otherwise"""
        sp_model = self.spm_source if self._decode_use_source_tokenizer else self.spm_target
        current_sub_tokens = []
        out_string = ""
        for token in tokens:
            # make sure that special tokens are not decoded using the sentencepiece model
            if token in self.all_special_tokens:
                out_string += sp_model.decode_pieces(current_sub_tokens) + token + " "
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
        out_string += sp_model.decode_pieces(current_sub_tokens)
        out_string = out_string.replace(SPIECE_UNDERLINE, " ")
        return out_string.strip()

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
        """Build model inputs from a sequence by appending eos_token_id."""
        if token_ids_1 is None:
            return token_ids_0 + [self.eos_token_id]
        # We don't expect to process pairs, but leave the pair logic for API consistency.
        return token_ids_0 + token_ids_1 + [self.eos_token_id]

    def _switch_to_input_mode(self):
        self.current_spm = self.spm_source
        self.current_encoder = self.encoder

    def _switch_to_target_mode(self):
        self.current_spm = self.spm_target
        if self.separate_vocabs:
            self.current_encoder = self.target_encoder

    @property
    def vocab_size(self) -> int:
        return len(self.encoder)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not os.path.isdir(save_directory):
NzVocabulary path (z) should be a directory-rK   r   r   r   r   wb)ospathisdirloggererrorr*   joinVOCAB_FILES_NAMES	save_jsonr.   r0   ra   zipr3   r5   r6   abspathisfiler   openZserialized_model_protowritetuple)r<   rn   ro   Zsaved_filesZout_src_vocab_fileZout_tgt_vocab_fileZout_vocab_fileZspm_save_filenameZspm_orig_pathZ	spm_modelZspm_save_pathfiZcontent_spiece_modelr   r   r   save_vocabulary3  sR    
 

 (
zMarianTokenizer.save_vocabularyc                 C   s   |   S rA   )get_src_vocabri   r   r   r   	get_vocab`  s    zMarianTokenizer.get_vocabc                 C   s   t | jf| jS rA   )dictr.   Zadded_tokens_encoderri   r   r   r   r   c  s    zMarianTokenizer.get_src_vocabc                 C   s   t | jf| jS rA   )r   r0   Zadded_tokens_decoderri   r   r   r   get_tgt_vocabf  s    zMarianTokenizer.get_tgt_vocabc                 C   s"   | j  }|dd dD  |S )Nc                 S   s   i | ]
}|d qS rA   r   r"   r   r   r   r   l  s      z0MarianTokenizer.__getstate__.<locals>.<dictcomp>)r5   r6   r7   rF   r   )__dict__copyupdate)r<   stater   r   r   __getstate__i  s
    
zMarianTokenizer.__getstate__)dr   c                    sF   | _ t dsi  _ fdd jD \ _ _ j _   d S )Nr   c                 3   s   | ]}t | jV  qd S rA   )r4   r   )r   fri   r   r   	<genexpr>w  s     z/MarianTokenizer.__setstate__.<locals>.<genexpr>)r   hasattrr   r3   r5   r6   r7   r9   )r<   r   r   ri   r   __setstate__p  s    
zMarianTokenizer.__setstate__c                 O   s   dS )zJust EOS   r   )r<   argsr=   r   r   r   num_special_tokens_to_add{  s    z)MarianTokenizer.num_special_tokens_to_addc                    s(   t | j  | j  fdd|D S )Nc                    s   g | ]}| krd ndqS )r   r   r   )r   rB   all_special_idsr   r   r#     s     z7MarianTokenizer._special_token_mask.<locals>.<listcomp>)setr   removeZunk_token_id)r<   seqr   r   r   _special_token_mask  s    
z#MarianTokenizer._special_token_mask)rf   rg   already_has_special_tokensr   c                 C   s>   |r|  |S |dkr&|  |dg S |  || dg S dS )zCGet list where entries are [1] if a token is [eos] or [pad] else 0.Nr   )r   )r<   rf   rg   r   r   r   r   get_special_tokens_mask  s
    
z'MarianTokenizer.get_special_tokens_mask)	NNNr   r   r   r   NF)N)N)NF)1__name__
__module____qualname____doc__rx   Zvocab_files_namesPRETRAINED_VOCAB_FILES_MAPZpretrained_vocab_files_mapPRETRAINED_INIT_CONFIGURATIONZpretrained_init_configuration&PRETRAINED_POSITIONAL_EMBEDDINGS_SIZESZmax_model_input_sizesZmodel_input_namesrecompilerP   r   r   rW   r   r;   r9   rE   rN   rU   r   rY   intr[   r\   r^   re   rh   rj   rk   propertyrm   r   r   r   r   r   r   r   r   r   boolr   __classcell__r   r   r>   r   r   ?   sd   :
         >	-     r   )rs   r   r   c                 C   s   t jf |}||  |S rA   )sentencepieceSentencePieceProcessorLoad)rs   r   Zspmr   r   r   r4     s    
r4   )rs   r   c              	   C   s*   t |d}tj| |dd W 5 Q R X d S )Nw   )indent)r}   jsondump)datars   r   r   r   r   ry     s    ry   c              
   C   s,   t | d}t|W  5 Q R  S Q R X d S )Nr)r}   r   load)rs   r   r   r   r   r-     s    r-   )"r   rr   r   rI   pathlibr   shutilr   typingr   r   r   r   r   r	   r   Ztokenization_utilsr   utilsr   Z
get_loggerr   ru   rx   r   r   r   rc   r   rW   r   r4   ry   r-   r   r   r   r   <module>   sL    

     S