U
    9%e|                     @   s  d dl Z d dlZd dlZd dlZd dlmZ d dlmZ d dl	m
Z
 d dlmZ d dlmZ d dlmZ d dlmZ d d	lmZmZmZmZ d d
lmZ d dlmZmZmZ d dlmZ d dl m!Z!m"Z" d dl#m$Z$ d dl%m&Z&m'Z'm(Z(m)Z)m*Z* dd Z+dd Z,dd Z-ej./dddddgdd Z0ej./ddddgdd Z1ej./ddddgdd  Z2ej./dddddgd!d" Z3d#d$ Z4d%d& Z5d'd( Z6d)d* Z7ej./dddgej./d+de8e9gd,d- Z:ej./dddgej./d.d/d0gd1d2 Z;ej./dddgej./d+e9e<d3e<d4gd5d6 Z=d7d8 Z>ej./d9dej?d:d;d gd<d= Z@ej./d+e8d>gd?d@ ZAej./dAdBdCej?fgdDdE ZBdFdG ZCej./dHejDejEgdIdJ ZFej./d9dej?d:d;d gdKdL ZGej./d+e8d>gdMdN ZHej./dOdPgdQggdPgej?gggdRdS ZIdTdU ZJdVdW ZKdXdY ZLdZd[ ZMd\d] ZNej./d^d_d`dadbdcgddde ZOej./dfde e e e gdgdh ZPdidj ZQdkdl ZRdmdn ZSej./ddddgdodp ZTdqdr ZUdsdt ZVdudv ZWej./dwdxdygdzd{ ZXd|d} ZYd~d ZZdd Z[ej.j/dd de\d gdx dgdx gfdde\ej] gdx ej]gdx gfej] ej]e\ej] gdx ej]gdx gfddydgdddge\ddydgdddggfdej] dgddej]ge\dej] dgddej]ggfgdddddgddd Z^ej./ddej]ej] dfddygddd gdfgdd Z_ej.j/dddgej] ej]gfddgdgd dgd gfgddgddd Z`ej./dddgdd Zaej./dddPejbjcdPdgej./dddPejbjcdPdgdd Zdej./de\ddPgdPdQgge\ddPgdPdggddddfej\ddgddgge9dej\ddgddgge9di dfgdd Zeej./dej?ejfej\fd ejgej\fdejgej\fej?ejfejhfdejgejhfej?ejfejDfdejgejDfej?ejfejifdejgejifej?ejfejjfdejgejjfej?ejfejkfdejgejkfgej./dddxe\d dPdQgfddxe\d dPdQgfgdd Zlej./dejhejDejiejjejkgdd Zmej./ddddgej./dej?ej\fd ej\fej?ejhfej?ejDfej?ejifej?ejjfgdd Zndd Zoej./dej\ddgddgge8ddej\ddddgddddgge8dfe\ej?dCgdCej?ggej?e\dCdCddgdCdCddggfej\ej?dgdej?gge8dej?ej\ddddgddddgge8dfej\ddgddgge8ddej\ddddgddddgge8dfgddĄ Zpej./deegej./ddej?dfdgdd˄ Zqdd̈́ Zrddτ Zsej./deegddф Ztej./dejhejDejiejjejkgddӄ Zuej./dddgdd؄ Zvej./ddadxddQd dPgfdbdPd dQddxgfgddۄ Zwej./ddej?gddބ Zxej./ddej?gdd Zyej./dddddge8ddQfddddge8ddPfdddge8ddQfddddge8ddQfddPdQdxgezddQfdPdPdPdQgezddPfddddPgezddQfdPdPdPdgezddQfgdd Z{ej./dddddgdd Z|dd Z}ej./dddgdd Z~dd Zdd Zdd Zej./dejejfgdd Zej./dddgej./dddgdd  Zej./dddgej./ddddgej./dddgdd ZdS (      Nsparse)kstest)tree)load_diabetes)DummyRegressor)ConvergenceWarning)enable_iterative_imputer)IterativeImputer
KNNImputerMissingIndicatorSimpleImputer)_most_frequent)ARDRegressionBayesianRidgeRidgeCV)GridSearchCV)Pipeline
make_union)_sparse_random_matrix)_convert_containerassert_allcloseassert_allclose_dense_sparseassert_array_almost_equalassert_array_equalc                 C   s   t | | | j|jkstd S N)r   dtypeAssertionErrorxy r!   _/var/www/html/Darija-Ai-API/env/lib/python3.8/site-packages/sklearn/impute/tests/test_impute.py"_assert_array_equal_and_same_dtype   s    
r#   c                 C   s   t | | | j|jkstd S r   )r   r   r   r   r!   r!   r"   _assert_allclose_and_same_dtype$   s    
r$   c           	      C   s   d||f }t }| jjdks(|jjdkr,t}t||d}|| |  }||j||	dd ||||	dd t||d}|t
|  |t
|  }t
|r| }||j||	dd ||||	dd dS )zUtility function for testing imputation for a given strategy.

    Test with dense and sparse arrays

    Check that:
        - the statistics (mean, median, mode) are correct
        - the missing values are imputed correctlyz<Parameters: strategy = %s, missing_values = %s, sparse = {0}fmissing_valuesstrategyF)err_msgTN)r   r   kindr   r   fit	transformcopystatistics_formatr   
csc_matrixissparsetoarray)	XX_truer(   
statisticsr'   r)   Z	assert_aeimputerX_transr!   r!   r"   _check_statistics)   s$    	
r8   r(   meanmedianmost_frequentconstantc                 C   s   t jdd}t j|d d d< t| d}|t|}|jdksFt	||}|jdks^t	t
| d}||}|jdkst	d S )N
      r(   )r=   r>   )initial_strategy)nprandomrandnnanr   fit_transformr   
csr_matrixshaper   r
   )r(   r3   r6   	X_imputedZiterative_imputerr!   r!   r"   test_imputation_shapeN   s    



rI   c              	   C   sT   t d}t j|d d df< t| d|}tjtdd || W 5 Q R X d S )N      r   r?   ZSkippingmatch)	rA   onesrD   r   r+   pytestwarnsUserWarningr,   r(   r3   r6   r!   r!   r"    test_imputation_deletion_warning_   s
    
rT   c              	   C   s   t d}tj}tjddddgtd}|j||d|gd|d	d
gg|d}t| d|}t	|j
| t jtdd || W 5 Q R X d S )Npandasabcdr         r>   r=   columnsr?   z6Skipping features without any observed values: \['b'\]rM   )rP   importorskiprA   rD   arrayobject	DataFramer   r+   r   Zfeature_names_in_rQ   rR   r,   )r(   pdr'   feature_namesr3   r6   r!   r!   r"   .test_imputation_deletion_warning_feature_namesi   s     


 re   c              	   C   s   t d}d|d< t|}t| dd}tjtdd || W 5 Q R X ||	  tjtdd |
| W 5 Q R X d S )NrJ   r   )r(   r'   zProvide a dense arrayrM   )rA   rO   r   r0   r   rP   raises
ValueErrorr+   r2   r,   rS   r!   r!   r"   test_imputation_error_sparse_0   s    

rh   c                 O   s8   t | dr| jnt| }|dkr&tjS tj| f||S Nsizer   )hasattrrj   lenrA   rD   r:   Zarrargskwargslengthr!   r!   r"   safe_median   s    rq   c                 O   s8   t | dr| jnt| }|dkr&tjS tj| f||S ri   )rk   rj   rl   rA   rD   r9   rm   r!   r!   r"   	safe_mean   s    rr   c               
   C   sv  t jd} d}d}|| || f}t |d }t d|d d }|dd d  |dd d< dt jdd fd	t jd
d fg}|D ]\}}}	t |}
t |}t |d }t|d D ]Z}|| d dk|| d  || d  }t|d ||  || ||   d}|d | | }|d | }t 	||}|| 
t|d |  }|	|||||< t |||f|
d d |f< d|krt |t 	|| || f|d d |f< n(t ||t 	|| |f|d d |f< t j||
d d |f  t j||d d |f  q|d	kr<t |jdd }nt |jdd }|d d |f }t|
|||| qd S )Nr   r=   r[   r\   r>   r9   c                 S   s   t t| |fS r   )rr   rA   hstackzvpr!   r!   r"   <lambda>       z-test_imputation_mean_median.<locals>.<lambda>r:   c                 S   s   t t| |fS r   )rq   rA   rs   rt   r!   r!   r"   rx      ry   )Zaxis)rA   rB   RandomStatezerosarangerD   emptyrangemaxrepeatZpermutationrl   rs   shuffleisnananyallr8   )rngdimdecrG   r{   valuestestsr(   Ztest_missing_valuesZtrue_value_funr3   r4   Ztrue_statisticsjZnb_zerosZnb_missing_valuesZ	nb_valuesru   rw   rv   Zcols_to_keepr!   r!   r"   test_imputation_mean_median   sJ    

(&
 
r   c                  C   s   t dt jt jgdt jt jgddt jgddt jgddt jgddt jgddt jgddt jgg } t dddgdddgdddgdddgddd	gddd
gdddgdddgg }ddddd	d
ddg}t| |d|t j d S )Nr   rL   r\   r>   g      g      @g      @g            ?r:   )rA   r`   rD   Z	transposer8   )r3   ZX_imputed_medianZstatistics_medianr!   r!   r"   $test_imputation_median_special_cases   s0    





r   r   c              	   C   s\   t jdddgdddgddd	gg|d
}d}tjt|d t| d}|| W 5 Q R X d S )NrV   rW   rK   r\   e   gh	   rZ   4non-numeric data:
could not convert string to float:rM   r?   )rA   r`   rP   rf   rg   r   rE   )r(   r   r3   msgr6   r!   r!   r"   .test_imputation_mean_median_error_invalid_type  s
    &
r   typelist	dataframec              	   C   sn   dddgdddgddd	gg}|d
kr8t d}||}d}t jt|d t| d}|| W 5 Q R X d S )NrV   rW   rK   r\   r   r   r   r   r   r   rU   r   rM   r?   )rP   r_   rb   rf   rg   r   rE   )r(   r   r3   rc   r   r6   r!   r!   r"   :test_imputation_mean_median_error_invalid_type_list_pandas  s    


r   USc              	   C   s   t jt jt jddgt jdt jdgt jddt jgt jdddgg|d}d}tjt|d	  t| d
}||| W 5 Q R X d S )NrV   r%   rX   rY   rW   r   rZ   z#SimpleImputer does not support datarM   r?   )	rA   r`   rD   rP   rf   rg   r   r+   r,   )r(   r   r3   r)   r6   r!   r!   r"   /test_imputation_const_mostf_error_invalid_types  s    

r   c               	   C   sz   t ddddgddddgddddgddddgg} t dddgdddgdddgdddgg}t| |dt jdddgd d S )	Nr   r   rL   r>   rK   r[      r;   )rA   r`   r8   rD   )r3   r4   r!   r!   r"   test_imputation_most_frequent.  s    



	r   markerZNAN c                 C   s   t j| | ddg| d| dg| dd| g| dddggtd}t jdddgdddgdddgdddggtd}t| dd	}|||}t|| d S )
NrV   r%   rX   rY   rW   r   rZ   r;   r&   )rA   r`   ra   r   r+   r,   r   r   r3   r4   r6   r7   r!   r!   r"   %test_imputation_most_frequent_objectsI  s&    





r   categoryc                 C   sr   t d}td}|j|| d}tjdddgdddgdddgd	ddggtd}td
d}|	|}t
|| d S )NrU   ,Cat1,Cat2,Cat3,Cat4
,i,x,
a,,y,
a,j,,
b,j,x,rZ   rV   ir   r   r    rW   r;   r?   rP   r_   ioStringIOZread_csvrA   r`   ra   r   rE   r   r   rc   r%   dfr4   r6   r7   r!   r!   r"   $test_imputation_most_frequent_pandasf  s    

"

r   zX_data, missing_value)r[   r         ?c              	   C   sN   t jd| td}||d< tjtdd t|ddd}|| W 5 Q R X d S )	NrJ   rZ   r   r   zimputing numericalrM   r<   r   r'   r(   
fill_value)rA   fullfloatrP   rf   rg   r   rE   )ZX_datamissing_valuer3   r6   r!   r!   r"   +test_imputation_constant_error_invalid_typez  s      r   c               	   C   s   t ddddgddddgddddgdd	d
dgg} t d
ddd
gdd
dd
gddd
d
gdd	d
d
gg}tddd
d}|| }t|| d S )Nr   r>   rK   r\   rL   r   r      r   r   r<   r   )rA   r`   r   rE   r   )r3   r4   r6   r7   r!   r!   r"    test_imputation_constant_integer  s
    22
r   array_constructorc              	   C   s   t t jddt jgdt jdt jgddt jt jgdddt jgg}t ddddgddddgddddgddddgg}| |}| |}tddd	}||}t|| d S )
Ng?r   333333??gffffff?      ?r   r<   )r(   r   )rA   r`   rD   r   rE   r   )r   r3   r4   r6   r7   r!   r!   r"   test_imputation_constant_float  s    	*
r   c                 C   s   t j| dd| gd| d| gdd| | gddd	| ggtd
}t jddddgddddgddddgddd	dggtd
}t| ddd}||}t|| d S )NrV   rW   rX   rY   r   r%   r   r   r   rZ   missingr<   r   )rA   r`   ra   r   rE   r   r   r!   r!   r"   test_imputation_constant_object  s.    









  
r   c                 C   sz   t d}td}|j|| d}tjddddgddddgdd	ddgd
d	ddggtd}tdd}|	|}t
|| d S )NrU   r   rZ   r   r   r   rV   r    r   rW   r<   r?   r   r   r!   r!   r"   test_imputation_constant_pandas  s    








r   r3   r[   r>   c                 C   sf   t  | }|jdkstt  }|dgdgg |jdks@t|dgtjgg |jdksbtd S )Nr   r[   r>   )r
   r+   n_iter_r   rA   rD   r3   r6   r!   r!   r"   "test_iterative_imputer_one_feature  s    r   c                  C   st   t dddd} | jd }tdt|dfdtjddfg}d	d
ddgi}t dddd }t||}|| | d S )Nd   皙?)densityr   r6   r'   r   random_stateZimputer__strategyr9   r:   r;   r[   )	r   datar   r   r   ZDecisionTreeRegressorr2   r   r+   )r3   r'   Zpipeline
parametersYgsr!   r!   r"   $test_imputation_pipeline_grid_search  s    

r   c                  C   sv  t ddddd} |   }tdddd}|||}d|d	< t||krTt|  }t|j	d ddd}|||}d|j	d< t|j	|j	krt|   }tddd
d}|||}d|d	< t
|| |   }t|j	d dd
d}|||}d|j	d< t
|j	|j	 |  }t|j	d dd
d}|||}d|j	d< t|j	|j	krrtd S )NrL   g      ?r   r   r   r9   T)r'   r(   r-   r   r   F)r   r-   r2   r   r+   r,   rA   r   r   r   r   Ztocsc)ZX_origr3   r6   Xtr!   r!   r"   test_imputation_copy  s4    



r   c                  C   s   t jd} d}d}t||d| d }|dk}t j||< tdd}||}t||j	
| tdd|}t |
||j	
|krtd|_t|
||j	
| d S )Nr   r   r=   r   r   )max_iterrL   )rA   rB   rz   r   r2   rD   r
   rE   r   initial_imputer_r,   r+   r   r   r   )r   nrY   r3   Zmissing_flagr6   rH   r!   r!   r"   !test_iterative_imputer_zero_iters/  s    


 r   c                  C   sp   t jd} d}d}t||d| d }tdddd}|| || tdddd}|| || d S )	Nr   r   rK   r   r   r[   )r'   r   verboser>   )rA   rB   rz   r   r2   r
   r+   r,   )r   r   rY   r3   r6   r!   r!   r"   test_iterative_imputer_verboseG  s    


r   c                  C   sB   d} d}t | |f}tddd}||}t||j| d S )Nr   rK   r   r[   )r'   r   )rA   r{   r
   rE   r   r   r,   )r   rY   r3   r6   rH   r!   r!   r"   "test_iterative_imputer_all_missingU  s    
r   imputation_orderrB   roman	ascending
descendingarabicc           
      C   sR  t jd}d}d}d}t||d|d }d|d d df< td|dd	d
ddd| |d
}|| dd |jD }t||j	 |j
kst| dkrt |d |d  t d|kstn| dkrt |d |d  t |d ddkstn^| dkr*|d |d  }||d d  }	||	ksNtn$d| krNt|||d  ksNtd S )Nr   r   r=   r>   r   r   r[   rL   FT)
r'   r   n_nearest_featuressample_posteriorskip_complete	min_value	max_valuer   r   r   c                 S   s   g | ]
}|j qS r!   Zfeat_idx).0r   r!   r!   r"   
<listcomp>v  s     z;test_iterative_imputer_imputation_order.<locals>.<listcomp>r   r   r   rB   ending)rA   rB   rz   r   r2   r
   rE   imputation_sequence_rl   r   Zn_features_with_missing_r   r   r|   )
r   r   r   rY   r   r3   r6   Zordered_idxZordered_idx_round_1Zordered_idx_round_2r!   r!   r"   'test_iterative_imputer_imputation_order^  s>    
(.

r   	estimatorc           	      C   s   t jd}d}d}t||d|d }tdd| |d}|| g }|jD ]>}| d k	r`t| ntt	 }t
|j|szt|t|j qLtt|t|kstd S )Nr   r   r=   r   r   r[   )r'   r   r   r   )rA   rB   rz   r   r2   r
   rE   r   r   r   
isinstancer   r   appendidrl   set)	r   r   r   rY   r3   r6   hashestripletexpected_typer!   r!   r"   !test_iterative_imputer_estimators  s$       

r   c                  C   s   t jd} d}d}t||d| d }tdddd| d}||}tt ||dk d tt 	||dk d t||dk ||dk  d S )	Nr   r   r=   r   r   r[   皙?)r'   r   r   r   r   
rA   rB   rz   r   r2   r
   rE   r   minr   r   r   rY   r3   r6   r   r!   r!   r"   test_iterative_imputer_clip  s        
r   c                  C   s   t jd} d}d}t||d| d }d|d d df< tdddd	dd
dd| d	}||}tt ||dk d tt 	||dk d
 t||dk ||dk  d S )Nr   r   r=   r   r   r[   r>   rL   Tr   rB   )	r'   r   r   r   r   r   r   r   r   r   r   r!   r!   r"   %test_iterative_imputer_clip_truncnorm  s(    
r   c                     s   t jd} | jdd t j d d< tddd| d  t  fdd	td
D }t	|dksnt
t	|dks~t
| |  }}t|| | d\}}|dkr|d7 }t|| | d\}}|dk s|dkst
dd S )N*   )rL   rL   )rj   r   r   T)r   r   r   r   c                    s   g | ]}  d  d  qS )r   )r,   )r   _r   r!   r"   r     s     zEtest_iterative_imputer_truncated_normal_posterior.<locals>.<listcomp>r   Znormg-q=r   r   z&The posterior does appear to be normal)rA   rB   rz   normalrD   r
   rE   r`   r~   r   r   r9   Zstdr   )r   ZimputationsmusigmaZks_statisticZp_valuer!   r   r"   1test_iterative_imputer_truncated_normal_posterior  s&       
r   c                 C   s   t jd}d}d}|jdd||fd}|jdd||fd}d|d d df< d|d< tdd| |d|}td| d	|}t||d d df ||d d df  d S )
Nr   r   r=   rK   )lowhighrj   r[   r   )r'   r   r@   r   r&   )	rA   rB   rz   randintr
   r+   r   r   r,   )r(   r   r   rY   X_trainX_testr6   Zinitial_imputerr!   r!   r"   +test_iterative_imputer_missing_at_transform  s(        r  c                  C   s   t jd} t jd}d}d}t||d| d }tddd| d}|| ||}||}t |t	
t |ksttddd	d d
| d}tddd	d d
|d}	|| |	| ||}
||}|	|}t|
| t|
| d S )Nr   r[   r   r=   r   r   T)r'   r   r   r   Fr   )r'   r   r   r   r   r   )rA   rB   rz   r   r2   r
   r+   r,   r9   rP   Zapproxr   r   )Zrng1Zrng2r   rY   r3   r6   Z
X_fitted_1Z
X_fitted_2imputer1imputer2ZX_fitted_1aZX_fitted_1br!   r!   r"   .test_iterative_imputer_transform_stochasticity  sL       


	





r  c                  C   s   t jd} | dd}t j|d d df< td| d}td| d}|||}||}t	|d d dd f | t	|| d S )Nr   r   r=   )r   r   r[   )
rA   rB   rz   randrD   r
   r+   r,   rE   r   )r   r3   m1m2Zpred1Zpred2r!   r!   r"   !test_iterative_imputer_no_missing4  s    
r
  c            	      C   s   t jd} d}| |d}| d|}t ||}| ||dk }| }t j||< tdd| d}||}t	||dd d S )	Nr   2   r[   r   rL   r   r   r   g{Gz?atol)
rA   rB   rz   r  dotr-   rD   r
   rE   r   )	r   rY   ABr3   nan_mask	X_missingr6   X_filledr!   r!   r"   test_iterative_imputer_rank_oneB  s    

r  rankrK   rL   c                 C   s   t jd}d}d}||| }|| |}t ||}|||dk }| }t j||< |d }|d | }	||d  }
||d  }tddd|d|	}|	|}t
|
|d	d
 d S )Nr   F   r   r>   rL   r   r[   )r   r   r   r   r   r  )rA   rB   rz   r  r  r-   rD   r
   r+   r,   r   )r  r   r   rY   r  r  r  r  r  r  X_test_filledr  r6   
X_test_estr!   r!   r"   )test_iterative_imputer_transform_recoveryQ  s.    
   
r  c               	   C   s  t jd} d}d}| ||}| ||}t |j}t|D ]R}t|D ]D}|d d || | f  |d d |f |d d |f  d 7  < qLq@| ||dk }| }	t j	|	|< |d }|	d | }
||d  }|	|d  }t
dd| d|
}||}t||dd	d
 d S )Nr   r   r=   r>   g      ?r[   r  gMbP?{Gz?)rtolr  )rA   rB   rz   rC   r{   rG   r~   r  r-   rD   r
   r+   r,   r   )r   r   rY   r  r  r  r   r   r  r  r  r  r  r6   r  r!   r!   r"   &test_iterative_imputer_additive_matrixj  s&    D

r  c                  C   s   t jd} d}d}| |d}| d|}t ||}| ||dk }| }t j||< tdddd| d	}||}	t	|j
||j kstt|jdd| d
}||}
t|	|
dd tdddd| d	}|| |j|jkstd S )Nr   r  rL   r[   r   r   r  F)r   Ztolr   r   r   )r   r   r   r   gHz>r  )rA   rB   rz   r  r  r-   rD   r
   rE   rl   r   r   r   r   r+   r   )r   r   rY   r  r  r3   r  r  r6   ZX_filled_100ZX_filled_earlyr!   r!   r"   %test_iterative_imputer_early_stopping  sF    
    
   
    
r  c            
   	   C   s   t dd\} }| j\}}d| d d df< tjd}d}t|D ]0}|jt|t|| dd}tj	| ||f< q@t
d	dd
}t  tdt || |}	W 5 Q R X tt|	rtd S )NT)Z
return_X_yr[   rK   r   g333333?F)rj   replacerL   )r   r   error)r   rG   rA   rB   rz   r~   choicer|   intrD   r
   warningscatch_warningssimplefilterRuntimeWarningrE   r   r   r   )
r3   r    Z	n_samples
n_featuresr   Zmissing_rateZfeatZ
sample_idxr6   ZX_fillr!   r!   r"   $test_iterative_imputer_catch_warning  s"    
 
 
r(  z$min_value, max_value, correct_outputr   r   r=      i,  ZscalarszNone-defaultinflistszlists-with-inf)Zidsc                 C   s   t jddd}t| |d}|| t|jt jrFt|j	t jsJt
|jjd |jd krv|j	jd |jd kszt
t|dd d f |j t|dd d f |j	 d S )Nr   r=   rK   r   r   r[   )rA   rB   rz   rC   r
   r+   r   Z
_min_valuendarrayZ
_max_valuer   rG   r   )r   r   Zcorrect_outputr3   r6   r!   r!   r"   )test_iterative_imputer_min_max_array_like  s    
 r.  zmin_value, max_value, err_msg)r   r   min_value >= max_value.r/  z_value' should be of shapec              	   C   s@   t jd}t| |d}tjt|d || W 5 Q R X d S )Nr=   rK   r,  rM   )rA   rB   r
   rP   rf   rg   r+   )r   r   r)   r3   r6   r!   r!   r"   *test_iterative_imputer_catch_min_max_error  s    r1  zmin_max_1, min_max_2ir\   zNone-vs-infzScalar-vs-vectorc              	   C   s   t t jdddgdt jt jdgddt jdgt jddt jgg}t t jdt jdgddt jt jgt jdddgg}t| d | d dd	}t|d |d dd	}|||}|||}t|d d df |d d df  d S )
Nr>   r[   r=   r   rK   r\   rL   r   )r   r   r   )rA   r`   rD   r
   r+   r,   r   )Z	min_max_1Z	min_max_2r  r  r  r  ZX_test_imputed1ZX_test_imputed2r!   r!   r"   4test_iterative_imputer_min_max_array_like_imputation  s.    *    r2  r   TFc              	   C   s   t jd}t ddddgddddgddddgdd	ddgg}t t jdd	dgt jd	ddgt jdddgg}td
| |d}|||}| rt|d d df t 	|d d df  n t|d d df dddgdd d S )Nr   rL   r>   r[   r=   r   rK   r   r\   r9   )r@   r   r         g-C6?)r  )
rA   rB   rz   r`   rD   r
   r+   r,   r   r9   )r   r   r  r  r6   r  r!   r!   r"   'test_iterative_imputer_skip_non_missing
  s    2.  *r5  
rs_imputer)seedrs_estimatorc                 C   sH   G dd d}||d}t | d}td}|| |j|ksDtd S )Nc                   @   s$   e Zd Zdd Zdd Zdd ZdS )zCtest_iterative_imputer_dont_set_random_state.<locals>.ZeroEstimatorc                 S   s
   || _ d S r   r   )selfr   r!   r!   r"   __init__!  s    zLtest_iterative_imputer_dont_set_random_state.<locals>.ZeroEstimator.__init__c                 _   s   | S r   r!   )r9  rn   Zkgardsr!   r!   r"   r+   $  s    zGtest_iterative_imputer_dont_set_random_state.<locals>.ZeroEstimator.fitc                 S   s   t |jd S )Nr   )rA   r{   rG   )r9  r3   r!   r!   r"   predict'  s    zKtest_iterative_imputer_dont_set_random_state.<locals>.ZeroEstimator.predictN)__name__
__module____qualname__r:  r+   r;  r!   r!   r!   r"   ZeroEstimator   s   r?  r   r0  )r
   rA   r{   r+   r   r   )r6  r8  r?  r   r6   r  r!   r!   r"   ,test_iterative_imputer_dont_set_random_state  s    




r@  zX_fit, X_trans, params, msg_errr   missing-onlyauto)featuresr   zBhave missing values in transform but have no missing values in fitrV   rW   rX   rZ   z1MissingIndicator does not support data with dtypec              	   C   sD   t dd}|jf | tjt|d || | W 5 Q R X d S )Nr   r   rM   )r   
set_paramsrP   rf   rg   r+   r,   )X_fitr7   paramsZmsg_err	indicatorr!   r!   r"   test_missing_indicator_error1  s    
rH  zmissing_values, dtype, arr_typez,param_features, n_features, features_indicesr   c                 C   s  t | | dgdd| gg}t | | dgdddgg}t dddgdddgg}t dddgdddgg}	|||}|||}||}|	|}	t| |dd}
|
|}|
|}|jd |kst|jd |kstt|
j	| t
||d d |f  t
||	d d |f  |jtks&t|jtks6tt|t jsHtt|t jsZt|
jd	d
 |
|}|
|}|jtkst|jtkst|jdkst|jdkstt
| | t
| | d S )Nr[   r\   r>   r4  r=   r   F)r'   rC  r   Tr   csc)rA   r`   astyper   rE   r,   rG   r   r   Z	features_r   r   boolr   r-  rD  r/   r2   )r'   arr_typer   Zparam_featuresr'  Zfeatures_indicesrE  r7   ZX_fit_expectedZX_trans_expectedrG  
X_fit_maskX_trans_maskZX_fit_mask_sparseZX_trans_mask_sparser!   r!   r"   test_missing_indicator_newI  sB    

  



rO  rL  c              	   C   s   d}t ||dgd|dgg}t ||dgdddgg}| |}| |}t|d}tjtdd	 || W 5 Q R X || tjtdd	 || W 5 Q R X d S )
Nr   r[   r\   r>   r4  r=   r   z"Sparse input with missing_values=0rM   )rA   r`   r   rP   rf   rg   rE   r,   )rL  r'   rE  r7   ZX_fit_sparseZX_trans_sparserG  r!   r!   r"   5test_missing_indicator_raise_on_sparse_with_missing_0  s    

rP  param_sparsezmissing_values, arr_typec                 C   sL  t ||dgd|dgg}t ||dgdddgg}| |t j}| |t j}t||d}||}||}|dkr|jdkst|jdkstn|d	kr|d
krt	|t j
stt	|t j
stn||dkrt	|t j
stt	|t j
stnRt|r$|jdkst|jdksHtn$t	|t j
s6tt	|t j
sHtd S )Nr[   r\   r>   r4  r=   )r'   r   TrI  rB  r   F)rA   r`   rJ  float64r   rE   r,   r/   r   r   r-  r   r1   )rL  r'   rQ  rE  r7   rG  rM  rN  r!   r!   r"   #test_missing_indicator_sparse_param  s*    

rS  c                  C   sX   t jdddgdddggtd} tddd}|| }t|t dddgdddgg d S )	NrV   rW   rX   rZ   r   )r'   rC  TF)rA   r`   ra   r   rE   r   )r3   rG  r7   r!   r!   r"   test_missing_indicator_string  s    
rT  zX, missing_values, X_trans_expc                 C   s0   t t|ddt|d}|| }t|| d S )Nr;   r&   r   )r   r   r   rE   r   )r3   r'   ZX_trans_expZtransr7   r!   r!   r"   #test_missing_indicator_with_imputer  s    

rU  imputer_constructorz.imputer_missing_values, missing_value, err_msgNaNzInput X contains NaN)z-1r   z(types are expected to be both numerical.c              	   C   sR   t jd}|dd}||d< | |d}tjt|d || W 5 Q R X d S )Nr   r=   r   r   rM   )rA   rB   rz   rC   rP   rf   rg   rE   )rV  Zimputer_missing_valuesr   r)   r   r3   r6   r!   r!   r"   (test_inconsistent_dtype_X_missing_values  s    
rX  c                  C   sB   t ddgddgg} tddd}|| }|jd dks>td S )Nr[   rA  r   rC  r'   r   )rA   r`   r   rE   rG   r   r3   mir   r!   r!   r"   !test_missing_indicator_no_missing  s    
r\  c                  C   sP   t dddgdddgdddgg} tddd}|| }| | ksLtd S )Nr   r[   r>   r   rY  )r   rF   r   rE   Zgetnnzsumr   rZ  r!   r!   r"   /test_missing_indicator_sparse_no_explicit_zeros  s    "
r^  c                 C   s8   t ddgddgg}|  }|| |jd ks4td S )Nr[   )rA   r`   r+   Z
indicator_r   )rV  r3   r6   r!   r!   r"   test_imputer_without_indicator)  s    
r_  c                 C   s   | t jddgdt jdgddt jgdddgg}t ddd	dd
d
gdddd
dd
gddd	d
d
dgdddd
d
d
gg}tt jdd}||}t|st|j|jkstt	|
 | d S )Nr[   rL   r>   r   rK   r         @r   g      @g               @g      @g      "@T)r'   add_indicator)rA   rD   r`   r   rE   r   r1   r   rG   r   r2   )rL  ZX_sparser4   r6   r7   r!   r!   r"   2test_simple_imputation_add_indicator_sparse_matrix2  s    .	
rc  zstrategy, expected)r;   rW   )r<   r   c                 C   sN   ddgdt jgg}t jddgd|ggtd}t| d}||}t|| d S )NrV   rW   rX   rZ   r?   )rA   rD   r`   ra   r   rE   r   )r(   expectedr3   r4   r6   r7   r!   r!   r"   "test_simple_imputation_string_listO  s
    

re  zorder, idx_orderc              	   C   s   t jd}|dd}t j|d ddf< t j|d ddf< t j|d dd	f< t j|d d
df< tt6 td| dd	|}dd |j
D }||kstW 5 Q R X d S )Nr   r   rL   r  r[      r      r>   r=   r\   )r   r   r   c                 S   s   g | ]
}|j qS r!   r   )r   r   r!   r!   r"   r   n  s     z)test_imputation_order.<locals>.<listcomp>)rA   rB   rz   r  rD   rP   rQ   r   r
   r+   r   r   )orderZ	idx_orderr   r3   Ztrsidxr!   r!   r"   test_imputation_order]  s    rj  r   c              	   C   sD  t d| ddgddddgdd| dgddd	| gg}t ddd
dgd
d| dgd| ddgddd
| gg}t d| ddg| d| | gd
| d| g| d| dgg}t ddddg| d
| dgd
dddg| d| d
gg}t| ddd}||}||}||}||}	t|| t|	| ||fD ]$}
||
}||}t||
 qd S )Nr   rK   r   r\   rL   r   r   r   r   r>   r[   r9   T)r'   r(   rb  )rA   r`   r   rE   inverse_transformr,   r   )r   X_1ZX_2ZX_3ZX_4r6   	X_1_transZX_1_inv_transZ	X_2_transZX_2_inv_transr3   r7   ZX_inv_transr!   r!   r"   (test_simple_imputation_inverse_transformr  sV    



	



	



	



	  







rn  c              	   C   sz   t d| ddgddddgdd| dgddd	| gg}t| d
d}||}tjtd|j dd || W 5 Q R X d S )Nr   rK   r   r\   rL   r   r   r   r   r9   r&   zGot 'add_indicator='rM   )	rA   r`   r   rE   rP   rf   rg   rb  rk  )r   rl  r6   rm  r!   r!   r"   3test_simple_imputation_inverse_transform_exceptions  s    



	
 rp  z)expected,array,dtype,extra_value,n_repeatextra_valueZmost_frequent_valuevaluer   Zmin_valuevalueru   rg  c                 C   s"   | t tj||d||kstd S )NrZ   )r   rA   r`   r   )rd  r`   r   rq  Zn_repeatr!   r!   r"   test_most_frequent  s
      rs  r@   c                 C   sp   t dt jdgdt jt jgg}t| dd}||}t|dddf d ||}t|dddf d dS )zCheck the behaviour of the iterative imputer with different initial strategy
    and keeping empty features (i.e. features containing only missing values).
    r[   r>   rK   T)r@   keep_empty_featuresNr   )rA   r`   rD   r
   rE   r   r,   )r@   r3   r6   rH   r!   r!   r"   *test_iterative_imputer_keep_empty_features  s      

ru  c               	   C   sb   t ddddgddddgddddgdd	d
dgg} d}tdd|d
d}||  t|jj| dS )z<Check that we propagate properly the parameter `fill_value`.r   r>   rK   r\   rL   r   r   r   r   r   r   r<   )r'   r@   r   r   N)rA   r`   r
   rE   r   r   r.   )r3   r   r6   r!   r!   r"   *test_iterative_imputer_constant_fill_value  s    2
rv  rt  c                 C   s   t dt jdgdt jt jgg}t| d}dD ]`}t|||}| rl|j|jksTtt|dddf d q.|j|jd |jd d fks.tq.dS )z>Check the behaviour of `keep_empty_features` for `KNNImputer`.r[   r>   rK   )rt  rE   r,   Nr   )rA   r`   rD   r   getattrrG   r   r   )rt  r3   r6   methodrH   r!   r!   r"   $test_knn_imputer_keep_empty_features  s     
rz  c                  C   s  t d} | d| jdd dgddi}t| jddd	}t||tj	dgdgdggt
d | d| jddd
gddi}tddd}t||tj	dgdgd
ggt
d | d| jdd dgddi}t| jddd	}t||tj	dgdgdggdd ttjddd	}t||tj	dgdgdggdd | d| jdd ddgddi}t| jdd}t||tj	dgdgdgdggdd | d| jdd dgddi}t| jdd}t||tj	dgdgdggdd | d| jdd dgddi}t| jddd	}t||tj	dgdgdggdd | d| jdd ddgddi}t| jdd}t||tj	dgdgdgdggdd d S )NrU   featureabcdestringrZ   r<   nar   Zfghok)r   r(   r[   rK   ZInt64r   rR  r>   r:   r&   r9   r   r   r`  g       ra  )rP   r_   rb   ZSeriesr   ZNAr#   rE   rA   r`   ra   r$   rD   )rc   r   r6   r!   r!   r"   test_simple_impute_pd_na  s`    
         r  c                  C   sj   t d} tj}| j||d|gd|ddggdddd	gd
}t|d|}| }dddg}t|| dS )zDCheck that missing indicator return the feature names with a prefix.rU   r[   r\   r>   r=   rV   rW   rX   rY   r]   r   Zmissingindicator_aZmissingindicator_bZmissingindicator_dN)	rP   r_   rA   rD   rb   r   r+   Zget_feature_names_outr   )rc   r'   r3   rG  rd   Zexpected_namesr!   r!   r"   (test_missing_indicator_feature_names_outG  s    




r  c                  C   s\   ddgddgddgg} t dd| }|tjtjgg}|jtksHtt|ddgg dS )zkCheck transform uses object dtype when fitted on an object dtype.

    Non-regression test for #19572.
    rV   rW   rX   r;   r?   N)	r   r+   r,   rA   rD   r   ra   r   r   )r3   Zimp_frequentr7   r!   r!   r"    test_imputer_lists_fit_transformZ  s
    r  
dtype_testc                 C   sp   t jddt jgt jddgdddggt jd}t |}t jt jt jt jgg| d}||}|j| ksltdS )	zACheck transform preserves numeric dtype independent of fit dtype.r   g333333@r   g@r>   r[   rZ   N)	rA   asarrayrD   rR  r   r+   r,   r   r   )r  r3   impr  r7   r!   r!   r"   .test_imputer_transform_preserves_numeric_dtypeg  s     
r  
array_typer`   r   c                 C   s   t t jdgt jdgt jdgg}t|| }d}td||d}dD ]X}t|||}|j|jksdt| dkr|d	d	d
f  n|d	d	d
f }t	|| qBd	S )zCheck the behaviour of `keep_empty_features` with `strategy='constant'.
    For backward compatibility, a column full of missing values will always be
    fill and never dropped.
    r>   rK   r   r=   r<   )r(   r   rt  rw  r   Nr   
rA   r`   rD   r   r   rx  rG   r   r2   r   )r  rt  r3   r   r6   ry  rH   constant_featurer!   r!   r"   0test_simple_imputer_constant_keep_empty_featurest  s    "
*r  c                 C   s   t t jdgt jdgt jdgg}t||}t| |d}dD ]}t|||}|r|j|jksbt|dkr~|dddf  n|dddf }t	|d q<|j|jd |jd	 d	 fks<tq<dS )
zYCheck the behaviour of `keep_empty_features` with all strategies but
    'constant'.
    r>   rK   r   )r(   rt  rw  r   Nr   r[   r  )r(   r  rt  r3   r6   ry  rH   r  r!   r!   r"   'test_simple_imputer_keep_empty_features  s    "
*r  )r   r#  numpyrA   rP   Zscipyr   Zscipy.statsr   Zsklearnr   Zsklearn.datasetsr   Zsklearn.dummyr   Zsklearn.exceptionsr   Zsklearn.experimentalr	   Zsklearn.imputer
   r   r   r   Zsklearn.impute._baser   Zsklearn.linear_modelr   r   r   Zsklearn.model_selectionr   Zsklearn.pipeliner   r   Zsklearn.random_projectionr   Zsklearn.utils._testingr   r   r   r   r   r#   r$   r8   markZparametrizerI   rT   re   rh   rq   rr   r   r   ra   strr   r   r   r   r   rD   r   r   r   r   rF   r  r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r
  r  r  r  r  r(  r`   r*  r.  r1  r2  r5  rB   rz   r@  rH  rR  Zint32r0   Z
coo_matrixZ
lil_matrixZ
bsr_matrixrO  rP  rS  rT  rU  rX  r\  r^  r_  rc  re  rj  rn  rp  r"  rs  ru  rv  rz  r  r  r  Zfloat32r  r  r  r!   r!   r!   r"   <module>   s"  	%

	

C 





"
+	 
% 
!
2
$**


	0


*,






   
	



 
"

9

 


<
