# coding=utf-8
""" PyTorch ProphetNet model, ported from the ProphetNet repo (fairseq version)."""

import copy
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import Tensor, nn
from torch.nn import LayerNorm

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_prophetnet import ProphetNetConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "ProphetNetConfig"
_CHECKPOINT_FOR_DOC = "microsoft/prophetnet-large-uncased"

PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "microsoft/prophetnet-large-uncased",
    # See all ProphetNet models at https://huggingface.co/models?filter=prophetnet
]


PROPHETNET_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    Original ProphetNet code can be found [here](https://github.com/microsoft/ProphetNet). Checkpoints were converted
    from original Fairseq checkpoints. For more information on the checkpoint conversion, please take a look at the
    file `convert_prophetnet_original_pytorch_checkpoint_to_pytorch.py`.

    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
    behavior.

    Parameters:
        config ([`ProphetNetConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

PROPHETNET_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            ProphetNet uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
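
        Example (an illustrative sketch of how these inputs fit together, assuming the standard
        `microsoft/prophetnet-large-uncased` checkpoint; adapt model and tokenizer names as needed):

        ```python
        >>> from transformers import AutoTokenizer, ProphetNetModel

        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
        >>> model = ProphetNetModel.from_pretrained("microsoft/prophetnet-large-uncased")

        >>> inputs = tokenizer("Studies have shown that owning a dog is good for you", return_tensors="pt")
        >>> decoder_inputs = tokenizer("Studies show that", return_tensors="pt")
        >>> outputs = model(input_ids=inputs.input_ids, decoder_input_ids=decoder_inputs.input_ids)
        ```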
"""

PROPHETNET_STANDALONE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


def softmax(hidden_state, dim, onnx_trace=False):
    if onnx_trace:
        return nn.functional.softmax(hidden_state.float(), dim=dim)
    else:
        return nn.functional.softmax(hidden_state, dim=dim, dtype=torch.float32)


def ngram_attention_bias(sequence_length, ngram, device, dtype):
    """
    This function computes the bias for the predict stream
    """
    left_block = (
        torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * torch.finfo(dtype).min
    )
    right_block = left_block.detach().clone()
    # create bias for each of the `ngram` predict streams
    for stream_idx in range(ngram):
        right_block[stream_idx].fill_diagonal_(0, wrap=False)
        left_block[stream_idx].triu_(-stream_idx + 1)

    left_block[:, :, 0] = 0
    return torch.cat([left_block, right_block], dim=2)
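

# An illustrative note on the mask layout (not part of the original file): for `sequence_length=3` and
# `ngram=2`, the returned bias has shape `(2, 3, 6)` -- the left half applies a per-stream causal mask
# over the main stream, while the right half unmasks exactly one diagonal position per predict stream.
#
#     bias = ngram_attention_bias(3, 2, torch.device("cpu"), torch.float32)
#     assert bias.shape == (2, 3, 6)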


def compute_relative_buckets(num_buckets, max_distance, relative_positions, is_bidirectional=False):
    """
    This function computes individual parts of the relative position buckets. For more detail, see paper.
    """
    inv_relative_positions = -relative_positions
    rel_positions_bucket = 0

    if is_bidirectional:
        num_buckets = num_buckets // 2
        rel_positions_bucket = (
            rel_positions_bucket
            + torch.lt(inv_relative_positions, torch.zeros_like(inv_relative_positions)).int() * num_buckets
        )
        inv_relative_positions = torch.abs(inv_relative_positions)
    else:
        inv_relative_positions = torch.max(inv_relative_positions, torch.zeros_like(inv_relative_positions))

    max_exact = num_buckets // 2
    is_small = torch.lt(inv_relative_positions, max_exact)
    val_if_large = max_exact + torch.log(inv_relative_positions.float() / max_exact) / math.log(
        max_distance / max_exact
    ) * (num_buckets - max_exact)
    val_if_large = torch.min(val_if_large, torch.ones_like(val_if_large) * (num_buckets - 1)).int()
    rel_positions_bucket = rel_positions_bucket + torch.where(is_small, inv_relative_positions.int(), val_if_large)
    return rel_positions_bucket


def compute_all_stream_relative_buckets(num_buckets, max_distance, position_ids):
    """
    This function computes both main and predict relative position buckets. For more detail, see paper.
    """
    # main stream
    main_stream_relative_positions = position_ids.unsqueeze(1).repeat(1, position_ids.size(-1), 1)
    main_stream_relative_positions = main_stream_relative_positions - position_ids.unsqueeze(-1)

    # predicting stream
    predicting_stream_relative_positions = torch.cat((position_ids - 1, position_ids), dim=-1).unsqueeze(1)
    predicting_stream_relative_positions = predicting_stream_relative_positions.repeat(1, position_ids.size(-1), 1)
    predicting_stream_relative_positions = predicting_stream_relative_positions - position_ids.unsqueeze(-1)

    # get both position buckets
    main_relative_position_buckets = compute_relative_buckets(
        num_buckets, max_distance, main_stream_relative_positions, is_bidirectional=False
    )
    predict_relative_position_buckets = compute_relative_buckets(
        num_buckets, max_distance, predicting_stream_relative_positions, is_bidirectional=False
    )
    return main_relative_position_buckets, predict_relative_position_buckets
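

# A small illustrative check (not part of the original file): with `num_buckets=32`, `max_distance=128`
# and `position_ids = torch.arange(1, 5).unsqueeze(0)`, the two returned bucket tensors have shapes
# `(1, 4, 4)` and `(1, 4, 8)` -- one bucket id per (query, key) pair for the main stream and for the
# twice-as-wide predict stream key space.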
eej ed< dZeeej  ed< dZeeej  ed< dZeeej  ed< dZeeej  ed	< dZeeej  ed
< dZeeej  ed< dZeej ed< dZeeej  ed< dZeeej  ed< edd ZdS )ProphetNetSeq2SeqLMOutputa  
    Base class for sequence-to-sequence language model outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss.
        logits (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, config.vocab_size)`):
            Prediction scores of the main stream language modeling head (scores for each vocabulary token before
            SoftMax).
        logits_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`):
            Prediction scores of the predict stream language modeling head (scores for each vocabulary token before
            SoftMax).
        past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
            num_attn_heads, decoder_sequence_length, embed_size_per_head)`).

            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
            used (see `past_key_values` input) to speed up sequential decoding.
        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, decoder_sequence_length, hidden_size)`.

            Hidden-states of main stream of the decoder at the output of each layer plus the initial embedding outputs.
        decoder_ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.

            Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
            outputs.
        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
            decoder_sequence_length, decoder_sequence_length)`.

            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        decoder_ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
            decoder_sequence_length, decoder_sequence_length)`.

            Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
            weighted average in the self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
            encoder_sequence_length, decoder_sequence_length)`.

            Attentions weights of the cross-attention layer of the decoder, after the attention softmax, used to
            compute the weighted average in the cross-attention heads.
        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, encoder_sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
            encoder_sequence_length, encoder_sequence_length)`. Attentions weights of the encoder, after the attention
            softmax, used to compute the weighted average in the self-attention heads.
    Nlosslogitslogits_ngrampast_key_valuesdecoder_hidden_statesdecoder_ngram_hidden_statesdecoder_attentionsdecoder_ngram_attentionscross_attentionsencoder_last_hidden_stateencoder_hidden_statesencoder_attentionsc                 C   s   t dt | jS Nzi`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions` instead.warningswarnFutureWarningrK   selfr   r   r   decoder_cross_attentionsG  s
    z2ProphetNetSeq2SeqLMOutput.decoder_cross_attentions)__name__
__module____qualname____doc__rC   r   r   FloatTensor__annotations__rD   rE   rF   r   rG   rH   rI   rJ   rK   rL   rM   rN   propertyrV   r   r   r   r   rB      s   
<rB   c                   @   s   e Zd ZU dZejed< dZeej ed< dZ	ee
ej  ed< dZee
ej  ed< dZee
ej  ed< dZee
ej  ed< dZee
ej  ed	< dZee
ej  ed
< dZeej ed< dZee
ej  ed< dZee
ej  ed< edd ZdS )ProphetNetSeq2SeqModelOutputa2  
    Base class for model encoder's outputs that also contains: pre-computed hidden states that can speed up sequential
    decoding.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, hidden_size)`):
            Sequence of main stream hidden-states at the output of the last layer of the decoder of the model.

            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
            hidden_size)` is output.
        last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`, *optional*):
            Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model.
        past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
            num_attn_heads, decoder_sequence_length, embed_size_per_head)`).

            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
            used (see `past_key_values` input) to speed up sequential decoding.
        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, decoder_sequence_length, hidden_size)`.

            Hidden-states of main stream of the decoder at the output of each layer plus the initial embedding outputs.
        decoder_ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.

            Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
            outputs.
        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
            decoder_sequence_length, decoder_sequence_length)`.

            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        decoder_ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
            decoder_sequence_length, decoder_sequence_length)`.

            Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
            weighted average in the
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
            encoder_sequence_length, decoder_sequence_length)`.

            Attentions weights of the cross-attention layer of the decoder, after the attention softmax, used to
            compute the weighted average in the cross-attention heads.
        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, encoder_sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
            encoder_sequence_length, encoder_sequence_length)`.

            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    last_hidden_state: torch.FloatTensor
    last_hidden_state_ngram: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[torch.FloatTensor]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_ngram_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    decoder_ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None

    @property
    def decoder_cross_attentions(self):
        warnings.warn(
            "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions`"
            " instead.",
            FutureWarning,
        )
        return self.cross_attentions


@dataclass
class ProphetNetDecoderModelOutput(ModelOutput):
    """
    Base class for model's outputs that may also contain past key/value states (to speed up sequential decoding).

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, hidden_size)`):
            Sequence of main stream hidden-states at the output of the last layer of the decoder of the model.

            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
            hidden_size)` is output.
        last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`):
            Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model.
        past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
            num_attn_heads, decoder_sequence_length, embed_size_per_head)`).

            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
            used (see `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, decoder_sequence_length, hidden_size)`.

            Hidden-states of main stream of the decoder at the output of each layer plus the initial embedding outputs.
        ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.

            Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
            outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
            decoder_sequence_length, decoder_sequence_length)`.

            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
            decoder_sequence_length, decoder_sequence_length)`.

            Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
            weighted average in the
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
            encoder_sequence_length, decoder_sequence_length)`.

            Attentions weights of the cross-attention layer of the decoder, after the attention softmax, used to
            compute the weighted average in the cross-attention heads.
    """

    last_hidden_state: torch.FloatTensor
    last_hidden_state_ngram: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    hidden_states_ngram: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class ProphetNetDecoderLMOutput(ModelOutput):
    """
    Base class for model's outputs that may also contain past key/value states (to speed up sequential decoding).

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss.
        logits (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, config.vocab_size)`):
            Prediction scores of the main stream language modeling head (scores for each vocabulary token before
            SoftMax).
        logits_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`):
            Prediction scores of the predict stream language modeling head (scores for each vocabulary token before
            SoftMax).
        past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
            num_attn_heads, decoder_sequence_length, embed_size_per_head)`).

            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
            used (see `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, decoder_sequence_length, hidden_size)`.

            Hidden-states of main stream of the decoder at the output of each layer plus the initial embedding outputs.
        ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.

            Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
            outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
            decoder_sequence_length, decoder_sequence_length)`.

            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
            decoder_sequence_length, decoder_sequence_length)`.

            Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
            weighted average in the
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
            encoder_sequence_length, decoder_sequence_length)`.

            Attentions weights of the cross-attention layer of the decoder, after the attention softmax, used to
            compute the weighted average in the cross-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    logits_ngram: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    hidden_states_ngram: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None


class ProphetNetPreTrainedModel(PreTrainedModel):
    config_class = ProphetNetConfig
    base_model_prefix = "prophetnet"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (ProphetNetDecoder, ProphetNetEncoder)):
            module.gradient_checkpointing = value

    def _shift_right(self, input_ids):
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        assert decoder_start_token_id is not None, (
            "self.model.config.decoder_start_token_id has to be defined. In ProphetNet it is usually set to the"
            " pad_token_id. See ProphetNet docs for more information."
        )

        # shift inputs to the right
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = decoder_start_token_id

        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"

        return shifted_input_ids
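

# Illustrative behavior of `_shift_right` (comment sketch, not part of the original file): with
# `decoder_start_token_id = 102` (hypothetical) and labels `[[21, -100, -100]]`, the shifted decoder
# input becomes `[[102, 21, pad_token_id]]` -- the start token is prepended, the sequence is shifted
# one step to the right, and any remaining `-100` label-padding value is replaced by `pad_token_id`.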


class ProphetNetPositionalEmbeddings(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
    based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to
    the forward function.
    N)ro   returnc                    s"   |j | _t |j |j|j d S ru   )max_position_embeddings
max_lengthsuper__init__hidden_sizer}   rU   ro   	__class__r   r   r   R  s    z'ProphetNetPositionalEmbeddings.__init__c                    s   |d ks| j d kstd|d kr|d k	rj|d d jd }|d | }tjdtj|dt| j |  }nN|d krtj|tj|d}tj|dd||  | j  }|	d| j
d }t ||fS )NzCIf position_ids is pre-computed then padding_idx should not be set.r   r"   r   )r   r   r   r    r   )rr   r~   r   r   r#   longr/   Zcumsumtype_asclampr   r   forward)rU   Zinputs_shaper    attention_maskrF   r>   Zprev_num_input_idsZnum_input_idsr   r   r   r   V  s(    z&ProphetNetPositionalEmbeddings.forwardc                    s   t  |S ru   )r   r   )rU   r>   r   r   r   _forwardr  s    z'ProphetNetPositionalEmbeddings._forward)NNN)	rW   rX   rY   rZ   r   r   r   r   __classcell__r   r   r   r   r   K  s   r   c                


class ProphetNetAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: ProphetNetConfig, num_attn_heads: int):
        super().__init__()
        hidden_size = config.hidden_size

        self.attention_dropout = config.attention_dropout
        self.dropout = config.dropout
        self.num_attn_heads = num_attn_heads
        self.head_dim = hidden_size // num_attn_heads

        assert self.head_dim * num_attn_heads == hidden_size, (
            "`config.hidden_size` must be divisible by `config.num_encoder_attention_heads` and"
            " `config.num_decoder_attention_heads`"
        )

        self.key_proj = nn.Linear(hidden_size, hidden_size)
        self.value_proj = nn.Linear(hidden_size, hidden_size)
        self.query_proj = nn.Linear(hidden_size, hidden_size)

        self.out_proj = nn.Linear(hidden_size, hidden_size)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states,
        key_value_states: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        layer_head_mask: Optional[Tensor] = None,
        past_key_value: Optional[Tuple[Tensor]] = None,
        output_attentions: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        batch_size, tgt_len, hidden_size = hidden_states.size()

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None
        assert list(hidden_states.size()) == [
            batch_size,
            tgt_len,
            hidden_size,
        ], f"Size of hidden states should be {batch_size, tgt_len, hidden_size}, but is {hidden_states.size()}"

        # previous time steps are cached - no need to recompute key and value if they are static
        query_states = self.query_proj(hidden_states) / (self.head_dim**0.5)

        if is_cross_attention and past_key_value is not None:
            # reuse cached cross-attention key/value states
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross attention
            key_states = self._shape(self.key_proj(key_value_states), -1, batch_size)
            value_states = self._shape(self.value_proj(key_value_states), -1, batch_size)
        else:
            # self attention
            key_states = self._shape(self.key_proj(hidden_states), -1, batch_size)
            value_states = self._shape(self.value_proj(hidden_states), -1, batch_size)

        if is_cross_attention:
            # if cross_attention, save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value states
            past_key_value = (key_states, value_states)

        # project states into the correct shape
        proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, batch_size).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(2)
        attn_weights = torch.einsum("bsij,bsjk->bsik", query_states, key_states.transpose(2, 3))
        expected_shape = (batch_size, self.num_attn_heads, tgt_len, src_len)
        if attn_weights.size() != expected_shape:
            raise ValueError(f"Attention weights should have size {expected_shape}, but is {attn_weights.size()}")

        # This is part of a workaround to get around fork/join parallelism not supporting Optional types
        if attention_mask is not None and attention_mask.dim() == 0:
            attention_mask = None

        expected_shape = (batch_size, self.num_attn_heads, 1, src_len)
        if attention_mask is not None and attention_mask.size() != expected_shape:
            raise ValueError(f"Attention mask should have size {expected_shape}, but is {attention_mask.size()}")
        if attention_mask is not None:  # don't attend to padding symbols
            attn_weights = attn_weights + attention_mask
        if output_attentions:
            attn_weights_reshaped = attn_weights
        else:
            attn_weights_reshaped = None

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            assert layer_head_mask.size() == (self.num_attn_heads,), (
                f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
                f" {layer_head_mask.size()}"
            )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
                batch_size, self.num_attn_heads, tgt_len, src_len
            )

            # apply head_mask also on attn_weights_reshaped, which is used for n-gram attention inside the model
            if attn_weights_reshaped is not None:
                attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped

        attn_probs = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.einsum("bsij,bsjk->bsik", attn_probs, value_states)
        expected_shape = (batch_size, self.num_attn_heads, tgt_len, self.head_dim)
        if attn_output.size() != expected_shape:
            raise ValueError(f"`attn_output` should have shape {expected_shape}, but is of shape {attn_output.size()}")

        attn_output = attn_output.transpose(1, 2).reshape(batch_size, tgt_len, hidden_size)
        attn_output = self.out_proj(attn_output)

        attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
        return attn_output, attn_weights_reshaped, past_key_value


class ProphetNetFeedForward(nn.Module):
    """
    This is the residual two feed-forward layer block based on the original Transformer implementation.
    """

    def __init__(self, config: ProphetNetConfig, ffn_dim: int):
        super().__init__()
        self.activation_fn = ACT2FN[config.activation_function]
        self.intermediate = nn.Linear(config.hidden_size, ffn_dim)
        self.output = nn.Linear(ffn_dim, config.hidden_size)
        self.activation_dropout = config.activation_dropout
        self.dropout = config.dropout

    def forward(self, hidden_states):
        hidden_states = self.intermediate(hidden_states)
        hidden_states = self.activation_fn(hidden_states)

        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.output(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        return hidden_states


class ProphetNetNgramSelfAttention(nn.Module):
    def __init__(self, config: ProphetNetConfig):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.num_buckets = config.num_buckets
        self.relative_max_distance = config.relative_max_distance
        self.num_attn_heads = config.num_decoder_attention_heads
        self.dropout = config.dropout
        self.attention_dropout = config.attention_dropout
        self.head_dim = config.hidden_size // self.num_attn_heads
        self.ngram = config.ngram

        assert (
            self.head_dim * self.num_attn_heads == config.hidden_size
        ), "config.hidden_size must be divisible by num_attn_heads"
        # key, value, query projection
        self.key_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.value_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.query_proj = nn.Linear(config.hidden_size, config.hidden_size)

        # out projection
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)

        # relative position embeddings
        self.relative_pos_embeddings = nn.Linear(config.hidden_size, self.num_buckets * self.num_attn_heads)

        # for onnx runtime
        self.onnx_trace = False

    def _shape(self, tensor, seq_len, batch_size):
        return tensor.view(batch_size, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous()

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        hidden_states,
        past_key_value: Optional[Tuple[Tensor]] = None,
        attention_mask=None,
        layer_head_mask=None,
        extended_predict_attention_mask=None,
        main_relative_position_buckets=None,
        predict_relative_position_buckets=None,
        position_ids=None,
    ):
        batch_size, ngram_sequence_length, hidden_size = hidden_states.size()
        assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
            f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape"
            f" {hidden_states.shape}"
        )

        # project
        query_states = self.query_proj(hidden_states)
        key_states = self.key_proj(hidden_states)
        value_states = self.value_proj(hidden_states)

        # normalize
        query_states = query_states / (self.head_dim**0.5)

        # reshape
        query_states = self._shape(query_states, ngram_sequence_length, batch_size)
        key_states = self._shape(key_states, -1, batch_size)
        value_states = self._shape(value_states, -1, batch_size)
        proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim)

        query_states = query_states.view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        # chunk into main stream and predict stream
        hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1)
        query_states_list = query_states.chunk(1 + self.ngram, dim=2)
        key_states_list = key_states.chunk(1 + self.ngram, dim=2)
        value_states_list = value_states.chunk(1 + self.ngram, dim=2)

        main_hidden_states, hidden_states_predict_list = hidden_states_list[0], hidden_states_list[1:]
        main_query_states, predict_query_states_list = query_states_list[0], query_states_list[1:]
        main_key_states, predict_key_states_list = key_states_list[0], key_states_list[1:]
        main_value_states, predict_value_states_list = value_states_list[0], value_states_list[1:]

        # saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim)
        if past_key_value is not None:
            prev_main_key_states = past_key_value[0]
            main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=2)
            prev_main_value_states = past_key_value[1]
            main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=2)

        # update cache
        past_key_value = (main_key_states, main_value_states)

        # get sequence_length of main stream only
        sequence_length = ngram_sequence_length // (1 + self.ngram)

        # MAIN-STREAM
        # main attn weights
        main_attn_weights = torch.einsum("bntc,bncs->bnts", main_query_states, main_key_states.transpose(2, 3))

        # retrieve relative position embeddings for each layer -> see paper for more details
        main_relative_pos_embeddings = self.get_main_relative_pos_embeddings(
            main_hidden_states, main_attn_weights, position_ids, main_relative_position_buckets
        )
        main_attn_weights = main_attn_weights + main_relative_pos_embeddings

        if attention_mask is not None:
            main_attn_weights = main_attn_weights + attention_mask

        main_attn_probs = softmax(main_attn_weights, dim=-1, onnx_trace=self.onnx_trace).type_as(main_attn_weights)

        if layer_head_mask is not None:
            assert layer_head_mask.size() == (self.num_attn_heads,), (
                f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
                f" {layer_head_mask.size()}"
            )
            main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs.view(
                batch_size, self.num_attn_heads, -1, sequence_length
            )

        main_attn_probs = nn.functional.dropout(main_attn_probs, p=self.attention_dropout, training=self.training)
        # project to attn_output
        main_attn_output = torch.einsum("bntc,bncs->bnts", main_attn_probs, main_value_states)

        # reshape so that num_heads dim is merged into last `head_dim` axis
        main_attn_output = main_attn_output.transpose(1, 2).reshape(batch_size, 1, sequence_length, hidden_size)
        main_attn_output = self.out_proj(main_attn_output)

        # PREDICT-STREAM
        # [batch_size, ngram, num_heads, sequence_length, head_dim]
        predict_query_states = torch.stack(predict_query_states_list, 1).view(
            batch_size, self.ngram, self.num_attn_heads, sequence_length, self.head_dim
        )
        # [batch_size, ngram, num_heads, 2 * sequence_length, head_dim]
        predict_key_states = torch.stack([torch.cat([main_key_states, key], 2) for key in predict_key_states_list], 1)

        # [batch_size, sequence_length, ngram, hidden_size]
        predict_hidden_states = torch.stack(hidden_states_predict_list, dim=2)

        # [batch_size, ngram, num_heads, 2 * sequence_length, head_dim]
        predict_value_states = torch.cat(
            [torch.cat([main_value_states, v_p], 2).unsqueeze(1) for v_p in predict_value_states_list], 1
        )

        # [batch_size, ngram, num_heads, sequence_length, 2 * sequence_length]
        predict_attn_weights = torch.einsum("bnhtc,bnhsc->bnhts", (predict_query_states, predict_key_states))

        # retrieve relative position embeddings for each layer -> see paper for more details
        predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings(
            predict_hidden_states, predict_attn_weights, position_ids, predict_relative_position_buckets
        )
        predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings

        if extended_predict_attention_mask is not None:
            # permute mask from [batch_size, num_heads, ngram, ...] to [batch_size, ngram, num_heads, ...]
            extended_predict_attention_mask = extended_predict_attention_mask.permute(0, 2, 1, 3, 4)
            extended_predict_attention_mask = extended_predict_attention_mask.to(predict_attn_weights.dtype)
            predict_attn_weights = predict_attn_weights + extended_predict_attention_mask

        predict_attn_probs = softmax(predict_attn_weights, dim=-1, onnx_trace=self.onnx_trace).type_as(
            predict_attn_weights
        )

        if layer_head_mask is not None:
            assert layer_head_mask.size() == (self.num_attn_heads,), (
                f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
                f" {layer_head_mask.size()}"
            )
            predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs

        predict_attn_probs = nn.functional.dropout(
            predict_attn_probs, p=self.attention_dropout, training=self.training
        )
        # project to attention output
        # [batch_size, ngram, num_heads, sequence_length, head_dim]
        predict_attn_output = torch.einsum("bnhts,bnhsc->bnhtc", (predict_attn_probs, predict_value_states))

        # reshape so that num_heads dim is merged into last `head_dim` axis
        # [batch_size, ngram, sequence_length, hidden_size]
        predict_attn_output = predict_attn_output.transpose(2, 3)
        predict_attn_output = predict_attn_output.reshape(batch_size, self.ngram, sequence_length, hidden_size)
        predict_attn_output = self.out_proj(predict_attn_output)

        # concat to single attn output
        # [batch_size, (1 + ngram) * sequence_length, hidden_size]
        attn_output = torch.cat([main_attn_output, predict_attn_output], 1).view(batch_size, -1, hidden_size)
        # reshape into better form for `config.output_attentions`
        main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, sequence_length, -1)

        attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)

        return attn_output, main_attn_probs, predict_attn_probs, past_key_value

    def get_main_relative_pos_embeddings(
        self, hidden_states, attn_weights, position_ids, main_relative_position_buckets
    ):
        # input hidden_states [batch_size, sequence_length, hidden_size]
        # input attn_weights [batch_size, num_heads, sequence_length, sequence_length]
        # input position_ids [batch_size, sequence_length] or [1, 1]
        batch_size, num_attn_heads, tgt_len, src_len = attn_weights.shape
        attn_weights = attn_weights.view(batch_size, num_attn_heads, tgt_len, src_len)
        if main_relative_position_buckets is None:
            batch_size, sequence_length = hidden_states.shape[:2]
            relative_positions = (
                torch.arange(1, attn_weights.shape[-1] + 1)
                .unsqueeze(0)
                .unsqueeze(0)
                .repeat(batch_size, sequence_length, 1)
                .to(position_ids.device)
            )
            # [batch_size, sequence_length, sequence_length]
            relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1)
            main_relative_position_buckets = compute_relative_buckets(
                self.num_buckets, self.relative_max_distance, relative_positions, False
            )

        # [batch_size, sequence_length, num_buckets * num_heads]
        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
        rel_pos_embeddings = rel_pos_embeddings.view(
            rel_pos_embeddings.shape[:2] + (self.num_buckets, self.num_attn_heads)
        )
        rel_pos_embeddings = rel_pos_embeddings.permute(0, 3, 1, 2)
        # [batch_size, num_heads, sequence_length, num_buckets]
        rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:3] + (-1,))

        main_relative_position_buckets = main_relative_position_buckets.repeat(1, self.num_attn_heads, 1)
        # [batch_size * num_heads * sequence_length, sequence_length]
        main_relative_position_buckets = main_relative_position_buckets.view(
            -1, main_relative_position_buckets.shape[-1]
        )
        main_relative_position_buckets = main_relative_position_buckets.long()
        # [batch_size * num_heads * sequence_length, num_buckets]
        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1))

        main_relative_pos_embeddings = torch.gather(rel_pos_embeddings, dim=1, index=main_relative_position_buckets)
        main_relative_pos_embeddings = main_relative_pos_embeddings.view(batch_size, num_attn_heads, tgt_len, -1)
        return main_relative_pos_embeddings

    def get_predict_relative_pos_embeddings(
        self, hidden_states, attn_weights, position_ids, predict_relative_position_buckets
    ):
        # input hidden_states [batch_size, sequence_length, ngram, hidden_size]
        # input attn_weights [batch_size, ngram, num_heads, sequence_length, 2 * sequence_length]
        # input position_ids [batch_size, sequence_length] or [1, 1]
        # input predict_relative_position_buckets [batch_size, sequence_length, 2 * sequence_length] or None
        batch_size, sequence_length = hidden_states.shape[0:2]

        if predict_relative_position_buckets is None:
            key_sequence_length = attn_weights.shape[-1]
            assert (
                position_ids[0][0] == key_sequence_length - 1
            ), "`position_ids` are incorrect. They should be of the format 1 2 3 4 5 ... (key_sequence_length - 1)"
            relative_positions = (
                torch.arange(0, key_sequence_length)
                .unsqueeze(0)
                .unsqueeze(0)
                .repeat(batch_size, sequence_length, 1)
                .to(position_ids.device)
            )

            relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1)
            predict_relative_position_buckets = compute_relative_buckets(
                self.num_buckets, self.relative_max_distance, relative_positions, False
            )

        # [batch_size, ngram, sequence_length, hidden_size]
        hidden_states = hidden_states.transpose(1, 2)
        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)

        # [batch_size, ngram, sequence_length, num_buckets, num_heads]
        rel_pos_embeddings = rel_pos_embeddings.view(
            hidden_states.shape[:-1] + (self.num_buckets, self.num_attn_heads)
        )
        rel_pos_embeddings = rel_pos_embeddings.permute(0, 2, 1, 4, 3)
        # [batch_size * ngram * sequence_length * num_heads, num_buckets]
        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, self.num_buckets)
        predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0)
        predict_relative_position_buckets = predict_relative_position_buckets.repeat(
            self.ngram, 1, self.num_attn_heads, 1
        )
        # [ngram * batch_size * num_heads * sequence_length, -1]
        predict_relative_position_buckets = predict_relative_position_buckets.view(
            -1, predict_relative_position_buckets.size(-1)
        ).long()

        predict_relative_pos_embeddings = torch.gather(
            rel_pos_embeddings, dim=1, index=predict_relative_position_buckets
        )

        # [batch_size, ngram, num_heads, sequence_length, -1]
        predict_relative_pos_embeddings = predict_relative_pos_embeddings.view(
            batch_size, self.ngram, self.num_attn_heads, sequence_length, -1
        )
        return predict_relative_pos_embeddings


class ProphetNetEncoderLayer(nn.Module):
    """
    Encoder block for Prophetnet
    """

    def __init__(self, config: ProphetNetConfig):
        super().__init__()
        # 1st residual block
        self.self_attn = ProphetNetAttention(config, config.num_encoder_attention_heads)
        self.self_attn_layer_norm = LayerNorm(config.hidden_size)

        # 2nd residual block
        self.feed_forward = ProphetNetFeedForward(config, config.encoder_ffn_dim)
        self.feed_forward_layer_norm = LayerNorm(config.hidden_size)

    def forward(self, hidden_states, attention_mask, layer_head_mask, output_attentions: bool = False):
        # 1st residual block
        attention_output, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = self.self_attn_layer_norm(attention_output + hidden_states)

        # 2nd residual block
        feed_forward_output = self.feed_forward(hidden_states)
        hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class ProphetNetDecoderLayer(nn.Module):
    """
    Decoder block for Prophetnet
    """

    def __init__(self, config: ProphetNetConfig):
        super().__init__()
        # 1st residual block
        self.self_attn = ProphetNetNgramSelfAttention(config)
        self.self_attn_layer_norm = LayerNorm(config.hidden_size)

        # 2nd residual block
        if config.add_cross_attention:
            self.cross_attn = ProphetNetAttention(config, config.num_decoder_attention_heads)
            self.cross_attn_layer_norm = LayerNorm(config.hidden_size)

        # 3rd residual block
        self.feed_forward = ProphetNetFeedForward(config, config.decoder_ffn_dim)
        self.feed_forward_layer_norm = LayerNorm(config.hidden_size)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attn_mask=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        extended_predict_attention_mask=None,
        main_relative_position_buckets=None,
        predict_relative_position_buckets=None,
        position_ids=None,
        past_key_value=None,
        use_cache: bool = True,
        output_attentions: bool = False,
    ):
        # 1st residual block
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        ngram_attention_output, self_attn_weights, self_attn_weights_ngram, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            extended_predict_attention_mask=extended_predict_attention_mask,
            main_relative_position_buckets=main_relative_position_buckets,
            predict_relative_position_buckets=predict_relative_position_buckets,
            position_ids=position_ids,
        )
        hidden_states = self.self_attn_layer_norm(hidden_states + ngram_attention_output)

        # cross-attention cached key/values tuple is at positions 3,4 of present_key_value tuple
        cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            # 2nd residual block
            attention_output, cross_attn_weights, cross_attn_present_key_value = self.cross_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attn_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
            )
            hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states)

            # add cross-attn to positions 3,4 of present_key_value tuple
            present_key_value = present_key_value + cross_attn_present_key_value

        # 3rd residual block
        feed_forward_output = self.feed_forward(hidden_states)
        hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, self_attn_weights_ngram, cross_attn_weights)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


@add_start_docstrings(
    "The standalone encoder part of the ProphetNetModel.",
    PROPHETNET_START_DOCSTRING,
)
class ProphetNetEncoder(ProphetNetPreTrainedModel):
    r"""
    word_embeddings  (`torch.nn.Embeddings` of shape `(config.vocab_size, config.hidden_size)`, *optional*):
        The word embedding parameters. This can be used to initialize [`ProphetNetEncoder`] with pre-defined word
        embeddings instead of randomly initialized word embeddings.
    """

    def __init__(self, config: ProphetNetConfig, word_embeddings: nn.Embedding = None):
        super().__init__(config)

        self.word_embeddings = (
            word_embeddings
            if word_embeddings is not None
            else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        )
        self.position_embeddings = ProphetNetPositionalEmbeddings(config)
        self.embeddings_layer_norm = LayerNorm(config.hidden_size)

        self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)])

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.word_embeddings

    def set_input_embeddings(self, value):
        self.word_embeddings = value

    @add_start_docstrings_to_model_forward(PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, ProphetNetEncoder
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
        >>> model = ProphetNetEncoder.from_pretrained("patrickvonplaten/prophetnet-large-uncased-standalone")
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None and inputs_embeds is None:
            raise ValueError("Either input_ids or inputs_embeds has to be passed.")
        elif input_ids is not None and inputs_embeds is not None:
            raise ValueError("Make sure to only pass input_ids or inputs_embeds.")
        elif input_ids is not None and inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # prepare attention mask
        if attention_mask is not None:
            extended_attention_mask = (
                1.0 - attention_mask[:, None, None, :].repeat(1, self.config.num_encoder_attention_heads, 1, 1)
            ) * torch.finfo(self.dtype).min
            extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype)
        else:
            extended_attention_mask = None

        position_embeddings, position_ids = self.position_embeddings(inputs_embeds.shape[:2], inputs_embeds.device)

        hidden_states = inputs_embeds + position_embeddings
        hidden_states = self.embeddings_layer_norm(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.config.dropout, training=self.training)

        encoder_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == len(
                self.layers
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_hidden_states = encoder_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    extended_attention_mask,
                    (head_mask[idx] if head_mask is not None else None),
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_hidden_states = encoder_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_hidden_states, attentions=all_attentions
        )


@add_start_docstrings(
    "The standalone decoder part of the ProphetNetModel.",
    PROPHETNET_START_DOCSTRING,
)
class ProphetNetDecoder(ProphetNetPreTrainedModel):
    r"""
    word_embeddings  (`torch.nn.Embeddings` of shape `(config.vocab_size, config.hidden_size)`, *optional*):
        The word embedding parameters. This can be used to initialize [`ProphetNetDecoder`] with pre-defined word
        embeddings instead of randomly initialized word embeddings.
    """

    def __init__(self, config: ProphetNetConfig, word_embeddings: Optional[nn.Embedding] = None):
        super().__init__(config)

        self.ngram = config.ngram
        self.num_buckets = config.num_buckets
        self.relative_max_distance = config.relative_max_distance
        self.dropout = config.dropout
        self.max_target_positions = config.max_position_embeddings

        self.word_embeddings = (
            word_embeddings
            if word_embeddings is not None
            else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        )
        self.position_embeddings = ProphetNetPositionalEmbeddings(config)

        self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None)
        self.layers = nn.ModuleList([ProphetNetDecoderLayer(config) for _ in range(config.num_decoder_layers)])
        self.embeddings_layer_norm = LayerNorm(config.hidden_size)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.word_embeddings

    def set_input_embeddings(self, value):
        self.word_embeddings = value

    @add_start_docstrings_to_model_forward(PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ProphetNetDecoderModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ProphetNetDecoderModelOutput]:
        r"""
        encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).


        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, ProphetNetDecoder
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
        >>> model = ProphetNetDecoder.from_pretrained("microsoft/prophetnet-large-uncased", add_cross_attention=False)
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> last_hidden_states = outputs.last_hidden_state
        ```NzGEither `decoder_input_ids` or `decoder_inputs_embeds` has to be passed.zFMake sure to only pass `decoder_input_ids` or `decoder_inputs_embeds`.r"   )r    rF   )NNr   zOAt the moment `use_cache` is only supported for `decoder_input_ids` of length 1c                    s&   g | ]}|d      d d qS r   )r<   r   r+   )r   r&  predicting_stream_pos_embedr   r   r     s   z-ProphetNetDecoder.forward.<locals>.<listcomp>c                    s   g | ]} |d    qS r)  r   r*  )r&  r+  r   r   r     s    r  r   r   r   zZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...Fr
  r(  zThe `z` should be specified for r  r  c                    s    fdd}|S )Nc                     s    | f S ru   r   r  )rs   r   r   r   r   r  &  s    zPProphetNetDecoder.forward.<locals>.create_custom_forward.<locals>.custom_forwardr   r  )r   r   r  r   r  %  s    z8ProphetNetDecoder.forward.<locals>.create_custom_forward)r   rM   r   r   r   r   r?   r@   r>   r   r   r   r   r	   c                 s   s   | ]}|d k	r|V  qd S ru   r   r  r   r   r   r  c  s   
z,ProphetNetDecoder.forward.<locals>.<genexpr>)r_   r`   rF   rb   rc   rd   re   rK   )-ro   r   r   r  r  r   r   r   r   r    !compute_buffered_relative_bucketsr   r&  rm   r=   r~   r(   r+   prepare_attention_maskprepare_predict_attention_maskr<   r   r   r$   r   r%   r   r)   r   r   r   r   r   r   rx   loggerZwarning_oncezipr  r  r  r  r  r  ra   )&rU   r   r   rM   r'  r
  r(  rF   r  r   r   r  r  r*   Zmain_stream_pos_embedr>   r?   r@   rb   Zngram_hidden_statesr   r   Zextended_encoder_attention_maskZall_main_stream_hidden_statesZall_ngram_stream_hidden_statesZall_main_stream_attnsZall_ngram_stream_attnsZall_cross_attnsZpresent_key_valuesZ	attn_maskZ	mask_namer!  Zdecoder_layerr   r  r"  r_   r`   r   )r   r&  r   r+  r   r   r     s   :






*

$
&zProphetNetDecoder.forwardc              	   C   s   |j \}}td| j|jdd}t| j| j	|\}}|d d d |d |f |dd}t
|d d d |d |f |d d d || j| j| f gd|dd}||fS r   )r   r   r   r%  r   r    r<   rA   r5   r   r)   )rU   r>   r   r*   Zmain_relative_bucketsZpredict_relative_bucketsr   r   r   r,  |  s0    
  $
   
z3ProphetNetDecoder.compute_buffered_relative_bucketsc                 C   s   |j d d \}}tj||ft|jj|j|jd}t|d}|d |d |f d d d d d d f || j	j
f|j  }|d k	rd|d d d d d d f  t| jj }|| }n|}||jS )Nr"   r   r   r  )r   r   fullr$   r   r%   r    Ztriuexpandro   r   r   )rU   rb   r   r   
seq_lengthZcausal_maskZextended_causal_maskr   r   r   r   r-    s     (*
z(ProphetNetDecoder.prepare_attention_maskc           	      C   s"  |j d d \}}t| j| j|j|j}tj|d d d |d |f |d d d || j| j| f gdd}|d d d d d d d d f || j	j
f|j  }|d k	rd|d d d d d d d f  t| jj }||| j	j
| j||f}tj|t|gdd}|| }n|}||jS )Nr"   r:   r   r  )r   r,   r%  r+   r    r   r   r)   r2  ro   r   r$   r%   r.   r   )	rU   rb   r   r   r3  Zpredict_causal_maskZextended_predict_causal_maskr   r   r   r   r   r.    sB       
 	
, 
z0ProphetNetDecoder.prepare_predict_attention_mask)N)NNNNNNNNNNNN)rW   rX   rY   rZ   r   r   r   rq   r   r  r  r   r#  r   ra   r$  r   r   r   r   r   r   r,  r-  r.  r   r   r   r   r   rv   X  sJ   
            
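
# A toy illustration, not part of the original file, of the masking convention used in
# `prepare_attention_mask` above: everything above the diagonal is filled with the
# dtype's most negative value, which softmax then turns into (near-)zero attention weight.
def _demo_causal_mask(seq_length: int = 4):
    causal_mask = torch.full((seq_length, seq_length), torch.finfo(torch.float32).min)
    causal_mask = torch.triu(causal_mask, 1)
    # row i keeps columns 0..i at 0.0 (attendable) and blocks the rest
    print(causal_mask)
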
 zrv   zXThe bare ProphetNet Model outputting raw hidden-states without any specific head on top.c                       s   e Zd ZddgZed fddZdd Zdd	 Zd
d Zdd Z	e
eeeeddeej eej eej eej eej eej eej ee eeeej   eej eej ee ee ee ee eeef dddZ  ZS )ProphetNetModelencoder.word_embeddings.weightdecoder.word_embeddings.weightr   c                    sx   t  | tj|j|j|jd| _t	|}d|_
d|_t|| j| _t	|}d|_d|_
t|| j| _|   d S )Nr   FT)r   r   r   rq   r   r   r}   r   copydeepcopyis_encoder_decoderr   rw   encoder
is_decoderrv   decoderr  )rU   ro   Zencoder_configZdecoder_configr   r   r   r     s    

zProphetNetModel.__init__c                 C   s   | j S ru   r  rT   r   r   r   r    s    z$ProphetNetModel.get_input_embeddingsc                 C   s   || _ | j | j_ | j | j_ d S ru   )r   r:  r<  r  r   r   r   r    s    
z$ProphetNetModel.set_input_embeddingsc                 C   s   | j S ru   )r:  rT   r   r   r   get_encoder  s    zProphetNetModel.get_encoderc                 C   s   | j S ru   r<  rT   r   r   r   get_decoder  s    zProphetNetModel.get_decoderr  N)r   r   decoder_input_idsdecoder_attention_maskr
  decoder_head_maskr(  encoder_outputsrF   r  decoder_inputs_embedsr   r   r  r  r   c                 C   s   |dk	r|n| j j}|dk	r |n| j j}|dk	r4|n| j j}|dk	rH|n| j j}|dkrp| j||||
|||d}| j|||d ||||	|||||d}|s|| S t|j|j	|j
|j|j|j|j|j|j|j|jdS )ag  
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, ProphetNetModel

        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
        >>> model = ProphetNetModel.from_pretrained("microsoft/prophetnet-large-uncased")

        >>> input_ids = tokenizer(
        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
        >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

        >>> last_hidden_states = outputs.last_hidden_state  # main stream hidden states
        >>> last_hidden_states_ngram = outputs.last_hidden_state_ngram  # predict hidden states
        ```N)r   r   r
  r  r   r  r  r   )r   r   rM   r'  r
  r(  rF   r  r   r  r   r  )r_   r`   rF   rG   rH   rI   rJ   rK   rL   rM   rN   )ro   r   r   r  r  r:  r<  r^   r_   r`   rF   rb   rc   rd   re   rK   )rU   r   r   r@  rA  r
  rB  r(  rC  rF   r  rD  r   r   r  r  Zdecoder_outputsr   r   r   r     sX    (zProphetNetModel.forward)NNNNNNNNNNNNNNN)rW   rX   rY   _tied_weights_keysr   r   r  r  r=  r?  r   PROPHETNET_INPUTS_DOCSTRINGr   r^   r$  r   r   r   
BoolTensorr   r   r   r   r   r   r   r   r   r4    sR   
               
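
# A minimal usage sketch, not part of the original file: it reuses the checkpoint from
# the docstring above to show how the two decoder streams differ in shape. The exact
# sizes depend on the tokenized lengths and on `config.ngram` (2 for this checkpoint).
def _demo_prophetnet_streams():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
    model = ProphetNetModel.from_pretrained("microsoft/prophetnet-large-uncased")

    input_ids = tokenizer(
        "Studies have been shown that owning a dog is good for you", return_tensors="pt"
    ).input_ids
    decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids
    outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

    # main stream: (batch, target_len, hidden); predict streams: (batch, ngram * target_len, hidden)
    print(outputs.last_hidden_state.shape)
    print(outputs.last_hidden_state_ngram.shape)
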
@add_start_docstrings(
    "The ProphetNet Model with a language modeling head. Can be used for sequence generation tasks.",
    PROPHETNET_START_DOCSTRING,
)
class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
    _tied_weights_keys = [
        "encoder.word_embeddings.weight",
        "decoder.word_embeddings.weight",
        "lm_head.weight",
    ]

    def __init__(self, config: ProphetNetConfig):
        super().__init__(config)
        self.prophetnet = ProphetNetModel(config)
        self.padding_idx = config.pad_token_id
        self.disable_ngram_loss = config.disable_ngram_loss

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_input_embeddings(self):
        return self.prophetnet.word_embeddings

    @add_start_docstrings_to_model_forward(PROPHETNET_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ProphetNetSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ProphetNetSeq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked); the loss is only computed for
            labels in `[0, ..., config.vocab_size - 1]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, ProphetNetForConditionalGeneration

        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
        >>> model = ProphetNetForConditionalGeneration.from_pretrained("microsoft/prophetnet-large-uncased")

        >>> input_ids = tokenizer(
        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
        >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

        >>> logits_next_token = outputs.logits  # logits to predict next token as usual
        >>> logits_ngram_next_tokens = outputs.logits_ngram  # logits to predict 2nd, 3rd, ... next tokens
        ```N)r   r   r@  rA  r
  rB  r(  rC  rF   r  rD  r   r   r  r  r"   r   r:   r   c                 s   s   | ]}|d k	r|V  qd S ru   r   r  r   r   r   r    s      z=ProphetNetForConditionalGeneration.forward.<locals>.<genexpr>)rC   rD   rE   rF   rG   rH   rI   rJ   rK   rL   rM   rN   )ro   r  r   rh   r   r   r+   rL  Zis_contiguousr   _compute_lossr  rB   rF   rG   rH   rI   rJ   rK   rL   rM   rN   )rU   r   r   r@  rA  r
  rB  r(  rC  rF   r  rD  rQ  r   r   r  r  r   r   r*   predicting_streamspredict_logitsrD   rE   rC   
all_logitsr   r   r   r   j  s`    .

$0z*ProphetNetForConditionalGeneration.forwardr{   c                 C   s  | | jj|d|d|}t| jjD ],}|dkrF| jrF q^|||d d d d f< q0|dd }t	j
j|d|ddtjd}t	j
j||ddd}| jjdkr|jddd	 }||d}	||	 }| }| jj|d }
d
| jj | |
|  }|S Nr   r   r:   r   rj   )Z	reductionri   T)r   Zkeepdimr  r   ro   r+   r=   Zfill_r(   rK  r   r   r   r   Zlog_softmaxr   r   r   Znll_lossZepssumnerj   rU   rD   rQ  Zignore_indexZexpend_targetsiZlprobsrC   Zsmooth_lossZnon_masked_tokensZeps_ir   r   r   rR    s(    $z0ProphetNetForConditionalGeneration._compute_lossc	           
   
   K   s@   |d k	st d|r(|d d dd f }d ||||||||d	S )Nz3`encoder_outputs` have to be passed for generation.r:   )	r   rC  rF   r@  r   r
  rB  r(  r   )r~   )
rU   r@  rF   r   r
  rB  r(  r   rC  kwargsr   r   r   prepare_inputs_for_generation  s    z@ProphetNetForConditionalGeneration.prepare_inputs_for_generation)rQ  c                 C   s
   |  |S ru   )r   )rU   rQ  r   r   r   %prepare_decoder_input_ids_from_labels  s    zHProphetNetForConditionalGeneration.prepare_decoder_input_ids_from_labelsc                    sB   d}| D ]4}|t  fdd|d d D |dd   f7 }q|S )Nr   c                 3   s"   | ]}| d  |jV  qdS r   NZindex_selectr   r    r   Z
past_statebeam_idxr   r   r    s     zDProphetNetForConditionalGeneration._reorder_cache.<locals>.<genexpr>r"   r  rF   rc  Zreordered_pastZ
layer_pastr   rb  r   _reorder_cache  s    
z1ProphetNetForConditionalGeneration._reorder_cachec                 C   s   | j jS ru   )rh   r:  rT   r   r   r   r=    s    z.ProphetNetForConditionalGeneration.get_encoderc                 C   s   | j jS ru   rh   r<  rT   r   r   r   r?    s    z.ProphetNetForConditionalGeneration.get_decoder)NNNNNNNNNNNNNNNN)r{   )NNNNNNN)rW   rX   rY   rE  r   r   rN  rP  r  r   rF  r   rB   r$  r   r   r   rG  r   r   r   r   rR  r]  r^  staticmethodrf  r=  r?  r   r   r   r   r   rH  O  sp   

                
h
       


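
# A minimal generation sketch, not part of the original file: beam search through the
# conditional-generation head; reordering of the cache across beams is handled by
# `_reorder_cache` above. The generation hyper-parameters here are arbitrary.
def _demo_prophetnet_generation():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
    model = ProphetNetForConditionalGeneration.from_pretrained("microsoft/prophetnet-large-uncased")

    input_ids = tokenizer(
        "Studies have been shown that owning a dog is good for you", return_tensors="pt"
    ).input_ids
    summary_ids = model.generate(input_ids, num_beams=4, max_length=20, early_stopping=True)
    print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
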
@add_start_docstrings(
    "The standalone decoder part of the ProphetNetModel with a lm head on top. The model can be used for causal"
    " language modeling.",
    PROPHETNET_START_DOCSTRING,
)
class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: ProphetNetConfig):
        # set config for causal language modeling
        config = copy.deepcopy(config)
        config.is_decoder = True
        config.is_encoder_decoder = False
        super().__init__(config)
        self.prophetnet = ProphetNetDecoderWrapper(config)

        self.padding_idx = config.pad_token_id
        self.disable_ngram_loss = config.disable_ngram_loss

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.prophetnet.decoder.word_embeddings

    def set_input_embeddings(self, value):
        self.prophetnet.decoder.word_embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.prophetnet.decoder = decoder

    def get_decoder(self):
        return self.prophetnet.decoder

    @replace_return_docstrings(output_type=ProphetNetDecoderLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ProphetNetDecoderLMOutput]:
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).

        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size - 1]` (see the `input_ids` docstring). Tokens with indices set to `-100`
            are ignored (masked); the loss is only computed for the tokens with labels in `[0, ...,
            config.vocab_size - 1]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, ProphetNetForCausalLM
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
        >>> model = ProphetNetForCausalLM.from_pretrained("microsoft/prophetnet-large-uncased")
        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> logits = outputs.logits

        >>> # Model can also be used with EncoderDecoder framework
        >>> from transformers import BertTokenizer, EncoderDecoderModel, AutoTokenizer
        >>> import torch

        >>> tokenizer_enc = BertTokenizer.from_pretrained("bert-large-uncased")
        >>> tokenizer_dec = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
        >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained(
        ...     "bert-large-uncased", "microsoft/prophetnet-large-uncased"
        ... )

        >>> ARTICLE = (
        ...     "the us state department said wednesday it had received no "
        ...     "formal word from bolivia that it was expelling the us ambassador there "
        ...     "but said the charges made against him are `` baseless ."
        ... )
        >>> input_ids = tokenizer_enc(ARTICLE, return_tensors="pt").input_ids
        >>> labels = tokenizer_dec(
        ...     "us rejects charges against its ambassador in bolivia", return_tensors="pt"
        ... ).input_ids
        >>> outputs = model(input_ids=input_ids, decoder_input_ids=labels[:, :-1], labels=labels[:, 1:])

        >>> loss = outputs.loss
        ```N)r   r   rM   r'  r
  r(  rF   r  r   r   r  r  r"   r   r:   r   c                 s   s   | ]}|d k	r|V  qd S ru   r   r  r   r   r   r    s      z0ProphetNetForCausalLM.forward.<locals>.<genexpr>)	rC   rD   rE   rF   rb   rc   rd   re   rK   )ro   r  rh   r<  r   r   r+   rL  rR  r  rf   rF   rb   rc   rd   re   rK   )rU   r   r   rM   r'  r
  r(  rF   r  rQ  r   r   r  r  r   r   r*   rS  rT  rD   rE   rC   rU  r   r   r   r   M  sJ    X 
$0zProphetNetForCausalLM.forwardr{   c                 C   s  | | jj|d|d|}t| jjD ],}|dkrF| jrF q^|||d d d d f< q0|dd }t	j
j|d|ddtjd}t	j
j||ddd}| jjdkr|jddd	 }||d}	||	 }| }| jj|d }
d
| jj | |
|  }|S rV  rW  rZ  r   r   r   rR    s(    $z#ProphetNetForCausalLM._compute_lossc                 K   s<   |d kr| |j}|r,|d d dd f }|||||dS )Nr:   )r   r   r
  rF   r   )Znew_onesr   )rU   r   rF   r   r
  r   r\  r   r   r   r]    s    
z3ProphetNetForCausalLM.prepare_inputs_for_generationc                    s.   d}| D ] }|t  fdd|D f7 }q|S )Nr   c                 3   s"   | ]}| d  |jV  qdS r_  r`  ra  rb  r   r   r  	  s     z7ProphetNetForCausalLM._reorder_cache.<locals>.<genexpr>rd  re  r   rb  r   rf  	  s    z$ProphetNetForCausalLM._reorder_cache)NNNNNNNNNNNNN)r{   )NNNN)rW   rX   rY   rE  r   r   r  r  rN  rP  rl  r?  r   r#  r   rf   r$  r   r   r   r   r   r   r   rR  r]  rh  rf  r   r   r   r   r   ri  #  s`   
             
 
    
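
# A toy walkthrough, not part of the original file, of how `_compute_loss` builds its
# targets: the labels are tiled once per n-gram stream, so with `ngram=2` the model is
# penalized both for the next token and for the token after it. All values are made up.
def _demo_ngram_loss_targets(ngram: int = 2):
    labels = torch.tensor([[5, 6, -100]])  # (batch=1, seq_len=3); -100 marks padding
    expend_targets = labels.new_zeros(ngram, labels.size(0), labels.size(1)).fill_(-100)
    for i in range(ngram):
        expend_targets[i, :, :] = labels
    # every stream shares the same targets; `-100` entries stay masked in the NLL loss
    print(expend_targets)
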
class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel):
    """
    This is a wrapper class, so that [`ProphetNetForCausalLM`] can correctly be loaded from pretrained ProphetNet
    classes.
    """

    def __init__(self, config: ProphetNetConfig):
        super().__init__(config)
        self.decoder = ProphetNetDecoder(config)

    def forward(self, *args, **kwargs):
        return self.decoder(*args, **kwargs)
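
# A minimal sketch, not part of the original file: thanks to the wrapper above, the
# decoder weights of the full seq2seq checkpoint can be loaded into the standalone
# causal-LM class, exactly as the `ProphetNetForCausalLM` docstring example does.
def _demo_load_causal_lm():
    model = ProphetNetForCausalLM.from_pretrained("microsoft/prophetnet-large-uncased")
    assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
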
   Zmodeling_outputsr   Zmodeling_utilsr   r  r   r   r   r   r   Zconfiguration_prophetnetr   Z
get_loggerrW   r/  r$  Z_CHECKPOINT_FOR_DOCZ(PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LISTZPROPHETNET_START_DOCSTRINGrF  r#  r   r,   r9   rA   rB   r^   ra   rf   rg   rq   r   Moduler   r   r   r   r   rw   rv   r4  rH  ri  rj  r   r   r   r   <module>   s   
I"

TU;=*+   A+T 	  t} Q k