Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for many new model output types #93

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open

Conversation

abigailgold
Copy link
Member

Including multi-label (output) models.
Note that not all model output types are supported in all modules and methods.

Signed-off-by: abigailt <[email protected]>
…cated with the new output types supported.

Signed-off-by: abigailt <[email protected]>
Copy link
Member

@andersonm-ibm andersonm-ibm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some readability suggestions

@@ -65,7 +63,7 @@ def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE:
:return: Predictions from the model as numpy array (class probabilities, if supported).
"""
predictions = self._art_model.predict(x.get_samples(), **kwargs)
check_correct_model_output(predictions, self.output_type)
# check_correct_model_output(predictions, self.output_type)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this commented out?

@@ -12,9 +13,16 @@


class ModelOutputType(Enum):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not inherit from Flag instead?
Then you could use much less members for the variations, e.g. - CLASSIFIER, MULTI, BINARY, LOGITS, and the checks in the rest of the code will be much easier



def get_nb_classes(y: OUTPUT_DATA_ARRAY_TYPE) -> int:
def is_multi_label(output_type: ModelOutputType) -> bool:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This whole section would be redundant if you ModelOutputType inherits from Flag.

:return: the score as float (for classifiers, between 0 and 1)
"""
raise NotImplementedError
predictions = kwargs['predictions'] if 'predictions' in kwargs else None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In these and the following you can use kwargs.get('predictions') variations.
It returns the value for key if key is in the dictionary, else default. If default is not given, it defaults to None,

if scoring_method == ScoringMethod.ACCURACY:
if not is_multi_label(self.output_type) and not is_binary(self.output_type) and nb_classes is not None:
y = check_and_transform_label_format(y, nb_classes=nb_classes)
if (self.output_type == ModelOutputType.CLASSIFIER_SINGLE_OUTPUT_CLASS_PROBABILITIES
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, a Flag ModelOutputType would make this check easier


if y_train_pred is not None and len(y_train_pred.shape) == 1:
self._nb_classes = get_nb_classes(y_train_pred)
# self._nb_classes = get_nb_classes(y_train_pred, self.output_type)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These commented-out lines can be removed, right?

def score(self, test_data: Dataset, **kwargs):
"""
Score the model using test data.

:param test_data: Test data.
:type train_data: `Dataset`
:type test_data: `Dataset`
:param predictions: Model predictions to score. If provided, these will be used instead of calling the model's
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why aren't all these defined in the API, but only in the method documentation?


def test_pytorch_predictions_multi_label_cat():
# This kind of model requires special training and will not be supported using the 'fit' method.
class multi_label_cat_model(nn.Module):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Class name should be MultiLabelCatModel



def test_pytorch_predictions_multi_label_binary():
class multi_label_binary_model(nn.Module):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, class names should normally use the CapWords convention, unless they are primarily used as a callable.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants