U
    9%e~                     @   s:  d dl mZ d dlZd dlZd dlmZ d dlm	Z	 d dl
mZmZ d dlmZ d dlmZ d dlmZ d dlmZmZmZmZmZmZmZ d d	lmZmZ d d
lmZm Z  d dl!m"Z"m#Z#m$Z$ d dl%m&Z& d dl'm(Z( d dl)m*Z*m+Z+ d dl,m-Z-m.Z. d dl/m0Z0m1Z1 d dl2m3Z3m4Z4 d dl5m6Z6 d dl7m8Z8m9Z9 d dl:m;Z;m<Z< dZ=ej>?de= dZ@eA ZBejCDd ZEeEFeBjGjHZIeBjJeI eB_JeBjGeI eB_GdZKdd ZLdd ZMdd ZNdd  ZOd!d" ZPd#d$ ZQd%d& ZRd'd( ZSd)d* ZTd+d, ZUd-d. ZVd/d0 ZWd1d2 ZXd3d4 ZYd5d6 ZZd7d8 Z[d9d: Z\d;d< Z]d=d> Z^d?d@ Z_dAdB Z`dCdD ZadEdF ZbdGdH ZcdIdJ ZddKdL ZedMdN ZfdOdP ZgdQdR ZhdSdT ZidUdV ZjdWdX ZkdYdZ Zld[d\ Zmd]d^ Znd_d` Zodadb Zpej>qdce#e"gddde Zrej>qdce#e"gdfdg Zsej>qdce#e"gdhdi Ztej>qdjejuejvgdkdl Zwdmdn ZxdS )o    )escapeN)assert_allclose)datasetssvm)load_breast_cancer)NotFittedError)SimpleImputer)
ElasticNetLassoLinearRegressionLogisticRegression
PerceptronRidgeSGDClassifier)precision_scorerecall_score)GridSearchCVcross_val_score)OneVsOneClassifierOneVsRestClassifierOutputCodeClassifier)MultinomialNB)KNeighborsClassifier)Pipelinemake_pipeline)SVC	LinearSVC)DecisionTreeClassifierDecisionTreeRegressor)check_arrayshuffle)CheckingClassifier)assert_almost_equalassert_array_equal)check_classification_targetstype_of_targetz/The default value for `force_alpha` will changezignore:z:FutureWarning   c               	   C   s   t tddd} tt | g  W 5 Q R X d}tjt|dD tddgddgg}tddgddgg}t t	 
|| W 5 Q R X tjt|dD tddgddgg}td	d
gddgg}t t	 
|| W 5 Q R X d S )Nautor   dualrandom_statez@Multioutput target data is not supported with label binarizationmatch      r&   g      ?g333333@g@皙?)r   r   pytestraisesr   predict
ValueErrornparrayr   fit)ovrmsgXy r;   \/var/www/html/Darija-Ai-API/env/lib/python3.8/site-packages/sklearn/tests/test_multiclass.pytest_ovr_exceptions4   s    r=   c               	   C   s@   t ddddg} t| }tjt|d t|  W 5 Q R X d S )N        g?       @g      @r+   )r4   r5   r%   r0   r1   r3   r$   )r:   r8   r;   r;   r<   !test_check_classification_targetsH   s    r@   c                  C   s   t tddd} | tjtjtj}t| jt	ks:t
tddd}|tjtjtj}ttj|kttj|kkst
t t } | tjtjtj}ttj|kdkst
d S )Nr'   r   r(   ?)r   r   r6   irisdatatargetr2   lenestimators_	n_classesAssertionErrorr4   meanr   )r7   predclfpred2r;   r;   r<   test_ovr_fit_predictP   s    $
rM   c                  C   s  t tjtjdd\} }tt }|| d d |d d t| || dd  |dd   |	| }tt }|
| |	| }t|| t|jtt|kstt||kdkstttjdd} ddddddddddddddg}ttdd d	dd
}|| d d |d d t| || dd  |dd   |	| }ttdd d	dd
}|
| |	| }t||kt||kksttt }t|drtd S )Nr   r*   d   rA      r.   r-   r&   F)max_itertolr    r*      partial_fit)r    rB   rC   rD   r   r   rT   r4   uniquer2   r6   r"   rE   rF   rH   rI   absrandomrandnr   r   hasattr)r9   r:   r7   rJ   Zovr2rL   Zovr1pred1r;   r;   r<   test_ovr_partial_fit`   s2    
$


 $
"
r[   c                  C   s   t t } ttjdd}ddddddddddddddg}| |d d |d d t| dg|dd  }d	}tj	t
|d
 | j|dd  |d W 5 Q R X d S )NrP   r.   r-   r&   r   rS      zAMini-batch contains \[.+\] while classes must be subset of \[.+\]r+   )r9   r:   )r   r   r4   rV   rW   rX   rT   rU   r0   r1   r3   )r7   r9   r:   y1r8   r;   r;   r<   test_ovr_partial_fit_exceptions   s    
 $r_   c                  C   s   t t } | tjtjtj}t| jt	ks4t
tt|dddg t|tjkdksbt
tt } | tjtjtj}t| jt	t	d  d kst
tt|dddg t|tjkdkst
d S )Nr   r-   r.   ?)r   r   r6   rB   rC   rD   r2   rE   rF   rG   rH   r#   r4   rU   rI   r   )r7   rJ   r;   r;   r<   test_ovr_ovo_regressor   s    

ra   c               
   C   s2  t jt jt jt jt jfD ]} tdd}tjddddddd	d
\}}|d d |d d  }}|dd  }t	|
||}||}t	|
|| |}	|	|}
|jstt |
stt|
 | |	|}|dk}t||
  t }t	|
|| |}	|	|d	kt}t||	|  qd S )Nr-   alpharO      r\   r&   2   Tr   	n_samples
n_featuresrG   Zn_labelslengthZallow_unlabeledr*   P         ?)spZ
csr_matrix
csc_matrixZ
coo_matrixZ
dok_matrixZ
lil_matrixr   r   make_multilabel_classificationr   r6   r2   multilabel_rH   issparser#   Ztoarraypredict_probar   r   decision_functionastypeint)sparsebase_clfr9   YX_trainY_trainX_testrK   Y_predZclf_sprsZY_pred_sprsY_probarJ   Zdec_predr;   r;   r<   test_ovr_fit_predict_sparse   s@    







r}   c               	   C   s  t d} d| d dd d f< t d}d|dd df< d|d d df< d|d d df< tt }d}tjt|d || | W 5 Q R X |	| }t
t |t | || }t |d d d	d f dkst|| }t
|d d d
f t | jd  t d}d|dd df< tt }d}tjt|d || | W 5 Q R X || }t
|d d d
f t | jd  d S )N
   r.   r   r\   )r   r&   r-   r.   z,Label .+ is present in all training examplesr+   r]   z/Label not 1 is present in all training examples)r4   oneszerosr   r   r0   ZwarnsUserWarningr6   r2   r#   r5   rr   rU   rH   rq   shape)r9   r:   r7   r8   y_predr;   r;   r<   test_ovr_always_present   s0    




"
"


r   c               	   C   s&  t dddgdddgdddgdddgdddgg} dddddg}t dddgdddgdddgdddgdddgg}td	 }t td
ddt t t fD ]}t	|
| |}t|j|kst|t dddggd }t|dg t	|
| |}|dddggd }t|dddg qd S )Nr   r\   r&      eggsspamZhamr-   zham eggs spamr'   r(      )r4   r5   setsplitr   r   r   r   r	   r   r6   classes_rH   r2   r#   )r9   r:   rw   classesrv   rK   r   r;   r;   r<   test_ovr_multiclass   s"    22
r   c               	      s   t dddgdddgdddgdddgdddgg dddddgt dddddggjtd d fd
d	} tdddt t t fD ]}| | qt	 t
ddt fD ]}| |dd qd S )Nr   r\   r&   r   r   r   r-   z	eggs spamFc                    s   t |  }t|jks"t|tdddggd }t|dg t	| drl|
 }|jdkslt|rtdddgg}||}dt|d kst|jtj|dd ||kstt |  }|d	ddggd }|dkstd S )
Nr   r   r   rr   )r\   r.   r-   Zaxisr&   )r   r6   r   r   rH   r2   r4   r5   r#   rY   rr   r   rq   rE   argmax)rv   test_predict_probarK   r   decrz   Zprobabilitiesr9   rw   r   r:   r;   r<   conduct_test  s    


"z%test_ovr_binary.<locals>.conduct_testr'   r(   Tprobability)r   )F)r4   r5   Tr   r   r   r   r   r	   r   r   r   )r   rv   r;   r   r<   test_ovr_binary  s    2

r   c               	   C   s   t dddgdddgdddgdddgdddgg} t dddgdddgdddgdddgdddgg}t tdddt t t td	d
fD ]D}t|	| |}|
dddggd }t|dddg |jstqd S )Nr   r   r\   r&   r   r-   r'   r(   rk   rb   )r4   r5   r   r   r   r   r	   r
   r   r6   r2   r#   ro   rH   )r9   r:   rv   rK   r   r;   r;   r<   test_ovr_multilabel9  s    22
r   c                  C   sJ   t t } | tjtj t| jdks.t	| 
tjtjdksFt	d S )Nr&   r`   )r   r   r   r6   rB   rC   rD   rE   rF   rH   score)r7   r;   r;   r<   test_ovr_fit_predict_svcL  s    r   c               
   C   s   t dd} tdddD ]\}}}tjdddd	d
|dd\}}|d d |d d  }}|dd  |dd   }}	t| ||}
|
|}|
jstt	t
|	|dd|d	d t	t|	|dd|d	d qd S )Nr-   rb   )TF)RQ?gQ?)r   r/   rO   rd   r\   r.   re   r   rf   rj   micro)Zaverage)decimal)r   zipr   rn   r   r6   r2   ro   rH   r"   r   r   )rv   auprecZrecallr9   rw   rx   ry   rz   ZY_testrK   r{   r;   r;   r<   test_ovr_multilabel_datasetS  s4    

	

    r   c               
   C   sH  t dd} dD ]2}tjddddd|d	d
\}}|d d |d d  }}|dd  }t| ||}tt ||}t|drtttj	dd}t|drt||| t|drtt|dstt
tj	ddddgid}	t|	}
t|
dr t|
|| t|
dst||}||}|dk}t|| qd S )Nr-   rb   )FTrO   rd   r\   r&   re   r   rf   rj   rq   Fr   rr   r   T)Z
param_gridrk   )r   r   rn   r   r6   r   SVRrY   rH   r   r   r2   rq   r#   )rv   r   r9   rw   rx   ry   rz   rK   decision_onlygsZproba_after_fitr{   r|   rJ   r;   r;   r<   !test_ovr_multilabel_predict_probam  sB    


	
 

r   c                  C   s   t dd} tjtj }}|d d |d d  }}|dd  }t| ||}tt ||}t|drpt	|
|}||}	t|	jddd |	jdd}
|
|  rt	d S )Nr-   rb   rj   rq   r         ?)r   rB   rC   rD   r   r6   r   r   rY   rH   r2   rq   r"   sumr   any)rv   r9   rw   rx   ry   rz   rK   r   r{   r|   rJ   r;   r;   r<   #test_ovr_single_label_predict_proba  s    


r   c               	   C   sz   t jdddddddd\} }| d d	 |d d	  }}| d	d  }tt ||}t||dkt	|
| d S )
NrO   rd   r\   r&   re   Tr   rf   rj   )r   rn   r   r   r   r6   r#   rr   rs   rt   r2   r9   rw   rx   ry   rz   rK   r;   r;   r<   %test_ovr_multilabel_decision_function  s     
	 r   c                  C   sp   t jdddd\} }| d d |d d  }}| dd  }tt ||}t|| dk|	| d S )NrO   rd   r   )rg   rh   r*   rj   )
r   Zmake_classificationr   r   r   r6   r#   rr   Zravelr2   r   r;   r;   r<   'test_ovr_single_label_decision_function  s
    r   c                  C   sV   t tddd} dddg}t| d|i}|tjtj |jjd j	}||ksRt
d S Nr'   r   r(   皙?rk   r/   estimator__C)r   r   r   r6   rB   rC   rD   best_estimator_rF   CrH   )r7   Cscvbest_Cr;   r;   r<   test_ovr_gridsearch  s    
r   c                  C   s`   t dt fg} t| }|tjtj tt }|tjtj t|tj|tj d S )Ntree)	r   r   r   r6   rB   rC   rD   r#   r2   )rK   Zovr_piper7   r;   r;   r<   test_ovr_pipeline  s    
r   c               	   C   s4   t tddd} tt | g  W 5 Q R X d S Nr'   r   r(   )r   r   r0   r1   r   r2   ovor;   r;   r<   test_ovo_exceptions  s    r   c                  C   s^   t tddd} | tjtjtj}dd tjD }| |ttj|}t|| d S )Nr'   r   r(   c                 S   s   g | ]}t |qS r;   )list).0ar;   r;   r<   
<listcomp>  s     z(test_ovo_fit_on_list.<locals>.<listcomp>)	r   r   r6   rB   rC   rD   r2   r   r#   )r   Zprediction_from_arrayZiris_data_listZprediction_from_listr;   r;   r<   test_ovo_fit_on_list  s    r   c                  C   s   t tddd} | tjtjtj t| jt	t	d  d ksFt
t t } | tjtjtj t| jt	t	d  d kst
d S )Nr'   r   r(   r-   r.   )r   r   r6   rB   rC   rD   r2   rE   rF   rG   rH   r   r   r;   r;   r<   test_ovo_fit_predict  s    
r   c                  C   s  t  } | j| j }}tt }||d d |d d t| ||dd  |dd   |	|}tt }|
|| |	|}t|jttd  d kstt||kdkstt|| tt }||d d |d d t| ||dd  |dd   |	|}tt }|
||	|}t|| t|jtt|ks`tt||kdksxttt }tjdd}dddddddd	d	d	d	d	ddg}||d d
 |d d
 ddddd	g ||d
d  |d
d   |	|}tt }|
||	|}t|| tt }ddddd	ddg}	tdt|	t|}
tjt|
d" ||d d
 |	t| W 5 Q R X tt }t|drtd S )NrO   r-   r.   rA   <   rP   r&   r   r   rS   r\   z6Mini-batch contains {0} while it must be subset of {1}r+   rT   )r   	load_irisrC   rD   r   r   rT   r4   rU   r2   r6   rE   rF   rG   rH   rI   r"   rW   Zrandr   formatr0   r1   r3   r   rY   )tempr9   r:   Zovo1rZ   Zovo2rL   r   rJ   Zerror_yZ
message_rer7   r;   r;   r<   test_ovo_partial_fit_predict  sT    
$




$



 (



 &
r   c            	      C   s  t jjd } ttddd}|t jt jdk |t j}|j| fksLt|t jt j |t j}|j| t	fksztt
|jdd|t j t| t	f}d}tt	D ]b}t|d t	D ]N}|j| t j}||dk|f  d7  < ||dk|f  d7  < |d7 }qqt
|t| tt	D ]T}t|d d |f tdddgsXttt|d d |f d	ks*tq*d S )
Nr   r'   r(   r-   r   r>   r   r?      )rB   rC   r   r   r   r6   rD   rr   rH   rG   r#   r   r2   r4   r   rangerF   roundr   issubsetrE   rU   )	rg   Zovo_clfZ	decisionsvoteskijrJ   Z	class_idxr;   r;   r<   test_ovo_decision_function1  s*    *r   c                  C   sV   t tddd} dddg}t| d|i}|tjtj |jjd j	}||ksRt
d S r   )r   r   r   r6   rB   rC   rD   r   rF   r   rH   )r   r   r   r   r;   r;   r<   test_ovo_gridsearch`  s    
r   c                  C   s   t ddgddgddgddgg} t ddddg}ttddd d}|| || }|| }t |}|| }t|dd d f d tt j	|dd  dd	|dd   |d |d 	 kst
d S )
Nr-   r.   r   r]   r   Fr   r    rQ   rR   r   )r4   r5   r   r   r6   r2   rr   r   r#   r   rH   )r9   r:   	multi_clfovo_predictionZovo_decisionr   Znormalized_confidencesr;   r;   r<   test_ovo_tiesi  s    "

$r   c                  C   s   t ddgddgddgddgg} t ddddg}tdD ]H}|| d }ttddd d	}|| || }|d |d ks<tq<d S )
Nr-   r.   r   r]   r   r&   Fr   r   )r4   r5   r   r   r   r6   r2   rH   )r9   Zy_refr   r:   r   r   r;   r;   r<   test_ovo_ties2  s    "r   c                  C   sJ   t d} t ddddg}ttdd}|| | t|||  d S )Nr   r   bcdr'   r)   )r4   eyer5   r   r   r6   r#   r2   )r9   r:   r   r;   r;   r<   test_ovo_string_y  s
    
r   c               	   C   sV   t d} t dgd }ttdd}d}tjt|d || | W 5 Q R X d S )Nr   r   r'   r   zwhen only one classr+   )	r4   r   r5   r   r   r0   r1   r3   r6   r9   r:   r   r8   r;   r;   r<   test_ovo_one_class  s    
r   c               	   C   sT   t j} t jd d df }ttdd}d}tjt|d || | W 5 Q R X d S Nr   r'   r   zUnknown label typer+   )rB   rC   r   r   r0   r1   r3   r6   r   r;   r;   r<   test_ovo_float_y  s    r   c               	   C   s4   t tddd} tt | g  W 5 Q R X d S r   )r   r   r0   r1   r   r2   ecocr;   r;   r<   test_ecoc_exceptions  s    r   c                  C   s   t tdddddd} | tjtjtj t| jt	d ksDt
t t ddd} | tjtjtj t| jt	d kst
d S )Nr'   r   r(   r.   )Z	code_sizer*   )r   r   r6   rB   rC   rD   r2   rE   rF   rG   rH   r   r   r;   r;   r<   test_ecoc_fit_predict  s    
  r   c                  C   sZ   t tddddd} dddg}t| d|i}|tjtj |jjd j	}||ksVt
d S )	Nr'   r   r(   rN   r   rk   r/   r   )r   r   r   r6   rB   rC   rD   r   rF   r   rH   )r   r   r   r   r;   r;   r<   test_ecoc_gridsearch  s    
r   c               	   C   sT   t j} t jd d df }ttdd}d}tjt|d || | W 5 Q R X d S r   )rB   rC   r   r   r0   r1   r3   r6   r   r;   r;   r<   test_ecoc_float_y  s    r   c               	   C   s   t jt j } }t| }ttdddd}t|dd}tj	t
dd ||| W 5 Q R X || | tj	t
dd || W 5 Q R X ttd	dd
}|||| t|jdkstd S )NTF)Z	ensure_2dZaccept_sparse)Zcheck_XZcheck_X_paramsr   rN   zA sparse matrix was passedr+   r'   r(   r   )rB   rC   rD   rl   rm   r!   r   r   r0   r1   	TypeErrorr6   r2   r   rE   rF   rH   )r9   r:   ZX_spZbase_estimatorr   r;   r;   r<   (test_ecoc_delegate_sparse_base_estimator  s    
r   c                  C   s~   t jdd} tjtj }}t| }t||j}|	|| t
|j}|j}|D ](}|jd | |d  |jd ksPtqPd S )NprecomputedZkernelr   r-   )r   r   rB   rC   rD   r   r4   dotr   r6   rE   rF   Zpairwise_indices_r   rH   )clf_precomputedr9   r:   	ovr_falselinear_kernelZn_estimatorsZprecomputed_indicesidxr;   r;   r<   test_pairwise_indices  s    
r   c            
      C   s   t jt j } }|d dkst| dd } |dd }| jdksDttjdd| |}|jdksftt	|| |}|jdkst|j
D ]}|jdkstqt|| |}|jdkst|jdkstt|j
dkst|j
D ]}|jdkstq| | j }|jd	ksttjd
d||}|jdks4tt	|||}|jdksTt|jdksdtt|j
dksxt|j
D ]}|jdks~tq~t|||}	|	jdkst|jdkstt|j
dkst|	j
d jdkst|	j
d jdkst|	j
d jdkstdS )a  Check the n_features_in_ attributes of the meta and base estimators

    When the training data is a regular design matrix, everything is intuitive.
    However, when the training data is a precomputed kernel matrix, the
    multiclass strategy can resample the kernel matrix of the underlying base
    estimator both row-wise and column-wise and this has a non-trivial impact
    on the expected value for the n_features_in_ of both the meta and the base
    estimators.
    r]   r   N)   r   linearr   r   r&   )r   r   r   r   c   r-   r.   rO   )rB   rC   rD   rH   r   r   r   r6   Zn_features_in_r   rF   r   Z
n_classes_rE   r   )
r9   r:   clf_notprecomputedZovr_notprecomputedZestZovo_notprecomputedKr   Zovr_precomputedZovo_precomputedr;   r;   r<   test_pairwise_n_features_in  sD    




r   MultiClassClassifierc                 C   sH   t jdd}t  }| |}| d r,t| |}| d sDtd S )Nr   r   pairwise)r   r   Z	_get_tagsrH   )r   r   r   r   Zovr_truer;   r;   r<   test_pairwise_tagF  s    r   c           
      C   sr   t jdd}t jdd}tjtj }}| |}| |}t||j}t|||dd}t|||dd}	t	|	| d S )Nr   r   r   raise)Zerror_score)
r   r   rB   rC   rD   r4   r   r   r   r#   )
r   r   r   r9   r:   Zmulticlass_clf_notprecomputedZmulticlass_clf_precomputedr   Zscore_not_precomputedZscore_precomputedr;   r;   r<   test_pairwise_cross_val_scoreT  s&          r   c                 C   s|   t jd}tjtj }}t |}|jddg|jddgd	t
}t j||< tt t|d}| ||||| d S )N*   r-   r   r   r`   )prN   )r4   rW   RandomStaterB   rC   rD   copychoicer   rs   boolnanr   r   r   r6   r   )r   rngr9   r:   masklrr;   r;   r<   test_support_missing_valuesj  s    	
 
r  make_yc                 C   sj   t d}| dt jd}tt }||| ||}t |jd df}d|dddf< t	|| dS )zUCheck that constant y target does not raise.

    Non-regression test for #21869
    r~   )r   r-   )Zdtyper   r.   r-   N)
r4   r   Zint32r   r   r6   rq   r   r   r   )r  r9   r:   r7   r   expectedr;   r;   r<   test_constant_int_target}  s    


r
  c                  C   sT   t dd\} }tddd}t|}|| | || | t|| ||  dS )z^Check that ovo is consistent with binary classifier.

    Non-regression test for #13617.
    T)Z
return_X_y   Zdistance)Zn_neighborsweightsN)r   r   r   r6   r#   r2   )r9   r:   rK   r   r;   r;   r<   )test_ovo_consistent_binary_classification  s    r  )yrer   numpyr4   r0   Zscipy.sparseru   rl   Znumpy.testingr   Zsklearnr   r   Zsklearn.datasetsr   Zsklearn.exceptionsr   Zsklearn.imputer   Zsklearn.linear_modelr	   r
   r   r   r   r   r   Zsklearn.metricsr   r   Zsklearn.model_selectionr   r   Zsklearn.multiclassr   r   r   Zsklearn.naive_bayesr   Zsklearn.neighborsr   Zsklearn.pipeliner   r   Zsklearn.svmr   r   Zsklearn.treer   r   Zsklearn.utilsr   r    Zsklearn.utils._mockingr!   Zsklearn.utils._testingr"   r#   Zsklearn.utils.multiclassr$   r%   r8   markfilterwarningsZ
pytestmarkr   rB   rW   r   r  ZpermutationrD   sizepermrC   rG   r=   r@   rM   r[   r_   ra   r}   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   Zparametrizer   r   r  r   r   r
  r  r;   r;   r;   r<   <module>   s   $	$0$(-	5/	
	E 
 
 

