Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support sklearn models with multiple outputs #2433

Open
wants to merge 2 commits into
base: dev_1.18.0
Choose a base branch
from

Conversation

abigailgold
Copy link
Collaborator

(i.e., nb_classes is an array instead of an integer).

Description

The Classifier nb_classes setter now accepts non-integer values to support cases where the classifier has multiple outputs.
check_and_transform_label_format still only supports integer nb_classes values, but throws a TypeError if a different type is received.

Type of change

Please check all relevant options.

  • Improvement (non-breaking)
  • Bug fix (non-breaking)
  • New feature (non-breaking)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Testing

Added a test with a multi-output DecisionTreeClassifier

Test Configuration:

  • OS: MacOS 14.4
  • Python version: 3.9
  • ART version or commit number
  • TensorFlow / Keras / PyTorch / MXNet version

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • My changes have been tested using both CPU and GPU devices

@beat-buesser beat-buesser self-requested a review April 17, 2024 13:13
@beat-buesser beat-buesser self-assigned this Apr 17, 2024
@beat-buesser beat-buesser added the enhancement New feature or request label Apr 17, 2024
@beat-buesser beat-buesser added this to the ART 1.18.0 milestone Apr 17, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
ART 1.18.0
Awaiting triage
Development

Successfully merging this pull request may close these issues.

None yet

2 participants