
tft.apply_saved_model raises ValueError when running beam pipeline #231

Open
thisisandreeeee opened this issue Apr 5, 2021 · 5 comments

thisisandreeeee commented Apr 5, 2021

I would like to ensemble several pre-trained models within a single TF graph, and I'd like to understand whether this is feasible with TensorFlow Transform. My plan is to use the tft.apply_saved_model function to compute predictions, then export the transform_fn for use in the serving signature of a wrapper model. However, I get a ValueError when I try to run inference with a simple toy model, and the stack trace isn't very informative.

Versions

  • TensorFlow: 2.3.0
  • TensorFlow Transform: 0.22.0
  • Beam: 2.21.0

Steps to reproduce

Create toy model

First, I create a toy classification model that takes two float features as input and returns the probability that the label is positive.

```python
import csv
import tempfile
import uuid

import apache_beam as beam
import numpy as np
import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam
from tensorflow.keras import layers
from tensorflow_transform.tf_metadata import dataset_metadata, schema_utils

FEATURES = ["feature1", "feature2"]
TARGET = "target"

def create_model(features, target):
    inputs = [layers.Input(shape=(1,), name=f, dtype=tf.float32) for f in features]
    x = layers.Concatenate()(inputs)
    x = layers.Dense(64, activation="relu")(x)
    outputs = layers.Dense(1, activation="sigmoid", name=target)(x)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])

    # Fit briefly on random data; inputs and labels are keyed by layer name.
    n = 50
    model.fit(
        {f: tf.constant([np.random.uniform() for _ in range(n)]) for f in features},
        {target: tf.constant([np.random.randint(2) for _ in range(n)])},
    )

    # Save in the TF2 SavedModel format to a unique directory.
    save_dir = f"/tmp/{str(uuid.uuid4())}"
    model.save(save_dir)
    return save_dir


model_dir = create_model(FEATURES, TARGET)
```

Run beam pipeline

Then, I construct a Beam pipeline that runs tft.apply_saved_model on a sample dataset, which I create as follows:

```python
def inference_data(features, n):
    csv_file = f"/tmp/{str(uuid.uuid4())}.csv"
    with open(csv_file, "w") as f:
        w = csv.writer(f)
        w.writerow(features)
        for _ in range(n):
            w.writerow([np.random.uniform() for _ in features])
    return csv_file

csv_dataset = inference_data(FEATURES, 50)
```
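As a quick sanity check on the dataset, the sketch below writes the same kind of CSV using only the standard library and reads it back. It is a hypothetical variant, not the code above: `random.uniform` and `tempfile` stand in for NumPy and the hard-coded `/tmp` path, and the helper names `write_inference_csv` and `read_rows` are made up for illustration. Each decoded row is a dict with one scalar float per feature, which is roughly what a schema of `tf.io.FixedLenFeature([], tf.float32)` columns describes.

```python
import csv
import os
import random
import tempfile

FEATURES = ["feature1", "feature2"]


def write_inference_csv(features, n, seed=0):
    """Write a header plus n rows of uniform random floats; return the path."""
    rng = random.Random(seed)  # seeded so the sketch is reproducible
    fd, path = tempfile.mkstemp(suffix=".csv")
    with os.fdopen(fd, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(features)
        for _ in range(n):
            w.writerow([rng.uniform(0.0, 1.0) for _ in features])
    return path


def read_rows(path):
    """Decode each data row into {feature_name: float}, skipping the header."""
    with open(path, newline="") as f:
        return [{k: float(v) for k, v in row.items()} for row in csv.DictReader(f)]


path = write_inference_csv(FEATURES, 50)
rows = read_rows(path)
print(len(rows), sorted(rows[0]))  # 50 ['feature1', 'feature2']
os.remove(path)
```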

We can then proceed to write the inference function and execute it:

```python
def inference_fn(inputs):
    outputs = inputs.copy()
    outputs["prediction"] = tft.apply_saved_model(
        model_dir,
        inputs,
        tags=[tf.saved_model.SERVING],
        signature_name="serving_default",
    )
    return outputs

spec = {name: tf.io.FixedLenFeature([], tf.float32) for name in FEATURES}
metadata = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec(spec)
)

with beam.Pipeline() as pipeline:
    with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
        csv_data_coder = tft.coders.CsvCoder(FEATURES, metadata.schema)
        data = (
            pipeline
            | "ReadData"
            >> beam.io.textio.ReadFromText(csv_dataset, skip_header_lines=1)
            | "DecodeData" >> beam.Map(csv_data_coder.decode)
        )
        _, transform_fn = (
            data,
            metadata,
        ) | "AnalyzeAndTransform" >> tft_beam.AnalyzeAndTransformDataset(inference_fn)
```

Stack trace

When the above pipeline is run, I encounter the following error:

WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.
WARNING:tensorflow:Tensorflow version (2.3.0) found. Note that Tensorflow Transform support for TF 2.0 is currently in beta, and features such as tf.function may not work as intended. 
WARNING:tensorflow:From /Users/me/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/tensorflow_transform/pretrained_models.py:139: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.
INFO:tensorflow:Restoring parameters from /tmp/4f3c543d-df49-4d2c-9be6-db15fddb4c2b/variables/variables
WARNING:tensorflow:From /Users/me/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/tensorflow_transform/pretrained_models.py:198: convert_variables_to_constants (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`
WARNING:tensorflow:From /Users/me/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/tensorflow/python/framework/convert_to_constants.py:854: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
~/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/tensorflow/python/framework/importer.py in _import_graph_def_internal(graph_def, input_map, return_elements, validate_colocation_constraints, name, producer_op_list)
    496         results = c_api.TF_GraphImportGraphDefWithResults(
--> 497             graph._c_graph, serialized, options)  # pylint: disable=protected-access
    498         results = c_api_util.ScopedTFImportGraphDefResults(results)

InvalidArgumentError: Input 2 of node import/StatefulPartitionedCall was passed float from import/dense_3/kernel:0 incompatible with expected resource.

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-16-f7c09b6dbfc7> in <module>
     17             data,
     18             metadata,
---> 19         ) | "AnalyzeAndTransform" >> tft_beam.AnalyzeAndTransformDataset(inference_fn)

~/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/apache_beam/transforms/ptransform.py in __ror__(self, pvalueish, _unused)
    996 
    997   def __ror__(self, pvalueish, _unused=None):
--> 998     return self.transform.__ror__(pvalueish, self.label)
    999 
   1000   def expand(self, pvalue):

~/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/apache_beam/transforms/ptransform.py in __ror__(self, left, label)
    560     pvalueish = _SetInputPValues().visit(pvalueish, replacements)
    561     self.pipeline = p
--> 562     result = p.apply(self, pvalueish, label)
    563     if deferred:
    564       return result

~/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/apache_beam/pipeline.py in apply(self, transform, pvalueish, label)
    585       try:
    586         old_label, transform.label = transform.label, label
--> 587         return self.apply(transform, pvalueish)
    588       finally:
    589         transform.label = old_label

~/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/apache_beam/pipeline.py in apply(self, transform, pvalueish, label)
    628         transform.type_check_inputs(pvalueish)
    629 
--> 630       pvalueish_result = self.runner.apply(transform, pvalueish, self._options)
    631 
    632       if type_options is not None and type_options.pipeline_type_check:

~/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/apache_beam/runners/runner.py in apply(self, transform, input, options)
    196       m = getattr(self, 'apply_%s' % cls.__name__, None)
    197       if m:
--> 198         return m(transform, input, options)
    199     raise NotImplementedError(
    200         'Execution of [%s] not implemented in runner %s.' % (transform, self))

~/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/apache_beam/runners/runner.py in apply_PTransform(self, transform, input, options)
    226   def apply_PTransform(self, transform, input, options):
    227     # The base case of apply is to call the transform's expand.
--> 228     return transform.expand(input)
    229 
    230   def run_transform(self,

~/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/tensorflow_transform/beam/impl.py in expand(self, dataset)
   1027     # e.g. caching the values of expensive computations done in AnalyzeDataset.
   1028     transform_fn = (
-> 1029         dataset | 'AnalyzeDataset' >> AnalyzeDataset(self._preprocessing_fn))
   1030 
   1031     if Context.get_use_deep_copy_optimization():

~/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/apache_beam/transforms/ptransform.py in __ror__(self, pvalueish, _unused)
    996 
    997   def __ror__(self, pvalueish, _unused=None):
--> 998     return self.transform.__ror__(pvalueish, self.label)
    999 
   1000   def expand(self, pvalue):

~/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/apache_beam/transforms/ptransform.py in __ror__(self, left, label)
    560     pvalueish = _SetInputPValues().visit(pvalueish, replacements)
    561     self.pipeline = p
--> 562     result = p.apply(self, pvalueish, label)
    563     if deferred:
    564       return result

~/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/apache_beam/pipeline.py in apply(self, transform, pvalueish, label)
    585       try:
    586         old_label, transform.label = transform.label, label
--> 587         return self.apply(transform, pvalueish)
    588       finally:
    589         transform.label = old_label

~/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/apache_beam/pipeline.py in apply(self, transform, pvalueish, label)
    628         transform.type_check_inputs(pvalueish)
    629 
--> 630       pvalueish_result = self.runner.apply(transform, pvalueish, self._options)
    631 
    632       if type_options is not None and type_options.pipeline_type_check:

~/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/apache_beam/runners/runner.py in apply(self, transform, input, options)
    196       m = getattr(self, 'apply_%s' % cls.__name__, None)
    197       if m:
--> 198         return m(transform, input, options)
    199     raise NotImplementedError(
    200         'Execution of [%s] not implemented in runner %s.' % (transform, self))

~/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/apache_beam/runners/runner.py in apply_PTransform(self, transform, input, options)
    226   def apply_PTransform(self, transform, input, options):
    227     # The base case of apply is to call the transform's expand.
--> 228     return transform.expand(input)
    229 
    230   def run_transform(self,

~/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/tensorflow_transform/beam/impl.py in expand(self, dataset)
    974     input_values, input_metadata = dataset
    975     result, cache = super(AnalyzeDataset, self).expand((input_values, None,
--> 976                                                         None, input_metadata))
    977     assert not cache
    978     return result

~/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/tensorflow_transform/beam/impl.py in expand(self, dataset)
    826         copied_inputs = impl_helper.copy_tensors(input_signature)
    827 
--> 828       output_signature = self._preprocessing_fn(copied_inputs)
    829 
    830     # At this point we check that the preprocessing_fn has at least one

<ipython-input-10-c7be554d3125> in inference_fn(inputs)
      5         inputs,
      6         tags=[tf.saved_model.SERVING],
----> 7         signature_name="serving_default",
      8     )
      9     return outputs

~/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/tensorflow_transform/pretrained_models.py in apply_saved_model(model_dir, inputs, tags, signature_name, output_keys_in_signature)
    202       constant_graph_def,
    203       input_map=input_name_to_tensor_map,
--> 204       return_elements=output_tensor_names + loaded_initializer_op_names)
    205   returned_output_tensors = returned_elements[:len(output_tensor_names)]
    206   returned_initializer_ops = returned_elements[len(output_tensor_names):]

~/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    505                 'in a future version' if date is None else ('after %s' % date),
    506                 instructions)
--> 507       return func(*args, **kwargs)
    508 
    509     doc = _add_deprecated_arg_notice_to_docstring(

~/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/tensorflow/python/framework/importer.py in import_graph_def(***failed resolving arguments***)
    403       return_elements=return_elements,
    404       name=name,
--> 405       producer_op_list=producer_op_list)
    406 
    407 

~/.pyenv/versions/miniconda3-4.3.30/envs/my-project/lib/python3.7/site-packages/tensorflow/python/framework/importer.py in _import_graph_def_internal(graph_def, input_map, return_elements, validate_colocation_constraints, name, producer_op_list)
    499       except errors.InvalidArgumentError as e:
    500         # Convert to ValueError for backwards compatibility.
--> 501         raise ValueError(str(e))
    502 
    503     # Create _DefinedFunctions for any imported functions.

ValueError: Input 2 of node import/StatefulPartitionedCall was passed float from import/dense_3/kernel:0 incompatible with expected resource.

I'm not sure what the issue is. I suspect something simple, like an incorrect input signature, but running inference manually works fine:

```python
tf.keras.models.load_model(model_dir).predict(
    {f: tf.constant([np.random.uniform() for _ in range(5)]) for f in FEATURES}
)
```
arghyaganguly commented Apr 5, 2021

@thisisandreeeee, per the linked documentation, InvalidArgumentError occurs when an operation receives an input tensor that has an invalid value or shape.
In your example, the shape of the inputs passed to inference_fn might be causing this error.

thisisandreeeee (Author) commented Apr 5, 2021

@arghyaganguly I think you're right, but I can't figure out what the input shape should be. When I call .predict(..) on the model directly, it seems the input should be a dictionary whose keys are feature names and whose values are tensors.

```python
# this works
tf.keras.models.load_model(model_dir).predict(
    {f: tf.constant([np.random.uniform() for _ in range(5)]) for f in FEATURES}
)
```

This also appears to match the inputs passed to inference_fn, which print as:

```
{'feature1': <tf.Tensor 'inputs/inputs/feature1_copy:0' shape=(None,) dtype=float32>, 'feature2': <tf.Tensor 'inputs/inputs/feature2_copy:0' shape=(None,) dtype=float32>}
```

The input shapes look the same to me, so I'm unclear as to why there's an InvalidArgumentError.
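For reference, the printed tensors have shape `(None,)` (one scalar per example), while each Keras input was declared with `layers.Input(shape=(1,))`, i.e. a batch shape of `(None, 1)`. Whether that difference matters for this particular error is unconfirmed. The NumPy sketch below only illustrates the two layouts and the reshape that reconciles them; the `to_model_inputs` helper is hypothetical, and inside a TF graph the analogous call would be `tf.expand_dims(x, -1)` or `tf.reshape(x, [-1, 1])`.

```python
import numpy as np

FEATURES = ["feature1", "feature2"]

# Per-example records, one scalar per feature (as in the CSV rows).
rows = [{"feature1": 0.1, "feature2": 0.2},
        {"feature1": 0.3, "feature2": 0.4},
        {"feature1": 0.5, "feature2": 0.6}]

# Batching scalars column-wise gives rank-1 arrays: shape (batch,).
batch = {f: np.array([r[f] for r in rows], dtype=np.float32) for f in FEATURES}


def to_model_inputs(batch):
    """Add a trailing axis so each feature has shape (batch, 1)."""
    return {f: np.reshape(x, (-1, 1)) for f, x in batch.items()}


model_inputs = to_model_inputs(batch)
print(batch["feature1"].shape, model_inputs["feature1"].shape)  # (3,) (3, 1)
```

The maintainers' replies in this thread point to apply_saved_model's incompatibility with Keras, rather than shapes, as the likely cause of the ValueError.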

zoyahav (Member) commented Apr 7, 2021

I believe apply_saved_model is incompatible with Keras.
Is this a simplified version of your preprocessing_fn, or are you planning to just run inference over all of your data? (Using TFT may be overkill in that case.)
Could you expand a bit on your use case?

The following works, but it isn't practical because it only works with Beam's DirectRunner:

```python
from tfx_bsl.public import tfxio

model = tf.keras.models.load_model(model_dir)
def inference_fn(inputs):
    outputs = inputs.copy()
    outputs["prediction"] = model(inputs)
    return outputs

spec = {name: tf.io.FixedLenFeature([], tf.float32) for name in FEATURES}
metadata = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec(spec)
)
csv_tfxio = tfxio.CsvTFXIO(csv_dataset, column_names=FEATURES,
                           schema=metadata.schema, telemetry_descriptors=['test'], skip_header_lines=1)
tensor_adapter_config = csv_tfxio.TensorAdapterConfig()

with beam.Pipeline() as pipeline:
    with tft_beam.Context(temp_dir=tempfile.mkdtemp(), force_tf_compat_v1=False):
        data = (
            pipeline
            | 'TFXIORead' >> csv_tfxio.BeamSource()
        )
        transformed, transform_fn = (
            data,
            tensor_adapter_config,
        ) | "AnalyzeAndTransform" >> tft_beam.AnalyzeAndTransformDataset(inference_fn)
        transformed[0] | 'print' >> beam.Map(print)
```

Note that I'm using TFXIO because CsvCoder.decode is deprecated in the most recent version of TFT, and I'm setting force_tf_compat_v1=False because you're running Keras-related code in your preprocessing_fn.
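The error message in the original report is consistent with that incompatibility. The log warnings show apply_saved_model loading the model through the TF1 loader and freezing it with convert_variables_to_constants. One reading of the error (an assumption, not confirmed in this thread) is that freezing turned the `dense_3/kernel` variable handle, which produces a `resource`, into a `float` constant, while the Keras serving signature still routes the call through a `StatefulPartitionedCall` node that expects a resource handle. The pure-Python toy below reproduces only that type check; `Node`, `check_inputs` and `freeze` are made-up helpers, not TensorFlow APIs, and the graph is a three-node caricature.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    name: str
    op: str
    out_dtype: str  # dtype this node produces
    inputs: List[str] = field(default_factory=list)       # names of producer nodes
    expected_in: List[str] = field(default_factory=list)  # dtype expected per input


def check_inputs(graph):
    """Raise ValueError if any node is fed a dtype it does not accept."""
    by_name = {n.name: n for n in graph}
    for node in graph:
        for i, (src, want) in enumerate(zip(node.inputs, node.expected_in)):
            got = by_name[src].out_dtype
            if got != want:
                raise ValueError(
                    f"Input {i} of node {node.name} was passed {got} "
                    f"from {src}:0 incompatible with expected {want}.")


def freeze(graph):
    """Toy freezing: replace variable handles with float constants, nothing else."""
    return [Node(n.name, "Const", "float") if n.op == "VarHandleOp" else n
            for n in graph]


graph = [
    Node("x", "Placeholder", "float"),
    Node("dense/kernel", "VarHandleOp", "resource"),
    Node("call", "StatefulPartitionedCall", "float",
         inputs=["x", "dense/kernel"], expected_in=["float", "resource"]),
]

check_inputs(graph)  # original graph: the dtypes line up
try:
    check_inputs(freeze(graph))  # frozen graph: the call still expects a resource
except ValueError as e:
    # prints: Input 1 of node call was passed float from dense/kernel:0 incompatible with expected resource.
    print(e)
```

In this toy, the fix would be to rewrite the call node as well as the variable, which is the kind of function-aware conversion a TF1-style freeze of a TF2 Keras SavedModel may not perform.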

varshaan (Contributor) commented Apr 8, 2021

As a follow-up to Zohar's suggestion above: using tft.make_and_track_object [1] to create the Keras model and then invoking it lets you do this inside inference_fn, which in turn allows using other Beam runners as well.

Note this only works when TF2 behavior is not disabled and force_tf_compat_v1=False.

```python
def inference_fn(inputs):
    outputs = inputs.copy()
    model = tft.make_and_track_object(lambda: tf.keras.models.load_model(model_dir))
    outputs["prediction"] = model(inputs)
    return outputs
```

[1] https://www.tensorflow.org/tfx/transform/api_docs/python/tft/make_and_track_object

arghyaganguly commented

Closing this as there has been no recent update to the comment thread (awaiting a response from the user). Please feel free to reopen based on the comment trace above. Thanks.
