from collections import namedtuple

import numpy as np
import pytest
from numpy.testing import assert_array_equal
from scipy.sparse import csr_matrix

from sklearn._config import config_context, get_config
from sklearn.utils._set_output import (
    _get_output_config,
    _safe_set_output,
    _SetOutputMixin,
    _wrap_in_pandas_container,
)


def test__wrap_in_pandas_container_dense():
    """Check _wrap_in_pandas_container for dense data."""
    pd = pytest.importorskip("pandas")
    X = np.asarray([[1, 0, 3], [0, 0, 1]])
    columns = np.asarray(["f0", "f1", "f2"], dtype=object)
    index = np.asarray([0, 1])

    dense_named = _wrap_in_pandas_container(X, columns=lambda: columns, index=index)
    assert isinstance(dense_named, pd.DataFrame)
    assert_array_equal(dense_named.columns, columns)
    assert_array_equal(dense_named.index, index)


def test__wrap_in_pandas_container_dense_update_columns_and_index():
    """Check that _wrap_in_pandas_container overrides columns and index."""
    pd = pytest.importorskip("pandas")
    X_df = pd.DataFrame([[1, 0, 3], [0, 0, 1]], columns=["a", "b", "c"])
    new_columns = np.asarray(["f0", "f1", "f2"], dtype=object)
    new_index = [10, 12]

    new_df = _wrap_in_pandas_container(X_df, columns=new_columns, index=new_index)
    assert_array_equal(new_df.columns, new_columns)

    # Index does not change when the input is a DataFrame
    assert_array_equal(new_df.index, X_df.index)


def test__wrap_in_pandas_container_error_validation():
    """Check errors in _wrap_in_pandas_container."""
    X = np.asarray([[1, 0, 3], [0, 0, 1]])
    X_csr = csr_matrix(X)
    match = "Pandas output does not support sparse data"
    with pytest.raises(ValueError, match=match):
        _wrap_in_pandas_container(X_csr, columns=["a", "b", "c"])


class EstimatorWithoutSetOutputAndWithoutTransform:
    pass


class EstimatorNoSetOutputWithTransform:
    def transform(self, X, y=None):
        return X  # pragma: no cover


class EstimatorWithSetOutput(_SetOutputMixin):
    def fit(self, X, y=None):
        self.n_features_in_ = X.shape[1]
        return self

    def transform(self, X, y=None):
        return X

    def get_feature_names_out(self, input_features=None):
        return np.asarray([f"X{i}" for i in range(self.n_features_in_)], dtype=object)


def test__safe_set_output():
    """Check _safe_set_output works as expected."""

    # Estimator without transform will not raise when setting set_output for transform.
    est = EstimatorWithoutSetOutputAndWithoutTransform()
    _safe_set_output(est, transform="pandas")

    # Estimator with transform but without set_output will raise
    est = EstimatorNoSetOutputWithTransform()
    with pytest.raises(ValueError, match="Unable to configure output"):
        _safe_set_output(est, transform="pandas")

    est = EstimatorWithSetOutput().fit(np.asarray([[1, 2, 3]]))
    _safe_set_output(est, transform="pandas")
    config = _get_output_config("transform", est)
    assert config["dense"] == "pandas"

    _safe_set_output(est, transform="default")
    config = _get_output_config("transform", est)
    assert config["dense"] == "default"

    # transform is None is a no-op, so the config remains "default"
    _safe_set_output(est, transform=None)
    config = _get_output_config("transform", est)
    assert config["dense"] == "default"


class EstimatorNoSetOutputWithTransformNoFeatureNamesOut(_SetOutputMixin):
    def transform(self, X, y=None):
        return X  # pragma: no cover


def test_set_output_mixin():
    """Estimator without get_feature_names_out does not define `set_output`."""
    est = EstimatorNoSetOutputWithTransformNoFeatureNamesOut()
    assert not hasattr(est, "set_output")


def test__safe_set_output_error():
    """Check transform with invalid config."""
    X = np.asarray([[1, 0, 3], [0, 0, 1]])

    est = EstimatorWithSetOutput()
    _safe_set_output(est, transform="bad")

    msg = "output config must be 'default'"
    with pytest.raises(ValueError, match=msg):
        est.transform(X)


def test_set_output_method():
    """Check that the output is pandas."""
    pd = pytest.importorskip("pandas")

    X = np.asarray([[1, 0, 3], [0, 0, 1]])
    est = EstimatorWithSetOutput().fit(X)

    # transform=None is a no-op
    est2 = est.set_output(transform=None)
    assert est2 is est
    X_trans_np = est2.transform(X)
    assert isinstance(X_trans_np, np.ndarray)

    est.set_output(transform="pandas")

    X_trans_pd = est.transform(X)
    assert isinstance(X_trans_pd, pd.DataFrame)


def test_set_output_method_error():
    """Check transform fails with invalid transform."""

    X = np.asarray([[1, 0, 3], [0, 0, 1]])
    est = EstimatorWithSetOutput().fit(X)
    est.set_output(transform="bad")

    msg = "output config must be 'default'"
    with pytest.raises(ValueError, match=msg):
        est.transform(X)


def test__get_output_config():
    """Check _get_output_config works as expected."""

    # Without a configuration set, the global config is used
    global_config = get_config()["transform_output"]
    config = _get_output_config("transform")
    assert config["dense"] == global_config

    with config_context(transform_output="pandas"):
        # with estimator=None, the global config is used
        config = _get_output_config("transform")
        assert config["dense"] == "pandas"

        est = EstimatorNoSetOutputWithTransform()
        config = _get_output_config("transform", est)
        assert config["dense"] == "pandas"

        est = EstimatorWithSetOutput()
        # If estimator has not config, use global config
        config = _get_output_config("transform", est)
        assert config["dense"] == "pandas"

        # If estimator has a config, use local config
        est.set_output(transform="default")
        config = _get_output_config("transform", est)
        assert config["dense"] == "default"

    est.set_output(transform="pandas")
    config = _get_output_config("transform", est)
    assert config["dense"] == "pandas"


class EstimatorWithSetOutputNoAutoWrap(_SetOutputMixin, auto_wrap_output_keys=None):
    def transform(self, X, y=None):
        return X


def test_get_output_auto_wrap_false():
    """Check that auto_wrap_output_keys=None does not wrap."""
    est = EstimatorWithSetOutputNoAutoWrap()
    assert not hasattr(est, "set_output")

    X = np.asarray([[1, 0, 3], [0, 0, 1]])
    assert X is est.transform(X)


def test_auto_wrap_output_keys_errors_with_incorrect_input():
    msg = "auto_wrap_output_keys must be None or a tuple of keys."
    with pytest.raises(ValueError, match=msg):

        class BadEstimator(_SetOutputMixin, auto_wrap_output_keys="bad_parameter"):
            pass


class AnotherMixin:
    def __init_subclass__(cls, custom_parameter, **kwargs):
        super().__init_subclass__(**kwargs)
        cls.custom_parameter = custom_parameter


def test_set_output_mixin_custom_mixin():
    """Check that multiple init_subclasses passes parameters up."""

    class BothMixinEstimator(_SetOutputMixin, AnotherMixin, custom_parameter=123):
        def transform(self, X, y=None):
            return X

        def get_feature_names_out(self, input_features=None):
            return input_features

    est = BothMixinEstimator()
    assert est.custom_parameter == 123
    assert hasattr(est, "set_output")


def test__wrap_in_pandas_container_column_errors():
    """If a callable `columns` errors, it has the same semantics as columns=None."""
    pd = pytest.importorskip("pandas")

    def get_columns():
        raise ValueError("No feature names defined")

    X_df = pd.DataFrame({"feat1": [1, 2, 3], "feat2": [3, 4, 5]})

    X_wrapped = _wrap_in_pandas_container(X_df, columns=get_columns)
    assert_array_equal(X_wrapped.columns, X_df.columns)

    X_np = np.asarray([[1, 3], [2, 4], [3, 5]])
    X_wrapped = _wrap_in_pandas_container(X_np, columns=get_columns)
    assert_array_equal(X_wrapped.columns, range(X_np.shape[1]))


def test_set_output_mro():
    """Check that multi-inheritance resolves to the correct class method.

    Non-regression test gh-25293.
    """

    class Base(_SetOutputMixin):
        def transform(self, X):
            return "Base"  # noqa

    class A(Base):
        pass

    class B(Base):
        def transform(self, X):
            return "B"

    class C(A, B):
        pass

    assert C().transform(None) == "B"


class EstimatorWithSetOutputIndex(_SetOutputMixin):
    def fit(self, X, y=None):
        self.n_features_in_ = X.shape[1]
        return self

    def transform(self, X, y=None):
        import pandas as pd

        # transform by giving output a new index.
        return pd.DataFrame(X.to_numpy(), index=[f"s{i}" for i in range(X.shape[0])])

    def get_feature_names_out(self, input_features=None):
        return np.asarray([f"X{i}" for i in range(self.n_features_in_)], dtype=object)


def test_set_output_pandas_keep_index():
    """Check that set_output does not override index.

    Non-regression test for gh-25730.
    """
    pd = pytest.importorskip("pandas")

    X = pd.DataFrame([[1, 2, 3], [4, 5, 6]], index=[0, 1])
    est = EstimatorWithSetOutputIndex().set_output(transform="pandas")
    est.fit(X)

    X_trans = est.transform(X)
    assert_array_equal(X_trans.index, ["s0", "s1"])


class EstimatorReturnTuple(_SetOutputMixin):
    def __init__(self, OutputTuple):
        self.OutputTuple = OutputTuple

    def transform(self, X, y=None):
        return self.OutputTuple(X, 2 * X)


def test_set_output_named_tuple_out():
    """Check that namedtuples are kept by default."""
    Output = namedtuple("Output", "X, Y")
    X = np.asarray([[1, 2, 3]])
    est = EstimatorReturnTuple(OutputTuple=Output)
    X_trans = est.transform(X)

    assert isinstance(X_trans, Output)
    assert_array_equal(X_trans.X, X)
    assert_array_equal(X_trans.Y, 2 * X)


class EstimatorWithListInput(_SetOutputMixin):
    def fit(self, X, y=None):
        assert isinstance(X, list)
        self.n_features_in_ = len(X[0])
        return self

    def transform(self, X, y=None):
        return X

    def get_feature_names_out(self, input_features=None):
        return np.asarray([f"X{i}" for i in range(self.n_features_in_)], dtype=object)


def test_set_output_list_input():
    """Check set_output for list input.

    Non-regression test for #27037.
    """
    pd = pytest.importorskip("pandas")

    X = [[0, 1, 2, 3], [4, 5, 6, 7]]
    est = EstimatorWithListInput()
    est.set_output(transform="pandas")

    X_out = est.fit(X).transform(X)
    assert isinstance(X_out, pd.DataFrame)
    assert_array_equal(X_out.columns, ["X0", "X1", "X2", "X3"])
