U
    9%e^J                    @   s  d Z ddlZddlZddlmZ ddlmZmZmZ ddl	Z	ddl	m
Z
 ddlmZ ddlmZ dd	lmZmZmZmZ dd
lmZ ddlmZmZmZ ddlmZmZmZmZ ddlm Z  e!e"Z#dZ$dZ%dZ&dZ'ddddddddgZ(dd Z)G dd de
j*Z+G dd de
j*Z,G dd de
j*Z-G d d! d!e
j*Z.G d"d# d#e
j*Z/G d$d% d%e
j*Z0G d&d' d'e
j*Z1G d(d) d)e
j*Z2G d*d+ d+e
j*Z3eG d,d- d-eZ4eG d.d/ d/eZ5eG d0d1 d1eZ6eG d2d3 d3eZ7G d4d5 d5e
j*Z8G d6d7 d7e
j*Z9G d8d9 d9e
j*Z:G d:d; d;e
j*Z;G d<d= d=e
j*Z<d>Z=d?Z>G d@dA dAeZ?G dBdC dCe?Z@edDe=G dEdF dFe?ZAedGe=G dHdI dIe?ZBedJe=G dKdL dLe?ZCedMe=G dNdO dOe?ZDdPZEedQe=G dRdS dSe?ZFdS )Tz PyTorch REALM model.    N)	dataclass)OptionalTupleUnion)nn)CrossEntropyLoss   )ACT2FN))BaseModelOutputWithPastAndCrossAttentions,BaseModelOutputWithPoolingAndCrossAttentionsMaskedLMOutputModelOutput)PreTrainedModel)apply_chunking_to_forward find_pruneable_heads_and_indicesprune_linear_layer)add_start_docstrings%add_start_docstrings_to_model_forwardloggingreplace_return_docstrings   )RealmConfigz(google/realm-cc-news-pretrained-embedderz'google/realm-cc-news-pretrained-encoderz&google/realm-cc-news-pretrained-scorerr   z&google/realm-cc-news-pretrained-openqazgoogle/realm-orqa-nq-openqazgoogle/realm-orqa-nq-readerzgoogle/realm-orqa-wq-openqazgoogle/realm-orqa-wq-readerc                 C   s  zddl }ddl}ddl}W n  tk
r<   td  Y nX tj|}t	d|  |j
|}g }g }	|D ]@\}
}t	d|
 d|  |j
||
}||
 |	| qrt||	D ]\}
}t| trd|
krt	d|
 d	| jj d
 q|
ds|
dr4t| tr4|
dd}
|
dd}
|
dsL|
drdt| trd|
dd}
|
drt| trdnd}|
d| d}
|
d| d}
|
d| d}
|
d| d}
|
d| d}
|
drlt| trdnd}|
d| d}
|
d| d }
|
d!| d"}
|
d#| d$}
|
d%| d}
|
d&| d$}
nD|
d'rt| trdnd}|
d(| d }
|
d)| d"}
|
d*}
td+d, |
D rt	dd*|
  q| }|
D ]}|d-|r|d.|}n|g}|d d/ks0|d d0kr<t|d1}nn|d d2ksX|d d3krdt|d4}nFzt||d }W n2 tk
r   t	dd*|
  Y qY nX t|d5krt|d6 }|| }q|d7d d8krt|d1}n|d/kr| |}z,|j!|j!ks,t"d9|j! d:|j! d;W n< t"k
rj } z| j#|j!|j!f7  _# W 5 d}~X Y nX t	d<|
  t$%||_&q| S )=z'Load tf checkpoints in a pytorch model.r   NzLoading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see https://www.tensorflow.org/install/ for installation instructions.z&Converting TensorFlow checkpoint from zLoading TF weight z with shape readerz	Skipping z as it is not z's parameterZbertclszbert/zreader/realm/zcls/zreader/cls/zrealm/ zreader/zreader/module/bert/zreader/module/cls/zreader/dense/zqa_outputs/dense_intermediate/zreader/dense_1/zqa_outputs/dense_output/zreader/layer_normalizationzqa_outputs/layer_normalizationzmodule/module/module/z	embedder/z!module/module/module/module/bert/zmodule/module/module/LayerNorm/zcls/LayerNorm/zmodule/module/module/dense/z
cls/dense/z,module/module/module/module/cls/predictions/zcls/predictions/zmodule/module/module/bert/z%module/module/module/cls/predictions/zmodule/module/zmodule/module/LayerNorm/zmodule/module/dense//c                 s   s   | ]}|d kV  qdS ))Zadam_vZadam_mZAdamWeightDecayOptimizerZAdamWeightDecayOptimizer_1Zglobal_stepN ).0nr   r   g/var/www/html/Darija-Ai-API/env/lib/python3.8/site-packages/transformers/models/realm/modeling_realm.py	<genexpr>|   s   z+load_tf_weights_in_realm.<locals>.<genexpr>z[A-Za-z]+_\d+z_(\d+)ZkernelgammaweightZoutput_biasbetabias   r   iZ_embeddingszPointer shape z and array shape z mismatchedzInitialize PyTorch weight )'renumpyZ
tensorflowImportErrorloggererrorospathabspathinfotrainZlist_variablesZload_variableappendzip
isinstanceRealmReader	__class____name__
startswithRealmForOpenQAreplaceRealmKnowledgeAugEncoderRealmEmbeddersplitanyjoin	fullmatchgetattrAttributeErrorlenint	transposeshapeAssertionErrorargstorchZ
from_numpydata)modelconfigZtf_checkpoint_pathr&   nptfZtf_pathZ	init_varsnamesZarraysnamerD   arrayZreader_prefixZembedder_prefixZpointerZm_nameZscope_namesnumer   r   r   load_tf_weights_in_realm:   s    
$$




rR   c                       sT   e Zd ZdZ fddZd	eej eej eej eej e	ej
dddZ  ZS )
RealmEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    s   t    tj|j|j|jd| _t|j|j| _	t|j
|j| _tj|j|jd| _t|j| _t|dd| _| jdt|jddd | jd	tj| j tjd
dd d S )N)padding_idxZepsposition_embedding_typeabsoluteposition_ids)r   F)
persistenttoken_type_idsdtype)super__init__r   	Embedding
vocab_sizehidden_sizeZpad_token_idword_embeddingsmax_position_embeddingsposition_embeddingsZtype_vocab_sizetoken_type_embeddings	LayerNormlayer_norm_epsDropouthidden_dropout_probdropoutr?   rV   register_bufferrG   arangeexpandzerosrX   sizelongselfrJ   r4   r   r   r_      s"    
    zRealmEmbeddings.__init__Nr   )	input_idsr[   rX   inputs_embedspast_key_values_lengthreturnc                 C   s   |d k	r|  }n|  d d }|d }|d krL| jd d ||| f }|d krt| dr| jd d d |f }||d |}	|	}ntj|tj| jjd}|d kr| 	|}| 
|}
||
 }| jdkr| |}||7 }| |}| |}|S )NrY   r   r[   r   r]   devicerW   )rp   rX   hasattrr[   rn   rG   ro   rq   rz   rc   rf   rV   re   rg   rk   )rs   ru   r[   rX   rv   rw   input_shape
seq_lengthbuffered_token_type_ids buffered_token_type_ids_expandedrf   
embeddingsre   r   r   r   forward   s,    







zRealmEmbeddings.forward)NNNNr   )r5   
__module____qualname____doc__r_   r   rG   
LongTensorFloatTensorrB   Tensorr   __classcell__r   r   rt   r   rS      s        rS   c                
       s   e Zd Zd fdd	ZejejdddZdejeej eej eej eej ee	e	ej   ee
 e	ej dd	d
Z  ZS )RealmSelfAttentionNc                    s   t    |j|j dkr>t|ds>td|j d|j d|j| _t|j|j | _| j| j | _t	
|j| j| _t	
|j| j| _t	
|j| j| _t	|j| _|pt|dd| _| jdks| jd	kr|j| _t	d
|j d | j| _|j| _d S )Nr   Zembedding_sizezThe hidden size (z6) is not a multiple of the number of attention heads ()rV   rW   relative_keyrelative_key_queryr%   r   )r^   r_   rb   num_attention_headsr{   
ValueErrorrB   attention_head_sizeall_head_sizer   Linearquerykeyvalueri   Zattention_probs_dropout_probrk   r?   rV   rd   r`   distance_embedding
is_decoderrs   rJ   rV   rt   r   r   r_      s*    
  zRealmSelfAttention.__init__)xrx   c                 C   s6   |  d d | j| jf }||}|ddddS )NrY   r   r%   r   r   )rp   r   r   viewpermute)rs   r   Znew_x_shaper   r   r   transpose_for_scores  s    
z'RealmSelfAttention.transpose_for_scoresFhidden_statesattention_mask	head_maskencoder_hidden_statesencoder_attention_maskpast_key_valueoutput_attentionsrx   c                 C   s  |  |}|d k	}	|	r4|d k	r4|d }
|d }|}n|	r^| | |}
| | |}|}nv|d k	r| | |}
| | |}tj|d |
gdd}
tj|d |gdd}n | | |}
| | |}| |}|d k	}| jr|
|f}t||
dd}| j	dks | j	dkr|j
d |
j
d  }}|r^tj|d tj|jd	dd}ntj|tj|jd	dd}tj|tj|jd	dd}|| }| || j d }|j|jd
}| j	dkrtd||}|| }n4| j	dkrtd||}td|
|}|| | }|t| j }|d k	r:|| }tjj|dd}| |}|d k	rf|| }t||}|dddd }| d d | jf }||}|r||fn|f}| jr||f }|S )Nr   r   r%   dimrY   r   r   ry   r\   zbhld,lrd->bhlrzbhrd,lrd->bhlrr   ) r   r   r   r   rG   catr   matmulrC   rV   rD   tensorrq   rz   r   rm   r   rd   tor]   einsummathsqrtr   r   Z
functionalZsoftmaxrk   r   
contiguousrp   r   )rs   r   r   r   r   r   r   r   Zmixed_query_layerZis_cross_attentionZ	key_layerZvalue_layerZquery_layer	use_cacheZattention_scoresZquery_lengthZ
key_lengthZposition_ids_lZposition_ids_rZdistanceZpositional_embeddingZrelative_position_scoresZrelative_position_scores_queryZrelative_position_scores_keyZattention_probsZcontext_layerZnew_context_layer_shapeoutputsr   r   r   r     sp    


 





zRealmSelfAttention.forward)N)NNNNNF)r5   r   r   r_   rG   r   r   r   r   r   boolr   r   r   r   rt   r   r      s$         r   c                       s4   e Zd Z fddZejejejdddZ  ZS )RealmSelfOutputc                    sB   t    t|j|j| _tj|j|jd| _t|j	| _
d S NrU   )r^   r_   r   r   rb   denserg   rh   ri   rj   rk   rr   rt   r   r   r_   o  s    
zRealmSelfOutput.__init__r   input_tensorrx   c                 C   s&   |  |}| |}| || }|S Nr   rk   rg   rs   r   r   r   r   r   r   u  s    

zRealmSelfOutput.forwardr5   r   r   r_   rG   r   r   r   r   r   rt   r   r   n  s   r   c                
       sv   e Zd Zd
 fdd	Zdd Zdejeej eej eej eej ee	e	ej   ee
 e	ej ddd	Z  ZS )RealmAttentionNc                    s.   t    t||d| _t|| _t | _d S )NrV   )r^   r_   r   rs   r   outputsetpruned_headsr   rt   r   r   r_   ~  s    

zRealmAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   r   )rA   r   rs   r   r   r   r   r   r   r   r   r   r   union)rs   headsindexr   r   r   prune_heads  s       zRealmAttention.prune_headsFr   c              	   C   s<   |  |||||||}| |d |}	|	f|dd   }
|
S )Nr   r   )rs   r   )rs   r   r   r   r   r   r   r   Zself_outputsattention_outputr   r   r   r   r     s    
	zRealmAttention.forward)N)NNNNNF)r5   r   r   r_   r   rG   r   r   r   r   r   r   r   r   r   rt   r   r   }  s$         r   c                       s0   e Zd Z fddZejejdddZ  ZS )RealmIntermediatec                    sB   t    t|j|j| _t|jt	r6t
|j | _n|j| _d S r   )r^   r_   r   r   rb   intermediate_sizer   r2   
hidden_actstrr	   intermediate_act_fnrr   rt   r   r   r_     s
    
zRealmIntermediate.__init__r   rx   c                 C   s   |  |}| |}|S r   )r   r   rs   r   r   r   r   r     s    

zRealmIntermediate.forwardr   r   r   rt   r   r     s   r   c                       s4   e Zd Z fddZejejejdddZ  ZS )RealmOutputc                    sB   t    t|j|j| _tj|j|jd| _t	|j
| _d S r   )r^   r_   r   r   r   rb   r   rg   rh   ri   rj   rk   rr   rt   r   r   r_     s    
zRealmOutput.__init__r   c                 C   s&   |  |}| |}| || }|S r   r   r   r   r   r   r     s    

zRealmOutput.forwardr   r   r   rt   r   r     s   r   c                
       st   e Zd Z fddZd
ejeej eej eej eej eeeej   ee	 eej dddZ
dd	 Z  ZS )
RealmLayerc                    sr   t    |j| _d| _t|| _|j| _|j| _| jrZ| jsLt|  dt|dd| _	t
|| _t|| _d S )Nr   z> should be used as a decoder model if cross attention is addedrW   r   )r^   r_   chunk_size_feed_forwardseq_len_dimr   	attentionr   add_cross_attentionr   crossattentionr   intermediater   r   rr   rt   r   r   r_     s    


zRealmLayer.__init__NFr   c              	   C   s  |d k	r|d d nd }| j |||||d}	|	d }
| jrP|	dd }|	d }n|	dd  }d }| jr|d k	rt| dstd|  d|d k	r|d	d  nd }| |
||||||}|d }
||dd  }|d }|| }t| j| j| j|
}|f| }| jr||f }|S )
Nr%   r   r   r   r   rY   r   z'If `encoder_hidden_states` are passed, z` has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`r   )	r   r   r{   r   r   r   feed_forward_chunkr   r   )rs   r   r   r   r   r   r   r   Zself_attn_past_key_valueZself_attention_outputsr   r   Zpresent_key_valueZcross_attn_present_key_valueZcross_attn_past_key_valueZcross_attention_outputslayer_outputr   r   r   r     sV    


	   

zRealmLayer.forwardc                 C   s   |  |}| ||}|S r   )r   r   )rs   r   Zintermediate_outputr   r   r   r   r     s    
zRealmLayer.feed_forward_chunk)NNNNNF)r5   r   r   r_   rG   r   r   r   r   r   r   r   r   r   r   rt   r   r     s$         Ar   c                       s   e Zd Z fddZd	ejeej eej eej eej eeeej   ee	 ee	 ee	 ee	 e
eej ef dddZ  ZS )
RealmEncoderc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r   )r   )r   _rJ   r   r   
<listcomp>)  s     z)RealmEncoder.__init__.<locals>.<listcomp>F)	r^   r_   rJ   r   Z
ModuleListrangenum_hidden_layerslayergradient_checkpointingrr   rt   r   r   r_   &  s    
 zRealmEncoder.__init__NFT)r   r   r   r   r   past_key_valuesr   r   output_hidden_statesreturn_dictrx   c              	      st  |	rdnd } rdnd } r(| j jr(dnd }| jrJ| jrJ|rJtd d}|rRdnd }t| jD ]\}}|	rv||f }|d k	r|| nd }|d k	r|| nd | jr| jrև fdd}tj	j

|||||||}n|||||| }|d }|r||d f7 } r`||d f }| j jr`||d	 f }q`|	r@||f }|
sbtd
d |||||fD S t|||||dS )Nr   zZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...Fc                    s    fdd}|S )Nc                     s    | f S r   r   )inputs)moduler   r   r   r   custom_forwardO  s    zKRealmEncoder.forward.<locals>.create_custom_forward.<locals>.custom_forwardr   )r   r   r   )r   r   create_custom_forwardN  s    z3RealmEncoder.forward.<locals>.create_custom_forwardr   rY   r   r%   c                 s   s   | ]}|d k	r|V  qd S r   r   )r   vr   r   r   r    s  s   z'RealmEncoder.forward.<locals>.<genexpr>)last_hidden_stater   r   
attentionscross_attentions)rJ   r   r   trainingr)   Zwarning_once	enumerater   rG   utils
checkpointtupler
   )rs   r   r   r   r   r   r   r   r   r   r   Zall_hidden_statesZall_self_attentionsZall_cross_attentionsZnext_decoder_cacheiZlayer_moduleZlayer_head_maskr   Zlayer_outputsr   r   r   r   ,  sv    
	

zRealmEncoder.forward)	NNNNNNFFT)r5   r   r   r_   rG   r   r   r   r   r   r   r
   r   r   r   r   rt   r   r   %  s.   	         r   c                       s0   e Zd Z fddZejejdddZ  ZS )RealmPoolerc                    s*   t    t|j|j| _t | _d S r   )r^   r_   r   r   rb   r   ZTanh
activationrr   rt   r   r   r_     s    
zRealmPooler.__init__r   c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r   )rs   r   Zfirst_token_tensorpooled_outputr   r   r   r     s    

zRealmPooler.forwardr   r   r   rt   r   r     s   r   c                   @   sL   e Zd ZU dZdZejed< dZe	e
ej  ed< dZe	e
ej  ed< dS )RealmEmbedderOutputa*  
    Outputs of [`RealmEmbedder`] models.

    Args:
        projected_score (`torch.FloatTensor` of shape `(batch_size, config.retriever_proj_size)`):

            Projected score.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nprojected_scorer   r   )r5   r   r   r   r   rG   r   __annotations__r   r   r   r   r   r   r   r   r     s   
r   c                   @   s<   e Zd ZU dZdZejed< dZejed< dZ	ejed< dS )RealmScorerOutputa'  
    Outputs of [`RealmScorer`] models.

    Args:
        relevance_score (`torch.FloatTensor` of shape `(batch_size, config.num_candidates)`):
            The relevance score of document candidates (before softmax).
        query_score (`torch.FloatTensor` of shape `(batch_size, config.retriever_proj_size)`):
            Query score derived from the query embedder.
        candidate_score (`torch.FloatTensor` of shape `(batch_size, config.num_candidates, config.retriever_proj_size)`):
            Candidate score derived from the embedder.
    Nrelevance_scorequery_scorecandidate_score)
r5   r   r   r   r   rG   r   r   r   r   r   r   r   r   r     s   
r   c                   @   s   e Zd ZU dZdZejed< dZejed< dZ	ejed< dZ
ejed< dZejed< dZejed< dZejed	< dZejed
< dZejed< dZeeej  ed< dZeeej  ed< dS )RealmReaderOutputa+	  
    Outputs of [`RealmReader`] models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `start_positions`, `end_positions`, `has_answers` are provided):
            Total loss.
        retriever_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `start_positions`, `end_positions`, `has_answers` are provided):
            Retriever loss.
        reader_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `start_positions`, `end_positions`, `has_answers` are provided):
            Reader loss.
        retriever_correct (`torch.BoolTensor` of shape `(config.searcher_beam_size,)`, *optional*):
            Whether or not an evidence block contains answer.
        reader_correct (`torch.BoolTensor` of shape `(config.reader_beam_size, num_candidates)`, *optional*):
            Whether or not a span candidate contains answer.
        block_idx (`torch.LongTensor` of shape `()`):
            The index of the retrieved evidence block in which the predicted answer is most likely.
        candidate (`torch.LongTensor` of shape `()`):
            The index of the retrieved span candidates in which the predicted answer is most likely.
        start_pos (`torch.IntTensor` of shape `()`):
            Predicted answer starting position in *RealmReader*'s inputs.
        end_pos (`torch.IntTensor` of shape `()`):
            Predicted answer ending position in *RealmReader*'s inputs.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nlossretriever_lossreader_lossretriever_correctreader_correct	block_idx	candidate	start_posend_posr   r   )r5   r   r   r   r   rG   r   r   r   r   r   
BoolTensorr   r   r   r  r  Zint32r  r   r   r   r   r   r   r   r   r     s   
#r   c                   @   s,   e Zd ZU dZdZeed< dZej	ed< dS )RealmForOpenQAOutputz

    Outputs of [`RealmForOpenQA`] models.

    Args:
        reader_output (`dict`):
            Reader output.
        predicted_answer_ids (`torch.LongTensor` of shape `(answer_sequence_length)`):
            Predicted answer ids.
    Nreader_outputpredicted_answer_ids)
r5   r   r   r   r  dictr   r  rG   r   r   r   r   r   r    s   
r  c                       s$   e Zd Z fddZdd Z  ZS )RealmPredictionHeadTransformc                    sV   t    t|j|j| _t|jtr6t	|j | _
n|j| _
tj|j|jd| _d S r   )r^   r_   r   r   rb   r   r2   r   r   r	   transform_act_fnrg   rh   rr   rt   r   r   r_   	  s    
z%RealmPredictionHeadTransform.__init__c                 C   s"   |  |}| |}| |}|S r   )r   r
  rg   r   r   r   r   r     s    


z$RealmPredictionHeadTransform.forwardr5   r   r   r_   r   r   r   r   rt   r   r	    s   	r	  c                       s$   e Zd Z fddZdd Z  ZS )RealmLMPredictionHeadc                    sL   t    t|| _tj|j|jdd| _t	t
|j| _| j| j_d S )NF)r$   )r^   r_   r	  	transformr   r   rb   ra   decoder	ParameterrG   ro   r$   rr   rt   r   r   r_     s
    

zRealmLMPredictionHead.__init__c                 C   s   |  |}| |}|S r   )r  r  r   r   r   r   r   '  s    

zRealmLMPredictionHead.forwardr  r   r   rt   r   r    s   r  c                       s$   e Zd Z fddZdd Z  ZS )RealmOnlyMLMHeadc                    s   t    t|| _d S r   )r^   r_   r  predictionsrr   rt   r   r   r_   .  s    
zRealmOnlyMLMHead.__init__c                 C   s   |  |}|S r   )r  )rs   sequence_outputprediction_scoresr   r   r   r   2  s    
zRealmOnlyMLMHead.forwardr  r   r   rt   r   r  -  s   r  c                       s$   e Zd Z fddZdd Z  ZS )RealmScorerProjectionc                    s>   t    t|| _t|j|j| _tj	|j|j
d| _	d S r   )r^   r_   r  r  r   r   rb   retriever_proj_sizer   rg   rh   rr   rt   r   r   r_   8  s    

zRealmScorerProjection.__init__c                 C   s   |  |}| |}|S r   )r   rg   r   r   r   r   r   >  s    

zRealmScorerProjection.forwardr  r   r   rt   r   r  7  s   r  c                       s$   e Zd Z fddZdd Z  ZS )RealmReaderProjectionc                    sX   t    || _t|j|jd | _t|jd| _tj	|j|j
d| _t | _d S )Nr%   r   rU   )r^   r_   rJ   r   r   rb   Zspan_hidden_sizedense_intermediatedense_outputrg   Zreader_layer_norm_epslayer_normalizationZReLUrelurr   rt   r   r   r_   E  s    
zRealmReaderProjection.__init__c                    s    fdd}t jfdd} |}|jddd\}}||\}}}	t j|d|d	}
t j|d|d	}|
| } |} |} |d}|||	|j	d
7 }|||fS )Nc                    s   j \}fdd t fddtjjD  \}}t|d}t|d}tjd|d}tjd|d}|| }|||fS )aK  
            Generate span candidates.

            Args:
                masks: <bool> [num_retrievals, max_sequence_len]

            Returns:
                starts: <int32> [num_spans] ends: <int32> [num_spans] span_masks: <int32> [num_retrievals, num_spans]
                whether spans locate in evidence block.
            c                    s6   t j|  d  jd}t j| d  jd}||fS )Nr   rz   )rG   rm   rz   )widthZcurrent_startsZcurrent_ends)masksmax_sequence_lenr   r   _spans_given_width[  s    zRRealmReaderProjection.forward.<locals>.span_candidates.<locals>._spans_given_widthc                 3   s   | ]} |d  V  qdS )r   Nr   )r   w)r  r   r   r    `  s     zIRealmReaderProjection.forward.<locals>.span_candidates.<locals>.<genexpr>r   rY   r   r   )rD   r1   r   rJ   max_span_widthrG   r   index_select)r  r   ZstartsZendsZstart_masksZ	end_masksZ
span_masksrs   )r  r  r  r   span_candidatesN  s    
"z6RealmReaderProjection.forward.<locals>.span_candidatesc                 S   s   d|  | t|j S N      ?typerG   Zfinfominmaskr]   r   r   r   mask_to_scorem  s    z4RealmReaderProjection.forward.<locals>.mask_to_scorer%   rY   r   r   r!  r\   )
rG   float32r  chunkr#  r  r  r  squeezer]   )rs   r   
block_maskr%  r-  Zstart_projectionZend_projectioncandidate_startscandidate_endsZcandidate_maskZcandidate_start_projectionsZcandidate_end_projectionsZcandidate_hiddenreader_logitsr   r$  r   r   M  s    


zRealmReaderProjection.forwardr  r   r   rt   r   r  D  s   r  aH  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`RealmConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
a5
  
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
c                   @   s,   e Zd ZdZeZeZdZdd Z	dd Z
dS )RealmPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    realmc                 C   s   t |tjr:|jjjd| jjd |jdk	r|jj	  nft |tj
rz|jjjd| jjd |jdk	r|jj|j 	  n&t |tjr|jj	  |jjd dS )zInitialize the weightsg        )meanZstdNr'  )r2   r   r   r"   rH   Znormal_rJ   Zinitializer_ranger$   Zzero_r`   rT   rg   Zfill_)rs   r   r   r   r   _init_weights  s    

z"RealmPreTrainedModel._init_weightsc                 G   sT   g }|D ]F}|dkr | d q|j}t|dkrD|d|d f}| | q|S )z.Flatten inputs' shape to (-1, input_shape[-1])Nr%   rY   )r0   rD   rA   r   )rs   r   Zflattened_inputsr   r|   r   r   r   _flatten_inputs  s    z$RealmPreTrainedModel._flatten_inputsN)r5   r   r   r   r   config_classrR   Zload_tf_weightsZbase_model_prefixr8  r9  r   r   r   r   r5    s   r5  c                       sD   e Zd ZdZd fdd	Zdd Zdd Zd	d
 ZdddZ  Z	S )RealmBertModelz?
    Same as the original BertModel but remove docstrings.
    Tc                    sD   t  | || _t|| _t|| _|r2t|nd | _| 	  d S r   )
r^   r_   rJ   rS   r   r   encoderr   pooler	post_init)rs   rJ   Zadd_pooling_layerrt   r   r   r_     s    

zRealmBertModel.__init__c                 C   s   | j jS r   r   rc   r$  r   r   r   get_input_embeddings  s    z#RealmBertModel.get_input_embeddingsc                 C   s   || j _d S r   r?  rs   r   r   r   r   set_input_embeddings  s    z#RealmBertModel.set_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr<  r   r   r   )rs   Zheads_to_pruner   r   r   r   r   _prune_heads  s    zRealmBertModel._prune_headsNc                 C   s^  |d k	r|n| j j}|d k	r |n| j j}|d k	r4|n| j j}| j jrZ|
d k	rP|
n| j j}
nd}
|d k	rx|d k	rxtdn@|d k	r| || | }n"|d k	r| d d }ntd|\}}|d k	r|j	n|j	}|	d k	r|	d d j
d nd}|d krtj||| f|d}|d krft| jdrT| jjd d d |f }|||}|}ntj|tj|d	}| ||}| j jr|d k	r| \}}}||f}|d krtj||d}| |}nd }| || j j}| j|||||d
}| j||||||	|
|||d
}|d }| jd k	r$| |nd }|sB||f|dd   S t|||j|j|j|jdS )NFzDYou cannot specify both input_ids and inputs_embeds at the same timerY   z5You have to specify either input_ids or inputs_embedsr   r%   r  r[   ry   )ru   rX   r[   rv   rw   )	r   r   r   r   r   r   r   r   r   r   )r   pooler_outputr   r   r   r   )rJ   r   r   use_return_dictr   r   r   Z%warn_if_padding_and_no_attention_maskrp   rz   rD   rG   Zonesr{   r   r[   rn   ro   rq   Zget_extended_attention_maskZinvert_attention_maskZget_head_maskr   r<  r=  r   r   r   r   r   )rs   ru   r   r[   rX   r   rv   r   r   r   r   r   r   r   r|   
batch_sizer}   rz   rw   r~   r   Zextended_attention_maskZencoder_batch_sizeZencoder_sequence_lengthr   Zencoder_hidden_shapeZencoder_extended_attention_maskZembedding_outputZencoder_outputsr  r   r   r   r   r     s    




zRealmBertModel.forward)T)NNNNNNNNNNNNN)
r5   r   r   r   r_   r@  rB  rD  r   r   r   r   rt   r   r;    s&   
             r;  z`The embedder of REALM outputting projected score that will be used to calculate relevance score.c                       s   e Zd ZdgZ fddZdd Zdd Zee	de
eed	deej eej eej eej eej eej ee ee ee eeef d
ddZ  ZS )r:   zcls.predictions.decoder.biasc                    s0   t  | t| j| _t| j| _|   d S r   )r^   r_   r;  rJ   r6  r  r   r>  rr   rt   r   r   r_     s    zRealmEmbedder.__init__c                 C   s
   | j jjS r   r6  r   rc   r$  r   r   r   r@    s    z"RealmEmbedder.get_input_embeddingsc                 C   s   || j j_d S r   rH  rA  r   r   r   rB    s    z"RealmEmbedder.set_input_embeddingsbatch_size, sequence_lengthoutput_typer:  N)
ru   r   r[   rX   r   rv   r   r   r   rx   c
                 C   sn   |	dk	r|	n| j j}	| j|||||||||	d	}
|
d }| |}|	sX|f|
dd  S t||
j|
jdS dS )a  
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, RealmEmbedder
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/realm-cc-news-pretrained-embedder")
        >>> model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> projected_score = outputs.projected_score
        ```
        Nr   r[   rX   r   rv   r   r   r   r   r%      )r   r   r   )rJ   rF  r6  r   r   r   r   )rs   ru   r   r[   rX   r   rv   r   r   r   Zrealm_outputsrE  r   r   r   r   r     s*    !
zRealmEmbedder.forward)	NNNNNNNNN)r5   r   r   _tied_weights_keysr_   r@  rB  r   REALM_INPUTS_DOCSTRINGformatr   r   _CONFIG_FOR_DOCr   rG   r   r   r   r   r   r   r   r   r   rt   r   r:   {  s6   
         
r:   zoThe scorer of REALM outputting relevance scores representing the score of document candidates (before softmax).c                       s   e Zd ZdZd
 fdd	Zeedee	e
ddeej eej eej eej eej eej eej eej eej eej ee ee ee eee	f ddd	Z  ZS )RealmScorerz
    Args:
        query_embedder ([`RealmEmbedder`]):
            Embedder for input sequences. If not specified, it will use the same embedder as candidate sequences.
    Nc                    s8   t  | t| j| _|d k	r$|n| j| _|   d S r   )r^   r_   r:   rJ   embedderquery_embedderr>  )rs   rJ   rT  rt   r   r   r_     s    zRealmScorer.__init__rI  rJ  )ru   r   r[   rX   candidate_input_idscandidate_attention_maskcandidate_token_type_idscandidate_inputs_embedsr   rv   r   r   r   rx   c                 C   s   |dk	r|n| j j}|dkr,|
dkr,td|dkrD|dkrDtd| j|||||	|
|||d	}| |||\}}}| j|||||	||||d	}|d }|d }|d| j j| j j}t	
d||}|s|||fS t|||dS )	a
  
        candidate_input_ids (`torch.LongTensor` of shape `(batch_size, num_candidates, sequence_length)`):
            Indices of candidate input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        candidate_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_candidates, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        candidate_token_type_ids (`torch.LongTensor` of shape `(batch_size, num_candidates, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        candidate_inputs_embeds (`torch.FloatTensor` of shape `(batch_size * num_candidates, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `candidate_input_ids` you can choose to directly pass an embedded
            representation. This is useful if you want more control over how to convert *candidate_input_ids* indices
            into associated vectors than the model's internal embedding lookup matrix.

        Returns:

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, RealmScorer

        >>> tokenizer = AutoTokenizer.from_pretrained("google/realm-cc-news-pretrained-scorer")
        >>> model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer", num_candidates=2)

        >>> # batch_size = 2, num_candidates = 2
        >>> input_texts = ["How are you?", "What is the item in the picture?"]
        >>> candidates_texts = [["Hello world!", "Nice to meet you!"], ["A cute cat.", "An adorable dog."]]

        >>> inputs = tokenizer(input_texts, return_tensors="pt")
        >>> candidates_inputs = tokenizer.batch_encode_candidates(candidates_texts, max_length=10, return_tensors="pt")

        >>> outputs = model(
        ...     **inputs,
        ...     candidate_input_ids=candidates_inputs.input_ids,
        ...     candidate_attention_mask=candidates_inputs.attention_mask,
        ...     candidate_token_type_ids=candidates_inputs.token_type_ids,
        ... )
        >>> relevance_score = outputs.relevance_score
        ```Nz5You have to specify either input_ids or input_embeds.zJYou have to specify either candidate_input_ids or candidate_inputs_embeds.rL  r   rY   z
bd,bnd->bn)r   r   r   )rJ   rF  r   rT  r9  rS  r   num_candidatesr  rG   r   r   )rs   ru   r   r[   rX   rU  rV  rW  rX  r   rv   r   r   r   Zquery_outputsflattened_input_idsflattened_attention_maskflattened_token_type_idsZcandidate_outputsr   r   r   r   r   r   r     sV    I  

  zRealmScorer.forward)N)NNNNNNNNNNNNN)r5   r   r   r   r_   r   rO  rP  r   r   rQ  r   rG   r   r   r   r   r   r   r   r   r   rt   r   rR    sB   	
             
rR  zrThe knowledge-augmented encoder of REALM outputting masked language model logits and marginal log-likelihood loss.c                       s   e Zd ZdgZ fddZdd Zdd Zdd	 Zd
d Ze	e
deeeddeej eej eej eej eej eej eej eej eej ee ee ee eeef dddZ  ZS )r9   zcls.predictions.decoderc                    s0   t  | t| j| _t| j| _|   d S r   )r^   r_   r;  rJ   r6  r  r   r>  rr   rt   r   r   r_   h  s    z!RealmKnowledgeAugEncoder.__init__c                 C   s
   | j jjS r   rH  r$  r   r   r   r@  n  s    z-RealmKnowledgeAugEncoder.get_input_embeddingsc                 C   s   || j j_d S r   rH  rA  r   r   r   rB  q  s    z-RealmKnowledgeAugEncoder.set_input_embeddingsc                 C   s
   | j jjS r   r   r  r  r$  r   r   r   get_output_embeddingst  s    z.RealmKnowledgeAugEncoder.get_output_embeddingsc                 C   s   || j j_d S r   r]  )rs   Znew_embeddingsr   r   r   set_output_embeddingsw  s    z.RealmKnowledgeAugEncoder.set_output_embeddingsz+batch_size, num_candidates, sequence_lengthrJ  N)ru   r   r[   rX   r   rv   r   labelsmlm_maskr   r   r   rx   c                 C   sz  |dk	r|n| j j}| |||\}}}| j|||||||
||d	}|d }| |}|}d}|dk	r6|dkrxtd| \}}|	dkrtj|tj	d}	n|	
tj	}	tdd}|d| j j}|d	| j jd}||||| j j| }|dd}|| }|d	}tt||	 t|	  }|sf|f|d
d  }|dk	rb|f| S |S t|||j|jdS )a  
        relevance_score (`torch.FloatTensor` of shape `(batch_size, num_candidates)`, *optional*):
            Relevance score derived from RealmScorer, must be specified if you want to compute the masked language
            modeling loss.

        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        mlm_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid calculating joint loss on certain positions. If not specified, the loss will not be masked.
            Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

        Returns:

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, RealmKnowledgeAugEncoder

        >>> tokenizer = AutoTokenizer.from_pretrained("google/realm-cc-news-pretrained-encoder")
        >>> model = RealmKnowledgeAugEncoder.from_pretrained(
        ...     "google/realm-cc-news-pretrained-encoder", num_candidates=2
        ... )

        >>> # batch_size = 2, num_candidates = 2
        >>> text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]]

        >>> inputs = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        ```NrL  r   zZYou have to specify `relevance_score` when `labels` is specified in order to compute loss.r\   none)Z	reductionrY   r   r%   rM  )r   logitsr   r   )rJ   rF  r9  r6  r   r   rp   rG   Z	ones_liker.  r)  r   r   ra   ZtilerY  Zlog_softmax	unsqueeze	logsumexpZnansumsumr   r   r   )rs   ru   r   r[   rX   r   rv   r   r`  ra  r   r   r   rZ  r[  r\  Zjoint_outputsZjoint_outputr  r   Zmasked_lm_lossrG  r}   Zloss_fctZ
mlm_logitsZmlm_targetsZmasked_lm_log_probZcandidate_log_probZjoint_gold_log_probZmarginal_gold_log_probsr   r   r   r   r   z  sf    9  




  
 z RealmKnowledgeAugEncoder.forward)NNNNNNNNNNNN)r5   r   r   rN  r_   r@  rB  r^  r_  r   rO  rP  r   r   rQ  r   rG   r   r   r   r   r   r   r   r   r   rt   r   r9   `  sJ   
            
r9   zThe reader of REALM.c                       s   e Zd Z fddZeedeee	dd	e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
e e
e e
e eeef dddZ  ZS )
r3   c                    s>   t  | |j| _t|| _t|| _t|| _| 	  d S r   )
r^   r_   Z
num_labelsr;  r6  r  r   r  
qa_outputsr>  rr   rt   r   r   r_     s    


zRealmReader.__init__z!reader_beam_size, sequence_lengthrJ  N)ru   r   r[   rX   r   rv   r   r1  start_positionsend_positionshas_answersr   r   r   rx   c           $      C   sL  |dk	r|n| j j}|dkr$td|dkr4td|d| j jk rNtd| j|||||||||d	}|d }| ||d| j j \}}}t	|d| j j d}||7 }t
tj|dd	j}t
tj|dd	j}tj|d|d
}tj|d|d
}d}d}d}d}d}|	dk	r|
dk	r|dk	rdd }dd }|d} |	d| }	|
d| }
|}t|}!||||	d| j j |
d| j j d}t|}"|||}||d|d}||!tj9 }||"tj9 }||  }|s*||||f|dd  }#|dk	r&|||||f|# S |#S t||||||||||j|jdS )ar  
        relevance_score (`torch.FloatTensor` of shape `(searcher_beam_size,)`, *optional*):
            Relevance score, which must be specified if you want to compute the logits and marginal log loss.
        block_mask (`torch.BoolTensor` of shape `(searcher_beam_size, sequence_length)`, *optional*):
            The mask of the evidence block, which must be specified if you want to compute the logits and marginal log
            loss.
        start_positions (`torch.LongTensor` of shape `(searcher_beam_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(searcher_beam_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        has_answers (`torch.BoolTensor` of shape `(searcher_beam_size,)`, *optional*):
            Whether or not the evidence block has answer(s).

        Returns:
        NzCYou have to specify `relevance_score` to calculate logits and loss.zOYou have to specify `block_mask` to separate question block and evidence block.r   zQThe input sequence length must be greater than or equal to config.max_span_width.rL  r   rY   r   r!  c                 S   s\   t t t | ddt |d}t t t |ddt |d}t t ||dS )zCompute correct span.r   rY   r   )rG   eqrd  r<   logical_and)r2  r3  gold_starts	gold_endsZis_gold_startZis_gold_endr   r   r   compute_correct_candidates[  s     
 
z7RealmReader.forward.<locals>.compute_correct_candidatesc                 S   s@   t jfdd}t j| ||| jd dd}t j| dd}|| S )z3Loss based on the negative marginal log-likelihood.c                 S   s   d|  | t|j S r&  r(  r+  r   r   r   r-  k  s    zERealmReader.forward.<locals>.marginal_log_loss.<locals>.mask_to_scorer\   rY   r   )rG   r.  re  r]   )rc  Z
is_correctr-  Zlog_numeratorZlog_denominatorr   r   r   marginal_log_lossh  s    z.RealmReader.forward.<locals>.marginal_log_loss)r2  r3  rm  rn  r%   )r   r   r   r   r   r   r  r  r  r   r   )rJ   rF  r   rp   r"  r6  rg  reader_beam_sizerG   rd  Zargmaxmaxvaluesr#  clampr<   r   r)  r.  r7  r   r   r   )$rs   ru   r   r[   rX   r   rv   r   r1  rh  ri  rj  r   r   r   r   r  r4  r2  r3  Zretriever_logitsZpredicted_block_indexZpredicted_candidateZpredicted_startZpredicted_endZ
total_lossr   r   r   r   ro  rp  Zignored_indexZany_retriever_correctZany_reader_correctr   r   r   r   r     s    & 




zRealmReader.forward)NNNNNNNNNNNNNN)r5   r   r   r_   r   rO  rP  r   r   rQ  r   rG   r   r   r  r   r   r   r   r   r   r   rt   r   r3     sD   

              
r3   ay  
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token (should not be used in this model by design).

            [What are token type IDs?](../glossary#token-type-ids)
        answer_ids (`list` of shape `(num_answers, answer_length)`, *optional*):
            Answer ids for computing the marginal log-likelihood loss. Indices should be in `[-1, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-1` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
z?`RealmForOpenQA` for end-to-end open domain question answering.c                       s   e Zd Zd fdd	Zedd Zdd Zee	de
eed	deej eej eej eej ee eeef d
ddZ  ZS )r7   Nc              	      s`   t  | t|| _t|| _| dtdj	|j
|jftjtdd || _|   d S )N	block_embr   cpu)rp   r]   rz   )r^   r_   r:   rS  r3   r   rl   rG   ro   Z	new_emptyZnum_block_recordsr  r.  rz   	retrieverr>  )rs   rJ   rw  rt   r   r   r_     s    



zRealmForOpenQA.__init__c                 C   s   | j r| jjS | jjS r   )r   rJ   searcher_beam_sizerq  r$  r   r   r   rx    s    z!RealmForOpenQA.searcher_beam_sizec                 C   s   | j || _ dS )zSend `self.block_emb` to a specific device.

        Args:
            device (`str` or `torch.device`):
                The device to which `self.block_emb` will be sent.
        N)ru  r   )rs   rz   r   r   r   block_embedding_to  s    z!RealmForOpenQA.block_embedding_toz1, sequence_lengthrJ  )ru   r   r[   
answer_idsr   rx   c                 C   s  |dk	r|n| j j}|dk	r2|jd dkr2td| j|||dd}|d }td| j|| jj	}tj
|| jdd	\}	}
|
 }
tj| jd|
d
}| j|
 ||| j jd\}}}}|| jj	}|jtjj| jj	d}| |jtj |dk	rDtj|tj| jj	d}tj|tj| jj	d}tj|tj| jj	d}td| || jj	}| j|jd| j j |jd| j j |jd| j j |||||dd	}|j|j }||j|jd  }|s||fS t ||dS )a  
        Returns:

        Example:

        ```python
        >>> import torch
        >>> from transformers import RealmForOpenQA, RealmRetriever, AutoTokenizer

        >>> retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/realm-orqa-nq-openqa")
        >>> model = RealmForOpenQA.from_pretrained("google/realm-orqa-nq-openqa", retriever=retriever)

        >>> question = "Who is the pioneer in modern computer science?"
        >>> question_ids = tokenizer([question], return_tensors="pt")
        >>> answer_ids = tokenizer(
        ...     ["alan mathison turing"],
        ...     add_special_tokens=False,
        ...     return_token_type_ids=False,
        ...     return_attention_mask=False,
        ... ).input_ids

        >>> reader_output, predicted_answer_ids = model(**question_ids, answer_ids=answer_ids, return_dict=False)
        >>> predicted_answer = tokenizer.decode(predicted_answer_ids)
        >>> loss = reader_output.loss
        ```Nr   r   z'The batch_size of the inputs must be 1.T)ru   r[   r   r   z	BD,QD->QBrY   )kr   r!  )
max_lengthr  ry   zD,BD->B)	ru   r   r[   r   r1  rj  rh  ri  r   )r  r  )!rJ   rF  rD   r   rS  rG   r   ru  r   rz   Ztopkrx  r0  r#  rw  rv  Zreader_seq_lenr   Zspecial_tokens_maskr)  r   Zlogical_not_Zlogical_and_r[   r   rq   ru   rq  r   r   r  r  r  )rs   ru   r   r[   rz  r   Zquestion_outputsZquestion_projectionZbatch_scoresr   Zretrieved_block_idsZretrieved_block_embrj  r  r  Zconcat_inputsr1  Zretrieved_logitsr  Zpredicted_blockr  r   r   r   r     sf    %      
  zRealmForOpenQA.forward)N)NNNN)r5   r   r   r_   propertyrx  ry  r   REALM_FOR_OPEN_QA_DOCSTRINGrP  r   r  rQ  r   rG   r   r   r   r   r   r   r   r   r   rt   r   r7     s$   


    
r7   )Gr   r   r+   dataclassesr   typingr   r   r   rG   r   Ztorch.nnr   Zactivationsr	   Zmodeling_outputsr
   r   r   r   Zmodeling_utilsr   Zpytorch_utilsr   r   r   r   r   r   r   r   Zconfiguration_realmr   Z
get_loggerr5   r)   Z_EMBEDDER_CHECKPOINT_FOR_DOCZ_ENCODER_CHECKPOINT_FOR_DOCZ_SCORER_CHECKPOINT_FOR_DOCrQ  Z#REALM_PRETRAINED_MODEL_ARCHIVE_LISTrR   ModulerS   r   r   r   r   r   r   r   r   r   r   r   r  r	  r  r  r  r  ZREALM_START_DOCSTRINGrO  r5  r;  r:   rR  r9   r3   r~  r7   r   r   r   r   <module>   s   
lA 2Wc1
C2( N   (!