ENH check_scoring() has raise_exc for multimetric scoring #28992

Merged (12 commits) on May 24, 2024
2 changes: 1 addition & 1 deletion sklearn/feature_selection/_rfe.py
@@ -539,7 +539,7 @@ class RFECV(RFE):
``cv`` default value of None changed from 3-fold to 5-fold.

scoring : str, callable or None, default=None
-A string (see model evaluation documentation) or
+A string (see :ref:`scoring_parameter`) or
a scorer callable object / function with signature
``scorer(estimator, X, y)``.

4 changes: 2 additions & 2 deletions sklearn/linear_model/_logistic.py
@@ -636,7 +636,7 @@ def _log_reg_scoring_path(
values are chosen in a logarithmic scale between 1e-4 and 1e4.

scoring : callable
-A string (see model evaluation documentation) or
+A string (see :ref:`scoring_parameter`) or
a scorer callable object / function with signature
``scorer(estimator, X, y)``. For a list of scoring functions
that can be used, look at :mod:`sklearn.metrics`.
@@ -1521,7 +1521,7 @@ class LogisticRegressionCV(LogisticRegression, LinearClassifierMixin, BaseEstima
solver.

scoring : str or callable, default=None
-A string (see model evaluation documentation) or
+A string (see :ref:`scoring_parameter`) or
a scorer callable object / function with signature
``scorer(estimator, X, y)``. For a list of scoring functions
that can be used, look at :mod:`sklearn.metrics`. The
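The `scoring` docstrings updated above describe two accepted forms: a string from the scorer registry (see :ref:`scoring_parameter`) or a callable with the signature ``scorer(estimator, X, y)``. A minimal illustrative sketch of both forms; the toy dataset and the custom balanced-accuracy scorer are assumptions for the example, not part of this PR:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegressionCV
from sklearn.metrics import balanced_accuracy_score

X, y = make_classification(n_samples=200, random_state=0)

# String form: resolved through the scorer registry.
clf = LogisticRegressionCV(scoring="accuracy", cv=3).fit(X, y)

# Callable form: any function accepting (estimator, X, y) and returning a float.
def balanced_accuracy_scorer(estimator, X, y):
    return balanced_accuracy_score(y, estimator.predict(X))

clf = LogisticRegressionCV(scoring=balanced_accuracy_scorer, cv=3).fit(X, y)
```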
23 changes: 18 additions & 5 deletions sklearn/metrics/_scorer.py
@@ -955,10 +955,11 @@ def get_scorer_names():
None,
],
"allow_none": ["boolean"],
"raise_exc": ["boolean"],
},
prefer_skip_nested_validation=True,
)
-def check_scoring(estimator=None, scoring=None, *, allow_none=False):
+def check_scoring(estimator=None, scoring=None, *, allow_none=False, raise_exc=True):
"""Determine scorer from user options.

A TypeError will be thrown if the estimator cannot be scored.
@@ -969,7 +970,7 @@ def check_scoring(estimator=None, scoring=None, *, allow_none=False):
The object to use to fit the data. If `None`, then this function may error
depending on `allow_none`.

-scoring : str, callable, list, tuple, or dict, default=None
+scoring : str, callable, list, tuple, set, or dict, default=None
Scorer to use. If `scoring` represents a single score, one can use:

- a single string (see :ref:`scoring_parameter`);
@@ -985,8 +986,20 @@ def check_scoring(estimator=None, scoring=None, *, allow_none=False):
If None, the provided estimator object's `score` method is used.

allow_none : bool, default=False
-If no scoring is specified and the estimator has no score function, we
-can either return None or raise an exception.
+Whether to return None or raise an error if no `scoring` is specified and the
+estimator has no `score` method.

+raise_exc : bool, default=True
+Whether to raise an exception if a subset of the scorers in multimetric scoring
+fails or to return an error code.
+
+- If set to `True`, raises the failing scorer's exception.
+- If set to `False`, a formatted string of the exception details is passed as
+result of the failing scorer(s).
+
+This applies if `scoring` is list, tuple, set, or dict. Ignored if `scoring`
+is a str or a callable.

Returns
-------
@@ -1026,7 +1039,7 @@ def check_scoring(estimator=None, scoring=None, *, allow_none=False):
return get_scorer(scoring)
if isinstance(scoring, (list, tuple, set, dict)):
scorers = _check_multimetric_scoring(estimator, scoring=scoring)
-return _MultimetricScorer(scorers=scorers)
+return _MultimetricScorer(scorers=scorers, raise_exc=raise_exc)
if scoring is None:
if hasattr(estimator, "score"):
return _PassthroughScorer(estimator)
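As a usage sketch of the change above (assuming a scikit-learn version that includes this PR): with list/tuple/set/dict `scoring`, `check_scoring` now builds the multimetric scorer with the given `raise_exc`, so a failing scorer either raises (`raise_exc=True`, the default) or reports a formatted error string as its result (`raise_exc=False`). The dataset and estimator below are illustrative assumptions:

```python
from sklearn.datasets import load_iris
from sklearn.metrics import check_scoring
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0).fit(X, y)

scorer = check_scoring(
    clf,
    scoring={"acc": "accuracy", "bal_acc": "balanced_accuracy"},
    raise_exc=False,  # a failing scorer would report an error string instead of raising
)
scores = scorer(clf, X, y)  # dict of metric name -> score, e.g. {"acc": 1.0, "bal_acc": 1.0}
```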
2 changes: 1 addition & 1 deletion sklearn/model_selection/_classification_threshold.py
@@ -636,7 +636,7 @@ class TunedThresholdClassifierCV(BaseThresholdClassifier):
The objective metric to be optimized. Can be one of:

* a string associated to a scoring function for binary classification
-(see model evaluation documentation);
+(see :ref:`scoring_parameter`);
* a scorer callable object created with :func:`~sklearn.metrics.make_scorer`;

response_method : {"auto", "decision_function", "predict_proba"}, default="auto"
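For illustration, the two documented forms of the objective metric for `TunedThresholdClassifierCV`: a string naming a binary-classification scorer or a callable built with :func:`~sklearn.metrics.make_scorer`. The dataset, the estimator, and the assumption that the parameter is named `scoring` in the released API are mine, not part of this diff:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, make_scorer
from sklearn.model_selection import TunedThresholdClassifierCV

X, y = make_classification(n_samples=300, weights=[0.8, 0.2], random_state=0)

# String form: a scorer name for binary classification.
tuned = TunedThresholdClassifierCV(LogisticRegression(), scoring="balanced_accuracy").fit(X, y)

# Callable form: a scorer created with make_scorer.
tuned = TunedThresholdClassifierCV(
    LogisticRegression(), scoring=make_scorer(f1_score)
).fit(X, y)
```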
2 changes: 1 addition & 1 deletion sklearn/model_selection/_search.py
@@ -972,7 +972,7 @@ def evaluate_candidates(candidate_params, cv=None, more_results=None):
first_test_score = all_out[0]["test_scores"]
self.multimetric_ = isinstance(first_test_score, dict)

-# check refit_metric now for a callabe scorer that is multimetric
+# check refit_metric now for a callable scorer that is multimetric
if callable(self.scoring) and self.multimetric_:
self._check_refit_for_multimetric(first_test_score)
refit_metric = self.refit
14 changes: 4 additions & 10 deletions sklearn/model_selection/_validation.py
@@ -27,7 +27,7 @@
from ..base import clone, is_classifier
from ..exceptions import FitFailedWarning, UnsetMetadataPassedError
from ..metrics import check_scoring, get_scorer_names
-from ..metrics._scorer import _check_multimetric_scoring, _MultimetricScorer
+from ..metrics._scorer import _MultimetricScorer
from ..preprocessing import LabelEncoder
from ..utils import Bunch, _safe_indexing, check_random_state, indexable
from ..utils._param_validation import (
@@ -352,15 +352,9 @@ def cross_validate(

cv = check_cv(cv, y, classifier=is_classifier(estimator))

-if callable(scoring):
-    scorers = scoring
-elif scoring is None or isinstance(scoring, str):
-    scorers = check_scoring(estimator, scoring)
-else:
-    scorers = _check_multimetric_scoring(estimator, scoring)
-    scorers = _MultimetricScorer(
-        scorers=scorers, raise_exc=(error_score == "raise")
-    )
+scorers = check_scoring(
+    estimator, scoring=scoring, raise_exc=(error_score == "raise")
+)

if _routing_enabled():
# For estimators, a MetadataRouter is created in get_metadata_routing
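A sketch of the behaviour this refactor preserves, assuming a post-merge scikit-learn: `cross_validate` now obtains its scorers from `check_scoring(..., raise_exc=(error_score == "raise"))`, so `error_score="raise"` still propagates a failing scorer's exception, while any other value records the error instead:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_validate
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

results = cross_validate(
    DecisionTreeClassifier(random_state=0),
    X,
    y,
    scoring=["accuracy", "balanced_accuracy"],  # multimetric scoring, now routed through check_scoring
    error_score="raise",  # maps to raise_exc=True in the returned multimetric scorer
    cv=5,
)
# results["test_accuracy"] and results["test_balanced_accuracy"] hold the per-fold scores.
```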