Skip to content

Commit

Permalink
Merge pull request #16 from google-research/test_635124847
Browse files Browse the repository at this point in the history
Update the OCC training to use negative and unlabeled samples for training.
  • Loading branch information
raj-sinha committed May 20, 2024
2 parents 903affb + edbf26c commit 71f0727
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 29 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

## [Unreleased]

## [0.2.2] - 2024-05-19

* Update the OCC training to use both negative and unlabeled samples (previously only unlabeled samples were used).

## [0.2.1] - 2024-05-18

* Updates to data loaders. Label column filter can now be a list of integers.
Expand All @@ -35,7 +39,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

* Initial release

[Unreleased]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.1...HEAD
[Unreleased]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.2...HEAD
[0.2.2]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.1...v0.2.2
[0.2.1]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.0...v0.2.1
[0.2.0]: https://github.com/google-research/spade_anomaly_detection/compare/v0.1.0...v0.2.0
[0.1.0]: https://github.com/google-research/spade_anomaly_detection/releases/tag/v0.1.0
2 changes: 1 addition & 1 deletion spade_anomaly_detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@

# A new PyPI release will be pushed every time `__version__` is increased.
# When changing this, also update the CHANGELOG.md.
__version__ = '0.2.1'
__version__ = '0.2.2'
7 changes: 7 additions & 0 deletions spade_anomaly_detection/performance_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ def setUp(self):
filter_label_value=self.runner_parameters.unlabeled_data_value,
)
self.unlabeled_record_count = len(self.unlabeled_labels)
_, negative_labels = data_loader.load_dataframe(
dataset_name=csv_path,
filter_label_value=self.runner_parameters.negative_data_value,
)
self.negative_record_count = len(negative_labels)

self.occ_fit_batch_size = (
self.unlabeled_record_count // self.runner_parameters.ensemble_count
Expand Down Expand Up @@ -128,6 +133,7 @@ def test_spade_auc_performance_pnu_single_batch(self):
self.mock_get_total_records.side_effect = [
self.total_record_count,
self.unlabeled_record_count,
self.negative_record_count,
]

runner_object = runner.Runner(self.runner_parameters)
Expand All @@ -151,6 +157,7 @@ def test_spade_auc_performance_pu_single_batch(self):
self.mock_get_total_records.side_effect = [
self.total_record_count,
self.unlabeled_record_count,
self.negative_record_count,
]

runner_object = runner.Runner(self.runner_parameters)
Expand Down
66 changes: 45 additions & 21 deletions spade_anomaly_detection/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,31 @@ def __init__(self, runner_parameters: parameters.RunnerParameters):
else self.runner_parameters.negative_threshold
)

def _get_record_count_based_on_labels(self, label_value: int) -> int:
"""Gets the number of records in the input table that match a label value.
Builds a `<label_col_name> = <label_value>` filter, combines it with
`runner_parameters.where_statements` when that is set, and asks the data
loader for the length of the resulting query on
`runner_parameters.input_bigquery_table_path`.
Args:
label_value: The value of the label to use as the filter for records.
Returns:
The count of records, as reported by
`data_loader.get_query_record_result_length` for the combined filters.
"""
# Filter expression that restricts the count to rows carrying `label_value`.
label_record_count_filter = (
f'{self.runner_parameters.label_col_name} = {label_value}'
)
# NOTE(review): `where_statements` is wrapped as a single list element, so it
# is presumably one string rather than a list - confirm against
# RunnerParameters.
if self.runner_parameters.where_statements:
label_record_count_where_statements = [
self.runner_parameters.where_statements
] + [label_record_count_filter]
else:
# No user-supplied where statements: the label filter is the only condition.
label_record_count_where_statements = [label_record_count_filter]

label_record_count = self.data_loader.get_query_record_result_length(
input_path=self.runner_parameters.input_bigquery_table_path,
where_statements=label_record_count_where_statements,
)
return label_record_count

def check_data_tables(
self,
total_record_count: int,
Expand Down Expand Up @@ -166,12 +191,13 @@ def check_data_tables(
)

def instantiate_and_fit_ensemble(
self, unlabeled_record_count: int
self, unlabeled_record_count: int, negative_record_count: int
) -> occ_ensemble.GmmEnsemble:
"""Creates and fits an OCC ensemble on the specified input data.
Args:
unlabeled_record_count: Number of unlabeled records in the table.
negative_record_count: Number of negative records in the table.
Returns:
A trained one class classifier ensemble.
Expand All @@ -183,7 +209,8 @@ def instantiate_and_fit_ensemble(
negative_threshold=self.runner_parameters.negative_threshold,
)

records_per_occ = unlabeled_record_count // ensemble_object.ensemble_count
training_record_count = unlabeled_record_count + negative_record_count
records_per_occ = training_record_count // ensemble_object.ensemble_count
batch_size = records_per_occ // self.runner_parameters.batches_per_model
batch_size = np.min([batch_size, self.runner_parameters.max_occ_batch_size])

Expand All @@ -195,7 +222,11 @@ def instantiate_and_fit_ensemble(
where_statements=self.runner_parameters.where_statements,
ignore_columns=self.runner_parameters.ignore_columns,
batch_size=batch_size,
label_column_filter_value=self.runner_parameters.unlabeled_data_value,
# Train using negative labeled data and unlabeled data.
label_column_filter_value=[
self.runner_parameters.unlabeled_data_value,
self.runner_parameters.negative_data_value,
],
)

logging.info('Fitting ensemble.')
Expand Down Expand Up @@ -527,20 +558,11 @@ def run(self) -> None:
where_statements=self.runner_parameters.where_statements,
)

unlabeled_record_count_filter = (
f'{self.runner_parameters.label_col_name} = '
f'{self.runner_parameters.unlabeled_data_value}'
unlabeled_record_count = self._get_record_count_based_on_labels(
self.runner_parameters.unlabeled_data_value
)
if self.runner_parameters.where_statements:
unlabeled_record_count_where_statements = [
self.runner_parameters.where_statements
] + [unlabeled_record_count_filter]
else:
unlabeled_record_count_where_statements = [unlabeled_record_count_filter]

unlabeled_record_count = self.data_loader.get_query_record_result_length(
input_path=self.runner_parameters.input_bigquery_table_path,
where_statements=unlabeled_record_count_where_statements,
negative_record_count = self._get_record_count_based_on_labels(
self.runner_parameters.negative_data_value
)

self.check_data_tables(
Expand All @@ -549,7 +571,8 @@ def run(self) -> None:
)

ensemble_object = self.instantiate_and_fit_ensemble(
unlabeled_record_count=unlabeled_record_count
unlabeled_record_count=unlabeled_record_count,
negative_record_count=negative_record_count,
)

batch_size = (
Expand Down Expand Up @@ -615,10 +638,11 @@ def run(self) -> None:
)

if not self.runner_parameters.upload_only:
if self.supervised_model_object is None:
raise ValueError('Supervised model was not created and trained.')
self.evaluate_model()
if self.supervised_model_object is not None:
self.supervised_model_object.save(
save_location=self.runner_parameters.output_gcs_uri
)
self.supervised_model_object.save(
save_location=self.runner_parameters.output_gcs_uri
)

logging.info('SPADE training completed.')
22 changes: 16 additions & 6 deletions spade_anomaly_detection/runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def _create_mock_datasets(self) -> None:
self.per_class_labeled_example_count * 2
) + self.unlabeled_examples
self.total_test_records = self.per_class_labeled_example_count * 2
self.negative_examples = self.per_class_labeled_example_count * 1

unlabeled_features = np.random.rand(
self.unlabeled_examples, num_features
Expand Down Expand Up @@ -203,6 +204,7 @@ def _create_mock_datasets(self) -> None:
self.mock_get_query_record_result_length.side_effect = [
self.all_examples,
self.unlabeled_examples,
self.negative_examples,
self.total_test_records,
]
else:
Expand All @@ -213,6 +215,7 @@ def _create_mock_datasets(self) -> None:
self.mock_get_query_record_result_length.side_effect = [
self.all_examples,
self.unlabeled_examples,
self.negative_examples,
]

def test_runner_data_loader_no_error(self):
Expand All @@ -228,8 +231,14 @@ def test_runner_data_loader_no_error(self):
label_col_name=self.runner_parameters.label_col_name,
where_statements=self.runner_parameters.where_statements,
ignore_columns=self.runner_parameters.ignore_columns,
label_column_filter_value=self.runner_parameters.unlabeled_data_value,
batch_size=self.unlabeled_examples
# Verify that both negative and unlabeled samples are used.
label_column_filter_value=[
self.runner_parameters.unlabeled_data_value,
self.runner_parameters.negative_data_value,
],
# Verify that batch size is computed with both negative and unlabeled
# sample counts.
batch_size=(self.unlabeled_examples + self.negative_examples)
// self.runner_parameters.ensemble_count,
)
# Assert that the data loader is also called to fetch all records.
Expand Down Expand Up @@ -311,7 +320,7 @@ def test_runner_get_record_count_without_where_statement_no_error(self):

def test_runner_record_count_raise_error(self):
self.runner_parameters.ensemble_count = 10
self.mock_get_query_record_result_length.side_effect = [5, 0]
self.mock_get_query_record_result_length.side_effect = [5, 0, 1]
runner_object = runner.Runner(self.runner_parameters)

with self.assertRaisesRegex(
Expand All @@ -320,7 +329,7 @@ def test_runner_record_count_raise_error(self):
runner_object.run()

def test_runner_no_records_raise_error(self):
self.mock_get_query_record_result_length.side_effect = [0, 0]
self.mock_get_query_record_result_length.side_effect = [0, 0, 0]
runner_object = runner.Runner(self.runner_parameters)

with self.assertRaisesRegex(
Expand All @@ -340,7 +349,7 @@ def _assert_regex_in(

def test_record_count_warning_raise(self):
# Will raise a warning when there are < 1k samples in the entire dataset.
self.mock_get_query_record_result_length.side_effect = [500, 100]
self.mock_get_query_record_result_length.side_effect = [500, 100, 10]
runner_object = runner.Runner(self.runner_parameters)

with self.assertLogs() as training_logs:
Expand Down Expand Up @@ -452,7 +461,7 @@ def test_batch_sizing_no_error(self, mock_split, mock_pseudo_label):

def test_batch_size_too_large_throw_error(self):
self.runner_parameters.labeling_and_model_training_batch_size = 1000
self.mock_get_query_record_result_length.side_effect = [100, 5]
self.mock_get_query_record_result_length.side_effect = [100, 5, 10]
runner_object = runner.Runner(self.runner_parameters)

with self.assertRaisesRegex(
Expand Down Expand Up @@ -695,6 +704,7 @@ def test_dataset_label_values_positive_and_negative_throws_error(self):
self.mock_get_query_record_result_length.side_effect = [
self.all_examples,
self.unlabeled_examples,
self.negative_examples,
total_test_records,
]

Expand Down

0 comments on commit 71f0727

Please sign in to comment.