-
Notifications
You must be signed in to change notification settings - Fork 26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support for many new model output types #93
base: main
Are you sure you want to change the base?
Conversation
…i-label classifiers Signed-off-by: abigailt <[email protected]>
Signed-off-by: abigailt <[email protected]>
…upports multiple output types. Existing tests pass. Still need more tests for new types. Signed-off-by: abigailt <[email protected]>
Signed-off-by: abigailt <[email protected]>
…ytorch model passing. Signed-off-by: abigailt <[email protected]>
Signed-off-by: abigailt <[email protected]>
Signed-off-by: abigailt <[email protected]>
…ions Signed-off-by: abigailt <[email protected]>
Signed-off-by: abigailt <[email protected]>
Signed-off-by: abigailt <[email protected]>
Signed-off-by: abigailt <[email protected]>
…cated with the new output types supported. Signed-off-by: abigailt <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some readability suggestions
apt/utils/models/keras_model.py
Outdated
@@ -65,7 +63,7 @@ def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE: | |||
:return: Predictions from the model as numpy array (class probabilities, if supported). | |||
""" | |||
predictions = self._art_model.predict(x.get_samples(), **kwargs) | |||
check_correct_model_output(predictions, self.output_type) | |||
# check_correct_model_output(predictions, self.output_type) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this commented out?
@@ -12,9 +13,16 @@ | |||
|
|||
|
|||
class ModelOutputType(Enum): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not inherit from Flag instead?
Then you could use much less members for the variations, e.g. - CLASSIFIER, MULTI, BINARY, LOGITS, and the checks in the rest of the code will be much easier
|
||
|
||
def get_nb_classes(y: OUTPUT_DATA_ARRAY_TYPE) -> int: | ||
def is_multi_label(output_type: ModelOutputType) -> bool: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This whole section would be redundant if you ModelOutputType inherits from Flag.
apt/utils/models/model.py
Outdated
:return: the score as float (for classifiers, between 0 and 1) | ||
""" | ||
raise NotImplementedError | ||
predictions = kwargs['predictions'] if 'predictions' in kwargs else None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In these and the following you can use kwargs.get('predictions')
variations.
It returns the value for key if key is in the dictionary, else default. If default is not given, it defaults to None,
if scoring_method == ScoringMethod.ACCURACY: | ||
if not is_multi_label(self.output_type) and not is_binary(self.output_type) and nb_classes is not None: | ||
y = check_and_transform_label_format(y, nb_classes=nb_classes) | ||
if (self.output_type == ModelOutputType.CLASSIFIER_SINGLE_OUTPUT_CLASS_PROBABILITIES |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again, a Flag ModelOutputType would make this check easier
apt/utils/models/model.py
Outdated
|
||
if y_train_pred is not None and len(y_train_pred.shape) == 1: | ||
self._nb_classes = get_nb_classes(y_train_pred) | ||
# self._nb_classes = get_nb_classes(y_train_pred, self.output_type) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These commented-out lines can be removed, right?
apt/utils/models/model.py
Outdated
def score(self, test_data: Dataset, **kwargs): | ||
""" | ||
Score the model using test data. | ||
|
||
:param test_data: Test data. | ||
:type train_data: `Dataset` | ||
:type test_data: `Dataset` | ||
:param predictions: Model predictions to score. If provided, these will be used instead of calling the model's |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why aren't all these defined in the API, but only in the method documentation?
tests/test_pytorch.py
Outdated
|
||
def test_pytorch_predictions_multi_label_cat(): | ||
# This kind of model requires special training and will not be supported using the 'fit' method. | ||
class multi_label_cat_model(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Class name should be MultiLabelCatModel
tests/test_pytorch.py
Outdated
|
||
|
||
def test_pytorch_predictions_multi_label_binary(): | ||
class multi_label_binary_model(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again, class names should normally use the CapWords convention, unless they are primarily used as a callable.
Signed-off-by: abigailt <[email protected]>
Including multi-label (output) models.
Note that not all model output types are supported in all modules and methods.