U
    0-e                     @   s  d dl Z d dlZd dlmZ d dlmZ d dlZd dlm  m	Z
 d dlmZmZmZ d dlmZ ddlmZ ddlmZ d	d
lmZmZ d	dlmZmZ e rd dlmZmZmZ e rd dl m!Z!m"Z"m#Z#m$Z$m%Z%m&Z&m'Z'm(Z( d dl)m*Z*m+Z+m,Z,m-Z- d dl.m/Z/m0Z0m1Z1 d dl2m3Z3m4Z4 d dl5m6Z6 d dl7m8Z8m9Z9m:Z:m;Z;m<Z< d dl=m>Z>m?Z?m@Z@mAZAmBZB d dl=mZC d dlDmEZE d dlFmGZG d dlHmIZI d dlJmKZKmLZL d dlMmNZNmOZO d dlPmQZQ d dlRmSZSmTZTmUZU d dlVmWZWmXZXmYZYmZZZ d?ddZ[dd  Z\G d!d" d"Z]d#d$ Z^G d%d& d&eZ_d'd( Z`G d)d* d*ZaG d+d, d,eZbd-d. ZcG d/d0 d0eZdG d1d2 d2edZeG d3d4 d4edZfG d5d6 d6edZgdi fd7d8ZhG d9d: d:ejjiZjd;d< Zkd=d> ZldS )@    N)ABC)partial)BCEWithLogitsLossCrossEntropyLossMSELoss)DistributedDataParallel   )AcceleratedOptimizer)AcceleratedScheduler   )is_megatron_lm_availableis_transformers_available)recursively_applysend_to_device)!CausalLMOutputWithCrossAttentionsSeq2SeqLMOutputSequenceClassifierOutput)get_argsget_num_microbatchesget_tensorboard_writer
get_timersget_tokenizermpuprint_rank_0print_rank_last)_add_data_args_add_validation_args
parse_argsvalidate_args)load_args_from_checkpointload_checkpointsave_checkpoint) MegatronPretrainingRandomSamplerMegatronPretrainingSampler)set_global_variables)_compile_dependencies_init_autoresume_set_random_seedset_jit_fusion_optionswrite_args_to_tensorboard)	BertModelFloat16ModuleGPTModel	ModelTypeT5Model)Classification)get_megatron_optimizer)get_forward_backward_func)broadcast_int_listbroadcast_tensor)%beam_search_and_return_on_first_stage/generate_tokens_probs_and_return_on_first_stage)_vocab_size_with_padding)	get_modelget_optimizer_param_schedulertraining_log))average_losses_across_data_parallel_groupcalc_params_l2_normget_ltor_masks_and_position_idsunwrap_modelTc                 C   s   t  }|jrdnd}|jdkr>td|j d| d td |jdkr|jrr|jrXd	nd}t||jd
| |d}qt|jd	| |d}nL|jdkrt	dd
| |d}n0|jdkrt
dd
| |||d}ntd|j |S )zBuild the model.zpre-trainingzfine-tuningr   z	Building z model in the z mode.zThe Megatron LM model weights are initialized at random in `accelerator.prepare`. Please use `accelerator.load_checkpoint` to load a pre-trained checkpoint matching the distributed setup.bertr   T)num_tokentypesZadd_binary_headparallel_outputpre_processpost_process)Znum_classesr?   rA   rB   gpt)r?   r@   rA   rB   t5)r?   r@   rA   rB   add_encoderadd_decoderUnsupported model type: )r   pretraining_flagrankprintmodel_type_namebert_binary_headr*   r/   
num_labelsr,   r.   
ValueError)rA   rB   rE   rF   argsmoder?   model rR   ]/var/www/html/Darija-Ai-Train/env/lib/python3.8/site-packages/accelerate/utils/megatron_lm.pymodel_provider_funcR   sH    

   

	rT   c                 C   s   |  d t }| jjjd k	rN| jjjd kr4td| jjj}| jj|}nL|jdkr`tj	}n0|jdkrtj
}|jd kr|jdkr|jd |_tt|}|S )NzPreparing modelzaYou must provide a `custom_model_provider_function` when using a `custom_prepare_model_function`.)r>   rC   rD   r   r   )rJ   r   statemegatron_lm_pluginZcustom_prepare_model_functionZcustom_model_provider_functionrN   rK   r-   Zencoder_or_decoderZencoder_and_decoder"pipeline_model_parallel_split_rankpipeline_model_parallel_sizer7   rT   )acceleratorrO   Zcustom_model_provider_funcrQ   Z
model_typerR   rR   rS   prepare_modelz   s"    




rZ   c                   @   s8   e Zd ZdZdd Zdd Zdd Zdd	 Zd
d ZdS )MegatronLMDummyDataLoaderz
    Dummy dataloader presents model parameters or param groups, this is primarily used to follow conventional training

    Args:
        **dataset_kwargs: Megatron data arguments.
    c                 K   sH   t  }t|}t|}| }t|d | _| j| d| jd< d S )Nr   Tmegatron_dataset_flag)argparseArgumentParserr   r   parse_known_argsvarsdataset_argsupdate)selfZdataset_kwargsparserZ	data_argsrR   rR   rS   __init__   s    z"MegatronLMDummyDataLoader.__init__c                 C   s*   t  }| j D ]\}}t||| qd S N)r   ra   itemssetattr)rc   rO   keyvaluerR   rR   rS   set_megatron_data_args   s    z0MegatronLMDummyDataLoader.set_megatron_data_argsc                 C   s   dd }|S )Nc                 S   s   t  }|j|j|j| |j |jd}|jdkrL||j|j	|j
|jd nV|jdkrh|d|ji n:|jdkr||j|j|j	|j
dd ntd|j |jdkrd	d
lm} nd	d
lm} |f |\}}}|||fS )z&Build train, valid, and test datasets.)Zdata_prefix	data_implZsplits_stringZtrain_valid_test_num_samplesZskip_warmupseedr>   )max_seq_lengthmasked_lm_probshort_seq_probZbinary_headrC   
seq_lengthrD   )rn   Zmax_seq_length_decro   rp   Zdataset_typerG   r   )build_train_valid_test_datasets)r   Z	data_pathrl   splitZmmap_warmuprm   rK   rb   rq   Z	mask_probrp   rL   Zencoder_seq_lengthZdecoder_seq_lengthrN   Zmegatron.data.gpt_datasetrr   Zmegatron.data.dataset_utils)train_val_test_num_samplesrO   ra   rr   train_dsvalid_dstest_dsrR   rR   rS   "train_valid_test_datasets_provider   sJ    

 


zlMegatronLMDummyDataLoader.get_train_valid_test_datasets_provider.<locals>.train_valid_test_datasets_providerrR   )rc   rx   rR   rR   rS   &get_train_valid_test_datasets_provider   s    -z@MegatronLMDummyDataLoader.get_train_valid_test_datasets_providerc              	   C   s   |d krd S t  }|j|j }|jdkrHtt|||t t d}n@|jdkrxt	|t|||t t |j
d}ntd|jtjjj|||jddS )Nsingle)total_samplesconsumed_samplesmicro_batch_sizedata_parallel_rankdata_parallel_sizecyclic)r{   r|   r}   r~   r   data_shardingz${} dataloader type is not supported.T)batch_samplernum_workersZ
pin_memory)r   r}   num_micro_batchesdataloader_typer#   lenr   get_data_parallel_rankget_data_parallel_world_sizer"   r   	Exceptionformattorchutilsdata
DataLoaderr   )rc   datasetr|   rO   r}   r   rR   rR   rS   build_pretraining_data_loader   s:    


   z7MegatronLMDummyDataLoader.build_pretraining_data_loaderc                 C   s  dd }t  }d\}}}td |jdkrT|jdkrT|jd ksFtd|j|j |_|jdkr|jdkr|jd kr|j|j |j	 |j |_t
 dkr|jr|j}n|j|j }|j|j d |j	 }|j	}|||j ||j g}	td td	|	d  td
|	d  td|	d  |  }
|
|	\}}}| ||j}| ||j}| |d}|d k	ot|jdk}|d k	o|j	dk}|d k	o|j	dk}tjt|t|t|g}ntjdddg}tjj|t
 t
 d |d  |_|d  |_|d  |_|j}|dks(t|d k	rR|dkrDt|n
t||}nd }|d k	r|dkrrt|n
t||}nd }|d k	r|dkrt|n
t||}nd }|||fS )Nc                 s   s   | D ]
}|V  qq d S rf   rR   )iterxrR   rR   rS   cyclic_iter   s    zTMegatronLMDummyDataLoader.build_train_valid_test_data_iterators.<locals>.cyclic_iter)NNNz3> building train, validation, and test datasets ...r   z?only backward compatiblity support for iteration-based trainingr   z( > datasets target sizes (minimum size):z    train:      {}z    validation: {}z    test:       {}r   group)rz   r   rz   )r   r   	iterationconsumed_train_samplestrain_samplesAssertionErrorZglobal_batch_sizeconsumed_valid_samplesZeval_interval
eval_itersr   Zget_tensor_model_parallel_rankZtrain_itersr   ry   r   r   cuda
LongTensorintdistributed	broadcastZ"get_tensor_model_parallel_src_rankZget_tensor_model_parallel_groupitemdo_traindo_validdo_testr   r   )rc   r   rO   Ztrain_dataloaderZvalid_dataloaderZtest_dataloaderr   r   Z
test_itersrt   rx   ru   rv   rw   r   r   r   flagsZdl_typetrain_data_iteratorvalid_data_iteratortest_data_iteratorrR   rR   rS   %build_train_valid_test_data_iterators   sr    

   


 z?MegatronLMDummyDataLoader.build_train_valid_test_data_iteratorsN)	__name__
__module____qualname____doc__re   rk   ry   r   r   rR   rR   rR   rS   r[      s   	0!r[   c           	   
      s0  |  d t }|jsddlm m} t }|j|j } fdd D }|d d krt|d t	j
jjrx||d _q|d= |d= |d= ||d	 _n|d	= ||d< t	j
jjjf||| jt t | jd
| j | jdS |jd k	r|j\|_|_|_nd\|_|_|_ \}}}|||fS d S )NzPreparing dataloaderr   )_PYTORCH_DATALOADER_KWARGSprepare_data_loaderc                    s   i | ]}|t | | qS rR   )getattr).0kr   
dataloaderrR   rS   
<dictcomp>\  s      z'prepare_data_loader.<locals>.<dictcomp>
batch_sizeZsamplershuffler   T)Znum_processesZprocess_indexsplit_batchesZput_on_device	rng_typesdispatch_batches)r   r   r   )rJ   r   r\   Zdata_loaderr   r   r}   r   
isinstancer   r   r   ZBatchSamplerr   r   r   devicer   r   r   r   r   copyr   r|   r   r   Zconsumed_test_samplesr   )	rY   r   rO   r   r}   kwargsr   r   r   rR   r   rS   r   T  sR    
r   c                       s:   e Zd Z fddZd
ddZdd Zedd	 Z  ZS )MegatronLMOptimizerWrapperc                    s   t  j|dd d d S )NF)Zdevice_placementZscalersuperre   )rc   	optimizer	__class__rR   rS   re     s    z#MegatronLMOptimizerWrapper.__init__Nc                 C   s   d S rf   rR   )rc   Zset_to_nonerR   rR   rS   	zero_grad  s    z$MegatronLMOptimizerWrapper.zero_gradc                 C   s   d S rf   rR   rc   rR   rR   rS   step  s    zMegatronLMOptimizerWrapper.stepc                 C   s   | j jS )zTWhether or not the optimizer step was done, or skipped because of gradient overflow.)r   skipped_iterr   rR   rR   rS   step_was_skipped  s    z+MegatronLMOptimizerWrapper.step_was_skipped)N)	r   r   r   re   r   r   propertyr   __classcell__rR   rR   r   rS   r     s
   
r   c                 C   s(   |  d t }t||j|j|j}|S )NzPreparing optimizer)rJ   r   r0   Zno_wd_decay_condZscale_lr_condZlr_mult)rY   rQ   rO   r   rR   rR   rS   prepare_optimizer  s    
r   c                   @   s   e Zd ZdZdddZdS )MegatronLMDummySchedulera  
    Dummy scheduler presents model parameters or param groups, this is primarily used to follow conventional training
    loop when scheduler config is specified in the deepspeed config file.

    Args:
        optimizer (`torch.optim.optimizer.Optimizer`):
            The optimizer to wrap.
        total_num_steps (int):
            Total number of steps.
        warmup_num_steps (int):
            Number of steps for warmup.
        **kwargs:
            Other arguments.
    Nr   c                 K   s   || _ || _|| _|| _d S rf   )r   total_num_stepswarmup_num_stepsr   )rc   r   r   r   r   rR   rR   rS   re     s    z!MegatronLMDummyScheduler.__init__)Nr   )r   r   r   r   re   rR   rR   rR   rS   r     s   r   c                       s$   e Zd Z fddZdd Z  ZS )MegatronLMSchedulerWrapperc                    s   t  || d S rf   r   )rc   	schedulerZ
optimizersr   rR   rS   re     s    z#MegatronLMSchedulerWrapper.__init__c                 O   s   d S rf   rR   )rc   rO   r   rR   rR   rS   r     s    zMegatronLMSchedulerWrapper.step)r   r   r   re   r   r   rR   rR   r   rS   r     s   r   c                 C   s   |  d t|}|S )NzPreparing scheduler)rJ   r8   )rY   r   r   rR   rR   rS   prepare_scheduler  s    
r   c                       s8   e Zd ZdZ fddZdd Zdd Zdd	 Z  ZS )
AbstractTrainStepz;Abstract class for batching, forward pass and loss handler.c                    s   t    || _d S rf   )r   re   name)rc   r   r   rR   rS   re     s    
zAbstractTrainStep.__init__c                 C   s   d S rf   rR   r   rR   rR   rS   get_batch_func  s    z AbstractTrainStep.get_batch_funcc                 C   s   d S rf   rR   r   rR   rR   rS   get_forward_step_func  s    z'AbstractTrainStep.get_forward_step_funcc                 C   s   d S rf   rR   r   rR   rR   rS   get_loss_func  s    zAbstractTrainStep.get_loss_func)	r   r   r   r   re   r   r   r   r   rR   rR   r   rS   r     s
   r   c                       s8   e Zd ZdZ fddZdd Zdd Zdd	 Z  ZS )
BertTrainStepzg
    Bert train step class.

    Args:
        args (`argparse.Namespace`): Megatron-LM arguments.
    c                    sV   t  d | |j| _| |j|j| _| 	|j|j
| _|jsLd | _nt| _d S )Nr   )r   re   r   r\   	get_batchr   rH   rM   	loss_funcr   rL   forward_stepmodel_return_dictmodel_output_classr   rc   rO   r   rR   rS   re     s    zBertTrainStep.__init__c                 C   s    dd }dd }|r|S |S d S )Nc                 S   s   ddddddg}t j}| dk	r(t| }nd}t|||}|d  }|d  }|d  }|d  }|d  }	|d  }
|||||	|
fS )Build the batch.texttypeslabelsZ	is_random	loss_maskpadding_maskNr   int64nextr   broadcast_datalongfloat)data_iteratorkeysdatatyper   data_btokensr   sentence_orderr   	lm_labelsr   rR   rR   rS   get_batch_megatron  s    
z8BertTrainStep.get_batch_func.<locals>.get_batch_megatronc                 S   s   t | }t|tj }|d  }|d  }d|krF|d  }nd}d|krt|d  }|d dktj}nd}d}d|kr|d  }nd}||||||fS )r   	input_idsattention_maskZtoken_type_idsNr   Znext_sentence_label)r   r   r   r   current_devicer   tor   )r   r   r   r   r   r   r   r   rR   rR   rS   get_batch_transformer  s     z;BertTrainStep.get_batch_func.<locals>.get_batch_transformerrR   rc   r\   r   r   rR   rR   rS   r     s
    zBertTrainStep.get_batch_funcc                    s&   dd } fdd}|r|S |S d S )Nc           	      S   s   |\}}|  }|   } t|d| d |   }|d k	rtj|dd  |ddd}|  }|| }t||g}||d |d dfS |}t|g}|d|d ifS d S )Nr   )Zignore_indexr   r   )lm losszsop lossr   )r   r   sumviewreshapeFZcross_entropyr:   )	r   r   output_tensorlm_loss_Z
sop_logitslm_lossZsop_losslossaveraged_lossesrR   rR   rS   loss_func_pretrain  s    ""
z7BertTrainStep.get_loss_func.<locals>.loss_func_pretrainc                    s    dkr&t  }||d| d}nLjdkrb| jtjtjfkrbt }||d | d}nt }||| }t	|g}|d|d ifS )Nr   r   r  r   )
r   r   rM   dtyper   r   r   r   r   r:   )r   logitsZloss_fctr  r  rM   rc   rR   rS   loss_func_finetune1  s    

z7BertTrainStep.get_loss_func.<locals>.loss_func_finetunerR   )rc   rH   rM   r  r	  rR   r  rS   r     s
    zBertTrainStep.get_loss_funcc                    s    fdd}|S )Nc           
         sf    | \}}}}}} sd}rD|||||d}|tj||fS ||||d}	|	tj|fS dS )Forward step.Ntokentype_idsr   )r  r   r   r   )
r   rQ   r   r   r   r   r   r   r   r  rL   rH   rc   rR   rS   r   E  s    z9BertTrainStep.get_forward_step_func.<locals>.forward_steprR   )rc   rH   rL   r   rR   r  rS   r   D  s    z#BertTrainStep.get_forward_step_func	r   r   r   r   re   r   r   r   r   rR   rR   r   rS   r     s
   
7'r   c                       s8   e Zd ZdZ fddZdd Zdd Zdd	 Z  ZS )
GPTTrainStepzf
    GPT train step class.

    Args:
        args (`argparse.Namespace`): Megatron-LM arguments.
    c                    s   t  d | |j| _|  | _|  | _|j	d | _
|jd k	rRt }|j| _
|j| _|j| _|j| _|jsxd | _nt| _d S )Nr  r   )r   re   r   r\   r   r   r   r   r   padded_vocab_size	eod_token
vocab_filer   eodreset_position_idsreset_attention_maskeod_mask_lossr   r   r   )rc   rO   	tokenizerr   rR   rS   re   ]  s    


zGPTTrainStep.__init__c                    s(    fdd} fdd}|r |S |S d S )Nc                    s   dg}t j}| dk	rt| }nd}t|||}|d  }|ddddf  }|ddddf  }t| j j	 j
 j\}}	}
|||	||
fS )zGenerate a batchr   Nr   r   )r   r   r   r   r   r   
contiguousr<   r  r  r  r  )r   r   r   r   r   tokens_r   r   r   r   position_idsr   rR   rS   r   o  s"    
    
z7GPTTrainStep.get_batch_func.<locals>.get_batch_megatronc           	         s   t | }d|d i}t|tj }|d  }tj|jd df|j|j	d j
 }tj||gdd}|d d dd f  }|d d d df  }t| j
 j jd\}}}|||||fS )Nr   r   r   )r  r   dimr   T)r   r   r   r   r   r   Zzerosshaper  r   r  concatr  r<   r  r  )	r   r   r  paddingr   r   r   r   r  r   rR   rS   r     s     $    
z:GPTTrainStep.get_batch_func.<locals>.get_batch_transformerrR   r   rR   r   rS   r   n  s
    zGPTTrainStep.get_batch_funcc                    s   t    fdd}|S )Nc                    sx    j r|\}}n|}| }| d } t|d|  |   }t|g}d|d i} j rp|d|i ||fS )Nr   r   r   r  )Zreturn_logitsr   r   r   r   r:   rb   )r   r   lossesr  r  Zaveraged_lossZoutput_dictrO   rR   rS   r     s    

z-GPTTrainStep.get_loss_func.<locals>.loss_func)r   rc   r   rR   r"  rS   r     s    zGPTTrainStep.get_loss_funcc                    s    fdd}|S )Nc                    s4     | \}}}}}|||||d}|t j|fS )r
  )r   r  )r   rQ   r   r   r   r   r  r   r   rR   rS   r     s    z8GPTTrainStep.get_forward_step_func.<locals>.forward_steprR   rc   r   rR   r   rS   r     s    z"GPTTrainStep.get_forward_step_funcr  rR   rR   r   rS   r  U  s
   /r  c                       s\   e Zd ZdZ fddZedd Zedd Zedd	 Zd
d Z	dd Z
dd Z  ZS )T5TrainStepze
    T5 train step class.

    Args:
        args (`argparse.Namespace`): Megatron-LM arguments.
    c                    sF   t  d | |j| _|  | _|  | _|j	s<d | _
nt| _
d S )Nr%  )r   re   r   r\   r   r   r   r   r   r   r   r   r   r   rR   rS   re     s    

zT5TrainStep.__init__c                 C   s(   |  d}|  d}|| }|dk }|S )Nr   r         ?)	unsqueeze)r   attention_mask_b1sattention_mask_bs1attention_mask_bssextended_attention_maskrR   rR   rS   attn_mask_postprocess  s
    

z!T5TrainStep.attn_mask_postprocessc                 C   s&   t t jd| | f|d}|dk }|S Nr   r   r&  )r   Ztrilones)rq   r   r   rR   rR   rS   get_decoder_mask  s    zT5TrainStep.get_decoder_maskc           	      C   s<   | j \}}| d}tj||df|d}|| }|dk }|S r-  )r  r'  r   r/  )	r   Zdec_seq_lengthr   r   _r(  r)  r*  r+  rR   rR   rS   get_enc_dec_mask  s    

zT5TrainStep.get_enc_dec_maskc                 C   s    dd }dd }|r|S |S d S )Nc                 S   s   dddddddg}t j}| dk	r*t| }nd}t|||}|d  }|d  }|d  }|d  }|d d	k }	|d d	k }
|d d	k }|||||	|
|fS )
r   Ztext_encZtext_decr   r   enc_maskdec_maskenc_dec_maskNr&  r   )r   r   r   r   r   
tokens_enc
tokens_decr   r   r3  r4  r5  rR   rR   rS   r     s    
z6T5TrainStep.get_batch_func.<locals>.get_batch_megatronc           	      S   s   t | }t|tj }|d  }|d  }|dktj}d|krV|d  }nN|j|j	|j
tjd}|dddf  |dd	df< d
|d< ||dkd
 t|d  }t|j	d	 |j
}t|d  |j	d	 |j
}|||||||fS )r   r   r   r   Zdecoder_input_ids)r   r  .Nr   r   r   ).r   r   )r   r   r   r   r   r   r   r   Z	new_zerosr  r   cloneZmasked_fill_r%  r,  r0  r2  )	r   r   r6  r   r   r7  r3  r4  r5  rR   rR   rS   r   
  s&     
  z9T5TrainStep.get_batch_func.<locals>.get_batch_transformerrR   r   rR   rR   rS   r     s
    zT5TrainStep.get_batch_funcc                 C   s   dd }|S )Nc                 S   sH   |  }t|d| d |   }|}t|g}|d|d ifS )Nr   r   r   )r   r   r   r   r   r:   )r   r   r  r  r  r  rR   rR   rS   r   '  s
    "
z,T5TrainStep.get_loss_func.<locals>.loss_funcrR   r#  rR   rR   rS   r   &  s    	zT5TrainStep.get_loss_funcc                    s    fdd}|S )Nc           
   	      s>     | \}}}}}}}||||||d|d}	|	t j|fS )r
  Nr  r  )
r   rQ   r6  r7  r   r   r3  r4  r5  r   r   rR   rS   r   3  s          z7T5TrainStep.get_forward_step_func.<locals>.forward_steprR   r$  rR   r   rS   r   2  s    z!T5TrainStep.get_forward_step_func)r   r   r   r   re   staticmethodr,  r0  r2  r   r   r   r   rR   rR   r   rS   r%    s   



6r%  c                 C   s,  |  d tj stdt|dd}| D ]L\}}t||d d k	rp|jdkrpt dj	|t|||ddd t
||| q0|js|d	d
r|jd k	stdt| t| t| dd }t }|  t  t  t  t }t|j||_|jdkr|jr|jdkrd|_nd
|_d|_d S )NzInitializing Megatron-LMzMegatron requires CUDA.T)Zignore_unknown_argsr   z[WARNING: overriding default arguments for {key}:{v}                         with {key}:{v2})ri   vZv2)flushuse_checkpoint_argsFz/--use-checkpoints-args requires --load argumentc                  S   s   t  } tj }tj | _tj | _|dkr| j| }| j	d k	rX| j	|ks^t
dn|| _	t rptd nt| j| j| j| j | jdkrtd| j t| j| j d S )Nr   z:expected local-rank to be the same as rank % device-count.z%model parallel is already initializedz > setting random seeds to {} ...)r   r   r   device_countr   get_rankrI   get_world_sizeZ
world_size
local_rankr   r   Zmodel_parallel_is_initializedrJ   Zinitialize_model_parallelZtensor_model_parallel_sizerX   Z$virtual_pipeline_model_parallel_sizerW   r   rm   r'   Zdata_parallel_random_init)rO   r=  r   rR   rR   rS   finish_mpu_initc  s(    




z#initialize.<locals>.finish_mpu_initr>   r   )rJ   r   r   Zis_availabler   r   rg   r   rI   r   rh   r<  getloadr   r   r$   r   r&   r%   r(   r6   Zorig_vocab_sizer  rK   rH   rM   rL   r   )rY   Zextra_args_providerZargs_defaultsrO   ri   rj   rA  rR   rR   rS   
initializeD  s@    

   rD  c                       sj   e Zd ZdZ fddZdd Zdd Zdd	 Zd
d Zdd Z	dd Z
dd Zdd ZdddZ  ZS )MegatronEnginez
    Megatron-LM model wrapper

    Args:
        accelerator (:class:`~accelerate.Accelerator`): The accelerator object to use.
        model: Megatron-LM model
        optimizer: Megatron-LM optimizer
        lr_scheduler: Megatron-LM lr scheduler
    c                    s   t t|   || _|d | _|| _|| _t }|jj	j
d k	rZ|jj	j
|f|jj	j| _nR|jdkrpt|| _n<|jdkrt|| _n&|jdkrt|| _ntd|j d| j_i | _i | _d| _d| _|jd k	rt  d S )Nr   r>   rC   rD   rG   FT)r   rE  re   module
base_modelr   r   r   rU   rV   Zcustom_train_step_classZcustom_train_step_kwargstrain_step_handlerrK   r   r  r%  rN   r   total_loss_dicteval_total_loss_dictr   report_memory_flagtensorboard_dirr)   )rc   rY   rQ   r   r   rO   r   rR   rS   re     s4    




zMegatronEngine.__init__c                 C   s    | j D ]}|  q|   d S rf   )rF  trainlog_eval_resultsrc   Zmodel_modulerR   rR   rS   rM    s    

zMegatronEngine.trainc                 C   s   | j D ]}|  qd S rf   )rF  evalrO  rR   rR   rS   rP    s    
zMegatronEngine.evalc                    s  t   t }t|dkr^g  jdkrXtd jD ]" fdd| D  q2n|gt| jdkrt|dkrfddtt| jD ndgt| j }nt|dkrtnd} j	dkr j
r| jD ]}|  q| j  t }|| jj|| j| jdd	d
} jdkr$tj  |d  | j | |d  |d  | j |\}}}	|d  |r|d  | j | |d  |r| jdk	rt  j  j }
| jj|
d d}nd}| | j_ jdkrtj    jt !  j t  7  _t j"ddri }|d D ]Pfdd|D }t|d j#dkrvt$|t| |< nt%||< q6||||	fS i |||	fS )z
        Training step for Megatron-LM

        Args:
            batch_data (:obj:`dict`): The batch data to train on.
        r   r   c                    s.   i | ]&\}}|| j  d   j   qS r   r}   r   r   r:  rO   irR   rS   r     s    z-MegatronEngine.train_step.<locals>.<dictcomp>c                    s   g | ]}t  qS rR   r   r   r1  data_chunksrR   rS   
<listcomp>  s     z-MegatronEngine.train_step.<locals>.<listcomp>NlocalF)forward_onlyzbackward-reduce-model-gradsr   zbackward-gather-model-params)	incrementr   TZignore_virtualc                    s   g | ]}|  qS rR   rR   r   r   ri   rR   rS   rZ  %  s     )&r   r   r   r   rangeappendrg   rF  r   ZDDP_implZ#use_contiguous_buffers_in_local_ddpZzero_grad_bufferr   r   r1   rH  r   empty_unused_memory_levelr   r   empty_cachestartZreduce_model_gradsstopr   Zgather_model_paramsr   r   r}   r   r   r   r   r   is_pipeline_last_stager  r   r  )rc   
batch_datatimersbatch_data_iterator	partitionforward_backward_funcZlosses_reducedZupdate_successful	grad_normnum_zeros_in_gradr]  r   loss_reducedlosses_reduced_for_keyrR   rO   rY  rU  ri   rS   
train_step  s    








zMegatronEngine.train_stepc                    sH  t   g  jdkrFtd jD ]" fdd| D  q n|gt| jdkrxfddtt| jD }nt}t }|| j	j
|| jdddd	} jdkrtj    jt  j t  7  _tjdd
r@i }|d D ]Nfdd|D }t|d jdkr,t|t| |< qt||< q|S i S dS )z
        Evaluation step for Megatron-LM

        Args:
            batch_data (:obj:`dict`): The batch data to evaluate on.
        r   r   c                    s.   i | ]&\}}|| j  d   j   qS rQ  rR  rS  rT  rR   rS   r   :  s      z,MegatronEngine.eval_step.<locals>.<dictcomp>c                    s   g | ]}t  qS rR   rV  rW  rX  rR   rS   rZ  @  s     z,MegatronEngine.eval_step.<locals>.<listcomp>NT)r   ri  r\  r^  c                    s   g | ]}|  qS rR   rR   r_  r`  rR   rS   rZ  X  s     )r   r   ra  rb  rg   r   rF  r   r1   rH  r   rc  r   r   rd  r   r   r   r}   r   rg  r  r   r  )rc   rh  rj  rl  Z
loss_dictsro  rp  rR   rq  rS   	eval_step-  sD    
	

zMegatronEngine.eval_stepc                 K   s  t  }| jd jr| jf |\}}}}|  jd7  _|jd k	r| j  }d }|j	r`t
| j}t|| j| jjd d | j|| j||||
| _n|| jf |}|jd k	r|D ]^}	| j|	tjdg||	  | j|	< | j|	d tjdgtjdg | j|	d < qtjd|jd}
|D ]&}	t||	 jdkr |
||	 7 }
q d }d|kr^|d }| jjd k	r|| jj|
|d	S |
S )
Nr   r   lr        
_num_iters      ?r.  r  )r  r  )r   rF  Ztrainingrr  r   rL  r   Zget_loss_scaler   Zlog_params_normr;   rQ   r9   rI  Zparam_groupsrK  rs  rJ  rB  r   r   ZFloatTensortensorr@  r   r  rH  r   )rc   rh  rO   Z	loss_dictr   rm  rn  Z
loss_scaleZparams_normri   r  r  rR   rR   rS   forwarda  sT    

 
zMegatronEngine.forwardc                 C   s  t  }|jd ks| jdkrd S t  }t }d| j d}| jD ]}|drNq>| j| | j|d   }|| d| d7 }ttd|	 }|j
r|| d| d7 }|r>|| d|	 | j |j
r>|| d	|| j q>t|d
 }td|  t| td|  i | _d S )Nr   zvalidation loss at iteration z | rv  z value:    z PPL: z validationz validation pplr   -)r   rL  r   r   rJ  endswithmathexpminr   rH   Z
add_scalarr   r   )rc   rO   writerstringri   rj   ZppllengthrR   rR   rS   rN    s.    

zMegatronEngine.log_eval_resultsc                 C   sB   |    t }||_tj  t| j| j| j	| j
 tj  d S rf   )rN  r   saver   r   barrierr!   r   rF  r   r   )rc   
output_dirrO   rR   rR   rS   r!     s    
zMegatronEngine.save_checkpointc                 C   sb   t  }||_d|_d|_tj  t| j| j	| j
}tj  || _|jr^| jdkr^| j	  d S )Nr   )r   rC  r   r   r   r   r  r    rF  r   r   r   Zfp16Zreload_model_params)rc   Z	input_dirrO   r   rR   rR   rS   r      s    

zMegatronEngine.load_checkpointNc
                 K   sR  t  }|jdkrtd|jdkr*td|jr8td|jdk	rJtd|jdkr\td|dkrt|dkrttd	|dkrd
}nd|  k rdksn td|dkrd}nd|  krdksn td|dkrd}n<|dkr|dkrtdn"d|  krd
ksn td|
dd}d|  kr:d
ksDn td|
dd}d|  krhd
ksrn td|
dd}t	|t
std|}|dk	rt	|tstd|dk rtd|jd dkrdS t }|
d|j}|dk	rt	|tstd|	dkrd
}	d}d}d}tj dkr|dkr`tj|jd g|jd  }n|jdd  }|dkr||jd  }|dkrtd!|r&||jd  d }d"t|d"  }||jd d  }tj|jg| g|jd  }tjtj|dddf dd | |gdd }nd||jd  }d"t|d"  }||jd  }tj|jg| g|jd  }tj| |gdd }|d|dg}td#|dd$}| }t|tj|dd%}t|d tj|dd%}|
d&d}tj| t | j!t"t#t$f}|dk	r,t%|||||d|	d'\}}n"t&|||d|||||d(d)
\}}}|S )*a  
        Generate method for GPT2 model. This method is used for inference. Supports both greedy and beam search along
        with sampling. Refer the Megatron-LM repo for more details

        Args:
            inputs (torch.Tensor): input ids
            attention_mask (torch.Tensor, optional): attention mask. Defaults to None.
            max_length (int, optional): max length of the generated sequence. Defaults to None.
            Either this or max_new_tokens should be provided.
            max_new_tokens (int, optional): max number of tokens to be generated. Defaults to None.
            Either this or max_length should be provided.
            num_beams (int, optional): number of beams to use for beam search. Defaults to None.
            temperature (float, optional): temperature for sampling. Defaults to 1.0.
            top_k (int, optional): top k tokens to consider for sampling. Defaults to 0.0.
            top_p (float, optional): tokens in top p probability are considered for sampling. Defaults to 0.0.
            length_penalty (float, optional): length penalty for beam search. Defaults to None.
            kwargs: additional key-value arguments
        rC   z1Generate method is not implemented for this modelr   z1Generate method requires data parallelism to be 1z9Generate method requires sequence parallelism to be FalseNz2Checkpoint activations cannot be set for inferencez$Vocab file is required for inferencez;`max_length` or `max_new_tokens` are required for inferencerw  ru  g      Y@zAtemperature must be a positive number less than or equal to 100.0r   i  z:top_k must be a positive number less than or equal to 1000z/top_p and top_k sampling cannot be set togetherz'top_p must be less than or equal to 1.0top_p_decayz-top_p_decay must be less than or equal to 1.0top_p_boundz-top_p_bound must be less than or equal to 1.0add_BOSFzadd_BOS must be a booleanzbeam_width must be an integerz!beam_width must be greater than 0z,When doing beam_search, batch size must be 1
stop_tokenzstop_token must be an integerr   )Zaxisz%max_new_tokens must be greater than 0   r   )Zint_listrI   )rx  rI   random_seed)r  Znum_return_genlength_penaltyT)Zreturn_output_log_probstop_ktop_pr  r  temperatureZ#use_eod_token_for_early_termination)'r   rK   NotImplementedErrorr   rN   Zsequence_parallelZrecompute_granularityr  rB  r   boolr   r  r   r  r   r   r>  r   r   r   r}  ceilr  r'  sizer2   tolistr3   r   randomZmanual_seedr=   rG  torchDDPLocalDDPr+   r4   r5   )rc   inputsr   
max_lengthZmax_new_tokensZ	num_beamsr  r  r  r  r   rO   r  r  r  Z
beam_widthr  r  Z
sizes_listZprompts_tokens_tensorZprompts_length_tensorr   Zsizes_tensorsizesZcontext_tokens_tensorZcontext_length_tensorr  Zunwrapped_modelr   r1  rR   rR   rS   megatron_generate  s    !









 

 "  

z MegatronEngine.megatron_generate)NNNNNNNN)r   r   r   r   re   rM  rP  rr  rs  ry  rN  r!   r    r  r   rR   rR   r   rS   rE    s&   
e4=        rE  c                 C   s   t | S )z
    Average losses across data parallel group.

    Args:
        losses (List[Tensor]): List of losses to average across data parallel group.
    )r:   )r!  rR   rR   rS   %avg_losses_across_data_parallel_group  s    r  c                 C   s   dd }t || ddS )z
    Recursively gather tensor in a nested list/tuple/dictionary of tensors from data parallel ranks.

    Args:
        tensor (nested list/tuple/dictionary of `torch.Tensor`):
            The data to gather across data parallel ranks.

    c                    s^    j dkr  d    fddttjjt dD }tjj| t d tj	|ddS )Nr   c                    s   g | ]}t  qS rR   )r   Z
empty_likerW  rx  rR   rS   rZ    s   zOgather_across_data_parallel_groups.<locals>._gpu_gather_one.<locals>.<listcomp>r   r  )
ndimr8  ra  r   r   r?  r   Zget_data_parallel_groupZ
all_gathercat)rx  Zoutput_tensorsrR   r  rS   _gpu_gather_one  s    

z;gather_across_data_parallel_groups.<locals>._gpu_gather_oneT)Zerror_on_other_type)r   )rx  r  rR   rR   rS   "gather_across_data_parallel_groups  s    

r  )TTTT)mr]   r}  abcr   	functoolsr   r   Ztorch.nn.functionalnnZ
functionalr   Ztorch.nnr   r   r   Ztorch.nn.parallel.distributedr   r  r   r	   r   r
   Zimportsr   r   
operationsr   r   Ztransformers.modeling_outputsr   r   r   Zmegatronr   r   r   r   r   r   r   r   Zmegatron.argumentsr   r   r   r   Zmegatron.checkpointingr   r    r!   Zmegatron.data.data_samplersr"   r#   Zmegatron.global_varsr$   Zmegatron.initializer%   r&   r'   r(   r)   Zmegatron.modelr*   r+   r,   r-   r.   r  Zmegatron.model.classificationr/   Zmegatron.optimizerr0   Zmegatron.schedulesr1   Z&megatron.text_generation.communicationr2   r3   Z#megatron.text_generation.generationr4   r5   Zmegatron.tokenizer.tokenizerr6   Zmegatron.trainingr7   r8   r9   Zmegatron.utilsr:   r;   r<   r=   rT   rZ   r[   r   r   r   r   r   r   r   r   r  r%  rD  ModulerE  r  r  rR   rR   rR   rS   <module>   sj   (
	
( E2 j S   l