U
    ,-eʴ                    @   s  d Z ddlZddlZddlZddlmZ ddlmZmZm	Z	 ddl
Z
ddlZ
ddl
mZ ddlmZmZmZ ddlmZ dd	lmZmZmZmZ dd
lmZ ddlmZmZmZmZ ddlm Z m!Z!m"Z"m#Z#m$Z$ ddl%m&Z& e#'e(Z)ese)*de
j+ d dZ,dZ-ddddddddddddddd d!d"d#d$d%d&d'd(d)d*gZ.d+Z/d,Z0eG d-d. d.e Z1d/d0 Z2G d1d2 d2ej3Z4G d3d4 d4ej3Z5G d5d6 d6ej3Z6G d7d8 d8ej3Z7G d9d: d:ej3Z8G d;d< d<ej3Z9G d=d> d>ej3Z:G d?d@ d@ej3Z;G dAdB dBej3Z<G dCdD dDej3Z=G dEdF dFej3Z>G dGdH dHej3Z?G dIdJ dJeZ@dKZAdLZBe!dMeAG dNdO dOe@ZCe!dPeAG dQdR dRe@ZDe!dSeAG dTdU dUe@ZEe!dVeAG dWdX dXe@ZFG dYdZ dZeGejHZIG d[d\ d\eJZKG d]d^ d^eKZLdd`daZMddcddZNddfdeZOdgdh ZPddjdkZQddmdnZRddpdqZSddsdtZTdudv ZUdwdx ZVdydz ZWd{d| ZXd}d~ ZYdd ZZdd Z[dd Z\de]dddZ^dd Z_dS )zPyTorch TAPAS model.    N)	dataclass)OptionalTupleUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)BaseModelOutputBaseModelOutputWithPoolingMaskedLMOutputSequenceClassifierOutput)PreTrainedModel)apply_chunking_to_forward find_pruneable_heads_and_indices#is_torch_greater_or_equal_than_1_12prune_linear_layer)ModelOutputadd_start_docstrings%add_start_docstrings_to_model_forwardloggingreplace_return_docstrings   )TapasConfigzYou are using torch==zH, but torch>=1.12.0 is required to use TapasModel. Please upgrade torch.r   zgoogle/tapas-basezgoogle/tapas-largez google/tapas-large-finetuned-sqaz google/tapas-large-finetuned-wtqz/google/tapas-large-finetuned-wikisql-supervisedz$google/tapas-large-finetuned-tabfactzgoogle/tapas-base-finetuned-sqazgoogle/tapas-base-finetuned-wtqz.google/tapas-base-finetuned-wikisql-supervisedz#google/tapas-base-finetuned-tabfactzgoogle/tapas-smallz google/tapas-small-finetuned-sqaz google/tapas-small-finetuned-wtqz/google/tapas-small-finetuned-wikisql-supervisedz$google/tapas-small-finetuned-tabfactzgoogle/tapas-minizgoogle/tapas-mini-finetuned-sqazgoogle/tapas-mini-finetuned-wtqz.google/tapas-mini-finetuned-wikisql-supervisedz#google/tapas-mini-finetuned-tabfactzgoogle/tapas-tinyzgoogle/tapas-tiny-finetuned-sqazgoogle/tapas-tiny-finetuned-wtqz.google/tapas-tiny-finetuned-wikisql-supervisedz#google/tapas-tiny-finetuned-tabfactg|=g     c                   @   sl   e Zd ZU dZdZeej ed< dZ	ejed< dZ
ejed< dZeeej  ed< dZeeej  ed< dS )TableQuestionAnsweringOutputa  
    Output type of [`TapasForQuestionAnswering`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` (and possibly `answer`, `aggregation_labels`, `numeric_values` and `numeric_values_scale` are provided)):
            Total loss as the sum of the hierarchical cell selection log-likelihood loss and (optionally) the
            semi-supervised regression loss and (optionally) supervised loss for aggregations.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Prediction scores of the cell selection head, for every token.
        logits_aggregation (`torch.FloatTensor`, *optional*, of shape `(batch_size, num_aggregation_labels)`):
            Prediction scores of the aggregation head, for every aggregation operator.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    Nlosslogitslogits_aggregationhidden_states
attentions)__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r   r   r    r   r!    r)   r)   i/var/www/html/Darija-Ai-Train/env/lib/python3.8/site-packages/transformers/models/tapas/modeling_tapas.pyr   a   s   
r   c                 C   s  zddl }ddl}ddl}W n  tk
r<   td  Y nX tj|}t	d|  |j
|}g }g }	|D ]@\}
}t	d|
 d|  |j
||
}||
 |	| qrt||	D ]\}
}|
d}
tdd	 |
D rt	d
d|
  qt| tr4tdd	 |
D r4t	d
d|
  qt| trltdd	 |
D rlt	d
d|
  qt| trtdd	 |
D rt	d
d|
  q|
d dkrd|
d< | }|
D ]}|d|r|d|}n|g}|d dks|d dkrt|d}nv|d dkr0t|d}nZ|d dkrdt| tsVt|d}n
t|d}n&|d dkrt|d}n
|d dkrt|d}n|d dkrt|d}n|d dkrt|d}t|d}n|d dkrt|d}t|d}n|d dkr t|d}t|d}nj|d d krDt|d}t|d}nFzt||d }W n2 tk
r   t	d
d|
  Y qY nX t|d!krt|d" }|| }q|d#d d$krt|d}n@|d%d d&d' td(D krt|d}n|dkr||}z,|j|jkr8td)|j d*|j d+W n< tk
rv } z| j |j|jf7  _  W 5 d}~X Y nX t	d,|
  |!|r|"|}t#$||_%q| S )-z
    Load tf checkpoints in a PyTorch model. This is an adaptation from load_tf_weights_in_bert

    - add cell selection and aggregation heads
    - take into account additional token type embedding layers
    r   NzLoading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see https://www.tensorflow.org/install/ for installation instructions.z&Converting TensorFlow checkpoint from zLoading TF weight z with shape /c                 s   s   | ]}|d kV  qdS ))Zadam_vZadam_mZAdamWeightDecayOptimizerZAdamWeightDecayOptimizer_1Zglobal_stepZseq_relationshipNr)   .0nr)   r)   r*   	<genexpr>   s   
z+load_tf_weights_in_tapas.<locals>.<genexpr>z	Skipping c                 s   s   | ]}|d kV  qdS ))output_biasoutput_weightsNr)   r,   r)   r)   r*   r/      s     c                 s   s   | ]}|d kV  qdS ))r0   r1   output_bias_clsoutput_weights_clsNr)   r,   r)   r)   r*   r/      s     c                 s   s   | ]}|d kV  qdS ))poolerNr)   r,   r)   r)   r*   r/      s     Zberttapasz[A-Za-z]+_\d+z_(\d+)Zkernelgammaweightbetabiasr0   r1   column_output_biascolumn_output_weightsZoutput_bias_aggaggregation_classifierZoutput_weights_aggr2   
classifierr3      r   iZ_embeddingsic                 S   s   g | ]}d | qS )Z_embeddings_r)   )r-   ir)   r)   r*   
<listcomp>   s     z,load_tf_weights_in_tapas.<locals>.<listcomp>   zPointer shape z and array shape z mismatchedzInitialize PyTorch weight )&renumpyZ
tensorflowImportErrorloggererrorospathabspathinfotrainZlist_variablesZload_variableappendzipsplitanyjoin
isinstanceTapasForSequenceClassification
TapasModelTapasForMaskedLM	fullmatchgetattrAttributeErrorlenintrange	transposeshape
ValueErrorAssertionErrorargsZisscalararrayr&   Z
from_numpydata)modelconfigZtf_checkpoint_pathrB   nptfZtf_pathZ	init_varsnamesZarraysnamer\   r`   ZpointerZm_nameZscope_namesnumer)   r)   r*   load_tf_weights_in_tapas   s    







 


rj   c                       s*   e Zd ZdZ fddZdddZ  ZS )TapasEmbeddingsz
    Construct the embeddings from word, position and token_type embeddings. Same as BertEmbeddings but with a number of
    additional token type embeddings to encode tabular structure.
    c                    s   t    tj|j|j|jd| _t|j|j| _	t
|jD ](\}}d| }t| |t||j q>t|j| _tj|j|jd| _t|j| _|| _d S )N)padding_idxtoken_type_embeddings_Zeps)super__init__r   	Embedding
vocab_sizehidden_sizeZpad_token_idword_embeddingsmax_position_embeddingsposition_embeddings	enumeratetype_vocab_sizessetattrrX   number_of_token_type_embeddings	LayerNormlayer_norm_epsDropouthidden_dropout_probdropoutrc   )selfrc   r?   rx   rg   	__class__r)   r*   rp     s    

zTapasEmbeddings.__init__Nc                 C   s  |d k	r|  }n|  d d }|d }|d k	r8|jn|j}|d kr tj|tj|d}|d|}| jjr t	|d d d d df | jj
d dd}t	|d d d d df | jj
d dd}	t||	}
t||
d }t||
}tj|tj|dd}ttj| jjd |d|| }|d krBtj|| j tj|d}|d krV| |}| |}|| }t| jD ]4}d| }|t| ||d d d d |f 7 }qr| |}| |}|S )	Nr   dtypedevicer   )
batch_dimsr>   r   rm   )sizer   r&   arangelong	unsqueezeexpandrc   Zreset_position_index_per_cellIndexMaprx   ProductIndexMap
reduce_mingathermin	as_tensorru   zerosrz   rt   rv   rZ   rV   r{   r   )r   	input_idstoken_type_idsposition_idsinputs_embedsinput_shapeZ
seq_lengthr   	col_index	row_indexZ
full_indexZfirst_position_per_segmentZfirst_positionpositionrv   
embeddingsr?   rg   r)   r)   r*   forward   sF    


((

 
  



(

zTapasEmbeddings.forward)NNNN)r"   r#   r$   r%   rp   r   __classcell__r)   r)   r   r*   rk     s   rk   c                       s.   e Zd Z fddZdd Zd	ddZ  ZS )
TapasSelfAttentionc                    s   t    |j|j dkr<t|ds<td|j d|j |j| _t|j|j | _| j| j | _t	
|j| j| _t	
|j| j| _t	
|j| j| _t	|j| _|j| _d S )Nr   Zembedding_sizezThe hidden size z4 is not a multiple of the number of attention heads )ro   rp   rs   num_attention_headshasattrr]   rY   attention_head_sizeall_head_sizer   Linearquerykeyvaluer}   Zattention_probs_dropout_probr   
is_decoderr   rc   r   r)   r*   rp   U  s    
zTapasSelfAttention.__init__c                 C   s6   |  d d | j| jf }|j| }|ddddS )Nr   r   r>   r   r
   )r   r   r   viewpermute)r   xZnew_x_shaper)   r)   r*   transpose_for_scoresh  s    
z'TapasSelfAttention.transpose_for_scoresNFc                 C   s  |  |}|d k	}	|	r4|d k	r4|d }
|d }|}n|	r^| | |}
| | |}|}nv|d k	r| | |}
| | |}tj|d |
gdd}
tj|d |gdd}n | | |}
| | |}| |}| jr|
|f}t||
dd}|t	
| j }|d k	r"|| }tjj|dd}| |}|d k	rN|| }t||}|dddd }| d d | jf }|j| }|r||fn|f}| jr||f }|S )Nr   r   r>   dimr   r
   )r   r   r   r   r&   catr   matmulr[   mathsqrtr   r   
functionalsoftmaxr   r   
contiguousr   r   r   )r   r    attention_mask	head_maskencoder_hidden_statesencoder_attention_maskpast_key_valueoutput_attentionsZmixed_query_layerZis_cross_attentionZ	key_layerZvalue_layerZquery_layerZattention_scoresZattention_probsZcontext_layerZnew_context_layer_shapeoutputsr)   r)   r*   r   m  sH    







zTapasSelfAttention.forward)NNNNNF)r"   r#   r$   rp   r   r   r   r)   r)   r   r*   r   T  s         r   c                       s4   e Zd Z fddZejejejdddZ  ZS )TapasSelfOutputc                    sB   t    t|j|j| _tj|j|jd| _t|j	| _
d S Nrn   )ro   rp   r   r   rs   denser{   r|   r}   r~   r   r   r   r)   r*   rp     s    
zTapasSelfOutput.__init__r    input_tensorreturnc                 C   s&   |  |}| |}| || }|S Nr   r   r{   r   r    r   r)   r)   r*   r     s    

zTapasSelfOutput.forwardr"   r#   r$   rp   r&   Tensorr   r   r)   r)   r   r*   r     s   r   c                
       st   e Zd Z fddZdd Zd
ejeej eej eej eej ee	e	ej   ee
 e	ej ddd	Z  ZS )TapasAttentionc                    s*   t    t|| _t|| _t | _d S r   )ro   rp   r   r   r   outputsetpruned_headsr   r   r)   r*   rp     s    


zTapasAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   r   )rX   r   r   r   r   r   r   r   r   r   r   r   r   union)r   headsindexr)   r)   r*   prune_heads  s       zTapasAttention.prune_headsNFr    r   r   r   r   r   r   r   c              	   C   s<   |  |||||||}| |d |}	|	f|dd   }
|
S )Nr   r   )r   r   )r   r    r   r   r   r   r   r   Zself_outputsattention_outputr   r)   r)   r*   r     s    
	zTapasAttention.forward)NNNNNF)r"   r#   r$   rp   r   r&   r   r   r'   r   boolr   r   r)   r)   r   r*   r     s$         r   c                       s0   e Zd Z fddZejejdddZ  ZS )TapasIntermediatec                    sB   t    t|j|j| _t|jt	r6t
|j | _n|j| _d S r   )ro   rp   r   r   rs   intermediate_sizer   rQ   
hidden_actstrr   intermediate_act_fnr   r   r)   r*   rp     s
    
zTapasIntermediate.__init__r    r   c                 C   s   |  |}| |}|S r   )r   r   r   r    r)   r)   r*   r     s    

zTapasIntermediate.forwardr   r)   r)   r   r*   r     s   r   c                       s4   e Zd Z fddZejejejdddZ  ZS )TapasOutputc                    sB   t    t|j|j| _tj|j|jd| _t	|j
| _d S r   )ro   rp   r   r   r   rs   r   r{   r|   r}   r~   r   r   r   r)   r*   rp     s    
zTapasOutput.__init__r   c                 C   s&   |  |}| |}| || }|S r   r   r   r)   r)   r*   r     s    

zTapasOutput.forwardr   r)   r)   r   r*   r     s   r   c                
       st   e Zd Z fddZd
ejeej eej eej eej eeeej   ee	 eej dddZ
dd	 Z  ZS )
TapasLayerc                    sn   t    |j| _d| _t|| _|j| _|j| _| jrV| jsLt|  dt|| _	t
|| _t|| _d S )Nr   z> should be used as a decoder model if cross attention is added)ro   rp   chunk_size_feed_forwardseq_len_dimr   	attentionr   Zadd_cross_attentionr]   crossattentionr   intermediater   r   r   r   r)   r*   rp     s    



zTapasLayer.__init__NFr   c              	   C   s  |d k	r|d d nd }| j |||||d}	|	d }
| jrP|	dd }|	d }n|	dd  }d }| jr|d k	rt| dstd|  d|d k	r|d	d  nd }| |
||||||}|d }
||dd  }|d }|| }t| j| j| j|
}|f| }| jr||f }|S )
Nr>   )r   r   r   r   r   r   z'If `encoder_hidden_states` are passed, z` has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`r   )	r   r   r   r]   r   r   feed_forward_chunkr   r   )r   r    r   r   r   r   r   r   Zself_attn_past_key_valueZself_attention_outputsr   r   Zpresent_key_valueZcross_attn_present_key_valueZcross_attn_past_key_valueZcross_attention_outputslayer_outputr)   r)   r*   r   $  sV    


	   

zTapasLayer.forwardc                 C   s   |  |}| ||}|S r   )r   r   )r   r   Zintermediate_outputr   r)   r)   r*   r   f  s    
zTapasLayer.feed_forward_chunk)NNNNNF)r"   r#   r$   rp   r&   r   r   r'   r   r   r   r   r   r)   r)   r   r*   r     s$         Br   c                	       s&   e Zd Z fddZdddZ  ZS )	TapasEncoderc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r)   )r   )r-   _rc   r)   r*   r@   p  s     z)TapasEncoder.__init__.<locals>.<listcomp>F)	ro   rp   rc   r   Z
ModuleListrZ   num_hidden_layerslayergradient_checkpointingr   r   r   r*   rp   m  s    
 zTapasEncoder.__init__NFTc              	      s   |	rdnd } rdnd }t | jD ]\}}|	r8||f }|d k	rH|| nd }| jr| jr fdd}tjj|||||||}n|||||| }|d } r"||d f }q"|	r||f }|
stdd |||fD S t|||dS )	Nr)   c                    s    fdd}|S )Nc                     s    | f S r   r)   )inputs)moduler   past_key_valuesr)   r*   custom_forward  s    zKTapasEncoder.forward.<locals>.create_custom_forward.<locals>.custom_forwardr)   )r   r   r   r   )r   r*   create_custom_forward  s    z3TapasEncoder.forward.<locals>.create_custom_forwardr   r   c                 s   s   | ]}|d k	r|V  qd S r   r)   )r-   vr)   r)   r*   r/     s      z'TapasEncoder.forward.<locals>.<genexpr>)last_hidden_stater    r!   )	rw   r   r   Ztrainingr&   utils
checkpointtupler   )r   r    r   r   r   r   r   Z	use_cacher   output_hidden_statesreturn_dictZall_hidden_statesZall_attentionsr?   Zlayer_moduleZlayer_head_maskr   Zlayer_outputsr)   r   r*   r   s  sJ    
		
  zTapasEncoder.forward)	NNNNNNFFTr"   r#   r$   rp   r   r   r)   r)   r   r*   r   l  s   	         r   c                       s0   e Zd Z fddZejejdddZ  ZS )TapasPoolerc                    s*   t    t|j|j| _t | _d S r   )ro   rp   r   r   rs   r   ZTanh
activationr   r   r)   r*   rp     s    
zTapasPooler.__init__r   c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r   )r   r    Zfirst_token_tensorpooled_outputr)   r)   r*   r     s    

zTapasPooler.forwardr   r)   r)   r   r*   r     s   r   c                       s0   e Zd Z fddZejejdddZ  ZS )TapasPredictionHeadTransformc                    sV   t    t|j|j| _t|jtr6t	|j | _
n|j| _
tj|j|jd| _d S r   )ro   rp   r   r   rs   r   rQ   r   r   r   transform_act_fnr{   r|   r   r   r)   r*   rp     s    
z%TapasPredictionHeadTransform.__init__r   c                 C   s"   |  |}| |}| |}|S r   )r   r   r{   r   r)   r)   r*   r     s    


z$TapasPredictionHeadTransform.forwardr   r)   r)   r   r*   r     s   	r   c                       s$   e Zd Z fddZdd Z  ZS )TapasLMPredictionHeadc                    sL   t    t|| _tj|j|jdd| _t	t
|j| _| j| j_d S )NF)r9   )ro   rp   r   	transformr   r   rs   rr   decoder	Parameterr&   r   r9   r   r   r)   r*   rp     s
    

zTapasLMPredictionHead.__init__c                 C   s   |  |}| |}|S r   )r   r  r   r)   r)   r*   r     s    

zTapasLMPredictionHead.forwardr   r)   r)   r   r*   r     s   r   c                       s0   e Zd Z fddZejejdddZ  ZS )TapasOnlyMLMHeadc                    s   t    t|| _d S r   )ro   rp   r   predictionsr   r   r)   r*   rp     s    
zTapasOnlyMLMHead.__init__)sequence_outputr   c                 C   s   |  |}|S r   )r  )r   r  prediction_scoresr)   r)   r*   r     s    
zTapasOnlyMLMHead.forwardr   r)   r)   r   r*   r    s   r  c                   @   s.   e Zd ZdZeZdZdZdd Zd
ddZ	d	S )TapasPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    r5   Tc                 C   s   t |tjr:|jjjd| jjd |jdk	r|jj	  nft |tj
rz|jjjd| jjd |jdk	r|jj|j 	  n&t |tjr|jj	  |jjd dS )zInitialize the weights        )meanstdN      ?)rQ   r   r   r7   ra   normal_rc   initializer_ranger9   Zzero_rq   rl   r{   Zfill_)r   r   r)   r)   r*   _init_weights  s    

z"TapasPreTrainedModel._init_weightsFc                 C   s   t |tr||_d S r   )rQ   r   r   )r   r   r   r)   r)   r*   _set_gradient_checkpointing  s    
z0TapasPreTrainedModel._set_gradient_checkpointingN)F)
r"   r#   r$   r%   r   config_classbase_model_prefixZsupports_gradient_checkpointingr  r  r)   r)   r)   r*   r    s   r  a?  
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`TapasConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
a@
  
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0}, 7)`, *optional*):
            Token indices that encode tabular structure. Indices can be obtained using [`AutoTokenizer`]. See this
            class for more info.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. If
            `reset_position_index_per_cell` of [`TapasConfig`] is set to `True`, relative position embeddings will be
            used. Selected in the range `[0, config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - 1
            indicates the head is **not masked**, - 0 indicates the head is **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
z_The bare Tapas Model transformer outputting raw hidden-states without any specific head on top.c                       s   e Zd ZdZd fdd	Zdd Zdd Zd	d
 Zee	
deeeddeej eej eej eej eej eej eej eej ee ee ee eeef dddZ  ZS )rS   a  
    This class is a small change compared to [`BertModel`], taking into account the additional token type ids.

    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    Tc                    sD   t  | || _t|| _t|| _|r2t|nd | _| 	  d S r   )
ro   rp   rc   rk   r   r   encoderr   r4   	post_init)r   rc   add_pooling_layerr   r)   r*   rp   [  s    

zTapasModel.__init__c                 C   s   | j jS r   r   rt   r   r)   r)   r*   get_input_embeddingsg  s    zTapasModel.get_input_embeddingsc                 C   s   || j _d S r   r  )r   r   r)   r)   r*   set_input_embeddingsj  s    zTapasModel.set_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr  r   r   r   )r   Zheads_to_pruner   r   r)   r)   r*   _prune_headsm  s    zTapasModel._prune_headsbatch_size, sequence_lengthoutput_typer  N)r   r   r   r   r   r   r   r   r   r   r   r   c              
   C   s  |	dk	r|	n| j j}	|
dk	r |
n| j j}
|dk	r4|n| j j}|dk	rV|dk	rVtdn@|dk	rt| || | }n"|dk	r| dd }ntd|dk	r|jn|j}|dkrtj	||d}|dkrtj
|t| j jftj|d}| ||}| j jrB|dk	rB| \}}}||f}|dkr6tj	||d}| |}nd}| || j j}| j||||d}| j||||||	|
|d}|d	 }| jdk	r| |nd}|s||f|d
d  S t|||j|jdS )ag  
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, TapasModel
        >>> import pandas as pd

        >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base")
        >>> model = TapasModel.from_pretrained("google/tapas-base")

        >>> data = {
        ...     "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
        ...     "Age": ["56", "45", "59"],
        ...     "Number of movies": ["87", "53", "69"],
        ... }
        >>> table = pd.DataFrame.from_dict(data)
        >>> queries = ["How many movies has George Clooney played in?", "How old is Brad Pitt?"]

        >>> inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> last_hidden_states = outputs.last_hidden_state
        ```NzDYou cannot specify both input_ids and inputs_embeds at the same timer   z5You have to specify either input_ids or inputs_embedsr   r   )r   r   r   r   )r   r   r   r   r   r   r   r   r   )r   Zpooler_outputr    r!   )rc   r   r   use_return_dictr]   Z%warn_if_padding_and_no_attention_maskr   r   r&   onesr   rX   rx   r   Zget_extended_attention_maskr   Zinvert_attention_maskZget_head_maskr   r   r  r4   r   r    r!   )r   r   r   r   r   r   r   r   r   r   r   r   r   r   Zextended_attention_maskZencoder_batch_sizeZencoder_sequence_lengthr   Zencoder_hidden_shapeZencoder_extended_attention_maskZembedding_outputZencoder_outputsr  r   r)   r)   r*   r   u  sp    )

  
   
zTapasModel.forward)T)NNNNNNNNNNN)r"   r#   r$   r%   rp   r  r  r  r   TAPAS_INPUTS_DOCSTRINGformatr   r   _CONFIG_FOR_DOCr   r&   
LongTensorr'   r   r   r   r   r   r)   r)   r   r*   rS   L  s@   

           
rS   z3Tapas Model with a `language modeling` head on top.c                       s   e Zd ZddgZeZdZ fddZdd Zdd	 Z	e
ed
eeeddeej eej eej eej eej eej eej eej eej ee ee ee eeef dddZ  ZS )rT   zcls.predictions.decoder.weightzcls.predictions.decoder.biasr5   c                    s0   t  | t|dd| _t|| _|   d S )NF)r  )ro   rp   rS   r5   r  clsr  r   r   r)   r*   rp     s    
zTapasForMaskedLM.__init__c                 C   s
   | j jjS r   r$  r  r  r  r)   r)   r*   get_output_embeddings  s    z&TapasForMaskedLM.get_output_embeddingsc                 C   s   || j j_d S r   r%  )r   Znew_embeddingsr)   r)   r*   set_output_embeddings  s    z&TapasForMaskedLM.set_output_embeddingsr  r  N)r   r   r   r   r   r   r   r   labelsr   r   r   r   c                 K   s   |dk	r|n| j j}| j|||||||||
||d}|d }| |}d}|	dk	rtt }||d| j j|	d}|s|f|dd  }|dk	r|f| S |S t|||j|j	dS )ax  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, TapasForMaskedLM
        >>> import pandas as pd

        >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base")
        >>> model = TapasForMaskedLM.from_pretrained("google/tapas-base")

        >>> data = {
        ...     "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
        ...     "Age": ["56", "45", "59"],
        ...     "Number of movies": ["87", "53", "69"],
        ... }
        >>> table = pd.DataFrame.from_dict(data)

        >>> inputs = tokenizer(
        ...     table=table, queries="How many [MASK] has George [MASK] played in?", return_tensors="pt"
        ... )
        >>> labels = tokenizer(
        ...     table=table, queries="How many movies has George Clooney played in?", return_tensors="pt"
        ... )["input_ids"]

        >>> outputs = model(**inputs, labels=labels)
        >>> logits = outputs.logits
        ```N)
r   r   r   r   r   r   r   r   r   r   r   r   r>   r   r   r    r!   )
rc   r  r5   r$  r   r   rr   r   r    r!   )r   r   r   r   r   r   r   r   r   r(  r   r   r   kwargsr   r  r  Zmasked_lm_lossloss_fctr   r)   r)   r*   r     s:    4
zTapasForMaskedLM.forward)NNNNNNNNNNNN)r"   r#   r$   Z_tied_weights_keysr   r  r  rp   r&  r'  r   r   r!  r   r   r"  r   r&   r#  r'   r   r   r   r   r   r)   r)   r   r*   rT     sF   	
            
rT   a  
    Tapas Model with a cell selection head and optional aggregation head on top for question-answering tasks on tables
    (linear layers on top of the hidden-states output to compute `logits` and optional `logits_aggregation`), e.g. for
    SQA, WTQ or WikiSQL-supervised tasks.
    c                       s   e Zd Zed fddZeedee	e
dd
eej eej eej eej eej eej eej eej eej eej eej eej ee ee ee eee	f ddd	Z  ZS )TapasForQuestionAnsweringr   c                    s   t  | t|| _t|j| _|jrTt	t
|j| _t	t
|j| _nPt	t
|j| _tjj| j|jd t	t
|j| _tjj| j|jd t	t
g | _t	t
g | _|jdkrt|j|j| _|   d S )N)r
  r   )ro   rp   rS   r5   r   r}   r~   r   Z#init_cell_selection_weights_to_zeror  r&   r   rs   r1   r;   emptyinitr  r  r0   r:   num_aggregation_labelsr   r<   r  r   r   r)   r*   rp   ^  s*    
  
z"TapasForQuestionAnswering.__init__r  r  N)r   r   r   r   r   r   
table_maskr(  aggregation_labelsfloat_answernumeric_valuesnumeric_values_scaler   r   r   r   c           /      C   s\  |dk	r|n| j j}| j|||||||||d	}|d }|d }| |}|dk	r\| }n| dd }|dk	rz|jn|j}|dkrtj|t| j j	ftj
|d}ddd	d
dddg}|dddd|d	f }|dddd|df }tt|tj| j jd |jd| j jdd}tt|tj| j jd |jd| j jdd}t||}|dk	rj| n| dd }|dk	r|jn|j}|dkrtj||d}|dkrt|dkt|t|}| |}| |}t||\}}t|| j j| j| j}d} | j jr8t|| j| j ||| j j!} d}!| j j"dkrT| #|}!d}"d}#|dk	rd}#| j j"dk p| j j$ }$|$rd}%nH|
dk	r|j%d |
j%d kst&dt'|
|| j j(|| j#}%nt)d| j j*rt||\}&}t+|&|}tj,j-|d}'d}(| j js~t|dktj|tj.d| j j/tj|tj.d })|'0| |) }*tj1|*| ddtj1|ddt2  }(n$t3|| ||||\}(}tj,j-|d}'| j j4rn,|$r|"t5|(7 }"n|"t5|(d|%  7 }"| j j"dkr|$r@|	dk	r6|j%d |	j%d kst&dt6|!|%|	| j j$| j j"| j j7}+nt)dn8tj|j%d tj
|jd}	t6|!|%|	| j j$| j j"| j j7}+| j j$r|dk	r|dk	r|j%|j%kst&t8|
|%|'||||!| j \},}-|+|,7 }+|+|-9 }+nt)d|"t5|+7 }"n t|}t3|| ||||\}}|s<||!f|dd  }.|#r8|"f|. S |.S t9|#rH|"nd||!|j:|j;dS )as  
        table_mask (`torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
            Mask for the table. Indicates which tokens belong to the table (1). Question tokens, table headers and
            padding are 0.
        labels (`torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
            Labels per token for computing the hierarchical cell selection loss. This encodes the positions of the
            answer appearing in the table. Can be obtained using [`AutoTokenizer`].

            - 1 for tokens that are **part of the answer**,
            - 0 for tokens that are **not part of the answer**.

        aggregation_labels (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
            Aggregation function index for every example in the batch for computing the aggregation loss. Indices
            should be in `[0, ..., config.num_aggregation_labels - 1]`. Only required in case of strong supervision for
            aggregation (WikiSQL-supervised).
        float_answer (`torch.FloatTensor` of shape `(batch_size, )`, *optional*):
            Float answer for every example in the batch. Set to *float('nan')* for cell selection questions. Only
            required in case of weak supervision (WTQ) to calculate the aggregate mask and regression loss.
        numeric_values (`torch.FloatTensor` of shape `(batch_size, seq_length)`, *optional*):
            Numeric values of every token, NaN for tokens which are not numeric values. Can be obtained using
            [`AutoTokenizer`]. Only required in case of weak supervision for aggregation (WTQ) to calculate the
            regression loss.
        numeric_values_scale (`torch.FloatTensor` of shape `(batch_size, seq_length)`, *optional*):
            Scale of the numeric values of every token. Can be obtained using [`AutoTokenizer`]. Only required in case
            of weak supervision for aggregation (WTQ) to calculate the regression loss.

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, TapasForQuestionAnswering
        >>> import pandas as pd

        >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base-finetuned-wtq")
        >>> model = TapasForQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq")

        >>> data = {
        ...     "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
        ...     "Age": ["56", "45", "59"],
        ...     "Number of movies": ["87", "53", "69"],
        ... }
        >>> table = pd.DataFrame.from_dict(data)
        >>> queries = ["How many movies has George Clooney played in?", "How old is Brad Pitt?"]

        >>> inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> logits = outputs.logits
        >>> logits_aggregation = outputs.logits_aggregation
        ```Nr   r   r   r   r   r   r   r   r   r   r   r   Zsegment_ids
column_idsrow_idsZprev_labelsZcolumn_ranksZinv_column_ranksZnumeric_relationsr   indicesnum_segmentsr   r  FTz>Make sure the answers are a FloatTensor of shape (batch_size,)zJYou have to specify float answers in order to calculate the aggregate maskr   r   r   r  zHMake sure the aggregation labels are a LongTensor of shape (batch_size,)zQYou have to specify aggregation labels in order to calculate the aggregation losszeYou have to specify numeric values and numeric values scale in order to calculate the regression lossr>   )r   r   r   r    r!   )<rc   r  r5   r   r   r   r&   r   rX   rx   r   r   r   r   r   Zmax_num_rowsZmax_num_columnsr   r  where	ones_like
zeros_likefloattoreduce_meancompute_token_logitstemperaturer1   r0   Zselect_one_columncompute_column_logitsr;   r:   allow_empty_column_selectionr/  r<   use_answer_as_supervisionr\   r^   _calculate_aggregate_maskcell_selection_preferencer]   Zaverage_logits_per_cellr   distributions	Bernoullifloat32Zpositive_label_weightlog_probsumEPSILON_ZERO_DIVISION"_single_column_cell_selection_lossZdisable_per_token_lossr	  _calculate_aggregation_lossaggregation_loss_weight_calculate_regression_lossr   r    r!   )/r   r   r   r   r   r   r   r0  r(  r1  r2  r3  r4  r   r   r   r   r  r   r   r   Ztoken_typesr7  r6  r   r   
cell_indexinput_mask_floatZtable_mask_float	cell_maskr   r   column_logitsr   Z
total_lossZcalculate_lossZis_supervisedaggregate_masklogits_per_cellZdist_per_tokenselection_loss_per_exampler7   Zselection_loss_per_tokenZper_example_additional_lossZanswer_losslarge_answer_loss_maskr   r)   r)   r*   r     sn   G

  

"





	



     

		



     z!TapasForQuestionAnswering.forward)NNNNNNNNNNNNNNN)r"   r#   r$   r   rp   r   r   r!  r   r   r"  r   r&   r#  r'   r   r   r   r   r   r)   r)   r   r*   r,  U  sH   	"
               
r,  z
    Tapas Model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for table
    entailment tasks, such as TabFact (Chen et al., 2020).
    c                       s   e Zd Z fddZeedeee	dd	e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
e e
e e
e eeej ef dddZ  ZS )
rR   c                    sJ   t  | |j| _t|| _t|j| _t	|j
|j| _|   d S r   )ro   rp   
num_labelsrS   r5   r   r}   r~   r   r   rs   r=   r  r   r   r)   r*   rp     s    
z'TapasForSequenceClassification.__init__r  r  N)r   r   r   r   r   r   r(  r   r   r   r   c                 C   s|  |
dk	r|
n| j j}
| j||||||||	|
d	}|d }| |}| |}d}|dk	r8| j jdkr| jdkrzd| j _n4| jdkr|jtj	ks|jtj
krd| j _nd| j _| j jdkrt }| jdkr|| | }n
|||}nN| j jdkrt }||d| j|d}n| j jdkr8t }|||}|
sh|f|dd  }|dk	rd|f| S |S t|||j|jd	S )
ad  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy). Note: this is called
            "classification_class_index" in the original implementation.

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, TapasForSequenceClassification
        >>> import torch
        >>> import pandas as pd

        >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base-finetuned-tabfact")
        >>> model = TapasForSequenceClassification.from_pretrained("google/tapas-base-finetuned-tabfact")

        >>> data = {
        ...     "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
        ...     "Age": ["56", "45", "59"],
        ...     "Number of movies": ["87", "53", "69"],
        ... }
        >>> table = pd.DataFrame.from_dict(data)
        >>> queries = [
        ...     "There is only one actor who is 45 years old",
        ...     "There are 3 actors which played in more than 60 movies",
        ... ]

        >>> inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="pt")
        >>> labels = torch.tensor([1, 0])  # 1 means entailed, 0 means refuted

        >>> outputs = model(**inputs, labels=labels)
        >>> loss = outputs.loss
        >>> logits = outputs.logits
        ```Nr5  r   Z
regressionZsingle_label_classificationZmulti_label_classificationr   r>   r)  )rc   r  r5   r   r=   Zproblem_typer\  r   r&   r   rY   r	   Zsqueezer   r   r   r   r    r!   )r   r   r   r   r   r   r   r(  r   r   r   r   r   r   r   r+  r   r)   r)   r*   r     sV    4




"


z&TapasForSequenceClassification.forward)
NNNNNNNNNN)r"   r#   r$   rp   r   r   r!  r   r   r"  r   r&   r#  r'   r   r   r   r   r   r   r)   r)   r   r*   rR     s4   
          rR   c                   @   s   e Zd ZdZdZdZdS )AverageApproximationFunctionratioZfirst_orderZsecond_orderN)r"   r#   r$   RATIOFIRST_ORDERSECOND_ORDERr)   r)   r)   r*   r]  .  s   r]  c                   @   s"   e Zd ZdZdddZdd ZdS )	r   z'Index grouping entries within a tensor.r   c                 C   s(   t || _t j||jd| _|| _dS )a  
        Creates an index

        Args:
            indices (`torch.LongTensor`, same shape as a *values* Tensor to which the indices refer):
                Tensor containing the indices.
            num_segments (`torch.LongTensor`):
                Scalar tensor, the number of segments. All elements in a batched segmented tensor must have the same
                number of segments (although many segments can be empty).
            batch_dims (`int`, *optional*, defaults to 0):
                The number of batch dimensions. The first *batch_dims* dimensions of a SegmentedTensor are treated as
                batch dimensions. Segments in different batch elements are always distinct even if they have the same
                index.
        r   N)r&   r   r9  r   r:  r   )r   r9  r:  r   r)   r)   r*   rp   :  s    zIndexMap.__init__c                 C   s   | j  d | j S r   )r9  r   r   r  r)   r)   r*   batch_shapeM  s    zIndexMap.batch_shapeN)r   )r"   r#   r$   r%   rp   rb  r)   r)   r)   r*   r   7  s   
r   c                       s0   e Zd ZdZ fddZdd Zdd Z  ZS )r   zThe product of two indices.c                    sN   |j |j krtdt j|j|j|j  |j|j |j d || _|| _dS )a  
        Combines indices i and j into pairs (i, j). The result is an index where each segment (i, j) is the
        intersection of segments i and j. For example if the inputs represent table cells indexed by respectively rows
        and columns the output will be a table indexed by (row, column) pairs, i.e. by cell. The implementation
        combines indices {0, .., n - 1} and {0, .., m - 1} into {0, .., nm - 1}. The output has *num_segments* equal to
        *outer_index.num_segments* * *inner_index.num_segments*

        Args:
            outer_index (`IndexMap`):
                IndexMap.
            inner_index (`IndexMap`):
                IndexMap, must have the same shape as *outer_index*.
        zCouter_index.batch_dims and inner_index.batch_dims must be the same.r8  N)r   r]   ro   rp   r9  r:  outer_indexinner_index)r   rc  rd  r   r)   r*   rp   T  s    
zProductIndexMap.__init__c                 C   s2   t j|j| jjddt j}t|| jj|j	dS )zDProjects an index with the same index set onto the outer components.floor)Zrounding_moder8  )
r&   divr9  rd  r:  typer   r   rc  r   )r   r   r9  r)   r)   r*   project_outerm  s    zProductIndexMap.project_outerc                 C   s6   t t|j| jjtj tj	| jj|j
dS )zDProjects an index with the same index set onto the inner components.r8  )r   r&   fmodr9  rd  r:  rg  r@  re  r   r   )r   r   r)   r)   r*   project_innerr  s    zProductIndexMap.project_inner)r"   r#   r$   r%   rp   rh  rj  r   r)   r)   r   r*   r   Q  s   r   segmented_gatherc                 C   sn   |j }t| j|jd dk rHt| |j||  d d| S |d	| j}t| |j|S dS )a  
    Gathers from *values* using the index map. For each element in the domain of the index map this operation looks up
    a value for that index in *values*. Two elements from the same segment always get assigned the same value.

    Args:
        values (`torch.Tensor` of shape (B1, ..., Bn, num_segments, V1, ...)):
            Tensor with segment values.
        index (`IndexMap` of shape (B1, ..., Bn, I1, ..., Ik)):
            IndexMap.
        name (`str`, *optional*, defaults to 'segmented_gather'):
            Name for the operation. Currently not used

    Returns:
        `tuple(torch.Tensor)`: Tensor of shape (B1, ..., Bn, I1, ..., Ik, V1, ...) with the gathered values.
    Nr>   r   r   )
r9  rX   r\   r   r&   r   r   r   r   r   )valuesr   rg   r9  r)   r)   r*   r   ~  s    
 
r   segmented_flattenc                 C   s   t t t|  }t jd|| jjd| j }||  }t	| j
t| j D ]}|d}qV|| j }t|d| j| ddS )aj  
    Flattens a batched index map (which is typically of shape batch_size, seq_length) to a 1d index map. This operation
    relabels the segments to keep batch elements distinct. The k-th batch element will have indices shifted by
    *num_segments* * (k - 1). The result is a tensor with *num_segments* multiplied by the number of elements in the
    batch.

    Args:
        index (`IndexMap`):
            IndexMap to flatten.
        name (`str`, *optional*, defaults to 'segmented_flatten'):
            Name for the operation. Currently not used

    Returns:
        (`IndexMap`): The flattened IndexMap.
    r   startendr   r   r8  )r&   prodZtensorlistrb  r   r:  r   r   rZ   r   rX   r9  r   r   r   )r   rg   Z
batch_sizeoffsetr   r9  r)   r)   r*   flatten  s    
rt  range_index_mapc                 C   s   t j| t jd} t|  dks$tt |}t| dksBtt jd||jd}t jt j	| t j|jd|j
ddgdd}dd | D }||}t j| t dggdd}|| }t||t|  d d	S )
a  
    Constructs an index map equal to range(num_segments).

    Args:
        batch_shape (`torch.Size`):
            Batch shape
        num_segments (`int`):
            Number of segments
        name (`str`, *optional*, defaults to 'range_index_map'):
            Name for the operation. Currently not used

    Returns:
        (`IndexMap`): IndexMap of shape batch_shape with elements equal to range(num_segments).
    r<  r   r   rn  r   r   c                 S   s   g | ]}t |qS r)   )rY   )r-   r   r)   r)   r*   r@     s     z#range_index_map.<locals>.<listcomp>r8  )r&   r   r   rX   r   r^   r   r   r   r>  r   tolistr   repeatr   rr  )rb  r:  rg   r9  Z
new_tensor	new_shapeZ	multiplesr)   r)   r*   ru    s*     
  
c                 C   s  t |}|  t|j d }tjtjdgtjdtj|tjdgdd}| |	 }tj
t|jtj|jd}|jd|j | |dd}	tjtj| tjdtj|jgtjdtj|tjdgdd}
|	 |
	 | j}t| |j}||fS )	a  
    Applies a segment reduction segment-wise.

    Args:
        values (`torch.Tensor`):
            Tensor with segment values.
        index (`IndexMap`):
            IndexMap.
        segment_reduce_fn (`str`):
            Name for the reduce operation. One of "sum", "mean", "max" or "min".
        name (`str`):
            Name for the operation. Currently not used

    Returns:
        (`IndexMap`): IndexMap of shape batch_shape with elements equal to range(num_segments).
    Nr   r<  r   r   r   F)r   r   srcreduceZinclude_self)rt  r   rX   r9  r&   r   r   r   Zreshaperv  r   rY   r:  r@  r   Zscatter_reducerb  cloner   rA  r   ru  )rl  r   Zsegment_reduce_fnrg   Z
flat_indexZvector_shapeZflattened_shapeZflat_valuesoutZsegment_meansrx  Zoutput_valuesZoutput_indexr)   r)   r*   _segment_reduce  s2          	r}  segmented_reduce_sumc                 C   s   t | |d|S )a~  
    Sums a tensor over its segments.

    Outputs 0 for empty segments.

    This operations computes the sum over segments, with support for:

        - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
        - Vectorization using the last dimension [V1, V2, ...]. If they are present, the output will be a sum of
          vectors rather than scalars. Only the middle dimensions [I1, ..., Ik] are reduced by the operation.

    Args:
        values (`torch.Tensor` of shape [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..]):
            Tensor containing the values of which the sum must be taken segment-wise.
        index (`IndexMap`, indices are of shape [B1, B2, ..., Bn, I1, .., Ik].):
            Index defining the segments.
        name (`str`, *optional*, defaults to 'segmented_reduce_sum'):
            Name for the operation. Currently not used

    Returns:
        output_values (`torch.Tensor`of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
        output values. output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments]. .
    rN  r}  rl  r   rg   r)   r)   r*   
reduce_sum  s    r  segmented_reduce_meanc                 C   s   t | |d|S )a  
    Averages a tensor over its segments.

    Outputs 0 for empty segments.

    This operations computes the mean over segments, with support for:

        - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
        - Vectorization using the last dimension [V1, V2, ...]. If they are present, the output will be a mean of
          vectors rather than scalars.

    Only the middle dimensions [I1, ..., Ik] are reduced by the operation.

    Args:
        values (`torch.Tensor` of shape [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..]):
            Tensor containing the values of which the mean must be taken segment-wise.
        index (`IndexMap`, indices are of shape [B1, B2, ..., Bn, I1, .., Ik].):
            Index defining the segments.
        name (`str`, *optional*, defaults to 'segmented_reduce_sum'):
            Name for the operation. Currently not used

    Returns:
        output_values (`torch.Tensor`of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
        output values. output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments].
    r	  r  r  r)   r)   r*   rB  0  s    rB  segmented_reduce_maxc                 C   s   t | |d|S )av  
    Computes the maximum over segments.

    This operation computes the maximum over segments, with support for:

        - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
        - Vectorization using the last dimension [V1, V2, ...]. If they are present, the output will be an element-wise
          maximum of vectors rather than scalars.

    Only the middle dimensions [I1, ..., Ik] are reduced by the operation.

    Args:
        values (`torch.Tensor` of shape [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..]):
            Tensor containing the values of which the max must be taken segment-wise.
        index (`IndexMap`, indices are of shape [B1, B2, ..., Bn, I1, .., Ik].):
            Index defining the segments.
        name (`str`, *optional*, defaults to 'segmented_reduce_sum'):
            Name for the operation. Currently not used

    Returns:
        output_values (`torch.Tensor`of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
        output values. output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments].
    Zamaxr  r  r)   r)   r*   
reduce_maxM  s    r  segmented_reduce_minc                 C   s   t | |d|S )aw  
    Computes the minimum over segments.

    This operations computes the minimum over segments, with support for:

        - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
        - Vectorization using the last dimension [V1, V2, ...]. If they are present, the output will be an element-wise
          minimum of vectors rather than scalars.

    Only the middle dimensions [I1, ..., Ik] are reduced by the operation.

    Args:
        values (`torch.Tensor` of shape [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..]):
            Tensor containing the values of which the min must be taken segment-wise.
        index (`IndexMap`, indices are of shape [B1, B2, ..., Bn, I1, .., Ik].):
            Index defining the segments.
        name (`str`, *optional*, defaults to 'segmented_reduce_sum'):
            Name for the operation. Currently not used

    Returns:
        output_values (`torch.Tensor`of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
        output values. output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments].
    Zaminr  r  r)   r)   r*   r   h  s    r   c                 C   s   t d| || }t||\}}||}	t|| |	\}
}t||	\}}|
|t  }
t |dk t |jd }|
t	t j
|t j|jd 7 }
|s|
t	t j
t |jdt j|jjd 7 }
|
S )a)  
    Computes the column logits.

    Args:
        sequence_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Also known as last_hidden_state. Sequence of hidden-states at the output of the last layer of the model.
        column_output_weights (`torch.FloatTensor` of shape `(hidden_size)`):
            Weights of the linear layer for column selection.
        column_output_bias (`torch.FloatTensor` of shape `()`):
            Bias of the linear layer for column selection.
        cell_index (`ProductIndexMap`):
            Index that groups tokens into cells.
        cell_mask (`torch.FloatTensor` of shape `(batch_size, max_num_rows * max_num_cols)`):
            Mask for cells that exist in the table (i.e. that are not padding).
        allow_empty_column_selection (`bool`):
            Whether to allow not to select any column

    Returns:
        column_logits (`torch.FloatTensor`of shape `(batch_size, max_num_cols)`): Tensor containing the column logits
        for every example in the batch.
    	bsj,j->bs      ?r   r   )r&   einsumrB  rj  r  rO  logical_andeqr9  CLOSE_ENOUGH_TO_LOG_ZEROr   rL  r   )r  r;   r:   rT  rV  rF  token_logitsZcell_logitsZcell_logits_indexcolumn_indexrW  Z	out_indexZ
cell_countr   Z
is_paddingr)   r)   r*   rE    s&    
  
  
rE  c                 C   s  t tj|tj|jd|\}}tj|dd}ttj|ddd d}	t|		|
 t||}tjj|d}
|
| }t| |\}}ttj|tj|jd|\}}||j}tjt|tj|ddtj|jd}tjj|d}||tj}tj|| | dd }|tj|| ddt  }|}|t|		|
 t||7 }tjtj|ddtj|jd}tjt|tj|ddtj|jd}tt|d	|
 t||}|td||    }t||}||fS )a  
    Computes the loss for cell selection constrained to a single column. The loss is a hierarchical log-likelihood. The
    model first predicts a column and then selects cells within that column (conditioned on the column). Cells outside
    the selected column are never selected.

    Args:
        token_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Tensor containing the logits per token.
        column_logits (`torch.FloatTensor` of shape `(batch_size, max_num_cols)`):
            Tensor containing the logits per column.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Labels per token.
        cell_index (`ProductIndexMap`):
            Index that groups tokens into cells.
        col_index (`IndexMap`):
            Index that groups tokens into columns.
        cell_mask (`torch.FloatTensor` of shape `(batch_size, max_num_rows * max_num_cols)`):
            Mask for cells that exist in the table (i.e. that are not padding).

    Returns:
        selection_loss_per_example (`torch.FloatTensor` of shape `(batch_size,)`): Loss for each example. logits
        (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): New logits which are only allowed to select
        cells in a single column. Logits outside of the most likely column according to *column_logits* will be set to
        a very low value (such that the probabilities are 0).
    r   r   r   r   r;  r   r  )r  r&   r   rL  r   Zargmaxr  maxr=  r   r   r?  rJ  CategoricalrM  rB  r  r   rj  r9  r   rK  rg  rN  rO  r  r   )r  rW  r(  rT  r   rV  Zlabels_per_columnr   Zcolumn_labelZno_cell_selectedZcolumn_distZcolumn_loss_per_examplerY  Zlabels_per_cellZlabels_indexZcolumn_id_for_cellsZcolumn_maskZ	cell_distZcell_log_probZ	cell_lossrZ  Zselected_column_idZselected_column_maskZnew_logits_per_cellr   r)   r)   r*   rP    sh        	  
rP  c                 C   s   t d| || | }|S )a  
    Computes logits per token

    Args:
        sequence_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Also known as last_hidden_state. Sequence of hidden-states at the output of the last layer of the model.
        temperature (`float`):
            Temperature for the Bernoulli distribution.
        output_weights (`torch.FloatTensor` of shape `(hidden_size,)`):
            Weights of the linear layer for cell selection.
        output_bias (`torch.FloatTensor` of shape `()`):
            Bias of the linear layer for cell selection

    Returns:
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): Logits per token.
    r  )r&   r  )r  rD  r1   r0   r   r)   r)   r*   rC  %  s    rC  c                 C   s   t t | t j| j}||}t jjj	|d}t j
|jddddf dd}||k}	t j
|dddk}
t t |	|
| t j|t jd|}| }|S )a  
    Finds examples where the model should select cells with no aggregation.

    Returns a mask that determines for which examples should the model select answers directly from the table, without
    any aggregation function. If the answer is a piece of text the case is unambiguous as aggregation functions only
    apply to numbers. If the answer is a number but does not appear in the table then we must use some aggregation
    case. The ambiguous case is when the answer is a number that also appears in the table. In this case we use the
    aggregation function probabilities predicted by the model to decide whether to select or aggregate. The threshold
    for this is a hyperparameter *cell_selection_preference*

    Args:
        answer (`torch.FloatTensor` of shape `(batch_size, )`):
            Answer for every example in the batch. Nan if there is no scalar answer.
        pooled_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Output of the pooler (BertPooler) on top of the encoder layer.
        cell_selection_preference (`float`):
            Preference for cell selection in ambiguous cases.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Labels per token. aggregation_classifier (`torch.nn.Linear`): Aggregation head

    Returns:
        aggregate_mask (`torch.FloatTensor` of shape `(batch_size,)`): A mask set to 1 for examples that should use
        aggregation functions.
    r;  Nr   r   r   r<  )r&   Zlogical_notisnanrg  r'   rA  r   rJ  categoricalr  rN  probsr=  r  r   r   r?  rL  detach)answerr   rI  r(  r<   Zaggregate_mask_initr   dist_aggregationaggregation_ops_total_massZis_pred_cell_selectionZis_cell_supervision_availablerX  r)   r)   r*   rH  ;  s      rH  c           	      C   sn   |rt j|t jd}n|}tjj||dt j}tjj| dd}t j	|| dd }|rf|d|  S |S dS )a  
    Calculates aggregation loss when its type is known during training.

    In the weakly supervised setting, the only known information is that for cell selection examples, "no aggregation"
    should be predicted. For other examples (those that require aggregation), no loss is accumulated. In the setting
    where aggregation type is always known, standard cross entropy loss is accumulated for all examples

    Args:
        logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`):
            Logits per aggregation operation.
        aggregate_mask (`torch.FloatTensor` of shape `(batch_size, )`):
            A mask set to 1 for examples that should use aggregation functions.
        aggregation_labels (`torch.LongTensor` of shape `(batch_size, )`):
            Aggregation function id for every example in the batch.
        use_answer_as_supervision (`bool`, *optional*):
            Whether to use the answer as the only supervision for aggregation examples.
        num_aggregation_labels (`int`, *optional*, defaults to 0):
            The number of aggregation operators to predict.

    Returns:
        aggregation_loss_known (`torch.FloatTensor` of shape `(batch_size,)`): Aggregation loss (when its type is known
        during training) per example.
    r<  )Znum_classesr   r   r   N)
r&   r?  r   r   r   Zone_hotrg  rL  Zlog_softmaxrN  )	r   rX  r1  rG  r/  Ztarget_aggregationZone_hot_labelsZ	log_probsZ$per_example_aggregation_intermediater)   r)   r*   !_calculate_aggregation_loss_knownn  s    r  c                 C   s@   t jjj| d}t j|jddddf dd}t | | S )a  
    Calculates aggregation loss in the case of answer supervision.

    Args:
        logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`):
            Logits per aggregation operation.
        aggregate_mask (`torch.FloatTensor` of shape `(batch_size, )`):
            A mask set to 1 for examples that should use aggregation functions

    Returns:
        aggregation_loss_unknown (`torch.FloatTensor` of shape `(batch_size,)`): Aggregation loss (in case of answer
        supervision) per example.
    r;  Nr   r   )r&   rJ  r  r  rN  r  log)r   rX  r  r  r)   r)   r*   #_calculate_aggregation_loss_unknown  s     r  c                 C   s*   t | ||||}|r"|t| |7 }|| S )a  
    Calculates the aggregation loss per example.

    Args:
        logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`):
            Logits per aggregation operation.
        aggregate_mask (`torch.FloatTensor` of shape `(batch_size, )`):
            A mask set to 1 for examples that should use aggregation functions.
        aggregation_labels (`torch.LongTensor` of shape `(batch_size, )`):
            Aggregation function id for every example in the batch.
        use_answer_as_supervision (`bool`, *optional*):
            Whether to use the answer as the only supervision for aggregation examples.
        num_aggregation_labels (`int`, *optional*, defaults to 0):
            The number of aggregation operators to predict.
        aggregation_loss_weight (`float`, *optional*, defaults to 1.0):
            Importance weight for the aggregation loss.

    Returns:
        aggregation_loss (`torch.FloatTensor` of shape `(batch_size,)`): Aggregation loss per example.
    )r  r  )r   rX  r1  rG  r/  rR  Zper_example_aggregation_lossr)   r)   r*   rQ    s        rQ  c                 C   s  |j r*tjj|j| j|j d}| }n| j}|| | }tj|dd}t	t
|t||}	tj||	 dd}
|j}|tjkr|
|t  }n|tjkrtj|ddd| d }tj|	| | dd}n|tjkr@tj|ddd| d }|d|  }tj|ddd| }|t| d | }tj|	| | dd}ntd|j |jrtjj|j|ddddf d}| }n&tjj|ddddf |j d	d}tjtj|
ddtj|ddtj|ddgdd}tj|| dd}|S )
a  
    Calculates the expected result given cell and aggregation probabilities.

    Args:
        dist_per_cell (`torch.distributions.Bernoulli`):
            Cell selection distribution for each cell.
        numeric_values (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
            Numeric values of every token. Nan for tokens which are not numeric values.
        numeric_values_scale (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
            Scale of the numeric values of every token.
        input_mask_float (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
            Mask for the table, without question tokens and table headers.
        logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`):
            Logits per aggregation operation.
        config ([`TapasConfig`]):
            Model configuration class with all the hyperparameters of the model

    Returns:
        expected_result (`torch.FloatTensor` of shape `(batch_size,)`): The expected result per example.
    )rD  r   r   r   T)r   Zkeepdimz(Invalid average_approximation_function: Nr;  r   )Zuse_gumbel_for_cellsr&   rJ  ZRelaxedBernoullirD  r   sampler  rN  r=  r  r?  Zaverage_approximation_functionr]  r_  rO  r`  ra  Zsquarer]   Zuse_gumbel_for_aggregationZRelaxedOneHotCategoricalZaggregation_temperaturer   r   r   r   r   )dist_per_cellr3  r4  rU  r   rc   Zgumbel_distZscaled_probability_per_cellZcount_resultZnumeric_values_maskedZ
sum_resultZavg_approximationZaverage_resultexZpointwise_varvar
multiplierZaggregation_op_only_probsZall_resultsexpected_resultr)   r)   r*   _calculate_expected_result  s^    

  

 
 	r  r  deltac                 C   s8   t | | }t ||k d|d  || d|d   S )Nr  r>   )r&   absr=  )inputtargetr  errorsr)   r)   r*   
huber_loss2	  s    r  c                 C   s   t ||||||}tt| t| | }	|jrvtt|t|	t 	 }
|	|
 }||
 }t
|| || }nt
|| |	| |jd}|jdkrtj|tjd}n,t||jktj|tjdtj|tjd}|j||  }||fS )a  
    Calculates the regression loss per example.

    Args:
        answer (`torch.FloatTensor` of shape `(batch_size,)`):
            Answer for every example in the batch. Nan if there is no scalar answer.
        aggregate_mask (`torch.FloatTensor` of shape `(batch_size,)`):
            A mask set to 1 for examples that should use aggregation functions.
        dist_per_cell (`torch.distributions.Bernoulli`):
            Cell selection distribution for each cell.
        numeric_values (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
            Numeric values of every token. Nan for tokens which are not numeric values.
        numeric_values_scale (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
            Scale of the numeric values of every token.
        input_mask_float (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
            Mask for the table, without question tokens and table headers.
        logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`):
            Logits per aggregation operation.
        config ([`TapasConfig`]):
            Model configuration class with all the parameters of the model

    Returns:
        per_example_answer_loss_scaled (`torch.FloatTensor` of shape `(batch_size,)`): Scales answer loss for each
        example in the batch. large_answer_loss_mask (`torch.FloatTensor` of shape `(batch_size,)`): A mask which is 1
        for examples for which their answer loss is larger than the answer_loss_cutoff.
    r  Nr<  )r  r&   r=  r  r?  Zuse_normalized_answer_lossr  r  rO  r  r  Zhuber_loss_deltaZanswer_loss_cutoffr>  rL  Zanswer_loss_importance)r  rX  r  r3  r4  rU  r   rc   r  Zanswer_maskedZ
normalizerZnormalized_answer_maskedZnormalized_expected_resultZper_example_answer_lossr[  Zper_example_answer_loss_scaledr)   r)   r*   rS  7	  s>    %         
rS  )rk  )rm  )ru  )r~  )r  )r  )r  )r  )`r%   enumr   rG   dataclassesr   typingr   r   r   r&   Ztorch.utils.checkpointr   Ztorch.nnr   r   r	   Zactivationsr   Zmodeling_outputsr   r   r   r   Zmodeling_utilsr   Zpytorch_utilsr   r   r   r   r   r   r   r   r   r   Zconfiguration_tapasr   Z
get_loggerr"   rE   warning__version__r"  Z_CHECKPOINT_FOR_DOCZ#TAPAS_PRETRAINED_MODEL_ARCHIVE_LISTrO  r  r   rj   Modulerk   r   r   r   r   r   r   r   r   r   r   r  r  ZTAPAS_START_DOCSTRINGr   rS   rT   r,  rR   r   Enumr]  objectr   r   r   rt  ru  r}  r  rB  r  r   rE  rP  rC  rH  r  r  rQ  r  r@  r  rS  r)   r)   r)   r*   <module>   s   
" O`4XE
 + l  Tx	-
!

)0



4k3.&X