import html
from contextlib import closing
from io import StringIO

import pytest

from sklearn import config_context
from sklearn.cluster import AgglomerativeClustering, Birch
from sklearn.compose import ColumnTransformer
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.ensemble import StackingClassifier, StackingRegressor, VotingClassifier
from sklearn.feature_selection import SelectPercentile
from sklearn.gaussian_process.kernels import ExpSineSquared
from sklearn.impute import SimpleImputer
from sklearn.kernel_ridge import KernelRidge
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RandomizedSearchCV
from sklearn.multiclass import OneVsOneClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import FeatureUnion, Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.svm import LinearSVC, LinearSVR
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils._estimator_html_repr import (
    _get_visual_block,
    _write_label_html,
    estimator_html_repr,
)


@pytest.mark.parametrize("checked", [True, False])
def test_write_label_html(checked):
    # Test checking logic and labeling
    name = "LogisticRegression"
    tool_tip = "hello-world"

    with closing(StringIO()) as out:
        _write_label_html(out, name, tool_tip, checked=checked)
        html_label = out.getvalue()
        assert "LogisticRegression</label>" in html_label
        assert html_label.startswith('<div class="sk-label-container">')
        assert "<pre>hello-world</pre>" in html_label
        if checked:
            assert "checked>" in html_label


@pytest.mark.parametrize("est", ["passthrough", "drop", None])
def test_get_visual_block_single_str_none(est):
    # Test estimators that are represented by strings
    est_html_info = _get_visual_block(est)
    assert est_html_info.kind == "single"
    assert est_html_info.estimators == est
    assert est_html_info.names == str(est)
    assert est_html_info.name_details == str(est)


def test_get_visual_block_single_estimator():
    est = LogisticRegression(C=10.0)
    est_html_info = _get_visual_block(est)
    assert est_html_info.kind == "single"
    assert est_html_info.estimators == est
    assert est_html_info.names == est.__class__.__name__
    assert est_html_info.name_details == str(est)


def test_get_visual_block_pipeline():
    pipe = Pipeline(
        [
            ("imputer", SimpleImputer()),
            ("do_nothing", "passthrough"),
            ("do_nothing_more", None),
            ("classifier", LogisticRegression()),
        ]
    )
    est_html_info = _get_visual_block(pipe)
    assert est_html_info.kind == "serial"
    assert est_html_info.estimators == tuple(step[1] for step in pipe.steps)
    assert est_html_info.names == [
        "imputer: SimpleImputer",
        "do_nothing: passthrough",
        "do_nothing_more: passthrough",
        "classifier: LogisticRegression",
    ]
    assert est_html_info.name_details == [str(est) for _, est in pipe.steps]


def test_get_visual_block_feature_union():
    f_union = FeatureUnion([("pca", PCA()), ("svd", TruncatedSVD())])
    est_html_info = _get_visual_block(f_union)
    assert est_html_info.kind == "parallel"
    assert est_html_info.names == ("pca", "svd")
    assert est_html_info.estimators == tuple(
        trans[1] for trans in f_union.transformer_list
    )
    assert est_html_info.name_details == (None, None)


def test_get_visual_block_voting():
    clf = VotingClassifier(
        [("log_reg", LogisticRegression()), ("mlp", MLPClassifier())]
    )
    est_html_info = _get_visual_block(clf)
    assert est_html_info.kind == "parallel"
    assert est_html_info.estimators == tuple(trans[1] for trans in clf.estimators)
    assert est_html_info.names == ("log_reg", "mlp")
    assert est_html_info.name_details == (None, None)


def test_get_visual_block_column_transformer():
    ct = ColumnTransformer(
        [("pca", PCA(), ["num1", "num2"]), ("svd", TruncatedSVD, [0, 3])]
    )
    est_html_info = _get_visual_block(ct)
    assert est_html_info.kind == "parallel"
    assert est_html_info.estimators == tuple(trans[1] for trans in ct.transformers)
    assert est_html_info.names == ("pca", "svd")
    assert est_html_info.name_details == (["num1", "num2"], [0, 3])


def test_estimator_html_repr_pipeline():
    num_trans = Pipeline(
        steps=[("pass", "passthrough"), ("imputer", SimpleImputer(strategy="median"))]
    )

    cat_trans = Pipeline(
        steps=[
            ("imputer", SimpleImputer(strategy="constant", missing_values="empty")),
            ("one-hot", OneHotEncoder(drop="first")),
        ]
    )

    preprocess = ColumnTransformer(
        [
            ("num", num_trans, ["a", "b", "c", "d", "e"]),
            ("cat", cat_trans, [0, 1, 2, 3]),
        ]
    )

    feat_u = FeatureUnion(
        [
            ("pca", PCA(n_components=1)),
            (
                "tsvd",
                Pipeline(
                    [
                        ("first", TruncatedSVD(n_components=3)),
                        ("select", SelectPercentile()),
                    ]
                ),
            ),
        ]
    )

    clf = VotingClassifier(
        [
            ("lr", LogisticRegression(solver="lbfgs", random_state=1)),
            ("mlp", MLPClassifier(alpha=0.001)),
        ]
    )

    pipe = Pipeline(
        [("preprocessor", preprocess), ("feat_u", feat_u), ("classifier", clf)]
    )
    html_output = estimator_html_repr(pipe)

    # top level estimators show estimator with changes
    assert html.escape(str(pipe)) in html_output
    for _, est in pipe.steps:
        assert (
            '<div class="sk-toggleable__content"><pre>' + html.escape(str(est))
        ) in html_output

    # low level estimators do not show changes
    with config_context(print_changed_only=True):
        assert html.escape(str(num_trans["pass"])) in html_output
        assert "passthrough</label>" in html_output
        assert html.escape(str(num_trans["imputer"])) in html_output

        for _, _, cols in preprocess.transformers:
            assert f"<pre>{html.escape(str(cols))}</pre>" in html_output

        # feature union
        for name, _ in feat_u.transformer_list:
            assert f"<label>{html.escape(name)}</label>" in html_output

        pca = feat_u.transformer_list[0][1]
        assert f"<pre>{html.escape(str(pca))}</pre>" in html_output

        tsvd = feat_u.transformer_list[1][1]
        first = tsvd["first"]
        select = tsvd["select"]
        assert f"<pre>{html.escape(str(first))}</pre>" in html_output
        assert f"<pre>{html.escape(str(select))}</pre>" in html_output

        # voting classifier
        for name, est in clf.estimators:
            assert f"<label>{html.escape(name)}</label>" in html_output
            assert f"<pre>{html.escape(str(est))}</pre>" in html_output


@pytest.mark.parametrize("final_estimator", [None, LinearSVC()])
def test_stacking_classifier(final_estimator):
    estimators = [
        ("mlp", MLPClassifier(alpha=0.001)),
        ("tree", DecisionTreeClassifier()),
    ]
    clf = StackingClassifier(estimators=estimators, final_estimator=final_estimator)

    html_output = estimator_html_repr(clf)

    assert html.escape(str(clf)) in html_output
    # If final_estimator's default changes from LogisticRegression
    # this should be updated
    if final_estimator is None:
        assert "LogisticRegression(" in html_output
    else:
        assert final_estimator.__class__.__name__ in html_output


@pytest.mark.parametrize("final_estimator", [None, LinearSVR()])
def test_stacking_regressor(final_estimator):
    reg = StackingRegressor(
        estimators=[("svr", LinearSVR())], final_estimator=final_estimator
    )
    html_output = estimator_html_repr(reg)

    assert html.escape(str(reg.estimators[0][0])) in html_output
    assert "LinearSVR</label>" in html_output
    if final_estimator is None:
        assert "RidgeCV</label>" in html_output
    else:
        assert html.escape(final_estimator.__class__.__name__) in html_output


def test_birch_duck_typing_meta():
    # Test duck typing meta estimators with Birch
    birch = Birch(n_clusters=AgglomerativeClustering(n_clusters=3))
    html_output = estimator_html_repr(birch)

    # inner estimators do not show changes
    with config_context(print_changed_only=True):
        assert f"<pre>{html.escape(str(birch.n_clusters))}" in html_output
        assert "AgglomerativeClustering</label>" in html_output

    # outer estimator contains all changes
    assert f"<pre>{html.escape(str(birch))}" in html_output


def test_ovo_classifier_duck_typing_meta():
    # Test duck typing metaestimators with OVO
    ovo = OneVsOneClassifier(LinearSVC(penalty="l1"))
    html_output = estimator_html_repr(ovo)

    # inner estimators do not show changes
    with config_context(print_changed_only=True):
        assert f"<pre>{html.escape(str(ovo.estimator))}" in html_output
        assert "LinearSVC</label>" in html_output

    # outer estimator
    assert f"<pre>{html.escape(str(ovo))}" in html_output


def test_duck_typing_nested_estimator():
    # Test duck typing metaestimators with random search
    kernel_ridge = KernelRidge(kernel=ExpSineSquared())
    param_distributions = {"alpha": [1, 2]}

    kernel_ridge_tuned = RandomizedSearchCV(
        kernel_ridge,
        param_distributions=param_distributions,
    )
    html_output = estimator_html_repr(kernel_ridge_tuned)
    assert "estimator: KernelRidge</label>" in html_output


@pytest.mark.parametrize("print_changed_only", [True, False])
def test_one_estimator_print_change_only(print_changed_only):
    pca = PCA(n_components=10)

    with config_context(print_changed_only=print_changed_only):
        pca_repr = html.escape(str(pca))
        html_output = estimator_html_repr(pca)
        assert pca_repr in html_output


def test_fallback_exists():
    """Check that repr fallback is in the HTML."""
    pca = PCA(n_components=10)
    html_output = estimator_html_repr(pca)

    assert (
        f'<div class="sk-text-repr-fallback"><pre>{html.escape(str(pca))}'
        in html_output
    )


def test_show_arrow_pipeline():
    """Show arrow in pipeline for top level in pipeline"""
    pipe = Pipeline([("scale", StandardScaler()), ("log_Reg", LogisticRegression())])

    html_output = estimator_html_repr(pipe)
    assert (
        'class="sk-toggleable__label sk-toggleable__label-arrow">Pipeline'
        in html_output
    )


def test_invalid_parameters_in_stacking():
    """Invalidate stacking configuration uses default repr.

    Non-regression test for #24009.
    """
    stacker = StackingClassifier(estimators=[])

    html_output = estimator_html_repr(stacker)
    assert html.escape(str(stacker)) in html_output


def test_estimator_get_params_return_cls():
    """Check HTML repr works where a value in get_params is a class."""

    class MyEstimator:
        def get_params(self, deep=False):
            return {"inner_cls": LogisticRegression}

    est = MyEstimator()
    assert "MyEstimator" in estimator_html_repr(est)
