""" Tokenization class for TAPAS model."""


import collections
import datetime
import enum
import itertools
import math
import os
import re
import unicodedata
from dataclasses import dataclass
from typing import Callable, Dict, Generator, List, Optional, Text, Tuple, Union

import numpy as np

from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from ...tokenization_utils_base import (
    ENCODE_KWARGS_DOCSTRING,
    VERY_LARGE_INTEGER,
    BatchEncoding,
    EncodedInput,
    PreTokenizedInput,
    TextInput,
)
from ...utils import ExplicitEnum, PaddingStrategy, TensorType, add_end_docstrings, is_pandas_available, logging


if is_pandas_available():
    import pandas as pd

logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "google/tapas-large-finetuned-sqa": "https://huggingface.co/google/tapas-large-finetuned-sqa/resolve/main/vocab.txt",
        "google/tapas-large-finetuned-wtq": "https://huggingface.co/google/tapas-large-finetuned-wtq/resolve/main/vocab.txt",
        "google/tapas-large-finetuned-wikisql-supervised": "https://huggingface.co/google/tapas-large-finetuned-wikisql-supervised/resolve/main/vocab.txt",
        "google/tapas-large-finetuned-tabfact": "https://huggingface.co/google/tapas-large-finetuned-tabfact/resolve/main/vocab.txt",
        "google/tapas-base-finetuned-sqa": "https://huggingface.co/google/tapas-base-finetuned-sqa/resolve/main/vocab.txt",
        "google/tapas-base-finetuned-wtq": "https://huggingface.co/google/tapas-base-finetuned-wtq/resolve/main/vocab.txt",
        "google/tapas-base-finetuned-wikisql-supervised": "https://huggingface.co/google/tapas-base-finetuned-wikisql-supervised/resolve/main/vocab.txt",
        "google/tapas-base-finetuned-tabfact": "https://huggingface.co/google/tapas-base-finetuned-tabfact/resolve/main/vocab.txt",
        "google/tapas-medium-finetuned-sqa": "https://huggingface.co/google/tapas-medium-finetuned-sqa/resolve/main/vocab.txt",
        "google/tapas-medium-finetuned-wtq": "https://huggingface.co/google/tapas-medium-finetuned-wtq/resolve/main/vocab.txt",
        "google/tapas-medium-finetuned-wikisql-supervised": "https://huggingface.co/google/tapas-medium-finetuned-wikisql-supervised/resolve/main/vocab.txt",
        "google/tapas-medium-finetuned-tabfact": "https://huggingface.co/google/tapas-medium-finetuned-tabfact/resolve/main/vocab.txt",
        "google/tapas-small-finetuned-sqa": "https://huggingface.co/google/tapas-small-finetuned-sqa/resolve/main/vocab.txt",
        "google/tapas-small-finetuned-wtq": "https://huggingface.co/google/tapas-small-finetuned-wtq/resolve/main/vocab.txt",
        "google/tapas-small-finetuned-wikisql-supervised": "https://huggingface.co/google/tapas-small-finetuned-wikisql-supervised/resolve/main/vocab.txt",
        "google/tapas-small-finetuned-tabfact": "https://huggingface.co/google/tapas-small-finetuned-tabfact/resolve/main/vocab.txt",
        "google/tapas-tiny-finetuned-sqa": "https://huggingface.co/google/tapas-tiny-finetuned-sqa/resolve/main/vocab.txt",
        "google/tapas-tiny-finetuned-wtq": "https://huggingface.co/google/tapas-tiny-finetuned-wtq/resolve/main/vocab.txt",
        "google/tapas-tiny-finetuned-wikisql-supervised": "https://huggingface.co/google/tapas-tiny-finetuned-wikisql-supervised/resolve/main/vocab.txt",
        "google/tapas-tiny-finetuned-tabfact": "https://huggingface.co/google/tapas-tiny-finetuned-tabfact/resolve/main/vocab.txt",
        "google/tapas-mini-finetuned-sqa": "https://huggingface.co/google/tapas-mini-finetuned-sqa/resolve/main/vocab.txt",
        "google/tapas-mini-finetuned-wtq": "https://huggingface.co/google/tapas-mini-finetuned-wtq/resolve/main/vocab.txt",
        "google/tapas-mini-finetuned-wikisql-supervised": "https://huggingface.co/google/tapas-mini-finetuned-wikisql-supervised/resolve/main/vocab.txt",
        "google/tapas-mini-finetuned-tabfact": "https://huggingface.co/google/tapas-mini-finetuned-tabfact/resolve/main/vocab.txt",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {name: 512 for name in PRETRAINED_VOCAB_FILES_MAP["vocab_file"].keys()}

PRETRAINED_INIT_CONFIGURATION = {
    name: {"do_lower_case": True} for name in PRETRAINED_VOCAB_FILES_MAP["vocab_file"].keys()
}


class TapasTruncationStrategy(ExplicitEnum):
    """
    Possible values for the `truncation` argument in [`~TapasTokenizer.__call__`]. Useful for tab-completion in an IDE.
    """

    DROP_ROWS_TO_FIT = "drop_rows_to_fit"
    DO_NOT_TRUNCATE = "do_not_truncate"


TableValue = collections.namedtuple("TokenValue", ["token", "column_id", "row_id"])


@dataclass(frozen=True)
class TokenCoordinates:
    column_index: int
    row_index: int
    token_index: int


@dataclass
class TokenizedTable:
    rows: List[List[List[Text]]]
    selected_tokens: List[TokenCoordinates]


@dataclass(frozen=True)
class SerializedExample:
    tokens: List[Text]
    column_ids: List[int]
    row_ids: List[int]
    segment_ids: List[int]


def _is_inner_wordpiece(token: Text):
    return token.startswith("##")


def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        tokens = reader.readlines()
    for index, token in enumerate(tokens):
        token = token.rstrip("\n")
        vocab[token] = index
    return vocab


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
            add_special_tokens (`bool`, *optional*, defaults to `True`):
                Whether or not to encode the sequences with the special tokens relative to their model.
            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
                Activates and controls padding. Accepts the following values:

                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence is provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
                  lengths).
            truncation (`bool`, `str` or [`TapasTruncationStrategy`], *optional*, defaults to `False`):
                Activates and controls truncation. Accepts the following values:

                - `True` or `'drop_rows_to_fit'`: Truncate to a maximum length specified with the argument `max_length`
                  or to the maximum acceptable input length for the model if that argument is not provided. This will
                  truncate row by row, removing rows from the table.
                - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output a batch with sequence
                  lengths greater than the model maximum admissible input size).
            max_length (`int`, *optional*):
                Controls the maximum length to use by one of the truncation/padding parameters.

                If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
                is required by one of the truncation/padding parameters. If the model has no specific maximum input
                length (like XLNet) truncation/padding to a maximum length will be deactivated.
            is_split_into_words (`bool`, *optional*, defaults to `False`):
                Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
                tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
                which it will tokenize. This is useful for NER or token classification.
            pad_to_multiple_of (`int`, *optional*):
                If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
                the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.
"""


class TapasTokenizer(PreTrainedTokenizer):
    r"""
    Construct a TAPAS tokenizer. Based on WordPiece. Flattens a table and one or more related sentences to be used by
    TAPAS models.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods. [`TapasTokenizer`] creates several token type ids to
    encode tabular structure. To be more precise, it adds 7 token type ids, in the following order: `segment_ids`,
    `column_ids`, `row_ids`, `prev_labels`, `column_ranks`, `inv_column_ranks` and `numeric_relations`:

    - segment_ids: indicate whether a token belongs to the question (0) or the table (1). 0 for special tokens and
      padding.
    - column_ids: indicate to which column of the table a token belongs (starting from 1). Is 0 for all question
      tokens, special tokens and padding.
    - row_ids: indicate to which row of the table a token belongs (starting from 1). Is 0 for all question tokens,
      special tokens and padding. Tokens of column headers are also 0.
    - prev_labels: indicate whether a token was (part of) an answer to the previous question (1) or not (0). Useful in
      a conversational setup (such as SQA).
    - column_ranks: indicate the rank of a table token relative to a column, if applicable. For example, if you have a
      column "number of movies" with values 87, 53 and 69, then the column ranks of these tokens are 3, 1 and 2
      respectively. 0 for all question tokens, special tokens and padding.
    - inv_column_ranks: indicate the inverse rank of a table token relative to a column, if applicable. For example, if
      you have a column "number of movies" with values 87, 53 and 69, then the inverse column ranks of these tokens are
      1, 3 and 2 respectively. 0 for all question tokens, special tokens and padding.
    - numeric_relations: indicate numeric relations between the question and the tokens of the table. 0 for all
      question tokens, special tokens and padding.

    [`TapasTokenizer`] runs end-to-end tokenization on a table and associated sentences: punctuation splitting and
    wordpiece.

    Args:
        vocab_file (`str`):
            File containing the vocabulary.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        do_basic_tokenize (`bool`, *optional*, defaults to `True`):
            Whether or not to do basic tokenization before WordPiece.
        never_split (`Iterable`, *optional*):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            `do_basic_tokenize=True`
        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        empty_token (`str`, *optional*, defaults to `"[EMPTY]"`):
            The token used for empty cell values in a table. Empty cell values include "", "n/a", "nan" and "?".
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
        cell_trim_length (`int`, *optional*, defaults to -1):
            If > 0: Trim cells so that the length is <= this value. Also disables further cell trimming and should
            thus be used with `truncation` set to `True`.
        max_column_id (`int`, *optional*):
            Max column id to extract.
        max_row_id (`int`, *optional*):
            Max row id to extract.
        strip_column_names (`bool`, *optional*, defaults to `False`):
            Whether to add empty strings instead of column names.
        update_answer_coordinates (`bool`, *optional*, defaults to `False`):
            Whether to recompute the answer coordinates from the answer text.
        min_question_length (`int`, *optional*):
            Minimum length of each question in terms of tokens (will be skipped otherwise).
        max_question_length (`int`, *optional*):
            Maximum length of each question in terms of tokens (will be skipped otherwise).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        do_basic_tokenize=True,
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        empty_token="[EMPTY]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        cell_trim_length: int = -1,
        max_column_id: int = None,
        max_row_id: int = None,
        strip_column_names: bool = False,
        update_answer_coordinates: bool = False,
        min_question_length=None,
        max_question_length=None,
        model_max_length: int = 512,
        additional_special_tokens: Optional[List[str]] = None,
        **kwargs,
    ):
        if not is_pandas_available():
            raise ImportError("Pandas is required for the TAPAS tokenizer.")

        if additional_special_tokens is not None:
            if empty_token not in additional_special_tokens:
                additional_special_tokens.append(empty_token)
        else:
            additional_special_tokens = [empty_token]

        if not os.path.isfile(vocab_file):
            raise ValueError(
                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google"
                " pretrained model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(
                do_lower_case=do_lower_case,
                never_split=never_split,
                tokenize_chinese_chars=tokenize_chinese_chars,
                strip_accents=strip_accents,
            )
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))

        # Additional properties
        self.cell_trim_length = cell_trim_length
        self.max_column_id = (
            max_column_id
            if max_column_id is not None
            else model_max_length
            if model_max_length is not None
            else VERY_LARGE_INTEGER
        )
        self.max_row_id = (
            max_row_id
            if max_row_id is not None
            else model_max_length
            if model_max_length is not None
            else VERY_LARGE_INTEGER
        )
        self.strip_column_names = strip_column_names
        self.update_answer_coordinates = update_answer_coordinates
        self.min_question_length = min_question_length
        self.max_question_length = max_question_length

        super().__init__(
            do_lower_case=do_lower_case,
            do_basic_tokenize=do_basic_tokenize,
            never_split=never_split,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            empty_token=empty_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            cell_trim_length=cell_trim_length,
            max_column_id=max_column_id,
            max_row_id=max_row_id,
            strip_column_names=strip_column_names,
            update_answer_coordinates=update_answer_coordinates,
            min_question_length=min_question_length,
            max_question_length=max_question_length,
            model_max_length=model_max_length,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

    @property
    def do_lower_case(self):
        return self.basic_tokenizer.do_lower_case
   t | jS r   )lenrN   r   r   r   r"   
vocab_size  s    zTapasTokenizer.vocab_sizec                 C   s   t | jf| jS r   )dictrN   Zadded_tokens_encoderr   r   r   r"   	get_vocab  s    zTapasTokenizer.get_vocabc                 C   st   t |tkr| jd gS g }| jrd| jj|| jdD ],}|| jjkrP|| q4|| j	|7 }q4n| j	|}|S )Nr   )rh   )
format_text
EMPTY_TEXTrd   rl   r}   tokenizeZall_special_tokensrh   ru   r   )r   rT   split_tokensr-   r   r   r"   	_tokenize  s    zTapasTokenizer._tokenizec                 C   s   | j || j | jS )z0Converts a token (str) in an id using the vocab.)rN   getrk   )r   r-   r   r   r"   _convert_token_to_id  s    z#TapasTokenizer._convert_token_to_idc                 C   s   | j || jS )z=Converts an index (integer) in a token (str) using the vocab.)r{   r   rk   )r   rP   r   r   r"   _convert_id_to_token  s    z#TapasTokenizer._convert_id_to_tokenc                 C   s   d |dd }|S )z:Converts a sequence of tokens (string) in a single string. z ## )joinreplacerR   )r   r;   Z
out_stringr   r   r"   convert_tokens_to_string  s    z'TapasTokenizer.convert_tokens_to_string)save_directoryfilename_prefixreturnc              	   C   s   d}t j|r4t j||r$|d ndtd  }n|r@|d nd| }t|dddZ}t| j dd	 d
D ]<\}}||krt	
d| d |}||d  |d7 }qnW 5 Q R X |fS )Nr   -r   r   wrD   rE   c                 S   s   | d S N   r   )kvr   r   r"   <lambda>      z0TapasTokenizer.save_vocabulary.<locals>.<lambda>keyzSaving vocabulary to z\: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!rG   r   )rv   rw   isdirr   VOCAB_FILES_NAMESrJ   sortedrN   rz   loggerwarningwrite)r   r   r   rP   r   writerr-   r4   r   r   r"   save_vocabulary  s"     
zTapasTokenizer.save_vocabulary)	query_idstable_valuesr   c                 C   s   dgdt | d t |  S )a  
        Creates the attention mask according to the query token IDs and a list of table values.

        Args:
            query_ids (`List[int]`): list of token IDs corresponding to the query.
            table_values (`List[TableValue]`): list of table values, which are named tuples containing the
                token value, the column ID and the row ID of said token.

        Returns:
            `List[int]`: List of ints containing the attention mask values.
        """
        return [1] * (1 + len(query_ids) + 1 + len(table_values))

    def create_segment_token_type_ids_from_sequences(
        self, query_ids: List[int], table_values: List[TableValue]
    ) -> List[int]:
        """
        Creates the segment token type IDs according to the query token IDs and a list of table values.

        Args:
            query_ids (`List[int]`): list of token IDs corresponding to the query.
            table_values (`List[TableValue]`): list of table values, which are named tuples containing the
                token value, the column ID and the row ID of said token.

        Returns:
            `List[int]`: List of ints containing the segment token type IDs values.
        """
        table_ids = list(zip(*table_values))[0] if table_values else []
        return [0] * (1 + len(query_ids) + 1) + [1] * len(table_ids)

    def create_column_token_type_ids_from_sequences(
        self, query_ids: List[int], table_values: List[TableValue]
    ) -> List[int]:
        """
        Creates the column token type IDs according to the query token IDs and a list of table values.

        Args:
            query_ids (`List[int]`): list of token IDs corresponding to the query.
            table_values (`List[TableValue]`): list of table values, which are named tuples containing the
                token value, the column ID and the row ID of said token.

        Returns:
            `List[int]`: List of ints containing the column token type IDs values.
        """
        table_column_ids = list(zip(*table_values))[1] if table_values else []
        return [0] * (1 + len(query_ids) + 1) + list(table_column_ids)

    def create_row_token_type_ids_from_sequences(
        self, query_ids: List[int], table_values: List[TableValue]
    ) -> List[int]:
        """
        Creates the row token type IDs according to the query token IDs and a list of table values.

        Args:
            query_ids (`List[int]`): list of token IDs corresponding to the query.
            table_values (`List[TableValue]`): list of table values, which are named tuples containing the
                token value, the column ID and the row ID of said token.

        Returns:
            `List[int]`: List of ints containing the row token type IDs values.
        """
        table_row_ids = list(zip(*table_values))[2] if table_values else []
        return [0] * (1 + len(query_ids) + 1) + list(table_row_ids)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a question and flattened table for question answering or sequence classification tasks
        by concatenating and adding special tokens.

        Args:
            token_ids_0 (`List[int]`): The ids of the question.
            token_ids_1 (`List[int]`, *optional*): The ids of the flattened table.

        Returns:
            `List[int]`: The model input with special tokens.
        """
        if token_ids_1 is None:
            raise ValueError("With TAPAS, you must provide both question IDs and table IDs.")

        return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + token_ids_1

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of question IDs.
            token_ids_1 (`List[int]`, *optional*):
                List of flattened table IDs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
        return [1] + ([0] * len(token_ids_0)) + [1]

    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
    def __call__(
        self,
        table: "pd.DataFrame",
        queries: Optional[
            Union[
                TextInput,
                PreTokenizedInput,
                EncodedInput,
                List[TextInput],
                List[PreTokenizedInput],
                List[EncodedInput],
            ]
        ] = None,
        answer_coordinates: Optional[Union[List[Tuple], List[List[Tuple]]]] = None,
        answer_text: Optional[Union[List[TextInput], List[List[TextInput]]]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TapasTruncationStrategy] = False,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """
        Main method to tokenize and prepare for the model one or several sequence(s) related to a table.

        Args:
            table (`pd.DataFrame`):
                Table containing tabular data. Note that all cell values must be text. Use *.astype(str)* on a Pandas
                dataframe to convert it to string.
            queries (`str` or `List[str]`):
                Question or batch of questions related to a table to be encoded. Note that in case of a batch, all
                questions must refer to the **same** table.
            answer_coordinates (`List[Tuple]` or `List[List[Tuple]]`, *optional*):
                Answer coordinates of each table-question pair in the batch. In case only a single table-question pair
                is provided, then the answer_coordinates must be a single list of one or more tuples. Each tuple must
                be a (row_index, column_index) pair. The first data row (not the column header row) has index 0. The
                first column has index 0. In case a batch of table-question pairs is provided, then the
                answer_coordinates must be a list of lists of tuples (each list corresponding to a single
                table-question pair).
            answer_text (`List[str]` or `List[List[str]]`, *optional*):
                Answer text of each table-question pair in the batch. In case only a single table-question pair is
                provided, then the answer_text must be a single list of one or more strings. Each string must be the
                answer text of a corresponding answer coordinate. In case a batch of table-question pairs is provided,
                then the answer_coordinates must be a list of lists of strings (each list corresponding to a single
                table-question pair).
        """
        assert isinstance(table, pd.DataFrame), "Table must be of type pd.DataFrame"

        # Input type checking for clearer error
        valid_query = False

        # Check that query has a valid type
        if queries is None or isinstance(queries, str):
            valid_query = True
        elif isinstance(queries, (list, tuple)):
            if len(queries) == 0 or isinstance(queries[0], str):
                valid_query = True

        if not valid_query:
            raise ValueError(
                "queries input must of type `str` (single example), `List[str]` (batch or single pretokenized"
                " example). "
            )
        is_batched = isinstance(queries, (list, tuple))

        shared_kwargs = dict(
            answer_coordinates=answer_coordinates,
            answer_text=answer_text,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
        )
        if is_batched:
            return self.batch_encode_plus(table=table, queries=queries, **shared_kwargs, **kwargs)
        else:
            return self.encode_plus(table=table, query=queries, **shared_kwargs, **kwargs)

    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
    def batch_encode_plus(
        self,
        table: "pd.DataFrame",
        queries: Optional[Union[List[TextInput], List[PreTokenizedInput], List[EncodedInput]]] = None,
        answer_coordinates: Optional[List[List[Tuple]]] = None,
        answer_text: Optional[List[List[TextInput]]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TapasTruncationStrategy] = False,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """
        Prepare a table and a list of strings for the model.

        <Tip warning={true}>

        This method is deprecated, `__call__` should be used instead.

        </Tip>

        Args:
            table (`pd.DataFrame`):
                Table containing tabular data. Note that all cell values must be text. Use *.astype(str)* on a Pandas
                dataframe to convert it to string.
            queries (`List[str]`):
                Batch of questions related to a table to be encoded. Note that all questions must refer to the
                **same** table.
            answer_coordinates (`List[Tuple]` or `List[List[Tuple]]`, *optional*):
                Answer coordinates of each table-question pair in the batch. Each tuple must be a (row_index,
                column_index) pair. The first data row (not the column header row) has index 0. The first column has
                index 0. The answer_coordinates must be a list of lists of tuples (each list corresponding to a single
                table-question pair).
            answer_text (`List[str]` or `List[List[str]]`, *optional*):
                Answer text of each table-question pair in the batch. In case a batch of table-question pairs is
                provided, then the answer_coordinates must be a list of lists of strings (each list corresponding to a
                single table-question pair). Each string must be the answer text of a corresponding answer coordinate.
        """
        if return_token_type_ids is not None and not add_special_tokens:
            raise ValueError(
                "Asking to return token_type_ids while setting add_special_tokens to False results in an undefined"
                " behavior. Please set add_special_tokens to True or set return_token_type_ids to None."
            )

        if (answer_coordinates and not answer_text) or (not answer_coordinates and answer_text):
            raise ValueError("In case you provide answers, both answer_coordinates and answer_text should be provided")
        elif answer_coordinates is None and answer_text is None:
            answer_coordinates = answer_text = [None] * len(queries)

        if "is_split_into_words" in kwargs:
            raise NotImplementedError("Currently TapasTokenizer only supports questions as strings.")

        if return_offsets_mapping:
            raise NotImplementedError(
                "return_offset_mapping is not available when using Python tokenizers. To use this feature, change"
                " your tokenizer to one deriving from transformers.PreTrainedTokenizerFast."
            )

        return self._batch_encode_plus(
            table=table,
            queries=queries,
            answer_coordinates=answer_coordinates,
            answer_text=answer_text,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_length=return_length,
            verbose=verbose,
            **kwargs,
        )

    def _get_question_tokens(self, query):
        """Tokenizes the query, taking into account the max and min question length."""
        query_tokens = self.tokenize(query)
        if self.max_question_length is not None and len(query_tokens) > self.max_question_length:
            logger.warning("Skipping query as its tokens are longer than the max question length")
            return "", []
        if self.min_question_length is not None and len(query_tokens) < self.min_question_length:
            logger.warning("Skipping query as its tokens are shorter than the min question length")
            return "", []
        return query, query_tokens

    def _batch_encode_plus(
        self,
        table,
        queries: Union[List[TextInput], List[PreTokenizedInput], List[EncodedInput]],
        answer_coordinates: Optional[List[List[Tuple]]] = None,
        answer_text: Optional[List[List[TextInput]]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TapasTruncationStrategy] = False,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = True,
        return_attention_mask: Optional[bool] = True,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        table_tokens = self._tokenize_table(table)

        queries_tokens = []
        for idx, query in enumerate(queries):
            query, query_tokens = self._get_question_tokens(query)
            queries[idx] = query
            queries_tokens.append(query_tokens)

        batch_outputs = self._batch_prepare_for_model(
            table,
            queries,
            tokenized_table=table_tokens,
            queries_tokens=queries_tokens,
            answer_coordinates=answer_coordinates,
            answer_text=answer_text,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            prepend_batch_axis=True,
            return_attention_mask=return_attention_mask,
            return_token_type_ids=return_token_type_ids,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_length=return_length,
            verbose=verbose,
        )

        return BatchEncoding(batch_outputs)

    def _batch_prepare_for_model(
        self,
        raw_table: "pd.DataFrame",
        raw_queries: Union[List[TextInput], List[PreTokenizedInput], List[EncodedInput]],
        tokenized_table: Optional[TokenizedTable] = None,
        queries_tokens: Optional[List[List[str]]] = None,
        answer_coordinates: Optional[List[List[Tuple]]] = None,
        answer_text: Optional[List[List[TextInput]]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TapasTruncationStrategy] = False,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = True,
        return_attention_mask: Optional[bool] = True,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        prepend_batch_axis: bool = False,
        **kwargs,
    ) -> BatchEncoding:
        batch_outputs = {}

        for index, example in enumerate(zip(raw_queries, queries_tokens, answer_coordinates, answer_text)):
            raw_query, query_tokens, answer_coords, answer_txt = example
            outputs = self.prepare_for_model(
                raw_table,
                raw_query,
                tokenized_table=tokenized_table,
                query_tokens=query_tokens,
                answer_coordinates=answer_coords,
                answer_text=answer_txt,
                add_special_tokens=add_special_tokens,
                padding=PaddingStrategy.DO_NOT_PAD.value,  # we pad in batch afterwards
                truncation=truncation,
                max_length=max_length,
                pad_to_multiple_of=None,  # we pad in batch afterwards
                return_attention_mask=False,  # we pad in batch afterwards
                return_token_type_ids=return_token_type_ids,
                return_special_tokens_mask=return_special_tokens_mask,
                return_length=return_length,
                return_tensors=None,  # we convert the whole batch to tensors at the end
                prepend_batch_axis=False,
                verbose=verbose,
                prev_answer_coordinates=answer_coordinates[index - 1] if index != 0 else None,
                prev_answer_text=answer_text[index - 1] if index != 0 else None,
            )

            for key, value in outputs.items():
                if key not in batch_outputs:
                    batch_outputs[key] = []
                batch_outputs[key].append(value)

        batch_outputs = self.pad(
            batch_outputs,
            padding=padding,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
        )

        batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)

        return batch_outputs

    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING)
    def encode(
        self,
        table: "pd.DataFrame",
        query: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TapasTruncationStrategy] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> List[int]:
        """
        Prepare a table and a string for the model. This method does not return token type IDs, attention masks, etc.
        which are necessary for the model to work correctly. Use this method if you want to build your processing on
        your own, otherwise refer to `__call__`.

        Args:
            table (`pd.DataFrame`):
                Table containing tabular data. Note that all cell values must be text. Use *.astype(str)* on a Pandas
                dataframe to convert it to string.
            query (`str` or `List[str]`):
                Question related to a table to be encoded.
        """
        encoded_inputs = self.encode_plus(
            table,
            query=query,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
            **kwargs,
        )

        return encoded_inputs["input_ids"]

    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
    def encode_plus(
        self,
        table: "pd.DataFrame",
        query: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
        answer_coordinates: Optional[List[Tuple]] = None,
        answer_text: Optional[List[TextInput]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TapasTruncationStrategy] = False,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """
        Prepare a table and a string for the model.

        Args:
            table (`pd.DataFrame`):
                Table containing tabular data. Note that all cell values must be text. Use *.astype(str)* on a Pandas
                dataframe to convert it to string.
            query (`str` or `List[str]`):
                Question related to a table to be encoded.
            answer_coordinates (`List[Tuple]` or `List[List[Tuple]]`, *optional*):
                Answer coordinates of each table-question pair in the batch. The answer_coordinates must be a single
                list of one or more tuples. Each tuple must be a (row_index, column_index) pair. The first data row
                (not the column header row) has index 0. The first column has index 0.
            answer_text (`List[str]` or `List[List[str]]`, *optional*):
                Answer text of each table-question pair in the batch. The answer_text must be a single list of one or
                more strings. Each string must be the answer text of a corresponding answer coordinate.
        """
        if return_token_type_ids is not None and not add_special_tokens:
            raise ValueError(
                "Asking to return token_type_ids while setting add_special_tokens to False results in an undefined"
                " behavior. Please set add_special_tokens to True or set return_token_type_ids to None."
            )

        if (answer_coordinates and not answer_text) or (not answer_coordinates and answer_text):
            raise ValueError("In case you provide answers, both answer_coordinates and answer_text should be provided")

        if "is_split_into_words" in kwargs:
            raise NotImplementedError("Currently TapasTokenizer only supports questions as strings.")

        if return_offsets_mapping:
            raise NotImplementedError(
                "return_offset_mapping is not available when using Python tokenizers. To use this feature, change"
                " your tokenizer to one deriving from transformers.PreTrainedTokenizerFast."
            )

        return self._encode_plus(
            table=table,
            query=query,
            answer_coordinates=answer_coordinates,
            answer_text=answer_text,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_special_tokens_mask=return_special_tokens_mask,
            return_length=return_length,
            verbose=verbose,
            **kwargs,
        )

    def _encode_plus(
        self,
        table: "pd.DataFrame",
        query: Union[TextInput, PreTokenizedInput, EncodedInput],
        answer_coordinates: Optional[List[Tuple]] = None,
        answer_text: Optional[List[TextInput]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TapasTruncationStrategy] = False,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = True,
        return_attention_mask: Optional[bool] = True,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ):
        if query is None:
            query = ""
            logger.warning(
                "TAPAS is a question answering model but you have not passed a query. Please be aware that the "
                "model will probably not behave correctly."
            )

        table_tokens = self._tokenize_table(table)
        query, query_tokens = self._get_question_tokens(query)

        return self.prepare_for_model(
            table,
            query,
            tokenized_table=table_tokens,
            query_tokens=query_tokens,
            answer_coordinates=answer_coordinates,
            answer_text=answer_text,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding=padding,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            prepend_batch_axis=True,
            return_attention_mask=return_attention_mask,
            return_token_type_ids=return_token_type_ids,
            return_special_tokens_mask=return_special_tokens_mask,
            return_length=return_length,
            verbose=verbose,
        )

    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
    def prepare_for_model(
        self,
        raw_table: "pd.DataFrame",
        raw_query: Union[TextInput, PreTokenizedInput, EncodedInput],
        tokenized_table: Optional[TokenizedTable] = None,
        query_tokens: Optional[List[str]] = None,
        answer_coordinates: Optional[List[Tuple]] = None,
        answer_text: Optional[List[TextInput]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TapasTruncationStrategy] = False,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = True,
        return_attention_mask: Optional[bool] = True,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        prepend_batch_axis: bool = False,
        **kwargs,
    ) -> BatchEncoding:
        """
        Prepares a sequence of input id so that it can be used by the model. It adds special tokens, truncates
        sequences if overflowing while taking into account the special tokens.

        Args:
            raw_table (`pd.DataFrame`):
                The original table before any transformation (like tokenization) was applied to it.
            raw_query (`TextInput` or `PreTokenizedInput` or `EncodedInput`):
                The original query before any transformation (like tokenization) was applied to it.
            tokenized_table (`TokenizedTable`):
                The table after tokenization.
            query_tokens (`List[str]`):
                The query after tokenization.
            answer_coordinates (`List[Tuple]` or `List[List[Tuple]]`, *optional*):
                Answer coordinates of each table-question pair in the batch. The answer_coordinates must be a single
                list of one or more tuples. Each tuple must be a (row_index, column_index) pair. The first data row
                (not the column header row) has index 0. The first column has index 0.
            answer_text (`List[str]` or `List[List[str]]`, *optional*):
                Answer text of each table-question pair in the batch. The answer_text must be a single list of one or
                more strings. Each string must be the answer text of a corresponding answer coordinate.
        """
        if isinstance(padding, bool):
            if padding and (max_length is not None or pad_to_multiple_of is not None):
                padding = PaddingStrategy.MAX_LENGTH
            else:
                padding = PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD
        elif not isinstance(padding, PaddingStrategy):
            padding = PaddingStrategy(padding)

        if isinstance(truncation, bool):
            truncation = (
                TapasTruncationStrategy.DROP_ROWS_TO_FIT if truncation else TapasTruncationStrategy.DO_NOT_TRUNCATE
            )
        elif not isinstance(truncation, TapasTruncationStrategy):
            truncation = TapasTruncationStrategy(truncation)

        encoded_inputs = {}

        is_part_of_batch = False
        prev_answer_coordinates, prev_answer_text = None, None
        if "prev_answer_coordinates" in kwargs and "prev_answer_text" in kwargs:
            is_part_of_batch = True
            prev_answer_coordinates = kwargs["prev_answer_coordinates"]
            prev_answer_text = kwargs["prev_answer_text"]

        num_rows = self._get_num_rows(raw_table, truncation != TapasTruncationStrategy.DO_NOT_TRUNCATE)
        num_columns = self._get_num_columns(raw_table)
        _, _, num_tokens = self._get_table_boundaries(tokenized_table)

        if truncation != TapasTruncationStrategy.DO_NOT_TRUNCATE:
            num_rows, num_tokens = self._get_truncated_table_rows(
                query_tokens, tokenized_table, num_rows, num_columns, max_length, truncation_strategy=truncation
            )
        table_data = list(self._get_table_values(tokenized_table, num_columns, num_rows, num_tokens))

        query_ids = self.convert_tokens_to_ids(query_tokens)
        table_ids = list(zip(*table_data))[0] if len(table_data) > 0 else list(zip(*table_data))
        table_ids = self.convert_tokens_to_ids(list(table_ids))

        if "return_overflowing_tokens" in kwargs and kwargs["return_overflowing_tokens"]:
            raise ValueError("TAPAS does not return overflowing tokens as it works on tables.")

        if add_special_tokens:
            input_ids = self.build_inputs_with_special_tokens(query_ids, table_ids)
        else:
            input_ids = query_ids + table_ids

        if max_length is not None and len(input_ids) > max_length:
            raise ValueError(
                "Could not encode the query and table header given the maximum length. Encoding the query and table"
                f" header results in a length of {len(input_ids)} which is higher than the max_length of {max_length}"
            )

        encoded_inputs["input_ids"] = input_ids

        segment_ids = self.create_segment_token_type_ids_from_sequences(query_ids, table_data)
        column_ids = self.create_column_token_type_ids_from_sequences(query_ids, table_data)
        row_ids = self.create_row_token_type_ids_from_sequences(query_ids, table_data)
        if not is_part_of_batch or (prev_answer_coordinates is None and prev_answer_text is None):
            # simply set the prev_labels to zeros
            prev_labels = [0] * len(row_ids)
        else:
            prev_labels = self.get_answer_ids(
                column_ids, row_ids, table_data, prev_answer_text, prev_answer_coordinates
            )

        # FIRST: parse both the table and question in terms of numeric values
        raw_table = add_numeric_table_values(raw_table)
        raw_query = add_numeric_values_to_question(raw_query)

        # SECOND: add numeric-related features (and not parse them in these functions)
        column_ranks, inv_column_ranks = self._get_numeric_column_ranks(column_ids, row_ids, raw_table)
        numeric_relations = self._get_numeric_relations(raw_query, column_ids, row_ids, raw_table)

        # Load from model defaults
        if return_token_type_ids is None:
            return_token_type_ids = "token_type_ids" in self.model_input_names
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names

        if return_attention_mask:
            attention_mask = self.create_attention_mask_from_sequences(query_ids, table_data)
            encoded_inputs["attention_mask"] = attention_mask

        if answer_coordinates is not None and answer_text is not None:
            labels = self.get_answer_ids(column_ids, row_ids, table_data, answer_text, answer_coordinates)
            numeric_values = self._get_numeric_values(raw_table, column_ids, row_ids)
            numeric_values_scale = self._get_numeric_values_scale(raw_table, column_ids, row_ids)

            encoded_inputs["labels"] = labels
            encoded_inputs["numeric_values"] = numeric_values
            encoded_inputs["numeric_values_scale"] = numeric_values_scale

        if return_token_type_ids:
            token_type_ids = [
                segment_ids,
                column_ids,
                row_ids,
                prev_labels,
                column_ranks,
                inv_column_ranks,
                numeric_relations,
            ]
            token_type_ids = [list(ids) for ids in list(zip(*token_type_ids))]
            encoded_inputs["token_type_ids"] = token_type_ids

        if return_special_tokens_mask:
            if add_special_tokens:
                encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(query_ids, table_ids)
            else:
                encoded_inputs["special_tokens_mask"] = [0] * len(input_ids)

        # Check lengths
        if max_length is None and len(encoded_inputs["input_ids"]) > self.model_max_length and verbose:
            if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False):
                logger.warning(
                    "Token indices sequence length is longer than the specified maximum sequence length for this"
                    f" model ({len(encoded_inputs['input_ids'])} > {self.model_max_length}). Running this sequence"
                    " through the model will result in indexing errors."
                )
            self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True

        # Padding
        if padding != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
            encoded_inputs = self.pad(
                encoded_inputs,
                max_length=max_length,
                padding=padding.value,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
            )

        if return_length:
            encoded_inputs["length"] = len(encoded_inputs["input_ids"])

        batch_outputs = BatchEncoding(
            encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
        )

        return batch_outputs
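
    # Sketch of the encoding produced above (added commentary, not original code):
    # for every position, `token_type_ids` stacks the 7 ids in this fixed order:
    #
    #     token_types = ["segment_ids", "column_ids", "row_ids", "prev_labels",
    #                    "column_ranks", "inv_column_ranks", "numeric_relations"]
    #     # encoding["token_type_ids"][i] is the 7-item list for token i; e.g. the
    #     # column id of token i is encoding["token_type_ids"][i][1].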

    def _get_truncated_table_rows(
        self,
        query_tokens: List[str],
        tokenized_table: TokenizedTable,
        num_rows: int,
        num_columns: int,
        max_length: int,
        truncation_strategy: Union[str, TapasTruncationStrategy],
    ) -> Tuple[int, int]:
        """
        Truncates a sequence pair in-place following the strategy.

        Args:
            query_tokens (`List[str]`):
                List of strings corresponding to the tokenized query.
            tokenized_table (`TokenizedTable`):
                Tokenized table
            num_rows (`int`):
                Total number of table rows
            num_columns (`int`):
                Total number of table columns
            max_length (`int`):
                Total maximum length.
            truncation_strategy (`str` or [`TapasTruncationStrategy`]):
                Truncation strategy to use. Seeing as this method should only be called when truncating, the only
                available strategy is the `"drop_rows_to_fit"` strategy.

        Returns:
            `Tuple(int, int)`: tuple containing the number of rows after truncation, and the number of tokens
            available for each table element.
        """
        if not isinstance(truncation_strategy, TapasTruncationStrategy):
            truncation_strategy = TapasTruncationStrategy(truncation_strategy)

        if max_length is None:
            max_length = self.model_max_length

        if truncation_strategy == TapasTruncationStrategy.DROP_ROWS_TO_FIT:
            while True:
                num_tokens = self._get_max_num_tokens(
                    query_tokens, tokenized_table, num_rows=num_rows, num_columns=num_columns, max_length=max_length
                )

                if num_tokens is not None:
                    # We're good to go
                    break

                # Try to drop a row to fit the table in the max length
                num_rows -= 1

                if num_rows < 1:
                    break
        elif truncation_strategy != TapasTruncationStrategy.DO_NOT_TRUNCATE:
            raise ValueError(f"Unknown truncation strategy {truncation_strategy}.")

        return num_rows, num_tokens or 1

    def _tokenize_table(self, table=None):
        """
        Tokenizes column headers and cell texts of a table.

        Args:
            table (`pd.Dataframe`):
                Table. Returns: `TokenizedTable`: TokenizedTable object.
        """
        tokenized_rows = []
        tokenized_row = []
        # tokenize column headers
        for column in table:
            if self.strip_column_names:
                tokenized_row.append(self.tokenize(""))
            else:
                tokenized_row.append(self.tokenize(column))
        tokenized_rows.append(tokenized_row)

        # tokenize cell values
        for idx, row in table.iterrows():
            tokenized_row = []
            for cell in row:
                tokenized_row.append(self.tokenize(cell))
            tokenized_rows.append(tokenized_row)

        token_coordinates = []
        for row_index, row in enumerate(tokenized_rows):
            for column_index, cell in enumerate(row):
                for token_index, _ in enumerate(cell):
                    token_coordinates.append(
                        TokenCoordinates(
                            row_index=row_index,
                            column_index=column_index,
                            token_index=token_index,
                        )
                    )

        return TokenizedTable(
            rows=tokenized_rows,
            selected_tokens=token_coordinates,
        )

    def _question_encoding_cost(self, question_tokens):
        # Two extra spots of SEP and CLS.
        return len(question_tokens) + 2

    def _get_token_budget(self, question_tokens, max_length=None):
        """
        Computes the number of tokens left for the table after tokenizing a question, taking into account the max
        sequence length of the model.

        Args:
            question_tokens (`List[String]`):
                List of question tokens. Returns: `int`: the number of tokens left for the table, given the model max
                length.
        """
        return (max_length if max_length is not None else self.model_max_length) - self._question_encoding_cost(
            question_tokens
        )

    def _get_table_values(self, table, num_columns, num_rows, num_tokens) -> Generator[TableValue, None, None]:
        """Iterates over partial table and returns token, column and row indexes."""
        for tc in table.selected_tokens:
            # First row is header row.
            if tc.row_index >= num_rows + 1:
                continue
            if tc.column_index >= num_columns:
                continue
            cell = table.rows[tc.row_index][tc.column_index]
            token = cell[tc.token_index]
            word_begin_index = tc.token_index
            # Don't add partial words. Find the starting word piece and check if it
            # fits in the token budget.
            while word_begin_index >= 0 and _is_inner_wordpiece(cell[word_begin_index]):
                word_begin_index -= 1
            if word_begin_index >= num_tokens:
                continue
            yield TableValue(token, tc.column_index + 1, tc.row_index)

    def _get_table_boundaries(self, table):
        """Return maximal number of rows, columns and tokens."""
        max_num_tokens = 0
        max_num_columns = 0
        max_num_rows = 0
        for tc in table.selected_tokens:
            max_num_columns = max(max_num_columns, tc.column_index + 1)
            max_num_rows = max(max_num_rows, tc.row_index + 1)
            max_num_tokens = max(max_num_tokens, tc.token_index + 1)
            max_num_columns = min(self.max_column_id, max_num_columns)
            max_num_rows = min(self.max_row_id, max_num_rows)
        return max_num_rows, max_num_columns, max_num_tokens

    def _get_table_cost(self, table, num_columns, num_rows, num_tokens):
        return sum(1 for _ in self._get_table_values(table, num_columns, num_rows, num_tokens))

    def _get_max_num_tokens(self, question_tokens, tokenized_table, num_columns, num_rows, max_length):
        """Computes max number of tokens that can be squeezed into the budget."""
        token_budget = self._get_token_budget(question_tokens, max_length)
        _, _, max_num_tokens = self._get_table_boundaries(tokenized_table)
        if self.cell_trim_length >= 0 and max_num_tokens > self.cell_trim_length:
            max_num_tokens = self.cell_trim_length
        num_tokens = 0
        for num_tokens in range(max_num_tokens + 1):
            cost = self._get_table_cost(tokenized_table, num_columns, num_rows, num_tokens + 1)
            if cost > token_budget:
                break
        if num_tokens < max_num_tokens:
            if self.cell_trim_length >= 0:
                # We don't allow dynamic trimming if a cell_trim_length is set.
                return None
            if num_tokens == 0:
                return None
        return num_tokens

    def _get_num_columns(self, table):
        num_columns = table.shape[1]
        if num_columns >= self.max_column_id:
            raise ValueError("Too many columns")
        return num_columns

    def _get_num_rows(self, table, drop_rows_to_fit):
        num_rows = table.shape[0]
        if num_rows >= self.max_row_id:
            if drop_rows_to_fit:
                num_rows = self.max_row_id - 1
            else:
                raise ValueError("Too many rows")
        return num_rows

    def _serialize_text(self, question_tokens):
        """Serializes texts in index arrays."""
        tokens = []
        segment_ids = []
        column_ids = []
        row_ids = []

        # add [CLS] token at the beginning
        tokens.append(self.cls_token)
        segment_ids.append(0)
        column_ids.append(0)
        row_ids.append(0)

        for token in question_tokens:
            tokens.append(token)
            segment_ids.append(0)
            column_ids.append(0)
            row_ids.append(0)

        return tokens, segment_ids, column_ids, row_ids

    def _serialize(self, question_tokens, table, num_columns, num_rows, num_tokens):
        """Serializes table and text."""
        tokens, segment_ids, column_ids, row_ids = self._serialize_text(question_tokens)

        # add [SEP] token between question and table tokens
        tokens.append(self.sep_token)
        segment_ids.append(0)
        column_ids.append(0)
        row_ids.append(0)

        for token, column_id, row_id in self._get_table_values(table, num_columns, num_rows, num_tokens):
            tokens.append(token)
            segment_ids.append(1)
            column_ids.append(column_id)
            row_ids.append(row_id)

        return SerializedExample(
            tokens=tokens,
            segment_ids=segment_ids,
            column_ids=column_ids,
            row_ids=row_ids,
        )

    def _get_column_values(self, table, col_index):
        table_numeric_values = {}
        for row_index, row in table.iterrows():
            cell = row[col_index]
            if cell.numeric_value is not None:
                table_numeric_values[row_index] = cell.numeric_value
        return table_numeric_values

    def _get_cell_token_indexes(self, column_ids, row_ids, column_id, row_id):
        for index in range(len(column_ids)):
            if column_ids[index] - 1 == column_id and row_ids[index] - 1 == row_id:
                yield index

    def _get_numeric_column_ranks(self, column_ids, row_ids, table):
        """Returns column ranks for all numeric columns."""
        ranks = [0] * len(column_ids)
        inv_ranks = [0] * len(column_ids)

        if table is not None:
            for col_index in range(len(table.columns)):
                table_numeric_values = self._get_column_values(table, col_index)

                if not table_numeric_values:
                    continue

                try:
                    key_fn = get_numeric_sort_key_fn(table_numeric_values.values())
                except ValueError:
                    continue

                table_numeric_values = {row_index: key_fn(value) for row_index, value in table_numeric_values.items()}

                table_numeric_values_inv = collections.defaultdict(list)
                for row_index, value in table_numeric_values.items():
                    table_numeric_values_inv[value].append(row_index)

                unique_values = sorted(table_numeric_values_inv.keys())

                for rank, value in enumerate(unique_values):
                    for row_index in table_numeric_values_inv[value]:
                        for index in self._get_cell_token_indexes(column_ids, row_ids, col_index, row_index):
                            ranks[index] = rank + 1
                            inv_ranks[index] = len(unique_values) - rank

        return ranks, inv_ranks

    def _get_numeric_sort_key_fn(self, table_numeric_values, value):
        """
        Returns the sort key function for comparing value to table values. The function returned will be a suitable
        input for the key param of the sort(). See number_annotation_utils._get_numeric_sort_key_fn for details

        Args:
            table_numeric_values: Numeric values of a column
            value: Numeric value in the question

        Returns:
            A key function to compare column and question values.
        """
        if not table_numeric_values:
            return None
        all_values = list(table_numeric_values.values())
        all_values.append(value)
        try:
            return get_numeric_sort_key_fn(all_values)
        except ValueError:
            return None

    def _get_numeric_relations(self, question, column_ids, row_ids, table):
        """
        Returns numeric relations embeddings

        Args:
            question: Question object.
            column_ids: Maps word piece position to column id.
            row_ids: Maps word piece position to row id.
            table: The table containing the numeric cell values.
        """
        numeric_relations = [0] * len(column_ids)

        # Create a dictionary that maps a table cell to the set of all relations
        # this cell has with any value in the question.
        cell_indices_to_relations = collections.defaultdict(set)
        if question is not None and table is not None:
            for numeric_value_span in question.numeric_spans:
                for value in numeric_value_span.values:
                    for column_index in range(len(table.columns)):
                        table_numeric_values = self._get_column_values(table, column_index)
                        sort_key_fn = self._get_numeric_sort_key_fn(table_numeric_values, value)
                        if sort_key_fn is None:
                            continue
                        for row_index, cell_value in table_numeric_values.items():
                            relation = get_numeric_relation(value, cell_value, sort_key_fn)
                            if relation is not None:
                                cell_indices_to_relations[column_index, row_index].add(relation)

        # For each cell add a special feature for all its word pieces.
        for (column_index, row_index), relations in cell_indices_to_relations.items():
            relation_set_index = 0
            for relation in relations:
                assert relation.value >= Relation.EQ.value
                relation_set_index += 2 ** (relation.value - Relation.EQ.value)
            for cell_token_index in self._get_cell_token_indexes(column_ids, row_ids, column_index, row_index):
                numeric_relations[cell_token_index] = relation_set_index

        return numeric_relations

    def _get_numeric_values(self, table, column_ids, row_ids):
        """Returns numeric values for computation of answer loss."""
        numeric_values = [float("nan")] * len(column_ids)

        if table is not None:
            num_rows = table.shape[0]
            num_columns = table.shape[1]

            for col_index in range(num_columns):
                for row_index in range(num_rows):
                    numeric_value = table.iloc[row_index, col_index].numeric_value
                    if numeric_value is not None:
                        if numeric_value.float_value is None:
                            continue
                        float_value = numeric_value.float_value
                        if float_value == float("inf"):
                            continue
                        for index in self._get_cell_token_indexes(column_ids, row_ids, col_index, row_index):
                            numeric_values[index] = float_value

        return numeric_values

    def _get_numeric_values_scale(self, table, column_ids, row_ids):
        """Returns a scale to each token to down weigh the value of long words."""
        numeric_values_scale = [1.0] * len(column_ids)

        if table is None:
            return numeric_values_scale

        num_rows = table.shape[0]
        num_columns = table.shape[1]

        for col_index in range(num_columns):
            for row_index in range(num_rows):
                indices = list(self._get_cell_token_indexes(column_ids, row_ids, col_index, row_index))
                num_indices = len(indices)
                if num_indices > 1:
                    for index in indices:
                        numeric_values_scale[index] = float(num_indices)

        return numeric_values_scale

    def _pad_to_seq_length(self, inputs):
        while len(inputs) > self.model_max_length:
            inputs.pop()
        while len(inputs) < self.model_max_length:
            inputs.append(0)

    def _get_all_answer_ids_from_coordinates(self, column_ids, row_ids, answers_list):
        """Maps lists of answer coordinates to token indexes."""
        answer_ids = [0] * len(column_ids)
        found_answers = set()
        all_answers = set()
        for answers in answers_list:
            column_index, row_index = answers
            all_answers.add((column_index, row_index))
            for index in self._get_cell_token_indexes(column_ids, row_ids, column_index, row_index):
                found_answers.add((column_index, row_index))
                answer_ids[index] = 1

        missing_count = len(all_answers) - len(found_answers)
        return answer_ids, missing_count

    def _get_all_answer_ids(self, column_ids, row_ids, answer_coordinates):
        """
        Maps answer coordinates of a question to token indexes.

        In the SQA format (TSV), the coordinates are given as (row, column) tuples. Here, we first swap them to
        (column, row) format before calling _get_all_answer_ids_from_coordinates.
        """

        def _to_coordinates(answer_coordinates_question):
            return [(coords[1], coords[0]) for coords in answer_coordinates_question]

        return self._get_all_answer_ids_from_coordinates(
            column_ids, row_ids, answers_list=(_to_coordinates(answer_coordinates))
        )

    def _find_tokens(self, text, segment):
        """Return start index of segment in text or None."""
        logging.info(f"text: {text} {segment}")
        for index in range(1 + len(text) - len(segment)):
            for seg_index, seg_token in enumerate(segment):
                if text[index + seg_index].piece != seg_token.piece:
                    break
            else:
                return index
        return None

    def _find_answer_coordinates_from_answer_text(self, tokenized_table, answer_text):
        """Returns all occurrences of answer_text in the table."""
        logging.info(f"answer text: {answer_text}")
        for row_index, row in enumerate(tokenized_table.rows):
            if row_index == 0:
                # We don't search for answers in the header.
                continue
            for col_index, cell in enumerate(row):
                token_index = self._find_tokens(cell, answer_text)
                if token_index is not None:
                    yield TokenCoordinates(
                        row_index=row_index,
                        column_index=col_index,
                        token_index=token_index,
                    )

    def _find_answer_ids_from_answer_texts(self, column_ids, row_ids, tokenized_table, answer_texts):
        """Maps question with answer texts to the first matching token indexes."""
        answer_ids = [0] * len(column_ids)
        for answer_text in answer_texts:
            for coordinates in self._find_answer_coordinates_from_answer_text(tokenized_table, answer_text):
                # Maps answer coordinates to indexes; this can fail if tokens / rows
                # have been pruned.
                indexes = list(
                    self._get_cell_token_indexes(
                        column_ids, row_ids, column_id=coordinates.column_index, row_id=coordinates.row_index - 1
                    )
                )
                indexes.sort()
                coordinate_answer_ids = []
                if indexes:
                    begin_index = coordinates.token_index + indexes[0]
                    end_index = begin_index + len(answer_text)
                    for index in indexes:
                        if index >= begin_index and index < end_index:
                            coordinate_answer_ids.append(index)
                if len(coordinate_answer_ids) == len(answer_text):
                    for index in coordinate_answer_ids:
                        answer_ids[index] = 1
                    break
        return answer_ids

    def _get_answer_ids(self, column_ids, row_ids, answer_coordinates):
        """Maps answer coordinates of a question to token indexes."""
        answer_ids, missing_count = self._get_all_answer_ids(column_ids, row_ids, answer_coordinates)

        if missing_count:
            raise ValueError("Couldn't find all answers")
        return answer_ids

    def get_answer_ids(self, column_ids, row_ids, tokenized_table, answer_texts_question, answer_coordinates_question):
        if self.update_answer_coordinates:
            return self._find_answer_ids_from_answer_texts(
                column_ids,
                row_ids,
                tokenized_table,
                answer_texts=[self.tokenize(at) for at in answer_texts_question],
            )
        return self._get_answer_ids(column_ids, row_ids, answer_coordinates_question)

    def _pad(
        self,
        encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
        max_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
    ) -> dict:
        """
        Pad encoded inputs (on left/right and up to predefined length or max length in the batch)

        Args:
            encoded_inputs:
                Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
            max_length: maximum length of the returned list and optionally padding length (see below).
                Will truncate by taking into account the special tokens.
            padding_strategy: PaddingStrategy to use for padding.

                - PaddingStrategy.LONGEST: Pad to the longest sequence in the batch
                - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
                - PaddingStrategy.DO_NOT_PAD: Do not pad
                The tokenizer padding sides are defined in self.padding_side:

                    - 'left': pads on the left of the sequences
                    - 'right': pads on the right of the sequences
            pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                `>= 7.5` (Volta).
            return_attention_mask:
                (optional) Set to False to avoid returning attention mask (default: set to model specifics)
        """
        # Load from model defaults
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names

        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = len(encoded_inputs["input_ids"])

        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_padded = (
            padding_strategy != PaddingStrategy.DO_NOT_PAD and len(encoded_inputs["input_ids"]) != max_length
        )

        if return_attention_mask and "attention_mask" not in encoded_inputs:
            encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])

        if needs_to_be_padded:
            difference = max_length - len(encoded_inputs["input_ids"])

            if self.padding_side == "right":
                if return_attention_mask:
                    encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
                if "token_type_ids" in encoded_inputs:
                    encoded_inputs["token_type_ids"] = (
                        encoded_inputs["token_type_ids"] + [[self.pad_token_type_id] * 7] * difference
                    )
                if "labels" in encoded_inputs:
                    encoded_inputs["labels"] = encoded_inputs["labels"] + [0] * difference
                if "numeric_values" in encoded_inputs:
                    encoded_inputs["numeric_values"] = encoded_inputs["numeric_values"] + [float("nan")] * difference
                if "numeric_values_scale" in encoded_inputs:
                    encoded_inputs["numeric_values_scale"] = (
                        encoded_inputs["numeric_values_scale"] + [1.0] * difference
                    )
                if "special_tokens_mask" in encoded_inputs:
                    encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
                encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference
            elif self.padding_side == "left":
                if return_attention_mask:
                    encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
                if "token_type_ids" in encoded_inputs:
                    encoded_inputs["token_type_ids"] = [[self.pad_token_type_id] * 7] * difference + encoded_inputs[
                        "token_type_ids"
                    ]
                if "labels" in encoded_inputs:
                    encoded_inputs["labels"] = [0] * difference + encoded_inputs["labels"]
                if "numeric_values" in encoded_inputs:
                    encoded_inputs["numeric_values"] = [float("nan")] * difference + encoded_inputs["numeric_values"]
                if "numeric_values_scale" in encoded_inputs:
                    encoded_inputs["numeric_values_scale"] = [1.0] * difference + encoded_inputs[
                        "numeric_values_scale"
                    ]
                if "special_tokens_mask" in encoded_inputs:
                    encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
                encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"]
            else:
                raise ValueError("Invalid padding strategy:" + str(self.padding_side))

        return encoded_inputs

    # Everything related to converting logits to predictions

    def _get_cell_token_probs(self, probabilities, segment_ids, row_ids, column_ids):
        for i, p in enumerate(probabilities):
            segment_id = segment_ids[i]
            col = column_ids[i] - 1
            row = row_ids[i] - 1
            if col >= 0 and row >= 0 and segment_id == 1:
                yield i, p

    def _get_mean_cell_probs(self, probabilities, segment_ids, row_ids, column_ids):
        """Computes average probability per cell, aggregating over tokens."""
        coords_to_probs = collections.defaultdict(list)
        for i, prob in self._get_cell_token_probs(probabilities, segment_ids, row_ids, column_ids):
            col = column_ids[i] - 1
            row = row_ids[i] - 1
            coords_to_probs[(col, row)].append(prob)
        return {coords: np.array(cell_probs).mean() for coords, cell_probs in coords_to_probs.items()}
        Converts logits of [`TapasForQuestionAnswering`] to actual predicted answer coordinates and optional
        aggregation indices.

        The original implementation, on which this function is based, can be found
        [here](https://github.com/google-research/tapas/blob/4908213eb4df7aa988573350278b44c4dbe3f71b/tapas/experiments/prediction_utils.py#L288).

        Args:
            data (`dict`):
                Dictionary mapping features to actual values. Should be created using [`TapasTokenizer`].
            logits (`torch.Tensor` or `tf.Tensor` of shape `(batch_size, sequence_length)`):
                Tensor containing the logits at the token level.
            logits_agg (`torch.Tensor` or `tf.Tensor` of shape `(batch_size, num_aggregation_labels)`, *optional*):
                Tensor containing the aggregation logits.
            cell_classification_threshold (`float`, *optional*, defaults to 0.5):
                Threshold to be used for cell selection. All table cells for which their probability is larger than
                this threshold will be selected.

        Returns:
            `tuple` comprising various elements depending on the inputs:

            - predicted_answer_coordinates (`List[List[tuple]]` of length `batch_size`): Predicted answer coordinates
              as a list of lists of tuples. Each element in the list contains the predicted answer coordinates of a
              single example in the batch, as a list of tuples. Each tuple is a cell, i.e. (row index, column index).
            - predicted_aggregation_indices (`List[int]` of length `batch_size`, *optional*, returned when
              `logits_aggregation` is provided): Predicted aggregation operator indices of the aggregation head.
        """
        # converting to numpy arrays to work with PT/TF
        logits = logits.numpy()
        if logits_agg is not None:
            logits_agg = logits_agg.numpy()
        data = {key: value.numpy() for key, value in data.items() if key != "training"}
        # input data is of type float32
        # np.log(np.finfo(np.float32).max) = 88.72284
        # Any value over 88.72284 will overflow when passed through the exponential, sending a warning.
        # We disable this warning by truncating the logits.
        logits[logits < -88.7] = -88.7

        # Compute probabilities from token logits
        probabilities = 1 / (1 + np.exp(-logits)) * data["attention_mask"]
        token_types = [
            "segment_ids",
            "column_ids",
            "row_ids",
            "prev_labels",
            "column_ranks",
            "inv_column_ranks",
            "numeric_relations",
        ]

        # collect input_ids, segment ids, row ids and column ids of batch. Shape (batch_size, seq_len)
        input_ids = data["input_ids"]
        segment_ids = data["token_type_ids"][:, :, token_types.index("segment_ids")]
        row_ids = data["token_type_ids"][:, :, token_types.index("row_ids")]
        column_ids = data["token_type_ids"][:, :, token_types.index("column_ids")]

        # next, get answer coordinates for every example in the batch
        num_batch = input_ids.shape[0]
        predicted_answer_coordinates = []
        for i in range(num_batch):
            probabilities_example = probabilities[i].tolist()
            segment_ids_example = segment_ids[i]
            row_ids_example = row_ids[i]
            column_ids_example = column_ids[i]

            max_width = column_ids_example.max()
            max_height = row_ids_example.max()

            if max_width == 0 and max_height == 0:
                continue

            cell_coords_to_prob = self._get_mean_cell_probs(
                probabilities_example,
                segment_ids_example.tolist(),
                row_ids_example.tolist(),
                column_ids_example.tolist(),
            )

            # Select the answers above the classification threshold.
            answer_coordinates = []
            for col in range(max_width):
                for row in range(max_height):
                    cell_prob = cell_coords_to_prob.get((col, row), None)
                    if cell_prob is not None:
                        if cell_prob > cell_classification_threshold:
                            answer_coordinates.append((row, col))
            answer_coordinates = sorted(answer_coordinates)
            predicted_answer_coordinates.append(answer_coordinates)

        output = (predicted_answer_coordinates,)

        if logits_agg is not None:
            predicted_aggregation_indices = logits_agg.argmax(axis=-1)
            output = (predicted_answer_coordinates, predicted_aggregation_indices.tolist())

        return output
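
# ------------------------------------------------------------------------------------------------
# Illustrative usage sketch for `convert_logits_to_predictions` (not part of the original module).
# The checkpoint name and the toy table below are assumptions made purely for illustration; any
# TAPAS question-answering checkpoint would work the same way.
def _example_convert_logits_to_predictions():
    import pandas as pd
    import torch
    from transformers import TapasForQuestionAnswering, TapasTokenizer

    tokenizer = TapasTokenizer.from_pretrained("google/tapas-base-finetuned-wtq")
    model = TapasForQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq")

    table = pd.DataFrame({"Actors": ["Brad Pitt", "George Clooney"], "Age": ["56", "59"]}).astype(str)
    inputs = tokenizer(table=table, queries=["How old is George Clooney?"], return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # `data` must be the dict produced by the tokenizer; the logits come from the model.
    coordinates, aggregation_indices = tokenizer.convert_logits_to_predictions(
        inputs, outputs.logits.detach(), outputs.logits_aggregation.detach()
    )
    # `coordinates[0]` is a list of (row, column) cells; `aggregation_indices[0]` selects the
    # aggregation operator of the aggregation head (e.g. NONE/SUM/COUNT/AVERAGE for WTQ).
    return coordinates, aggregation_indices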
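
# Illustrative sketch of the TAPAS-specific padding behaviour (not part of the original module).
# TAPAS uses 7 token type ids per token, so `_pad` extends `token_type_ids` with rows of
# `[pad_token_type_id] * 7`; the checkpoint name below is an assumption for the example.
def _example_padding():
    import pandas as pd
    from transformers import TapasTokenizer

    tokenizer = TapasTokenizer.from_pretrained("google/tapas-base-finetuned-wtq")
    table = pd.DataFrame({"Age": ["56", "59"]}).astype(str)
    encoding = tokenizer(table=table, queries="How old?", padding="max_length", max_length=64)
    assert len(encoding["input_ids"]) == 64
    assert len(encoding["token_type_ids"][0]) == 7  # 7 token type ids per token
    return encoding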
	#)&^	rV   c                   @   sN   e Zd ZdZdddZdddZdd	 Zdd
dZdd Zdd Z	dd Z
dS )r|   a  
    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).

    Args:
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (`Iterable`, *optional*):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            `do_basic_tokenize=True`
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters.

            This should likely be deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
        do_split_on_punc (`bool`, *optional*, defaults to `True`):
            In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
            the full context of the words, such as contractions.
    """

    def __init__(
        self,
        do_lower_case=True,
        never_split=None,
        tokenize_chinese_chars=True,
        strip_accents=None,
        do_split_on_punc=True,
    ):
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = set(never_split)
        self.tokenize_chinese_chars = tokenize_chinese_chars
        self.strip_accents = strip_accents
        self.do_split_on_punc = do_split_on_punc

    def tokenize(self, text, never_split=None):
        """
        Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.

        Args:
            never_split (`List[str]`, *optional*):
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                [`PreTrainedTokenizer.tokenize`]). List of tokens not to split.
        """
        # union() returns a new set by concatenating the two sets.
        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
        text = self._clean_text(text)

        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        # prevents treating the same character with different unicode codepoints as different characters
        unicode_normalized_text = unicodedata.normalize("NFC", text)
        orig_tokens = whitespace_tokenize(unicode_normalized_text)
        split_tokens = []
        for token in orig_tokens:
            if token not in never_split:
                if self.do_lower_case:
                    token = token.lower()
                    if self.strip_accents is not False:
                        token = self._run_strip_accents(token)
                elif self.strip_accents:
                    token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if not self.do_split_on_punc or (never_split is not None and text in never_split):
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)
            or (cp >= 0x20000 and cp <= 0x2A6DF)
            or (cp >= 0x2A700 and cp <= 0x2B73F)
            or (cp >= 0x2B740 and cp <= 0x2B81F)
            or (cp >= 0x2B820 and cp <= 0x2CEAF)
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)
        ):
            return True
        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)
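
# Illustrative sketch of BasicTokenizer in isolation (not part of the original module).
def _example_basic_tokenizer():
    tokenizer = BasicTokenizer(do_lower_case=True)
    # Punctuation is split off and text is lowercased:
    # -> ["hello", ",", "how", "are", "you", "?"]
    return tokenizer.tokenize("Hello, how are You?")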
r|   c                   @   s"   e Zd ZdZdddZdd ZdS )	r~   zRuns WordPiece tokenization.d   c                 C   s   || _ || _|| _d S r   )rN   rk   max_input_chars_per_word)r   rN   rk   r  r   r   r"   r     s    zWordpieceTokenizer.__init__c                 C   s   g }t |D ]}t|}t|| jkr4|| j qd}d}g }|t|k rt|}d}	||k rd||| }
|dkrd|
 }
|
| jkr|
}	q|d8 }qX|	dkrd}q||	 |}q@|r|| j q|| q|S )a  
        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.

        For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through *BasicTokenizer*.

        Returns:
            A list of wordpiece tokens.
        """
        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens
# Below: utilities for the TAPAS tokenizer, taken from constants.py of the original implementation.
# URL: https://github.com/google-research/tapas/blob/master/tapas/utils/constants.py
class Relation(enum.Enum):
    HEADER_TO_CELL = 1  # Connects header to cell.
    CELL_TO_HEADER = 2  # Connects cell to header.
    QUERY_TO_HEADER = 3  # Connects query to headers.
    QUERY_TO_CELL = 4  # Connects query to cells.
    ROW_TO_CELL = 5  # Connects row to cells.
    CELL_TO_ROW = 6  # Connects cells to row.
    EQ = 7  # Annotation value is same as cell value.
    LT = 8  # Annotation value is less than cell value.
    GT = 9  # Annotation value is greater than cell value.


@dataclass
class Date:
    year: Optional[int] = None
    month: Optional[int] = None
    day: Optional[int] = None


@dataclass
class NumericValue:
    float_value: Optional[float] = None
    date: Optional[Date] = None


@dataclass
class NumericValueSpan:
    begin_index: int = None
    end_index: int = None
    values: List[NumericValue] = None


@dataclass
class Cell:
    text: Text
    numeric_value: Optional[NumericValue] = None


@dataclass
class Question:
    original_text: Text  # The original raw question string.
    text: Text  # The question string after normalization.
    numeric_spans: Optional[List[NumericValueSpan]] = None


# Constants for parsing date expressions.
# Masks that specify (by a bool) which of (year, month, day) will be populated.
_DateMask = collections.namedtuple("_DateMask", ["year", "month", "day"])

_YEAR = _DateMask(True, False, False)
_YEAR_MONTH = _DateMask(True, True, False)
_YEAR_MONTH_DAY = _DateMask(True, True, True)
_MONTH = _DateMask(False, True, False)
_MONTH_DAY = _DateMask(False, True, True)

# Pairs of patterns to pass to 'datetime.strptime' and masks specifying which fields will be
# populated by the corresponding pattern.
_DATE_PATTERNS = (
    ("%B", _MONTH),
    ("%Y", _YEAR),
    ("%Ys", _YEAR),
    ("%b %Y", _YEAR_MONTH),
    ("%B %Y", _YEAR_MONTH),
    ("%B %d", _MONTH_DAY),
    ("%b %d", _MONTH_DAY),
    ("%d %b", _MONTH_DAY),
    ("%d %B", _MONTH_DAY),
    ("%B %d, %Y", _YEAR_MONTH_DAY),
    ("%d %B %Y", _YEAR_MONTH_DAY),
    ("%m-%d-%Y", _YEAR_MONTH_DAY),
    ("%Y-%m-%d", _YEAR_MONTH_DAY),
    ("%Y-%m", _YEAR_MONTH),
    ("%d %b %Y", _YEAR_MONTH_DAY),
    ("%b %d, %Y", _YEAR_MONTH_DAY),
    ("%d.%m.%Y", _YEAR_MONTH_DAY),
    ("%A, %b %d", _MONTH_DAY),
    ("%A, %B %d", _MONTH_DAY),
)

# Regexes used as a prefilter before attempting the (expensive) strptime calls.
_FIELD_TO_REGEX = (
    ("%A", r"\w+"),  # Weekday as locale's full name.
    ("%B", r"\w+"),  # Month as locale's full name.
    ("%Y", r"\d{4}"),  # Year with century as a decimal number.
    ("%b", r"\w{3}"),  # Month as locale's abbreviated name.
    ("%d", r"\d{1,2}"),  # Day of the month as a decimal number.
    ("%m", r"\d{1,2}"),  # Month as a decimal number.
)


def _process_date_pattern(dp):
    """Compute a regex for each date pattern to use as a prefilter."""
    pattern, mask = dp
    regex = pattern
    regex = regex.replace(".", re.escape("."))
    regex = regex.replace("-", re.escape("-"))
    regex = regex.replace(" ", r"\s+")
    for field, field_regex in _FIELD_TO_REGEX:
        regex = regex.replace(field, field_regex)
    # Make sure we didn't miss any of the fields.
    assert "%" not in regex, regex
    return pattern, mask, re.compile("^" + regex + "$")


def _process_date_patterns():
    return tuple(_process_date_pattern(dp) for dp in _DATE_PATTERNS)


_PROCESSED_DATE_PATTERNS = _process_date_patterns()

_MAX_DATE_NGRAM_SIZE = 5

_NUMBER_WORDS = [
    "zero", "one", "two", "three", "four", "five", "six", "seven", "eight",
    "nine", "ten", "eleven", "twelve",
]

_ORDINAL_WORDS = [
    "zeroth", "first", "second", "third", "fourth", "fith", "sixth", "seventh",
    "eighth", "ninth", "tenth", "eleventh", "twelfth",
]

_ORDINAL_SUFFIXES = ["st", "nd", "rd", "th"]

_NUMBER_PATTERN = re.compile(r"((^|\s)[+-])?((\.\d+)|(\d+(,\d\d\d)*(\.\d*)?))")

_MIN_YEAR = 1700
_MAX_YEAR = 2016

_INF = float("INF")


def _get_numeric_value_from_date(date, mask):
    """Converts date (datetime Python object) to a NumericValue object with a Date object value."""
    if date.year < _MIN_YEAR or date.year > _MAX_YEAR:
        raise ValueError(f"Invalid year: {date.year}")

    new_date = Date()
    if mask.year:
        new_date.year = date.year
    if mask.month:
        new_date.month = date.month
    if mask.day:
        new_date.day = date.day
    return NumericValue(date=new_date)


def _get_span_length_key(span):
    """Sorts span by decreasing length first and increasing first index second."""
    return span[1] - span[0], -span[0]
   t | dS )zDConverts float (Python) to a NumericValue object with a float value.)rE  )r  r   r   r   r"   _get_numeric_value_from_float	  s    r  c              	   C   s   t dd| } tD ]p\}}}|| s(qztj| | }W n tk
rV   Y qY nX zt||W   S  tk
r   Y qY qX qdS )zAAttempts to format a text as a standard date string (yyyy-mm-dd).zSept\bSepN)	r  sub_PROCESSED_DATE_PATTERNSmatchdatetimestrptimer  ry   r  )rT   Z
in_patternr  r  r  r   r   r"   _parse_date	  s    

r  c                 C   sx   t D ]$}| |r| dt|  }  q*q| dd} zt| }W n tk
rX   Y dS X t|rhdS |tkrtdS |S )z,Parses simple cardinal and ordinals numbers.N,r   )	_ORDINAL_SUFFIXESendswithr   r   rC  ry   mathisnan_INF)rT   suffixr   r   r   r"   _parse_number	  s    

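
# Illustrative sketch of the low-level parsing helpers above (not part of the original module).
# The example strings are assumptions; `_parse_date` only fills the fields allowed by the matching
# pattern's mask, and `_parse_number` strips ordinal suffixes and thousands separators first.
def _example_parse_helpers():
    date_value = _parse_date("August 2007")  # -> NumericValue(date=Date(year=2007, month=8))
    number_value = _parse_number("1,000th")  # -> 1000.0
    return date_value, number_value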
r  c                 c   s   g }t | D ]v\}}| sq|dks6| |d   s@|| |d t| ks`| |d   s|| d D ]}||d fV  qnqdS )z
    Split a text into all possible ngrams up to 'max_ngram_length'. Split points are white space and punctuation.

    Args:
      text: Text to split.
      max_ngram_length: maximal ngram length.
    Yields:
      Spans, tuples of begin-end index.
    """
    start_indexes = []
    for index, char in enumerate(text):
        if not char.isalnum():
            continue
        if index == 0 or not text[index - 1].isalnum():
            start_indexes.append(index)
        if index + 1 == len(text) or not text[index + 1].isalnum():
            for start_index in start_indexes[-max_ngram_length:]:
                yield start_index, index + 1
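
# Illustrative sketch of `get_all_spans` (not part of the original module): for "17 years" with
# max_ngram_length=2 it yields the unigrams (0, 2) and (3, 8) plus the bigram (0, 8).
def _example_get_all_spans():
    return list(get_all_spans("17 years", max_ngram_length=2))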
def normalize_for_match(text):
    return " ".join(text.lower().split())


def format_text(text):
    """Lowercases and strips punctuation."""
    text = text.lower().strip()
    if text == "n/a" or text == "?" or text == "nan":
        text = EMPTY_TEXT

    text = re.sub(r"[^\w\d]+", " ", text).replace("_", " ")
    text = " ".join(text.split())
    text = text.strip()
    if text:
        return text
    return EMPTY_TEXT


def parse_text(text):
    """
    Extracts longest number and date spans.

    Args:
      text: text to annotate

    Returns:
      List of longest numeric value spans.
    """
    span_dict = collections.defaultdict(list)
    for match in _NUMBER_PATTERN.finditer(text):
        span_text = text[match.start() : match.end()]
        number = _parse_number(span_text)
        if number is not None:
            span_dict[match.span()].append(_get_numeric_value_from_float(number))

    for begin_index, end_index in get_all_spans(text, max_ngram_length=1):
        if (begin_index, end_index) in span_dict:
            continue
        span_text = text[begin_index:end_index]

        number = _parse_number(span_text)
        if number is not None:
            span_dict[begin_index, end_index].append(_get_numeric_value_from_float(number))
        for number, word in enumerate(_NUMBER_WORDS):
            if span_text == word:
                span_dict[begin_index, end_index].append(_get_numeric_value_from_float(float(number)))
                break
        for number, word in enumerate(_ORDINAL_WORDS):
            if span_text == word:
                span_dict[begin_index, end_index].append(_get_numeric_value_from_float(float(number)))
                break

    for begin_index, end_index in get_all_spans(text, max_ngram_length=_MAX_DATE_NGRAM_SIZE):
        span_text = text[begin_index:end_index]
        date = _parse_date(span_text)
        if date is not None:
            span_dict[begin_index, end_index].append(date)

    spans = sorted(span_dict.items(), key=lambda span_value: _get_span_length_key(span_value[0]), reverse=True)
    selected_spans = []
    for span, value in spans:
        for selected_span, _ in selected_spans:
            if selected_span[0] <= span[0] and span[1] <= selected_span[1]:
                break
        else:
            selected_spans.append((span, value))

    selected_spans.sort(key=lambda span_value: span_value[0][0])

    numeric_value_spans = []
    for span, values in selected_spans:
        numeric_value_spans.append(NumericValueSpan(begin_index=span[0], end_index=span[1], values=values))
    return numeric_value_spans
# Below: all functions from number_annotation_utils.py and 2 functions (namely filter_invalid_unicode
# and filter_invalid_unicode_from_table) from text_utils.py of the original implementation. URL's:
# - https://github.com/google-research/tapas/blob/master/tapas/utils/number_annotation_utils.py
# - https://github.com/google-research/tapas/blob/master/tapas/utils/text_utils.py

_PrimitiveNumericValue = Union[float, Tuple[Optional[float], Optional[float], Optional[float]]]
_SortKeyFn = Callable[[NumericValue], Tuple[float, Ellipsis]]

_DATE_TUPLE_SIZE = 3

EMPTY_TEXT = "EMPTY"

NUMBER_TYPE = "number"
DATE_TYPE = "date"


def _get_value_type(numeric_value):
    if numeric_value.float_value is not None:
        return NUMBER_TYPE
    elif numeric_value.date is not None:
        return DATE_TYPE
    raise ValueError(f"Unknown type: {numeric_value}")


def _get_value_as_primitive_value(numeric_value):
    """Maps a NumericValue proto to a float or tuple of float."""
    if numeric_value.float_value is not None:
        return numeric_value.float_value
    if numeric_value.date is not None:
        date = numeric_value.date
        value_tuple = [None, None, None]
        # All date fields are cast to float to produce a simple primitive value.
        if date.year is not None:
            value_tuple[0] = float(date.year)
        if date.month is not None:
            value_tuple[1] = float(date.month)
        if date.day is not None:
            value_tuple[2] = float(date.day)
        return tuple(value_tuple)
    raise ValueError(f"Unknown type: {numeric_value}")


def _get_all_types(numeric_values):
    return {_get_value_type(value) for value in numeric_values}


def get_numeric_sort_key_fn(numeric_values):
    """
    Creates a function that can be used as a sort key or to compare the values. Maps to primitive types and finds the
    biggest common subset. Consider the values "05/05/2010" and "August 2007". With the corresponding primitive values
    (2010.,5.,5.) and (2007.,8., None). These values can be compared by year and date so we map to the sequence (2010.,
    5.), (2007., 8.). If we added a third value "2006" with primitive value (2006., None, None), we could only compare
    by the year so we would map to (2010.,), (2007.,) and (2006.,).

    Args:
     numeric_values: Values to compare

    Returns:
     A function that can be used as a sort key function (mapping numeric values to a comparable tuple)

    Raises:
      ValueError if values don't have a common type or are not comparable.
    """
    value_types = _get_all_types(numeric_values)
    if len(value_types) != 1:
        raise ValueError(f"No common value type in {numeric_values}")

    value_type = next(iter(value_types))
    if value_type == NUMBER_TYPE:
        # Primitive values are simple floats, nothing to do here.
        return _get_value_as_primitive_value

    # The type can only be Date at this point which means the primitive type
    # is a float triple.
    valid_indexes = set(range(_DATE_TUPLE_SIZE))

    for numeric_value in numeric_values:
        value = _get_value_as_primitive_value(numeric_value)
        assert isinstance(value, tuple)
        for tuple_index, inner_value in enumerate(value):
            if inner_value is None:
                valid_indexes.discard(tuple_index)

    if not valid_indexes:
        raise ValueError(f"No common value in {numeric_values}")

    def _sort_key_fn(numeric_value):
        value = _get_value_as_primitive_value(numeric_value)
        return tuple(value[index] for index in valid_indexes)

    return _sort_key_fn


def _consolidate_numeric_values(row_index_to_values, min_consolidation_fraction, debug_info):
    """
    Finds the most common numeric values in a column and returns them

    Args:
        row_index_to_values:
            For each row index all the values in that cell.
        min_consolidation_fraction:
            Fraction of cells that need to have consolidated value.
        debug_info:
            Additional information only used for logging

    Returns:
        For each row index the first value that matches the most common value. Rows that don't have a matching value
        are dropped. Empty list if values can't be consolidated.
    """
    type_counts = collections.Counter()
    for numeric_values in row_index_to_values.values():
        type_counts.update(_get_all_types(numeric_values))
    if not type_counts:
        return {}
    max_count = max(type_counts.values())
    if max_count < len(row_index_to_values) * min_consolidation_fraction:
        return {}

    valid_types = set()
    for value_type, count in type_counts.items():
        if count == max_count:
            valid_types.add(value_type)
    if len(valid_types) > 1:
        assert DATE_TYPE in valid_types
        max_type = DATE_TYPE
    else:
        max_type = next(iter(valid_types))

    new_row_index_to_value = {}
    for index, values in row_index_to_values.items():
        # Extract the first matching value.
        for value in values:
            if _get_value_type(value) == max_type:
                new_row_index_to_value[index] = value
                break

    return new_row_index_to_value


def _get_numeric_values(text):
    """Parses text and returns numeric values."""
    numeric_spans = parse_text(text)
    return itertools.chain(*(span.values for span in numeric_spans))


def _get_column_values(table, col_index):
    """
    Parses text in column and returns a dict mapping row_index to values. This is the _get_column_values function from
    number_annotation_utils.py of the original implementation

    Args:
      table: Pandas dataframe
      col_index: integer, indicating the index of the column to get the numeric values of
    """
    index_to_values = {}
    for index, row in table.iterrows():
        text = normalize_for_match(row[col_index].text)
        index_to_values[index] = list(_get_numeric_values(text))
    return index_to_values


def get_numeric_relation(value, other_value, sort_key_fn):
    """Compares two values and returns their relation or None."""
    value = sort_key_fn(value)
    other_value = sort_key_fn(other_value)
    if value == other_value:
        return Relation.EQ
    if value < other_value:
        return Relation.LT
    if value > other_value:
        return Relation.GT
    return None


def add_numeric_values_to_question(question):
    """Adds numeric value spans to a question."""
    original_text = question
    question = normalize_for_match(question)
    numeric_spans = parse_text(question)
    return Question(original_text=original_text, text=question, numeric_spans=numeric_spans)


def filter_invalid_unicode(text):
    """Return an empty string and True if 'text' is in invalid unicode."""
    return ("", True) if isinstance(text, bytes) else (text, False)


def filter_invalid_unicode_from_table(table):
    """
    Removes invalid unicode from table. Checks whether a table cell text contains an invalid unicode encoding. If yes,
    reset the table cell text to an empty str and log a warning for each invalid cell

    Args:
        table: table to clean.
    """
    # to do: add table id support
    if not hasattr(table, "table_id"):
        table.table_id = 0

    for row_index, row in table.iterrows():
        for col_index, cell in enumerate(row):
            cell, is_invalid = filter_invalid_unicode(cell)
            if is_invalid:
                logging.warning(
                    f"Scrub an invalid table body @ table_id: {table.table_id}, row_index: {row_index}, "
                    f"col_index: {col_index}",
                )
    for col_index, column in enumerate(table.columns):
        column, is_invalid = filter_invalid_unicode(column)
        if is_invalid:
            logging.warning(f"Scrub an invalid table header @ table_id: {table.table_id}, col_index: {col_index}")


def add_numeric_table_values(table, min_consolidation_fraction=0.7, debug_info=None):
    """
    Parses text in table column-wise and adds the consolidated values. Consolidation refers to finding values with a
    common type (date or number)

    Args:
        table:
            Table to annotate.
        min_consolidation_fraction:
            Fraction of cells in a column that need to have consolidated value.
        debug_info:
            Additional information used for logging.
    """
    table = table.copy()
    # Do per-table validation.
    filter_invalid_unicode_from_table(table)

    for row_index, row in table.iterrows():
        for col_index, cell in enumerate(row):
            table.iloc[row_index, col_index] = Cell(text=cell)

    for col_index, column in enumerate(table.columns):
        column_values = _consolidate_numeric_values(
            _get_column_values(table, col_index),
            min_consolidation_fraction=min_consolidation_fraction,
            debug_info=(debug_info, column),
        )

        for row_index, numeric_value in column_values.items():
            table.iloc[row_index, col_index].numeric_value = numeric_value

    return table