From 275e2ba3a43632b96d2bb448594cb256c9688cbd Mon Sep 17 00:00:00 2001
From: Raj Sinha
Date: Sat, 18 May 2024 16:20:53 +0000
Subject: [PATCH] Update data loaders. Now the label column filter can be an
 integer or a list of integers.

PiperOrigin-RevId: 635061726
---
 CHANGELOG.md                                |  9 ++++--
 pyproject.toml                              |  3 +-
 spade_anomaly_detection/__init__.py         |  2 +-
 spade_anomaly_detection/data_loader.py      | 31 +++++++++++++++----
 spade_anomaly_detection/data_loader_test.py |  8 +++--
 .../data_utils/bq_dataset.py                |  4 +++
 .../data_utils/bq_dataset_test.py           |  9 ++++--
 7 files changed, 49 insertions(+), 17 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c33355a..44288c7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,7 +23,11 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
 
 ## [Unreleased]
 
-## [0.2.0] - 2022-05-05
+## [0.2.1] - 2024-05-18
+
+* Updates to data loaders. The label column filter can now be a list of integers.
+
+## [0.2.0] - 2024-05-05
 
 * Add PyPi support. Minor reorganization of repository.
 
@@ -31,6 +35,7 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
 
 * Initial release
 
-[Unreleased]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.0...HEAD
+[Unreleased]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.1...HEAD
+[0.2.1]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.0...v0.2.1
 [0.2.0]: https://github.com/google-research/spade_anomaly_detection/compare/v0.1.0...v0.2.0
 [0.1.0]: https://github.com/google-research/spade_anomaly_detection/releases/tag/v0.1.0
diff --git a/pyproject.toml b/pyproject.toml
index d9db643..605b686 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -2,7 +2,6 @@
 # Project metadata. Available keys are documented at:
 # https://packaging.python.org/en/latest/specifications/declaring-project-metadata
 name = "spade_anomaly_detection"
-version = "0.2.0"
 description = "Semi-supervised Pseudo Labeler Anomaly Detection with Ensembling (SPADE) is a semi-supervised anomaly detection method that uses an ensemble of one class classifiers as the pseudo-labelers and supervised classifiers to achieve state of the art results especially on datasets with distribution mismatch between labeled and unlabeled samples."
 readme = "README.md"
 requires-python = ">=3.8"
@@ -35,7 +34,7 @@ dependencies = [
 ]
 
 # `version` is automatically set by flit to use `spade_anomaly_detection.__version__`
-# dynamic = ["version"]
+dynamic = ["version"]
 
 [project.urls]
 homepage = "https://github.com/google-research/spade_anomaly_detection"
diff --git a/spade_anomaly_detection/__init__.py b/spade_anomaly_detection/__init__.py
index 3b054fc..d2d9893 100644
--- a/spade_anomaly_detection/__init__.py
+++ b/spade_anomaly_detection/__init__.py
@@ -31,4 +31,4 @@
 
 # A new PyPI release will be pushed every time `__version__` is increased.
 # When changing this, also update the CHANGELOG.md.
-__version__ = '0.2.0'
+__version__ = '0.2.1'
diff --git a/spade_anomaly_detection/data_loader.py b/spade_anomaly_detection/data_loader.py
index a09e8c6..46fa1be 100644
--- a/spade_anomaly_detection/data_loader.py
+++ b/spade_anomaly_detection/data_loader.py
@@ -328,8 +328,9 @@ def load_tf_dataset_from_bigquery(
       where_statements: Optional[List[str]] = None,
       ignore_columns: Optional[Sequence[str]] = None,
       batch_size: Optional[int] = None,
-      label_column_filter_value: Optional[int] = None,
+      label_column_filter_value: Optional[int | list[int]] = None,
       convert_features_to_float64: bool = False,
+      page_size: Optional[int] = None,
   ) -> tf.data.Dataset:
     """Loads a TensorFlow dataset from a BigQuery Table.
 
@@ -346,10 +347,13 @@ def load_tf_dataset_from_bigquery(
         dataset is not batched. In this case, when iterating through the
         dataset, it will yield one record per call instead of a batch of
         records.
-      label_column_filter_value: An integer used when filtering the label column
-        values. No value will result in all data returned from the table.
+      label_column_filter_value: An integer or list of integers used when
+        filtering the label column values. No value will result in all data
+        returned from the table.
       convert_features_to_float64: Set to True to cast the contents of the
         features columns to float64.
+      page_size: The pagination size to use when retrieving data from BigQuery.
+        A large value can result in fewer BQ calls, hence time savings.
 
     Returns:
       A TensorFlow dataset.
@@ -362,7 +366,15 @@ def load_tf_dataset_from_bigquery(
       where_statements = (
          list() if where_statements is None else where_statements
       )
-      where_statements.append(f'{label_col_name} = {label_column_filter_value}')
+      if isinstance(label_column_filter_value, int):
+        where_statements.append(
+            f'{label_col_name} = {label_column_filter_value}'
+        )
+      else:
+        where_statements.append(
+            f'CAST({label_col_name} AS INT64) IN '
+            f'UNNEST({label_column_filter_value})'
+        )
 
     if ignore_columns is not None:
       metadata_builder = feature_metadata.BigQueryMetadataBuilder(
@@ -373,8 +385,10 @@ def load_tf_dataset_from_bigquery(
       )
 
     if where_statements:
-      metadata_retrieval_options = feature_metadata.MetadataRetrievalOptions(
-          where_clauses=where_statements
+      metadata_retrieval_options = (
+          feature_metadata.MetadataRetrievalOptions.get_none(
+              where_clauses=where_statements
+          )
       )
 
     tf_dataset, metadata = bq_dataset.get_dataset_and_metadata_for_table(
@@ -384,7 +398,12 @@ def load_tf_dataset_from_bigquery(
         drop_remainder=True,
         metadata_options=metadata_retrieval_options,
         metadata_builder=metadata_builder,
+        page_size=page_size,
     )
+    options = tf.data.Options()
+    # Avoid a large warning output by TF Dataset.
+    options.deterministic = False
+    tf_dataset = tf_dataset.with_options(options)
 
     self.input_feature_metadata = metadata
 
diff --git a/spade_anomaly_detection/data_loader_test.py b/spade_anomaly_detection/data_loader_test.py
index 60ce73d..d858c11 100644
--- a/spade_anomaly_detection/data_loader_test.py
+++ b/spade_anomaly_detection/data_loader_test.py
@@ -341,7 +341,9 @@ def test_load_bigquery_dataset_unlabeled_value(self, metadata_mock):
         label_column_filter_value=label_column_filter_value,
     )
 
-    mock_metadata_call_actual = metadata_mock.call_args.kwargs['where_clauses']
+    mock_metadata_call_actual = metadata_mock.get_none.call_args.kwargs[
+        'where_clauses'
+    ]
 
     # Ensure that a where statement was not created when we don't pass in label
    # values.
@@ -399,7 +401,7 @@ def test_where_statement_construction_no_error(self, mock_metadata):
 
     self.assertListEqual(
         expected_where_statements,
-        mock_metadata.call_args.kwargs['where_clauses'],
+        mock_metadata.get_none.call_args.kwargs['where_clauses'],
     )
 
   @mock.patch.object(
@@ -419,7 +421,7 @@ def test_where_statements_with_label_filter_no_error(self, mock_metadata):
 
     self.assertListEqual(
         where_statements,
-        mock_metadata.call_args.kwargs['where_clauses'],
+        mock_metadata.get_none.call_args.kwargs['where_clauses'],
     )
 
   @mock.patch.object(feature_metadata, 'BigQueryMetadataBuilder', autospec=True)
diff --git a/spade_anomaly_detection/data_utils/bq_dataset.py b/spade_anomaly_detection/data_utils/bq_dataset.py
index 01b66e2..335bfdd 100644
--- a/spade_anomaly_detection/data_utils/bq_dataset.py
+++ b/spade_anomaly_detection/data_utils/bq_dataset.py
@@ -721,6 +721,7 @@ def get_dataset_and_metadata_for_table(
     batch_size: int = 64,
     with_mask: bool = _WITH_MASK_DEFAULT,
     drop_remainder: bool = False,
+    page_size: Optional[int] = None,
 ) -> Tuple[tf.data.Dataset, feature_metadata.BigQueryTableMetadata]:
   """Gets the metadata and dataset for a BigQuery table.
 
@@ -735,6 +736,8 @@ def get_dataset_and_metadata_for_table(
     with_mask: Whether the dataset should be returned with a mask format. For
       more information see get_bigquery_dataset.
     drop_remainder: If true no partial batches will be yielded.
+    page_size: The pagination size to use when retrieving data from BigQuery. A
+      large value can result in fewer BQ calls, hence time savings.
 
   Returns:
     A tuple of the output dataset and metadata for the specified table.
@@ -783,6 +786,7 @@ def get_dataset_and_metadata_for_table(
       cache_location=None,
       where_clauses=metadata_options.where_clauses,
       drop_remainder=drop_remainder,
+      page_size=page_size,
   )
 
   return dataset, all_metadata
diff --git a/spade_anomaly_detection/data_utils/bq_dataset_test.py b/spade_anomaly_detection/data_utils/bq_dataset_test.py
index 360a2bb..05225fb 100644
--- a/spade_anomaly_detection/data_utils/bq_dataset_test.py
+++ b/spade_anomaly_detection/data_utils/bq_dataset_test.py
@@ -683,6 +683,7 @@ def _create_rand_df():
     self.assertIsInstance(output_dataset, tf.data.Dataset)
 
     # Loop through the batches and make sure they all match.
+    epoch = None
     for epoch in range(1, 3):
       batch_index = 0
       for cur_batch in output_dataset:
@@ -760,7 +761,7 @@ def test_get_dataset_and_metadata_for_table_path(
     mock_bq_storage_client = mock.create_autospec(
         bigquery_storage.BigQueryReadClient, spec_set=True, instance=True)
     metadata_options = feature_metadata.MetadataRetrievalOptions.get_all()
-    batch_sie = 128
+    batch_size = 128
     with_mask = False
 
     output_dataset, output_metadata = (
@@ -769,7 +770,7 @@ def test_get_dataset_and_metadata_for_table_path(
             bigquery_client=mock_bq_client,
             bigquery_storage_client=mock_bq_storage_client,
             metadata_options=metadata_options,
-            batch_size=batch_sie,
+            batch_size=batch_size,
             with_mask=with_mask,
         )
     )
@@ -780,11 +781,12 @@ def test_get_dataset_and_metadata_for_table_path(
         get_metadata_mock.return_value,
         mock_bq_client,
         bqstorage_client=mock_bq_storage_client,
-        batch_size=batch_sie,
+        batch_size=batch_size,
        with_mask=with_mask,
         cache_location=None,
         where_clauses=(),
         drop_remainder=False,
+        page_size=None,
     )
 
     self.assertEqual(output_dataset, get_bigquery_dataset_mock.return_value)
@@ -832,6 +834,7 @@ def test_get_dataset_and_metadata_for_table_parts_defaults(
         cache_location=None,
         where_clauses=(),
         drop_remainder=False,
+        page_size=None,
     )
 
     self.assertEqual(output_dataset, get_bigquery_dataset_mock.return_value)
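
For intuition about the data_loader.py change above, here is a minimal, self-contained sketch of the WHERE-clause branching the loader now performs. The helper name build_label_where_clause is hypothetical (the patch inlines this logic in load_tf_dataset_from_bigquery); only the two f-string predicates are taken from the diff.

from typing import List, Union


def build_label_where_clause(
    label_col_name: str,
    label_column_filter_value: Union[int, List[int]],
) -> str:
  # Hypothetical helper mirroring the branch added in data_loader.py.
  if isinstance(label_column_filter_value, int):
    # A single label keeps the original equality predicate.
    return f'{label_col_name} = {label_column_filter_value}'
  # A Python list such as [0, -1] formats as a valid BigQuery array
  # literal, so IN UNNEST(...) matches any of the listed label values.
  return (
      f'CAST({label_col_name} AS INT64) IN '
      f'UNNEST({label_column_filter_value})'
  )


print(build_label_where_clause('label', -1))
# -> label = -1
print(build_label_where_clause('label', [0, -1]))
# -> CAST(label AS INT64) IN UNNEST([0, -1])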
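
The same data_loader.py hunk also attaches tf.data.Options with deterministic set to False to the returned dataset. A minimal sketch of that option, assuming TensorFlow 2.8+ (where options.deterministic replaced options.experimental_deterministic):

import tensorflow as tf

dataset = tf.data.Dataset.range(8).batch(4)
options = tf.data.Options()
# Allow elements to be produced out of order. This relaxes ordering
# guarantees for parallel reads; per the patch comment, it also avoids
# a large warning printed by TF Dataset.
options.deterministic = False
dataset = dataset.with_options(options)

for batch in dataset:
  print(batch.numpy())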
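
On the new page_size argument threaded through bq_dataset.get_dataset_and_metadata_for_table: a larger page means fewer BigQuery read calls per dataset scan. For intuition only, the analogous knob on the public google-cloud-bigquery client is shown below; this is an assumption for illustration, not the code path the loader uses, and the table name is a placeholder.

from google.cloud import bigquery

client = bigquery.Client()  # Assumes application-default credentials.
# Each page is one API response; larger pages mean fewer round trips
# at the cost of more memory held per response.
rows = client.list_rows('my-project.my_dataset.my_table', page_size=10_000)
for row in rows:
  pass  # Process each row here.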