Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH Add missing base.is_clusterer() function #28936

Merged
7 changes: 7 additions & 0 deletions doc/whats_new/v1.6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ Changelog
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
where 123456 is the *pull request* number, not the issue number.

:mod:`sklearn.base`
...................

- |Enhancement| Added a function :func:`base.is_clusterer` which determines
whether a given estimator is a clusterer.
:pr:`28936` by :user:`Christian Veenhuis <ChVeen>`.

Thanks to everyone who has contributed to the maintenance and improvement of
the project since version 1.5, including:

Expand Down
39 changes: 39 additions & 0 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1374,13 +1374,17 @@ def is_classifier(estimator):
Examples
--------
>>> from sklearn.base import is_classifier
>>> from sklearn.cluster import KMeans
>>> from sklearn.svm import SVC, SVR
>>> classifier = SVC()
>>> regressor = SVR()
>>> kmeans = KMeans()
>>> is_classifier(classifier)
True
>>> is_classifier(regressor)
False
>>> is_classifier(kmeans)
False
"""
return getattr(estimator, "_estimator_type", None) == "classifier"

Expand All @@ -1401,17 +1405,52 @@ def is_regressor(estimator):
Examples
--------
>>> from sklearn.base import is_regressor
>>> from sklearn.cluster import KMeans
>>> from sklearn.svm import SVC, SVR
>>> classifier = SVC()
>>> regressor = SVR()
>>> kmeans = KMeans()
>>> is_regressor(classifier)
False
>>> is_regressor(regressor)
True
>>> is_regressor(kmeans)
False
"""
return getattr(estimator, "_estimator_type", None) == "regressor"


def is_clusterer(estimator):
    """Return True if the given estimator is (probably) a clusterer.

    The check compares the object's ``_estimator_type`` attribute with
    ``"clusterer"``; an object that lacks the attribute yields False.

    Parameters
    ----------
    estimator : object
        Estimator object to test.

    Returns
    -------
    out : bool
        True if estimator is a clusterer and False otherwise.

    Examples
    --------
    >>> from sklearn.base import is_clusterer
    >>> from sklearn.cluster import KMeans
    >>> from sklearn.svm import SVC, SVR
    >>> classifier = SVC()
    >>> regressor = SVR()
    >>> kmeans = KMeans()
    >>> is_clusterer(classifier)
    False
    >>> is_clusterer(regressor)
    False
    >>> is_clusterer(kmeans)
    True
    """
    # A missing attribute falls back to None, which never equals "clusterer".
    return getattr(estimator, "_estimator_type", None) == "clusterer"


def is_outlier_detector(estimator):
"""Return True if the given estimator is (probably) an outlier detector.

Expand Down
47 changes: 46 additions & 1 deletion sklearn/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
TransformerMixin,
clone,
is_classifier,
is_clusterer,
is_regressor,
)
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.exceptions import InconsistentVersionWarning
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.svm import SVC, SVR
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.utils._mocking import MockDataFrame
from sklearn.utils._set_output import _get_output_config
Expand Down Expand Up @@ -260,12 +263,54 @@ def test_get_params():


def test_is_classifier():
    """Check is_classifier on bare, grid-searched and pipelined estimators."""
    # classifier cases
    svc = SVC()
    assert is_classifier(svc)
    assert is_classifier(GridSearchCV(svc, {"C": [0.1, 1]}))
    assert is_classifier(Pipeline([("svc", svc)]))
    assert is_classifier(Pipeline([("svc_cv", GridSearchCV(svc, {"C": [0.1, 1]}))]))

    # non-classifier cases
    svr = SVR()
    assert not is_classifier(svr)
    assert not is_classifier(GridSearchCV(svr, {"C": [0.1, 1]}))
    assert not is_classifier(Pipeline([("svr", svr)]))
    assert not is_classifier(Pipeline([("svr_cv", GridSearchCV(svr, {"C": [0.1, 1]}))]))


def test_is_regressor():
    """Check is_regressor on bare, grid-searched and pipelined estimators."""
    param_grid = {"C": [0.1, 1]}

    # regressor cases: SVR alone and wrapped in meta-estimators
    svr = SVR()
    expected_regressors = [
        svr,
        GridSearchCV(svr, param_grid),
        Pipeline([("svr", svr)]),
        Pipeline([("svr_cv", GridSearchCV(svr, param_grid))]),
    ]
    for candidate in expected_regressors:
        assert is_regressor(candidate)

    # non-regressor cases: the same wrappers around a classifier
    svc = SVC()
    expected_non_regressors = [
        svc,
        GridSearchCV(svc, param_grid),
        Pipeline([("svc", svc)]),
        Pipeline([("svc_cv", GridSearchCV(svc, param_grid))]),
    ]
    for candidate in expected_non_regressors:
        assert not is_regressor(candidate)


def test_is_clusterer():
    """Check is_clusterer on bare, grid-searched and pipelined estimators."""
    # clusterer cases: KMeans alone and wrapped in meta-estimators
    kmeans = KMeans()
    kmeans_grid = {"n_clusters": [3, 8]}
    expected_clusterers = [
        kmeans,
        GridSearchCV(kmeans, kmeans_grid),
        Pipeline([("kmeans", kmeans)]),
        Pipeline([("kmeans_cv", GridSearchCV(kmeans, kmeans_grid))]),
    ]
    for candidate in expected_clusterers:
        assert is_clusterer(candidate)

    # non-clusterer cases: the same wrappers around a classifier
    svc = SVC()
    svc_grid = {"C": [0.1, 1]}
    expected_non_clusterers = [
        svc,
        GridSearchCV(svc, svc_grid),
        Pipeline([("svc", svc)]),
        Pipeline([("svc_cv", GridSearchCV(svc, svc_grid))]),
    ]
    for candidate in expected_non_clusterers:
        assert not is_clusterer(candidate)


def test_set_params():
# test nested estimator parameter setting
Expand Down