Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unsupervised fitting #21

Merged
merged 1 commit into from
May 17, 2024
Merged

Fix unsupervised fitting #21

merged 1 commit into from
May 17, 2024

Conversation

aazuspan
Copy link
Contributor

Closes #20

This fix allows fitting unsupervised estimators with the assumption that they will always predict to shape (n_samples,).

Output dtype is now determined based on the _estimator_type attribute. This is likely a temporary solution as _estimator_type is planned for deprecation in favor of tags and explicit estimator type checking functions, but neither of those solutions are fully implemented yet.

This fix allows fitting unsupervised estimators with the assumption that
they will always predict to shape (n_samples,).

Output dtype is now determined based on the `_estimator_type` attribute.
This is likely a temporary solution as `_estimator_type` is planned for
deprecation in favor of tags and explicit estimator type checking
functions, but neither of those solutions are fully implemented yet.

See scikit-learn/scikit-learn#28960
@aazuspan aazuspan added the bug Something isn't working label May 17, 2024
@aazuspan aazuspan requested a review from grovduck May 17, 2024 17:43
@aazuspan aazuspan self-assigned this May 17, 2024
Copy link
Member

@grovduck grovduck left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aazuspan, all looks great to me. Nice that you've left yourself some breadcrumbs to shift to sklearn tags when that comes around. And now we support unsupervised clustering as well!

y_pred = da.apply_gufunc(
estimator._wrapped.predict,
signature,
self.preprocessor.flat,
axis=self.preprocessor.flat_band_dim,
output_dtypes=[float],
output_dtypes=[output_dtype],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, this fixes type for classifiers as well.

@aazuspan
Copy link
Contributor Author

Thanks for the quick review!

@aazuspan aazuspan merged commit 23487fb into main May 17, 2024
5 checks passed
@aazuspan aazuspan deleted the unsupervised branch May 17, 2024 21:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Fitting unsupervised estimators fails
2 participants