diff --git a/CHANGELOG.md b/CHANGELOG.md
index 44288c7..ab6a411 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,6 +23,10 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
 
 ## [Unreleased]
 
+## [0.2.2] - 2024-05-19
+
+* Update the OCC training to use both negative and unlabeled samples.
+
 ## [0.2.1] - 2024-05-18
 
 * Updates to data loaders. Label column filter can now be a list of integers.
@@ -35,7 +39,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
 
 * Initial release
 
-[Unreleased]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.1...HEAD
+[Unreleased]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.2...HEAD
+[0.2.2]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.1...v0.2.2
 [0.2.1]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.0...v0.2.1
 [0.2.0]: https://github.com/google-research/spade_anomaly_detection/compare/v0.1.0...v0.2.0
 [0.1.0]: https://github.com/google-research/spade_anomaly_detection/releases/tag/v0.1.0
diff --git a/spade_anomaly_detection/__init__.py b/spade_anomaly_detection/__init__.py
index d2d9893..0cf4ec9 100644
--- a/spade_anomaly_detection/__init__.py
+++ b/spade_anomaly_detection/__init__.py
@@ -31,4 +31,4 @@
 
 # A new PyPI release will be pushed every time `__version__` is increased.
 # When changing this, also update the CHANGELOG.md.
-__version__ = '0.2.1'
+__version__ = '0.2.2'
diff --git a/spade_anomaly_detection/performance_test.py b/spade_anomaly_detection/performance_test.py
index 4c2943b..1934dac 100644
--- a/spade_anomaly_detection/performance_test.py
+++ b/spade_anomaly_detection/performance_test.py
@@ -78,6 +78,11 @@ def setUp(self):
         filter_label_value=self.runner_parameters.unlabeled_data_value,
     )
     self.unlabeled_record_count = len(self.unlabeled_labels)
+    _, negative_labels = data_loader.load_dataframe(
+        dataset_name=csv_path,
+        filter_label_value=self.runner_parameters.negative_data_value,
+    )
+    self.negative_record_count = len(negative_labels)
 
     self.occ_fit_batch_size = (
         self.unlabeled_record_count // self.runner_parameters.ensemble_count
@@ -128,6 +133,7 @@ def test_spade_auc_performance_pnu_single_batch(self):
     self.mock_get_total_records.side_effect = [
         self.total_record_count,
         self.unlabeled_record_count,
+        self.negative_record_count,
     ]
 
     runner_object = runner.Runner(self.runner_parameters)
@@ -151,6 +157,7 @@ def test_spade_auc_performance_pu_single_batch(self):
     self.mock_get_total_records.side_effect = [
         self.total_record_count,
         self.unlabeled_record_count,
+        self.negative_record_count,
     ]
 
     runner_object = runner.Runner(self.runner_parameters)
diff --git a/spade_anomaly_detection/runner.py b/spade_anomaly_detection/runner.py
index 7508e69..59e20b1 100644
--- a/spade_anomaly_detection/runner.py
+++ b/spade_anomaly_detection/runner.py
@@ -104,6 +104,31 @@ def __init__(self, runner_parameters: parameters.RunnerParameters):
         else self.runner_parameters.negative_threshold
     )
 
+  def _get_record_count_based_on_labels(self, label_value: int) -> int:
+    """Gets the number of records in the table with the given label value.
+
+    Args:
+      label_value: The label value used to filter records.
+
+    Returns:
+      The count of records with the given label.
+ """ + label_record_count_filter = ( + f'{self.runner_parameters.label_col_name} = {label_value}' + ) + if self.runner_parameters.where_statements: + label_record_count_where_statements = [ + self.runner_parameters.where_statements + ] + [label_record_count_filter] + else: + label_record_count_where_statements = [label_record_count_filter] + + label_record_count = self.data_loader.get_query_record_result_length( + input_path=self.runner_parameters.input_bigquery_table_path, + where_statements=label_record_count_where_statements, + ) + return label_record_count + def check_data_tables( self, total_record_count: int, @@ -166,12 +191,13 @@ def check_data_tables( ) def instantiate_and_fit_ensemble( - self, unlabeled_record_count: int + self, unlabeled_record_count: int, negative_record_count: int ) -> occ_ensemble.GmmEnsemble: """Creates and fits an OCC ensemble on the specified input data. Args: unlabeled_record_count: Number of unlabeled records in the table. + negative_record_count: Number of negative records in the table. Returns: A trained one class classifier ensemble. @@ -183,7 +209,8 @@ def instantiate_and_fit_ensemble( negative_threshold=self.runner_parameters.negative_threshold, ) - records_per_occ = unlabeled_record_count // ensemble_object.ensemble_count + training_record_count = unlabeled_record_count + negative_record_count + records_per_occ = training_record_count // ensemble_object.ensemble_count batch_size = records_per_occ // self.runner_parameters.batches_per_model batch_size = np.min([batch_size, self.runner_parameters.max_occ_batch_size]) @@ -195,7 +222,11 @@ def instantiate_and_fit_ensemble( where_statements=self.runner_parameters.where_statements, ignore_columns=self.runner_parameters.ignore_columns, batch_size=batch_size, - label_column_filter_value=self.runner_parameters.unlabeled_data_value, + # Train using negative labeled data and unlabeled data. 
+        label_column_filter_value=[
+            self.runner_parameters.unlabeled_data_value,
+            self.runner_parameters.negative_data_value,
+        ],
     )
 
     logging.info('Fitting ensemble.')
@@ -527,20 +558,11 @@ def run(self) -> None:
         where_statements=self.runner_parameters.where_statements,
     )
 
-    unlabeled_record_count_filter = (
-        f'{self.runner_parameters.label_col_name} = '
-        f'{self.runner_parameters.unlabeled_data_value}'
+    unlabeled_record_count = self._get_record_count_based_on_labels(
+        self.runner_parameters.unlabeled_data_value
     )
-    if self.runner_parameters.where_statements:
-      unlabeled_record_count_where_statements = [
-          self.runner_parameters.where_statements
-      ] + [unlabeled_record_count_filter]
-    else:
-      unlabeled_record_count_where_statements = [unlabeled_record_count_filter]
-
-    unlabeled_record_count = self.data_loader.get_query_record_result_length(
-        input_path=self.runner_parameters.input_bigquery_table_path,
-        where_statements=unlabeled_record_count_where_statements,
+    negative_record_count = self._get_record_count_based_on_labels(
+        self.runner_parameters.negative_data_value
     )
 
     self.check_data_tables(
@@ -549,7 +571,8 @@ def run(self) -> None:
     )
 
     ensemble_object = self.instantiate_and_fit_ensemble(
-        unlabeled_record_count=unlabeled_record_count
+        unlabeled_record_count=unlabeled_record_count,
+        negative_record_count=negative_record_count,
     )
 
     batch_size = (
@@ -615,10 +638,11 @@ def run(self) -> None:
     )
 
     if not self.runner_parameters.upload_only:
+      if self.supervised_model_object is None:
+        raise ValueError('Supervised model was not created and trained.')
       self.evaluate_model()
-      if self.supervised_model_object is not None:
-        self.supervised_model_object.save(
-            save_location=self.runner_parameters.output_gcs_uri
-        )
+      self.supervised_model_object.save(
+          save_location=self.runner_parameters.output_gcs_uri
+      )
 
     logging.info('SPADE training completed.')
diff --git a/spade_anomaly_detection/runner_test.py b/spade_anomaly_detection/runner_test.py
index c6b8069..c84ca89 100644
--- a/spade_anomaly_detection/runner_test.py
+++ b/spade_anomaly_detection/runner_test.py
@@ -129,6 +129,7 @@ def _create_mock_datasets(self) -> None:
         self.per_class_labeled_example_count * 2
     ) + self.unlabeled_examples
     self.total_test_records = self.per_class_labeled_example_count * 2
+    self.negative_examples = self.per_class_labeled_example_count * 1
 
     unlabeled_features = np.random.rand(
         self.unlabeled_examples, num_features
@@ -203,6 +204,7 @@ def _create_mock_datasets(self) -> None:
       self.mock_get_query_record_result_length.side_effect = [
          self.all_examples,
          self.unlabeled_examples,
+          self.negative_examples,
          self.total_test_records,
      ]
    else:
@@ -213,6 +215,7 @@ def _create_mock_datasets(self) -> None:
      self.mock_get_query_record_result_length.side_effect = [
          self.all_examples,
          self.unlabeled_examples,
+          self.negative_examples,
      ]
 
  def test_runner_data_loader_no_error(self):
@@ -228,8 +231,14 @@ def test_runner_data_loader_no_error(self):
         label_col_name=self.runner_parameters.label_col_name,
         where_statements=self.runner_parameters.where_statements,
         ignore_columns=self.runner_parameters.ignore_columns,
-        label_column_filter_value=self.runner_parameters.unlabeled_data_value,
-        batch_size=self.unlabeled_examples
+        # Verify that both negative and unlabeled samples are used.
+        label_column_filter_value=[
+            self.runner_parameters.unlabeled_data_value,
+            self.runner_parameters.negative_data_value,
+        ],
+        # Verify that batch size is computed with both negative and unlabeled
+        # sample counts.
+        batch_size=(self.unlabeled_examples + self.negative_examples) // self.runner_parameters.ensemble_count,
     )
 
     # Assert that the data loader is also called to fetch all records.
@@ -311,7 +320,7 @@ def test_runner_get_record_count_without_where_statement_no_error(self):
 
   def test_runner_record_count_raise_error(self):
     self.runner_parameters.ensemble_count = 10
-    self.mock_get_query_record_result_length.side_effect = [5, 0]
+    self.mock_get_query_record_result_length.side_effect = [5, 0, 1]
     runner_object = runner.Runner(self.runner_parameters)
 
     with self.assertRaisesRegex(
@@ -320,7 +329,7 @@ def test_runner_record_count_raise_error(self):
     ):
       runner_object.run()
 
   def test_runner_no_records_raise_error(self):
-    self.mock_get_query_record_result_length.side_effect = [0, 0]
+    self.mock_get_query_record_result_length.side_effect = [0, 0, 0]
     runner_object = runner.Runner(self.runner_parameters)
     with self.assertRaisesRegex(
@@ -340,7 +349,7 @@ def _assert_regex_in(
 
   def test_record_count_warning_raise(self):
     # Will raise a warning when there are < 1k samples in the entire dataset.
-    self.mock_get_query_record_result_length.side_effect = [500, 100]
+    self.mock_get_query_record_result_length.side_effect = [500, 100, 10]
     runner_object = runner.Runner(self.runner_parameters)
 
     with self.assertLogs() as training_logs:
@@ -452,7 +461,7 @@ def test_batch_sizing_no_error(self, mock_split, mock_pseudo_label):
 
   def test_batch_size_too_large_throw_error(self):
     self.runner_parameters.labeling_and_model_training_batch_size = 1000
-    self.mock_get_query_record_result_length.side_effect = [100, 5]
+    self.mock_get_query_record_result_length.side_effect = [100, 5, 10]
     runner_object = runner.Runner(self.runner_parameters)
 
     with self.assertRaisesRegex(
@@ -695,6 +704,7 @@ def test_dataset_label_values_positive_and_negative_throws_error(self):
     self.mock_get_query_record_result_length.side_effect = [
         self.all_examples,
         self.unlabeled_examples,
+        self.negative_examples,
         total_test_records,
     ]
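
Note: the core behavioral change in this patch is that the OCC ensemble now trains on the union of unlabeled and negative-labeled records, so the per-model fit batch size is derived from the combined count. Below is a minimal sketch of that sizing logic; the standalone occ_fit_batch_size helper and the example counts are illustrative assumptions, not part of the patch (the real computation lives in Runner.instantiate_and_fit_ensemble).

import numpy as np


def occ_fit_batch_size(
    unlabeled_record_count: int,
    negative_record_count: int,
    ensemble_count: int,
    batches_per_model: int,
    max_occ_batch_size: int,
) -> int:
  """Mirrors the batch sizing added in instantiate_and_fit_ensemble."""
  # Negative records are now pooled with unlabeled records for training.
  training_record_count = unlabeled_record_count + negative_record_count
  records_per_occ = training_record_count // ensemble_count
  batch_size = records_per_occ // batches_per_model
  # Cap the batch size at the configured maximum.
  return int(np.min([batch_size, max_occ_batch_size]))


# Assumed counts: 10,000 unlabeled and 2,000 negative records, a 5-model
# ensemble, 10 batches per model, and a 1,024-record cap.
print(occ_fit_batch_size(10_000, 2_000, 5, 10, 1_024))  # -> 240

With a purely unlabeled pool the same parameters would give 10_000 // 5 // 10 = 200, so pooling the negatives enlarges each model's batches rather than adding separate negative-only batches.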