U
    -e                    @   s  d dl Z d dlmZ d dlZd dlZd dlZd dlmZ	 d dl
mZmZmZ d dlmZmZ d dlmZ d dlmZ d dlmZ d dlmZ d d	lmZmZmZ d d
lmZ d dlm Z m!Z!m"Z"m#Z# d dl$m%Z% d dl&m'Z'm(Z(m)Z)m*Z*m+Z+ dd Z,G dd dej-Z.G dd dej/Z0G dd dej1Z2dd Z-dd Z/dd Z1dd Z3dd Z4d d! Z5e6d"d#gd#d#gd#d"gd$d$gd$d%gd%d$ggZ7d$d$d$d%d%d%gZ8e6d#d#gd%d%gd&d%ggZ9d$d%d%gZ:e6d#d$gd'd(gd)d*gd$d$gd+d(gd*d*gd#d#gd d,gd$d#gg	Z;d-gd& d.gd&  d/gd&  Z<e6d)d(gd$d%gd d"ggZ=d-d.d/gZ>e6d$d$d d d d gd$d$d d d d gd d d$d d d gd d d$d d d gd d d d d$d$gd d d d d$d$gd d d d$d d gd d d d$d d ggZ?e6d$d$d$d$d%d%d%d%gZ@e6d$d0d1d d d gd$d2d3d d d gd$d4d5d d d gd$d6d7d d d gd d d d8d6d$gd d d d9d2d$gd d d d6d:d$gd d d d;d$d$ggZAe6d$d$d$d$d%d%d%d%gZBeC ZDe6d"d#gd#d#gd#d"gd$d$gd$d%gd%d$ggZEd$d$d$d%d%d%gZFd d$d$gZGd
d=d>ZHd?d@ ZIejJKdAe-e3e/e4gejJKdBdCdDdEdFgdGdH ZLejJKdAe-e3e/e4gdIdJ ZMejJKdAe-e3e/e4gdKdL ZNejJKdAe-e3e/e4e1e5gdMdN ZOejJKdAe-e3e/e4e1e5gdOdP ZPejJKdAe-e3e/e4gdQdR ZQejJKdAe-e3e/e4gdSdT ZRejJKdAe-e3e/e4gdUdV ZSejJKdAe-e3e/e4gdWdX ZTejJKdAe-e3e/e4gdYdZ ZUejJKdAe-e3e/e4gd[d\ ZVejJKdAe-e3gd]d^ ZWejJKdAe-e3e1e5gd_d` ZXejJKdae-dbeYdcife3dbeYdcife1ddeYdcife5ddeYdcifgdedf ZZejJKdAe-e3e/e4gdgdh Z[ejJKdae-dbd ife3dbd ife1ddd ife5ddd ifgdidj Z\ejJKdAe-e3gdkdl Z]ejJKdAe-e3gdmdn Z^ejJKdAe-e3gdodp Z_ejJKdAe-e3gdqdr Z`ejJKdAe-e3gdsdt ZaejJKdAe-e3gdudv ZbejJKdAe-e3gdwdx ZcejJKdAe-e3gdydz ZdejJKdAe-e3gd{d| ZeejJKdAe-e3gd}d~ ZfejJKdAe-e3gdd ZgejJKdAe-e3gdd ZhejJKdAe-e3gdd ZiejJKdAe-e3gdd ZjejJKdAe-e3gdd ZkejJKdAe-e3gdd ZlejJKdAe-e3gdd ZmejJKdAe-e3gdd ZnejJKdAe-e3e1e5gdd ZoejJKdAe-e3gdd ZpejJKdAe-e3gdd ZqejJKdAe-e3gdd ZrejJKdAe-e3gdd ZsejJKdAe-e3gdd ZtejJKdAe-e3gejJKdBdCdDdEdFgdd ZuejJKdAe-e3gdd ZvejJKdAe-e3gdd ZwejJKdAe-e3gdd ZxejJKdAe/e4gdd ZyejJKdAe/e4gdd ZzejJKdAe/e4gdd Z{ejJKdAe/e4gdd Z|ejJKdAe/e4gdd Z}ejJKdAe/e4gdd Z~ejJKdAe/e4gdd ZejJKdAe/e4gdd Ze+ejJKdAe/e4gdd ZejJKdAe/e4gejJKdBdCdDdEdFgdd ZejJKdAe/e4gdd ZdddZejJKdAe1e5gdd ZejJKdAe1e5gejJKdBdCdDdEdFgdd ZejJKdAe1e5gdd ZejJKdAe1e5gdd ZejJKdAe1e5gejJKdBdCdDdEdFgddĄ ZejJKdAe1e5gddƄ ZejJKdAe1e5gddȄ ZejJKdAe1e5gddʄ ZejJKdAe1e5gdd̄ Zdd΄ ZddЄ Zdd҄ ZddԄ Zddք ZejJKddddgdd܄ Zddބ Zdd Zdd Zdd Zdd Zdd Zdd Zdd Zdd Zdd Zdd Zdd ZejJKddddgdd ZejJKdej-ej/gdd Zdd Zd d ZejJKde-e/gdd ZejJKde-e3e/e4e1e5gejJKdejejfdd ZejJKde-e3e/e4e1e5gdd	 ZdS (      N)Mock)datasetslinear_modelmetrics)cloneis_classifier)ConvergenceWarning)Nystroem)	_sgd_fast)_stochastic_gradient)RandomizedSearchCVShuffleSplitStratifiedShuffleSplit)make_pipeline)LabelEncoderMinMaxScalerStandardScalerscale)OneClassSVM)assert_allcloseassert_almost_equalassert_array_almost_equalassert_array_equalignore_warningsc                 C   s4   d| krd| d< d| kr d | d< d| kr0d| d< d S )Nrandom_state*   tolmax_iter    kwargsr   r   d/var/www/html/Darija-Ai-Train/env/lib/python3.8/site-packages/sklearn/linear_model/tests/test_sgd.py_update_kwargs    s    r#   c                       s@   e Zd Z fddZ fddZ fddZ fddZ  ZS )	_SparseSGDClassifierc                    s    t |}t j||f||S N)sp
csr_matrixsuperfitselfXyargskw	__class__r   r"   r)   +   s    
z_SparseSGDClassifier.fitc                    s    t |}t j||f||S r%   )r&   r'   r(   partial_fitr*   r0   r   r"   r2   /   s    
z _SparseSGDClassifier.partial_fitc                    s   t |}t |S r%   )r&   r'   r(   decision_functionr+   r,   r0   r   r"   r3   3   s    
z&_SparseSGDClassifier.decision_functionc                    s   t |}t |S r%   )r&   r'   r(   predict_probar4   r0   r   r"   r5   7   s    
z"_SparseSGDClassifier.predict_proba)__name__
__module____qualname__r)   r2   r3   r5   __classcell__r   r   r0   r"   r$   *   s   r$   c                   @   s$   e Zd Zdd Zdd Zdd ZdS )_SparseSGDRegressorc                 O   s"   t |}tjj| ||f||S r%   )r&   r'   r   SGDRegressorr)   r*   r   r   r"   r)   =   s    
z_SparseSGDRegressor.fitc                 O   s"   t |}tjj| ||f||S r%   )r&   r'   r   r;   r2   r*   r   r   r"   r2   A   s    
z_SparseSGDRegressor.partial_fitc                 O   s    t |}tjj| |f||S r%   )r&   r'   r   r;   r3   r+   r,   r.   r/   r   r   r"   r3   E   s    
z%_SparseSGDRegressor.decision_functionNr6   r7   r8   r)   r2   r3   r   r   r   r"   r:   <   s   r:   c                   @   s$   e Zd Zdd Zdd Zdd ZdS )_SparseSGDOneClassSVMc                 O   s    t |}tjj| |f||S r%   )r&   r'   r   SGDOneClassSVMr)   r<   r   r   r"   r)   L   s    
z_SparseSGDOneClassSVM.fitc                 O   s    t |}tjj| |f||S r%   )r&   r'   r   r?   r2   r<   r   r   r"   r2   P   s    
z!_SparseSGDOneClassSVM.partial_fitc                 O   s    t |}tjj| |f||S r%   )r&   r'   r   r?   r3   r<   r   r   r"   r3   T   s    
z'_SparseSGDOneClassSVM.decision_functionNr=   r   r   r   r"   r>   K   s   r>   c                  K   s   t |  tjf | S r%   )r#   r   SGDClassifierr    r   r   r"   r@   Y   s    r@   c                  K   s   t |  tjf | S r%   )r#   r   r;   r    r   r   r"   r;   ^   s    r;   c                  K   s   t |  tjf | S r%   )r#   r   r?   r    r   r   r"   r?   c   s    r?   c                  K   s   t |  tf | S r%   )r#   r$   r    r   r   r"   SparseSGDClassifierh   s    rA   c                  K   s   t |  tf | S r%   )r#   r:   r    r   r   r"   SparseSGDRegressorm   s    rB   c                  K   s   t |  tf | S r%   )r#   r>   r    r   r   r"   SparseSGDOneClassSVMr   s    rC            g            ?g            ?g      ?      onetwothree?皙?gzG?g\(\?Q?g)\(?gQ?Gz?g{Gz?gHzG?gffffff?g(\?        c                 C   s   |d krt |jd }n|}t |jd }|}	d}
d}| ttfkrJd}t|D ]\}}t ||}||	7 }|||  }|d||  9 }||| |  7 }|	||  | 7 }	||9 }||7 }||d  }|
|9 }
|
|	7 }
|
|d  }
qR||
fS )NrF   rS         ?{Gz?)npzerosshaperA   rB   	enumeratedot)klassr,   r-   etaalphaweight_initintercept_initweightsaverage_weights	interceptaverage_interceptdecayientrypgradientr   r   r"   asgd   s.    ri   c                 C   s   | ddd|d}| || | ddd|d}|j |||j |j d | dddd|d}| || |j|jksxtt|j|j |jdd | || |j|jkstt|j|j d S )	NrU   F)r]   eta0shufflelearning_rateMbP?	coef_initr_   T)r]   rj   rk   
warm_startrl   r]   )r)   coef_copy
intercept_t_AssertionErrorr   
set_params)r[   r,   Ylrclfclf2clf3r   r   r"   _test_warm_start   s$        r}   r[   ry   constantoptimalZ
invscalingadaptivec                 C   s   t | tt| d S r%   )r}   r,   rx   r[   ry   r   r   r"   test_warm_start   s    r   c              	   C   sd   | ddd}| tt ttd d tjf }tj||f }tt	 | t| W 5 Q R X d S )NrU   Fr]   rk   )
r)   r,   rx   rV   arrayZnewaxisZc_pytestraises
ValueError)r[   rz   Y_r   r   r"   test_input_format   s    r   c                 C   sV   | ddd}t |}|jdd |tt | ddd}|tt t|j|j d S )NrU   l1)r]   penaltyl2)r   )r   rw   r)   r,   rx   r   rr   r[   rz   r{   r   r   r"   
test_clone  s    r   c                 C   s   | ddd}| tt t|ds&tt|ds4tt|dsBtt|dsPt|  }| tt t|drptt|dr~tt|drtt|drtd S )NTrU   )averagerj   Z_average_coefZ_average_interceptZ_standard_interceptZ_standard_coef)r)   r,   rx   hasattrrv   r[   rz   r   r   r"   test_plain_has_no_average_attr  s    r   c                 C   s   | dd}|  }t dD ]R}t|rR|jttttd |jttttd q|tt |tt qt|j|jdd | t	t
ttfkrt|j|jdd n| ttfkrt|j|j d S )NiX  r   d   classes   decimal)ranger   r2   r,   rx   rV   uniquer   rr   r@   rA   r;   rB   r   rt   r?   rC   r   offset_)r[   clf1r{   _r   r   r"   %test_late_onset_averaging_not_reached:  s    
r   c              	   C   s   d}d}t t}d||dk< d||dk< | ddd	||dd
d}| ddd	||dd
d}|t| |t| t| t||||j |jd\}}t	|j | dd t
|j|dd d S )Nrm   -C6?      rF   rT   rG      r~   squared_errorF)r   rl   lossrj   r]   r   rk   r   )r^   r_   r   r   )rV   r   rx   r)   r,   ri   rr   ravelrt   r   r   )r[   rj   r]   ZY_encoder   r{   ra   rc   r   r   r"   !test_late_onset_averaging_reachedW  sH    
	


r   c                 C   sV   t jt jdk }t jt jdk }dD ],}d}| |d|d||}|j|k s$tq$d S )Nr   TF  rm   )early_stoppingr   r   )irisdatatargetr)   n_iter_rv   )r[   r,   rx   r   r   rz   r   r   r"   test_early_stopping  s     r   c                 C   sT   | ddddd}| tjtj | ddddd}| tjtj |j|jksPtd S )Nr   rU   rm   r   )rl   rj   r   r   r~   )r)   r   r   r   r   rv   )r[   r   r{   r   r   r"   "test_adaptive_longer_than_constant  s
    r   c              
   C   s   t jt j }}d}d}d}d}| dtj||ddd ||d}||| |j|ksXt| dtj|ddd ||d	}t	|rt
||d
}	nt||d
}	t|	||\}
}t|
}
|||
 ||
  |j|kstt|j|j d S )N皙?r   F
   Tr~   rU   )r   r   validation_fractionrl   rj   r   r   rk   )r   r   rl   rj   r   r   rk   )Z	test_sizer   )r   r   r   rV   randomRandomStater)   r   rv   r   r   r   nextsplitsortr   rr   )r[   r,   rx   r   seedrk   r   r   r{   ZcvZ	idx_trainZidx_valr   r   r"   )test_validation_set_not_used_for_training  sD    




r   c                    sB   t jt j  dD ]* fdddD }t|t| qd S )Nr   c                    s&   g | ]}|d dd  jqS )r   r   )r   n_iter_no_changer   r   )r)   r   ).0r   r,   rx   r   r[   r   r"   
<listcomp>  s   	 z)test_n_iter_no_change.<locals>.<listcomp>)rG   rH   r   )r   r   r   r   sorted)r[   Zn_iter_listr   r   r"   test_n_iter_no_change  s    	r   c              	   C   s2   | ddd}t t |tt W 5 Q R X d S )NTrR   )r   r   )r   r   r   r)   X3Y3r   r   r   r"   )test_not_enough_sample_for_early_stopping  s    r   c              	   C   s>   dD ]4}| ddd|ddd}| tt t|tt qd S )N)hingesquared_hingelog_lossmodified_huberr   rU   Tr   )r   r]   fit_interceptr   r   rk   )r)   r,   rx   r   predictTtrue_result)r[   r   rz   r   r   r"   test_sgd_clf  s    r   c              	   C   s6   t jtdd |  jtttdd W 5 Q R X dS )z1Check that the shape of `coef_init` is validated.z)Provided coef_init does not match datasetmatchrH   ro   N)r   r   r   r)   r,   rx   rV   rW   r[   r   r   r"   test_provide_coef  s    r   zklass, fit_paramsr_   r   offset_initc              	   C   s4   |  }t jtdd |jttf| W 5 Q R X dS )z:Check that `intercept_init` or `offset_init` is validated.zdoes not match datasetr   Nr   r   r   r)   r,   rx   )r[   
fit_paramsZsgd_estimatorr   r   r"   test_set_intercept_offset  s    r   c              	   C   s4   d}t jt|d | ddtt W 5 Q R X dS )zSCheck that we raise an error for `early_stopping` used with
    `partial_fit`.
    z/early_stopping should be False with partial_fitr   T)r   N)r   r   r   r2   r,   rx   )r[   err_msgr   r   r"   (test_sgd_early_stopping_with_partial_fit  s    r   c                 C   s   |  j ttf| dS )zdCheck that we can pass a scaler with binary classification to
    `intercept_init` or `offset_init`.N)r)   X5Y5)r[   r   r   r   r"    test_set_intercept_offset_binary$  s    r   c              
   C   s   d}d}d}d}t jd}|j||fd}|j|d}| dd||d	d
d	dd}t ||}	t |	}	|||	 t| ||	||\}
}|
d
d}
t	|j
|
dd t|j|dd d S )N皙?       @   r   r   sizer   r~   TrF   Fr   rl   rj   r]   r   r   r   rk   rE      r   )rV   r   r   normalrZ   signr)   ri   reshaper   rr   r   rt   )r[   r\   r]   	n_samples
n_featuresrngr,   wrz   r-   ra   rc   r   r   r"   &test_average_binary_computed_correctly3  s0    
r   c                 C   sH   |   tt}|  j tt|jd |   tt}|  j tt|jd d S )Nr_   )r)   r   r   rt   r,   rx   r   r   r   r"   test_set_intercept_to_interceptU  s    r   c              	   C   s8   | ddd}t t |ttd W 5 Q R X d S )NrU   r   r]   r   	   )r   r   r   r)   X2rV   onesr   r   r   r"   test_sgd_at_least_two_labels_  s    r   c              	   C   s>   d}t jt|d" | ddjttttd W 5 Q R X d S )Na`  class_weight 'balanced' is not supported for partial_fit\. In order to use 'balanced' weights, use compute_class_weight\('balanced', classes=classes, y=y\). In place of y you can use a large enough sample of the full training set target to properly estimate the class frequency distributions\. Pass the resulting weights as the class_weight parameter\.r   balanced)class_weightr   )r   r   r   r2   r,   rx   rV   r   )r[   regexr   r   r"   &test_partial_fit_weight_class_balancedg  s    
r   c                 C   sf   | ddd tt}|jjdks$t|jjdks4t|ddggjdksNt|t	}t
|t d S )NrU   r   r   rH   rG   r   r   rF   rH   r)   r   Y2rr   rX   rv   rt   r3   r   T2r   true_result2r[   rz   predr   r   r"   test_sgd_multiclassx  s    
r   c              
   C   s   d}d}| dd||ddddd}t t}|t| t |}t|D ]`\}}t |jd	 }d
|||k< t	| t|||\}	}
t
|	|j| dd t|
|j| dd qHd S )Nrm   rU   r   r~   TrF   Fr   r   rE   r   r   )rV   r   r   r)   r   r   rY   r   rX   ri   r   rr   r   rt   )r[   r\   r]   rz   Znp_Y2r   re   clZy_iaverage_coefrc   r   r   r"   test_sgd_multiclass_average  s*    

r   c                 C   sb   | ddd}|j tttdtdd |jjdks:t|jjsJtd|	t
}t|t d S )NrU   r   r   r   rH   rn   r   )r)   r   r   rV   rW   rr   rX   rv   rt   r   r   r   r   r   r   r   r"   "test_sgd_multiclass_with_init_coef  s    
r  c                 C   sh   | dddd tt}|jjdks&t|jjdks6t|ddggjdksPt|t	}t
|t d S )	NrU   r   rG   )r]   r   n_jobsr   r   r   r   r   r   r   r   r"   test_sgd_multiclass_njobs  s    
r  c              	   C   s   |  }t t |jtttdd W 5 Q R X |  jtttdd}|  }t t |jtttdd W 5 Q R X |  jtttdd}d S )N)rG   rG   r   r   rF   r   r   )r   r   r   r)   r   r   rV   rW   r   r   r   r"   test_set_coef_multiclass  s      r  c              
   C   s   t jjD ]}t|d}|dkr<t|ds,tt|dstqd|}t|drTtt|drbttjt|d |j	 W 5 Q R X tjt|d |j
 W 5 Q R X qd S )N)r   r   r   r5   predict_log_probaz5probability estimates are not available for loss={!r}r   )r   r@   loss_functionsr   rv   formatr   r   AttributeErrorr5   r  )r[   r   rz   messager   r   r"   $test_sgd_predict_proba_method_access  s    
r  c                 C   s  t dddd dtt}t|dr&tt|dr4tdD ]}| |ddd}|tt |d	d
gg}|d dksvt|ddgg}|d dk st|d	d
gg}|d |d kst|ddgg}|d |d k s8tq8| ddddtt	}|
ddgddgg}|ddgddgg}ttj|ddtj|dd t|d  d t|d dksjt|ddgg}|
ddgg}tt|d t|d  |d	d
gg}|d	d
gg}tt|| |ddgg}|ddgg}tt|| | dddd}|tt	 |
d	d
gg}|d	d
gg}| tkrptj|ddtj|ddkstn"tj|ddtj|ddksttjdd}|
|g}t|dk r||g}t|d dgd	  d S )Nr   rU   r   )r   r]   r   r   r5   r  r  )r   r]   r   rH   rG   r   rF   rI   rE   )r   r   r   r   皙333333?皙?rF   )Zaxisr   r   gUUUUUU?)r@   r)   r,   rx   r   rv   r5   r  r   r   r3   r   rV   Zargmaxr   sumallZargsortr   logrA   Zargminmean)r[   rz   r   rg   dZlpxr   r   r"   test_sgd_proba  sR    
$"r  c                 C   s   t t}tjd}t|}|| t|d d f }t| }| ddddd dd}||| t	|j
ddd	f td
 ||}t	|| |  t|j
st||}t	|| tt|}t|j
st||}t	|| d S )N   r   r  F  )r   r]   r   r   r   rk   r   rF   rE   )   )lenX4rV   r   r   arangerk   Y4r)   r   rr   rW   r   Zsparsifyr&   issparserv   pickleloadsdumps)r[   nr   idxr,   rx   rz   r   r   r   r"   test_sgd_l1'  s4    






r%  c                 C   s   t ddgddgddgddgddgg}dddddg}| dd	d
d d}||| t|ddggt dg | dd	d
ddid}||| t|ddggt dg d S )Nr   r   皙rT   rS   rF   rE   r   r   F)r]   r   r   r   r  rm   rV   r   r)   r   r   r[   r,   r-   rz   r   r   r"   test_class_weightsL  s    (r)  c                 C   s   ddgddgddgddgg}ddddg}| ddd d}| || ddgddgg}ddg}| dddddd}| || t|j|jdd	 d S )
NrF   r   r   r   r]   r   r   rI   r  rG   r   )r)   r   rr   )r[   r,   r-   rz   Zclf_weightedr   r   r"   test_equal_class_weight_  s    r+  c              	   C   s8   | ddddid}t t |tt W 5 Q R X d S )Nr   r   r   rI   r*  r   r   r   r   r"   test_wrong_class_weight_labelp  s    r,  c                 C   s   ddd}t jd}|tjd }t |}|tdk  |d 9  < |tdk  |d 9  < | dd|d	}| ddd
}|jtt|d |jtt|d t	|j
|j
 d S )Ng333333?r  )rF   rG   r   rF   rG   r   r   r*  r   sample_weight)rV   r   r   Zrandom_sampler  rX   rs   r)   r  r   rr   )r[   Zclass_weightsr   Zsample_weightsZmultiplied_togetherr   r{   r   r   r"   test_weights_multipliedx  s    

r/  c                 C   s  t jt j }}t|}t|jd }tjd}|	| || }|| }| ddd dd
||}tj|||dd}t|d	d
d | ddddd
||}tj|||dd}t|d	d
d t|j|jd ||dkd d f }||dk }	t|g|gd  }
t|g|	gd  }| dd dd}|
|
| ||}tj||ddd	k s^t| dddd}|
|
| ||}tj||ddd	kstd S )Nr      r   r   F)r]   r   r   rk   Zweightedr   rQ   rF   r   r   r   )r   r   rk   )r   r   r   r   rV   r  rX   r   r   rk   r)   r   Zf1_scorer   r   r   rr   Zvstackconcatenaterv   )r[   r,   r-   r$  r   rz   f1Zclf_balancedZX_0Zy_0ZX_imbalancedZy_imbalancedy_predr   r   r"   test_balanced_weight  sD    
    

r4  c                 C   s   t ddgddgddgddgddgg}dddddg}| dd	d
d}||| t|ddggt dg |j||dgd dgd  d t|ddggt dg d S )Nr   r   r&  rT   rS   rF   rE   r   r   Fr]   r   r   r  rm   rH   rG   r-  r'  r(  r   r   r"   test_sample_weights  s    ( r6  c              	   C   sf   | t tfkr| dddd}n| ttfkr6| dddd}tt |jtt	t
dd W 5 Q R X d S )Nr   r   Fr5  )nur   r   r   r-  )r@   rA   r?   rC   r   r   r   r)   r,   rx   rV   r  r   r   r   r"   test_wrong_sample_weights  s    r8  c              	   C   s0   | dd}t t |tt W 5 Q R X d S )NrU   rq   )r   r   r   r2   r   r   r   r   r   r"   test_partial_fit_exception  s    
r9  c                 C   s   t jd d }| dd}tt}|jt d | td | |d |jjdt jd fks\t|jjdkslt|	ddggjdkstt
|jj}|t |d  t|d   t
|jj}|st||t}t|t d S )Nr   rH   rU   rq   r   rF   r  )r,   rX   rV   r   rx   r2   rr   rv   rt   r3   idr   r   r   r   r   )r[   thirdrz   r   id1id2r3  r   r   r"   test_partial_fit_binary  s    

 
r>  c                 C   s   t jd d }| dd}tt}|jt d | td | |d |jjdt jd fks\t|jjdkslt|	ddggjdkstt
|jj}|t |d  t|d   t
|jj}|st|d S )	Nr   rH   rU   rq   r   rF   r   r   )r   rX   rV   r   r   r2   rr   rv   rt   r3   r:  r   )r[   r;  rz   r   r<  r=  r   r   r"   test_partial_fit_multiclass  s    

 r?  c                 C   s   t jd d }| dt jd d}tt}|jt d | td | |d |jjdt jd fksdt|jjdkstt|t |d  t|d   |jjdt jd fkst|jjdkstd S )Nr   rH   rU   )r]   r   r   rF   r   )	r   rX   rV   r   r   r2   rr   rv   rt   )r[   r;  rz   r   r   r   r"   #test_partial_fit_multiclass_average	  s    
 r@  c                 C   s"   |  }| tt |tt d S r%   )r)   r   r   r2   r   r   r   r"   test_fit_then_partial_fit  s    rA  c                 C   s   t ttftttffD ]\}}}| ddd|dd}||| ||}|j}t	
|}| dd|dd}tdD ]}	|j|||d qn||}
|j|kstt||
dd qd S )NrU   rG   F)r]   rj   r   rl   rk   r]   rj   rl   rk   r   r   )r,   rx   r   r   r   r   r)   r3   ru   rV   r   r   r2   rv   r   )r[   ry   ZX_r   ZT_rz   r3  tr   re   y_pred2r   r   r"   "test_partial_fit_equal_fit_classif"  s    


rE  c                 C   s   t jd}| dddd|d}|tt dt |ttkksFt| dddd|d}|tt dt |ttkkst| dd	|d
}|tt dt |ttkkst| dddd|d}|tt dt |ttkkstd S )NrF   rU   r~   r   epsilon_insensitive)r]   rl   rj   r   r   rT   Zsquared_epsilon_insensitivehuber)r]   r   r   r   )	rV   r   r   r)   r,   rx   r  r   rv   )r[   r   rz   r   r   r"   test_regression_losses5  s>    rH  c                 C   s   t | ttd d S )Nr   )r}   r   r   r   r   r   r"   test_warm_start_multiclass[  s    rI  c                 C   s\   | ddd}| tt t|ds&tdd t tD }| td d d df | d S )NrU   Fr   rr   c                 S   s   g | ]}d dg| qS )ZhamZspamr   )r   re   r   r   r"   r   h  s     z%test_multiple_fit.<locals>.<listcomp>rE   )r)   r,   rx   r   rv   r   fit_transform)r[   rz   r-   r   r   r"   test_multiple_fit`  s
    rK  c                 C   sN   | dddd}| ddgddgddggdddg |jd |jd ksJtd S )Nr   rG   Fr5  r   rF   )r)   rr   rv   r   r   r   r"   test_sgd_regp  s    $rL  c              
   C   s   d}d}d}d}t jd}|j||fd}|j|d}t ||}| dd||d	d
d	dd}	|	|| t| ||||\}
}t|	j|
dd t	|	j
|dd d S )Nrm   rU   r   r   r   r   r   r~   TrF   Fr   r   r   )rV   r   r   r   rZ   r)   ri   r   rr   r   rt   r[   r\   r]   r   r   r   r,   r   r-   rz   ra   rc   r   r   r"   $test_sgd_averaged_computed_correctlyx  s,    rN  c              
   C   s   d}d}d}d}t jd}|j||fd}|j|d}t ||}| dd||d	d
d	dd}	|	|d t|d  d d  |d t|d   |	|t|d d  d d  |t|d d   t| ||||\}
}t|	j	|
dd t
|	jd |dd d S )Nrm   rU   r   r   r   r   r   r~   TrF   Fr   rG   r   r   )rV   r   r   r   rZ   r2   intri   r   rr   r   rt   rM  r   r   r"   test_sgd_averaged_partial_fit  s.    44rP  c              
   C   s   d}d}| dd||ddddd}t jd	 }|td t|d
  d d  t d t|d
   |tt|d
 d  d d  t t|d
 d   t| tt ||\}}t|j|dd t|j	|dd d S )Nrm   rU   r   r~   TrF   Fr   r   rG   r   r   )
r   rX   r2   r   rO  ri   r   rr   r   rt   )r[   r\   r]   rz   r   ra   rc   r   r   r"   test_average_sparse  s$    
44rQ  c           	      C   s   d\}}d}t jd}t ||||d}d|  }| dddd	d
}||| |||}|dksntd|  |	|d  }| dddd	d
}||| |||}|dkstd S )Nr   r   r   rF   rI   r   r   r   F)r   r]   r   r   rR   
rV   r   r   Zlinspacer   r   r)   scorerv   randn	r[   ZxminZxmaxr   r   r,   r-   rz   rU  r   r   r"   test_sgd_least_squares_fit  s    rX  c           	      C   s   d\}}d}t jd}t ||||d}d|  }| dddd	d
d}||| |||}|dksptd|  |	|d  }| dddd	d
d}||| |||}|dkstd S )NrR  r   r   rF   rI   rF  rU   r   r   Fr   epsilonr]   r   r   rR   rT  rW  r   r   r"   test_sgd_epsilon_insensitive  s4    r[  c           	      C   s   d\}}d}t jd}t ||||d}d|  }| ddddd	d
}||| |||}|dksptd|  |	|d  }| ddddd	d
}||| |||}|dkstd S )NrR  r   r   rF   rI   rG  r   r   FrY  rR   rT  rW  r   r   r"   test_sgd_huber_fit  s    r\  c              	   C   s   d\}}t jd}|||}||}t ||}dD ]h}dD ]^}tj||dd}	|	|| | dd||dd	}
|
|| d
||f }t|	j	|
j	d|d qBq:d S )N)r   r   r   )rU   rm   )rI   rP   rT   F)r]   l1_ratior   
elasticnet2   )r   r   r]   r]  r   zNcd and sgd did not converge to comparable results for alpha=%f and l1_ratio=%frG   )r   r   )
rV   r   r   rV  rZ   r   Z
ElasticNetr)   r   rr   )r[   r   r   r   r,   Zground_truth_coefr-   r]   r]  cdZsgdr   r   r   r"   test_elasticnet_convergence)  s4    
  ra  c                 C   s   t jd d }| dd}|t d | td |  |jjt jd fksLt|jjdks\t|ddggjdksvtt|jj	}|t |d  t|d   t|jj	}|st|d S )Nr   rH   rU   rq   rF   r  )
r,   rX   r2   rx   rr   rv   rt   r   r:  r   )r[   r;  rz   r<  r=  r   r   r"   test_partial_fitK  s    
rb  c                 C   s   | ddd|dd}| tt |t}|j}| dd|dd}tdD ]}|tt qF|t}|j|ksptt	||dd d S )NrU   rG   F)r]   r   rj   rl   rk   rB  r   )
r)   r,   rx   r   r   ru   r   r2   rv   r   )r[   ry   rz   r3  rC  re   rD  r   r   r"   test_partial_fit_equal_fit]  s    

rc  c                 C   s0   | dd}|j dd |jd d dks,td S )NrO   )rZ  r   rG  rF   )rw   r  rv   r   r   r   r"   test_loss_function_epsilonn  s    
rd  c                 C   s  |d krt |jd }n|}t |jd }|}d| }	d}
d}| tkrNd}t|D ]\}}t ||}||	7 }|dkrd}nd}|tdd|| d  9 }||| |  7 }|	|||   | 7 }	||9 }||7 }||d  }|
|9 }
|
|	7 }
|
|d  }
qV|d|
 fS )NrF   rS   rT   rU   rE   r   rG   )rV   rW   rX   rC   rY   rZ   max)r[   r,   r\   r7  ro   r   coefr   offsetrb   rc   rd   re   rf   rg   rh   r   r   r"   asgd_oneclassz  s4    rh  c                 C   s   | ddd|d}| | | ddd|d}|j ||j |j d | dddd|d}| | |j|jksrtt|j|j |jdd	 | | |j|jkstt|j|j d S )
NrI   rU   F)r7  rj   rk   rl   r   ro   r   T)r7  rj   rk   rp   rl   r7  )r)   rr   rs   r   ru   rv   r   rw   )r[   r,   ry   rz   r{   r|   r   r   r"   _test_warm_start_oneclass  s    


rk  c                 C   s   t | t| d S r%   )rk  r,   r   r   r   r"   test_warm_start_oneclass  s    rl  c                 C   sN   | dd}t |}|jdd |t | dd}|t t|j|j d S )NrI   rj  r   )r   rw   r)   r,   r   rr   r   r   r   r"   test_clone_oneclass  s    



rm  c              	   C   s   t jd d }| dd}|t d |  |jjt jd fksBt|jjdksRt|ddggjdkslt|j}|t |d   |j|ksttt	 |t d d df  W 5 Q R X d S )Nr   rH   r   rj  rF   r  )
r,   rX   r2   rr   rv   r   r   r   r   r   )r[   r;  rz   Zprevious_coefsr   r   r"   test_partial_fit_oneclass  s    
rn  c           	      C   s   | ddd|dd}| t |t}|j}|j}|j}| ddd|dd}tdD ]}|t qR|t}|j|kszt	t
|| t
|j| t
|j| d S )N皙?rG   rU   F)r7  r   rj   rl   rk   rF   )r7  rj   r   rl   rk   )r)   r,   r3   r   ru   rr   r   r   r2   rv   r   )	r[   ry   rz   Zy_scoresrC  rf  rg  r   Z	y_scores2r   r   r"   #test_partial_fit_equal_fit_oneclass  s    



rp  c                 C   s   d}d}| dd||ddd}| dd||d	dd}| t | t t| t|||j |jd
\}}t|j |  t|j| d S )Nrm   ro  r   r~   rG   F)r   rl   rj   r7  r   rk   r   rF   ri  )r)   r,   rh  rr   r   r   r   )r[   rj   r7  r   r{   r   average_offsetr   r   r"   *test_late_onset_averaging_reached_oneclass  s<              

     
rr  c           
   	   C   sz   d}d}d}d}t jd}|j||fd}| d||dd	dd
d}|| t| |||\}}	t|j| t|j|	 d S )Nrm   ro  r   r   r   r   r~   TrF   Frl   rj   r7  r   r   r   rk   )	rV   r   r   r   r)   rh  r   rr   r   
r[   r\   r7  r   r   r   r,   rz   r   rq  r   r   r"   -test_sgd_averaged_computed_correctly_oneclass  s&    

ru  c           
   	   C   s   d}d}d}d}t jd}|j||fd}| d||dd	dd
d}||d t|d  d d   ||t|d d  d d   t| |||\}}	t|j| t|j	|	 d S )Nrm   ro  r   r   r   r   r~   TrF   Frs  rG   )
rV   r   r   r   r2   rO  rh  r   rr   r   rt  r   r   r"   &test_sgd_averaged_partial_fit_oneclass+  s(    
""rv  c              	   C   s   d}d}| d||ddddd}t jd }|t d t|d	   |t t|d	 d   t| t ||\}}t|j| t|j| d S )
Nrm   rU   r~   TrF   Frs  r   rG   )r   rX   r2   rO  rh  r   rr   r   )r[   r\   r7  rz   r   r   rq  r   r   r"   test_average_sparse_oneclassG  s"    

rw  c                  C   s   t ddgddgddgg} t ddgddgg}tdddddd}||  t|jt d	d
g |jd dksvt||}t|t ddg |||j }t|	|| |
|}t|t ddg d S )NrD   rE   rF   rI   rG   r~   F)r7  rj   rl   rk   r   g      g      ?r   rK   g      g      ?)rV   r   r?   r)   r   rr   r   rv   Zscore_samplesr3   r   r   )X_trainX_testrz   Zscoresdecr   r   r   r"   test_sgd_oneclass`  s$        


r{  c                  C   s.  d} d}d}t j|}d|dd }t j|d |d f }d|dd }t j|d |d f }t|d| d	}|| ||}||	d
d}	d}
t
||d}t| dd|
|d d}t||}|| ||}||	d
d}t ||kdkstt t |	|fd }|dks*td S )Nro  r   r   r    rG   r   Zrbf)gammaZkernelr7  rF   rE      )r}  r   T)r7  rk   r   r   r   r   rR   r  rO   )rV   r   r   rV  Zr_r   r)   r   r3   r   r	   r?   r   r  rv   corrcoefr1  )r7  r}  r   r   r,   rx  ry  rz   Zy_pred_ocsvmZ	dec_ocsvmr   Z	transformZclf_sgdZpipe_sgdZy_pred_sgdocsvmZdec_sgdocsvmr  r   r   r"   test_ocsvm_vs_sgdocsvmv  s:    




r  c                  C   s   t jddddd\} }tddd dd	d
d| |}tdddd
d d| |}t|j|j tddd ddd
d| |}tdddd
d d| |}t|j|j d S )Nr   r   r   i  )r   r   Zn_informativer   rm   r^  r0  gA?r   )r]   r   r   r   r]  r   r   )r]   r   r   r   r   g|=r   )r   make_classificationr@   r)   r   rr   )r,   r-   Zest_enZest_l1Zest_l2r   r   r"   test_l1_ratio  sd       
            r  c            	   
   C   s  t jdd t jd} d}d}| j||fd}|d d d df  d9  < t | sbtt 	|}t | st| j|d}t 
||d	kt j}tt |dd
g tdddd}||| t |j std}tjt|d ||| W 5 Q R X W 5 Q R X d S )Nraiser  r   r   r   r   rG   gu <7~rS   rF   r   r   r|  )r]   r   r   zwFloating-point under-/overflow occurred at epoch #.* Scaling input data with StandardScaler or MinMaxScaler might help.r   )rV   errstater   r   r   isfiniter  rv   r   rJ  rZ   astypeZint32r   r   r@   r)   rr   r   r   r   )	r   r   r   r,   ZX_scaledZground_truthr-   modelZ	msg_regxpr   r   r"   test_underflow_or_overlow  s&    r  c                  C   sZ   t ddddddddd d		} tjd
d | tjtj W 5 Q R X t| j	 sVt
d S )Nr   r   Tr^  r  rU   rm   r   )	r   r   rk   r   r]  r]   rj   r   r   r  r  )r@   rV   r  r)   r   r   r   r  rr   r  rv   )r  r   r   r"   'test_numerical_stability_large_gradient  s    r  r   r   r   r^  c              	   C   sV   t ddd| dd dd}tjdd |tjtj W 5 Q R X t|jt	|j d S )	Ng     j@r~   r   Fr0  )r]   rl   rj   r   rk   r   r   r  r  )
r@   rV   r  r)   r   r   r   r   rr   
zeros_like)r   r  r   r   r"   test_large_regularization  s    	r  c               	   C   s   t  tj} tjdk}d}td d|d}|| | ||jksDtd}tdd|d}|| | ||jkspt|jdks~ttdd|d}|| | |j|jkst|jdksttdd	dd
}d}t	j
t|d || | W 5 Q R X |jdkstd S )NrF   r   r   )r   r   r   r  r   r   rH   rm   )r   r   r   zhMaximum number of iteration reached before convergence. Consider increasing max_iter to improve the fit.r   )r   rJ  r   r   r   r@   r)   r   rv   r   warnsr   )r,   r-   r   Zmodel_0Zmodel_1Zmodel_2Zmodel_3Zwarning_messager   r   r"   test_tol_parameter  s*    
r  c                 C   s:   |D ]0\}}}}t | ||| t | ||| qd S r%   )r   py_losspy_dloss)Zloss_functioncasesrg   r-   Zexpected_lossZexpected_dlossr   r   r"   _test_loss_common1  s    r  c                  C   sT   t d} dddddddd	g}t| | t d
} ddddddddg}t| | d S )NrT   )g?rT   rS   rS          r   rS   rS   )rT   rT   rS   r   )r   r   rS   rT   )rI   rT   rI   r   )r   r         @rT   )rK   r   rI   rT   )rS   rT   rF   r   rS   rT   rT   rS   rS   )r  r   rS   rS   )rS   rT   rS   r   )rS   r   rS   rT   )rI   r   rI   rT   )r   r   r   rT   )rK   rT   rI   r   )r   rT   rT   r   )sgd_fastZHinger  r   r  r   r   r"   test_loss_hinge9  s,    


r  c                  C   s(   t d} ddddddg}t| | d S )NrT   r  r  )rT   r         @r  r   rT   r        )rI   rT   g      ?r   rI   r   g      @r  )r  ZSquaredHinger  r  r   r   r"   test_gradient_squared_hingeZ  s    
	r  c                  C   sH  t  } ddtdtd dtdd  fddtdtd dtdd  fddtdtd dtdd  fddtdtd dtdd  fddtddfddtddfddg}t| | t| d	dtd
d d t| d	dtd
d t| d
dtd
d d t| d
dd	d d S )NrT   r   rS   rG   rK   rI   )fffff1@r   r  rT   )gfffff1rT   r  r   g2@g2r   )	r  LogrV   r  expr  r   r  r  r  r   r   r"   test_loss_logi  s    ((((
r  c                  C   s$   t  } dddddg}t| | d S )NrS   rS   rS   rS   r  )rT   rS   rI   rT   )rI   r   g      ?rJ   )g      r   g     @$@g      )r  ZSquaredLossr  r  r   r   r"   test_loss_squared_loss~  s    r  c                  C   s(   t d} ddddddg}t| | d S )Nr   r  )r   rS   {Gzt?r   )rS   r   r  r  )g@r  g{GzT?g)      @r   gzG?r   )r   r  g
ףp=
?r  )r  ZHuberr  r  r   r   r"   test_loss_huber  s    
	r  c                  C   s*   t  } ddddddddg}t| | d S )	Nr  )r   r   rS   rS   )r   rT   rS   rS   )rS   rT   rT   r  r  r  )r  rT      r  )g      rT      r  )r  ZModifiedHuberr  r  r   r   r"   test_loss_modified_huber  s    r  c                  C   s,   t d} dddddddd	g}t| | d S )
Nr   r  r   rS   rS   rS   gffffff r  rS   rS   gffffff@r  rS   rS   )皙@r   r   rT   )r   r   333333@rT   )r   r  r   r   )r  rT   r  r   )r  ZEpsilonInsensitiver  r  r   r   r"   test_loss_epsilon_insensitive  s    
r  c                  C   s,   t d} dddddddd	g}t| | d S )
Nr   r  r  r  r  )r  r   rU   r  )r   r   R @g333333@)r   r  rU   gɿ)r  rT   r  g333333)r  ZSquaredEpsilonInsensitiver  r  r   r   r"   %test_loss_squared_epsilon_insensitive  s    
r  c               	   C   sf   t dddddddd} | tjtj | j| jks6t| j| jd k sJt| tjtjd	ksbtd S )
Nrm   r   Tr   r   rG   )r]   r   r   r   r   r   r  r   rP   )	r@   r)   r   r   r   r   r   rv   rU  )rz   r   r   r"   0test_multi_thread_multi_class_and_early_stopping  s    	r  c                  C   s^   t ddddddgd} tdd	d
dd}t|| dddd}|tjtj |jdksZt	d S )Nr  r   r   r   r_  )r]   r   rU   r   Tr   )r   r   r   r   rG   )Zn_iterr  r   rP   )
rV   Zlogspacer@   r   r)   r   r   r   Zbest_score_rv   )Z
param_gridrz   searchr   r   r"   -test_multi_core_gridsearch_and_early_stopping  s    r  backendZlokymultiprocessing	threadingc              	   C   s   t jd}tjdddd|d}|dd}tdd	dd
}||| tdddd
}tj| d ||| W 5 Q R X t	|j
|j
 d S )Nr   r|  r  g{Gz?Zcsr)Zdensityr	  r   r   r   rF   )r   r  r   r  )r  )rV   r   r   r&   choicer@   r)   joblibZparallel_backendr   rr   )r  r   r,   r-   Zclf_sequentialZclf_parallelr   r   r"   'test_SGDClassifier_fit_for_all_backends  s    r  	Estimatorc              	   C   s  | t jkrtj|d\}}ntj|d\}}| |dd}tt" |||j	}|j
dks`tW 5 Q R X | |dd}tt" |||j	}|j
dkstW 5 Q R X t|| | |d dd}tt" |||j	}|j
dkstW 5 Q R X t||  dkstd S )N)r   rF   )r   r   rT   )r   r;   r   Zmake_regressionr  r   r  r   r)   rr   r   rv   r   rV   absre  )r  Zglobal_random_seedr,   r-   ZestZcoef_same_seed_aZcoef_same_seed_bZcoef_other_seedr   r   r"   test_sgd_random_state  s"    

r  c           	      C   s   t jt j }}|jd }d}tjddd|d}ttjd}| 	td| |
|| |jd d	d
 \}}|jd t|| kst|jd t|| kstdS )ziTest that data passed to validation callback correctly subsets.

    Non-regression test for #23255.
    r   r  Trm   r   )r   r   r   r   )Zside_effect_ValidationScoreCallbackrF   rH   N)r   r   r   rX   r   r@   r   r   r  setattrr)   Z	call_argsrO  rv   )	Zmonkeypatchr,   rx   r   r   rz   ZmockZX_valZy_valr   r   r"   &test_validation_mask_correctly_subsets9  s    
r  c               	   C   s^   t jt j } }t|}d}tjd|dd}d}tjt	|d |j
| ||d W 5 Q R X d S )Nr   Tr   )r   r   r   z\The sample weights for validation set are all zero, consider using a different random state.r   r-  )r   r   r   rV   r  r   r@   r   r   r   r)   )r,   rx   r.  r   rz   error_messager   r   r"   (test_sgd_error_on_zero_validation_weightQ  s    
  r  c                 C   s   | dd tt dS )z!non-regression test for gh #25249rF   )verboseN)r)   r,   rx   )r  r   r   r"   test_sgd_verbosed  s    r  SGDEstimator	data_typec                 C   s>   t |}tjt|d}|  }||| |jj|ks:td S )Ndtype)	r,   r  rV   r   rx   r)   rr   r  rv   )r  r  Z_XZ_YZ	sgd_modelr   r   r"   test_sgd_dtype_matchj  s
    
r  c                 C   sz   t jtjd}tjttjd}t jtjd}tjttjd}| dd}||| | dd}||| t|j	|j	 d S )Nr  r   )r   )
r,   r  rV   float64r   rx   float32r)   r   rr   )r  ZX_64ZY_64ZX_32ZY_32Zsgd_64Zsgd_32r   r   r"   test_sgd_numerical_consistency~  s    

r  )NrS   )NrS   )r   Zunittest.mockr   r  numpyrV   r   Zscipy.sparsesparser&   Zsklearnr   r   r   Zsklearn.baser   r   Zsklearn.exceptionsr   Zsklearn.kernel_approximationr	   Zsklearn.linear_modelr
   r  r   Zsklearn.model_selectionr   r   r   Zsklearn.pipeliner   Zsklearn.preprocessingr   r   r   r   Zsklearn.svmr   Zsklearn.utils._testingr   r   r   r   r   r#   r@   r$   r;   r:   r?   r>   rA   rB   rC   r   r,   rx   r   r   r   r   r   r   r   r   r  r  Z	load_irisr   r   r   Ztrue_result5ri   r}   markZparametrizer   r   r   r   r   r   r   r   r   r   r   r   r   rW   r   r   r   r   r   r   r   r   r   r  r  r  r  r  r%  r)  r+  r,  r/  r4  r6  r8  r9  r>  r?  r@  rA  rE  rH  rI  rK  rL  rN  rP  rQ  rX  r[  r\  ra  rb  rc  rd  rh  rk  rl  rm  rn  rp  rr  ru  rv  rw  r{  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r   r   r   r"   <module>   s  	
.

.
	" 
 

 



 

+ 

 

 

) 

 


 

	
 

	



	

!
	











D
$




.
 






	
%



 
 


#

!
&






)#&
$!
# 

#