
tft.compute_and_apply_vocabulary is not robust to RaggedTensor type #265

chrispankow opened this issue Mar 29, 2022 · 2 comments

@chrispankow

TensorFlow version 2.8.0, TFT version 1.7.0.

I am currently working on constructing a module which has some multivalent inputs, as well as a multi-hot label endpoint. Both of these need a similar transform and feature engineering step: convert a string into tokens, then map the tokens to a sequence of integers which are fed to an embedding table. However, the number of tokens in a given example string is not constant, and tft.compute_and_apply_vocabulary seems to be unable to handle the output of tf.strings.split. In the context of the full model:

def _preprocess_multivalent_feature(feature, ncats):                             
    raw_values = tf.strings.split(feature, ',')                                  
    coded_values = tft.compute_and_apply_vocabulary(raw_values, num_oov_buckets=1)
    return coded_values

which lands me at the following (snipped for brevity):

TypeError                                 Traceback (most recent call last)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape, allow_broadcast)
    548     try:
--> 549       str_values = [compat.as_bytes(x) for x in proto_values]
    550     except TypeError:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/tensor_util.py in <listcomp>(.0)
    548     try:
--> 549       str_values = [compat.as_bytes(x) for x in proto_values]
    550     except TypeError:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/compat.py in as_bytes(bytes_or_text, encoding)
     86     raise TypeError('Expected binary or unicode string, got %r' %
---> 87                     (bytes_or_text,))
     88 

TypeError: Expected binary or unicode string, got tf.RaggedTensor(values=tf.RaggedTensor(values=Tensor("StringSplit/StringSplit/StringSplit/StringSplitV2:1", shape=(None,), dtype=string), row_splits=Tensor("StringSplit/StringSplit/StringSplit/RaggedFromValueRowIds/RowPartitionFromValueRowIds/concat:0", shape=(None,), dtype=int64)), row_splits=Tensor("StringSplit/RaggedFromTensor/RaggedFromUniformRowLength/RowPartitionFromUniformRowLength/mul:0", shape=(None,), dtype=int64))

During handling of the above exception, another exception occurred:

[snip]

/app/pipeline/components/transform.py in _preprocess_multivalent_feature(feature, ncats)
     69 def _preprocess_multivalent_feature(feature):
     70     raw_values = tf.strings.split(feature, ',')
---> 71     coded_values = tft.compute_and_apply_vocabulary(raw_values, num_oov_buckets=1)
     72     return coded_values

[snip]

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape, allow_broadcast)
    551       raise TypeError("Failed to convert object of type %s to Tensor. "
    552                       "Contents: %s. Consider casting elements to a "
--> 553                       "supported type." % (type(values), values))
    554     tensor_proto.string_val.extend(str_values)
    555     return tensor_proto

TypeError: Failed to convert object of type <class 'tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor'> to Tensor. Contents: tf.RaggedTensor(values=tf.RaggedTensor(values=Tensor("StringSplit/StringSplit/StringSplit/StringSplitV2:1", shape=(None,), dtype=string), row_splits=Tensor("StringSplit/StringSplit/StringSplit/RaggedFromValueRowIds/RowPartitionFromValueRowIds/concat:0", shape=(None,), dtype=int64)), row_splits=Tensor("StringSplit/RaggedFromTensor/RaggedFromUniformRowLength/RowPartitionFromUniformRowLength/mul:0", shape=(None,), dtype=int64)). Consider casting elements to a supported type.

Commenting out the tf.strings.split (and thus leaving the inputs as whole strings) allows the pipeline execution to continue. Despite my efforts, I cannot reproduce this exactly with a smaller example (this is work to scale up the pipeline through TFX, and providing that is out of scope for this issue). I am able to produce a RaggedTensor with the output of a similar function in a working example with faked inputs. However, I have a hard time believing that the tensor produced by that example would be usable by the Embedding layer to which it is putatively going to be connected:

Raw data:
[{'x': ['a,b', 'b,c,d', '']}, {'x': ['a,b,c', 'd', 'e,f,g,h,i']}]

Transformed data:
[{'_x$ragged_values': array([3, 0, 0, 2, 1, 9]),
  '_x$row_lengths_1': array([2, 3, 1])},
 {'_x$ragged_values': array([3, 0, 2, 1, 8, 7, 6, 5, 4]),
  '_x$row_lengths_1': array([3, 1, 5])}]
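
For my own sanity check, the flat representation above does seem to reassemble into the ragged structure I would expect (a rough sketch only; the actual decoding is handled by TFT / TFX_BSL, and the sizes below are made up):

import tensorflow as tf

# Reassemble the first transformed example from its flat representation.
values = tf.constant([3, 0, 0, 2, 1, 9], dtype=tf.int64)
row_lengths = tf.constant([2, 3, 1], dtype=tf.int64)
ids = tf.RaggedTensor.from_row_lengths(values, row_lengths)
# -> <tf.RaggedTensor [[3, 0], [0, 2, 1], [9]]>

# A padded view of the same ids could then feed an embedding table
# (made-up sizes; in practice padding/masking would need more care).
padded = ids.to_tensor(default_value=0)
embedding = tf.keras.layers.Embedding(input_dim=16, output_dim=8)
vectors = embedding(padded)  # shape (3, 3, 8)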

It would be undesirable, though perhaps acceptable, to convert the ragged tensor into a normal tensor, if there is a way to do so. I tried the change:

return coded_values -> return coded_values.to_tensor()

However, that runs into the same problem, since the error occurs earlier, inside compute_and_apply_vocabulary itself.
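
One untested idea on my side is to route the split tokens through a SparseTensor instead of a RaggedTensor, since compute_and_apply_vocabulary accepts sparse inputs as well (a sketch only, not verified against the full pipeline):

import tensorflow as tf
import tensorflow_transform as tft

def _preprocess_multivalent_feature(feature, ncats):
    # Untested: convert the ragged split output to sparse before the vocabulary.
    raw_values = tf.strings.split(feature, ',')    # RaggedTensor of string tokens
    sparse_values = raw_values.to_sparse()         # SparseTensor
    coded_values = tft.compute_and_apply_vocabulary(
        sparse_values, num_oov_buckets=1)
    return coded_values                            # SparseTensor of integer ids

The downstream model would then have to densify or re-ragged the sparse ids itself, which is part of why I would rather get the RaggedTensor path working.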

Any advice is appreciated.

@zoyahav
Member

zoyahav commented Apr 1, 2022

Could you please attach the full traceback (the part that shows the tensorflow_transform stack)?
tft.compute_and_apply_vocabulary should work with RaggedTensors, so we'd like to understand what the issue is.

The transformed data that you've pasted above is the flat representation that's useful when encoding it again as a tf.Example. TFT / TFX_BSL is then able to decode this into a tf.RaggedTensor which can then be converted to sparse/dense/etc.
Does this answer your question? If not, please clarify with a simple example.
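
For illustration only (the actual key names and partition spec come from the transform output schema, so treat these as placeholders), the flat representation you pasted can be parsed back into a RaggedTensor with a tf.io.RaggedFeature spec and converted from there:

import tensorflow as tf

# Build a serialized tf.Example from the flat representation above.
example = tf.train.Example(features=tf.train.Features(feature={
    '_x$ragged_values': tf.train.Feature(
        int64_list=tf.train.Int64List(value=[3, 0, 0, 2, 1, 9])),
    '_x$row_lengths_1': tf.train.Feature(
        int64_list=tf.train.Int64List(value=[2, 3, 1])),
}))

# Parse it back into a RaggedTensor using a RaggedFeature spec.
spec = {
    'x': tf.io.RaggedFeature(
        dtype=tf.int64,
        value_key='_x$ragged_values',
        partitions=[tf.io.RaggedFeature.RowLengths('_x$row_lengths_1')]),
}
parsed = tf.io.parse_example(tf.constant([example.SerializeToString()]), spec)
ragged = parsed['x']          # <tf.RaggedTensor [[[3, 0], [0, 2, 1], [9]]]>
dense = ragged.to_tensor()    # padded dense view
sparse = ragged.to_sparse()   # sparse view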

@chrispankow
Author

Sorry for the delay. The full traceback is appended. The example above was mostly for my own purposes, to try to understand the problem. If TFT / TFX is able to handle that like any normal tensor, then that's fine; I was just concerned that if the transformed output had keys the input training graph didn't know about, it could lead to trouble.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape, allow_broadcast)
    548     try:
--> 549       str_values = [compat.as_bytes(x) for x in proto_values]
    550     except TypeError:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/tensor_util.py in <listcomp>(.0)
    548     try:
--> 549       str_values = [compat.as_bytes(x) for x in proto_values]
    550     except TypeError:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/compat.py in as_bytes(bytes_or_text, encoding)
     86     raise TypeError('Expected binary or unicode string, got %r' %
---> 87                     (bytes_or_text,))
     88 

TypeError: Expected binary or unicode string, got tf.RaggedTensor(values=tf.RaggedTensor(values=Tensor("StringSplit/StringSplit/StringSplit/StringSplitV2:1", shape=(None,), dtype=string), row_splits=Tensor("StringSplit/StringSplit/StringSplit/RaggedFromValueRowIds/RowPartitionFromValueRowIds/concat:0", shape=(None,), dtype=int64)), row_splits=Tensor("StringSplit/RaggedFromTensor/RaggedFromUniformRowLength/RowPartitionFromUniformRowLength/mul:0", shape=(None,), dtype=int64))

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
<ipython-input-18-019b4f3f5a7b> in <module>
      7             transform=['train', 'valid', 'test'])
      8 )
----> 9 context.run(transform, enable_cache=False)

/opt/conda/lib/python3.7/site-packages/tfx/orchestration/experimental/interactive/interactive_context.py in run_if_ipython(*args, **kwargs)
     66       # __IPYTHON__ variable is set by IPython, see
     67       # https://ipython.org/ipython-doc/rel-0.10.2/html/interactive/reference.html#embedding-ipython.
---> 68       return fn(*args, **kwargs)
     69     else:
     70       absl.logging.warning(

/opt/conda/lib/python3.7/site-packages/tfx/orchestration/experimental/interactive/interactive_context.py in run(self, component, enable_cache, beam_pipeline_args)
    186         telemetry_utils.LABEL_TFX_RUNNER: runner_label,
    187     }):
--> 188       execution_id = launcher.launch().execution_id
    189 
    190     return execution_result.ExecutionResult(

/opt/conda/lib/python3.7/site-packages/tfx/orchestration/launcher/base_component_launcher.py in launch(self)
    207                          copy.deepcopy(execution_decision.input_dict),
    208                          execution_decision.output_dict,
--> 209                          copy.deepcopy(execution_decision.exec_properties))
    210 
    211     absl.logging.info('Running publisher for %s',

/opt/conda/lib/python3.7/site-packages/tfx/orchestration/launcher/in_process_component_launcher.py in _run_executor(self, execution_id, input_dict, output_dict, exec_properties)
     70     # output_dict can still be changed, specifically properties.
     71     executor.Do(
---> 72         copy.deepcopy(input_dict), output_dict, copy.deepcopy(exec_properties))

/opt/conda/lib/python3.7/site-packages/tfx/components/transform/executor.py in Do(self, input_dict, output_dict, exec_properties)
    484       label_outputs[labels.CACHE_OUTPUT_PATH_LABEL] = cache_output
    485     status_file = 'status_file'  # Unused
--> 486     self.Transform(label_inputs, label_outputs, status_file)
    487     absl.logging.debug('Cleaning up temp path %s on executor success',
    488                        temp_path)

/opt/conda/lib/python3.7/site-packages/tfx/components/transform/executor.py in Transform(***failed resolving arguments***)
    971     # order to fail faster if it fails.
    972     analyze_input_columns = tft.get_analyze_input_columns(
--> 973         preprocessing_fn, typespecs, force_tf_compat_v1=force_tf_compat_v1)
    974 
    975     if (not compute_statistics and not materialize_output_paths and

/opt/conda/lib/python3.7/site-packages/tensorflow_transform/inspect_preprocessing_fn.py in get_analyze_input_columns(preprocessing_fn, specs, force_tf_compat_v1)
     63   graph, structured_inputs, _ = (
     64       impl_helper.trace_preprocessing_function(
---> 65           preprocessing_fn, specs, use_tf_compat_v1=use_tf_compat_v1))
     66 
     67   tensor_sinks = graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS)

/opt/conda/lib/python3.7/site-packages/tensorflow_transform/impl_helper.py in trace_preprocessing_function(preprocessing_fn, input_specs, use_tf_compat_v1, base_temp_dir)
    641   else:
    642     return _trace_preprocessing_fn_v2(preprocessing_fn, input_specs,
--> 643                                       base_temp_dir)
    644 
    645 

/opt/conda/lib/python3.7/site-packages/tensorflow_transform/impl_helper.py in _trace_preprocessing_fn_v2(preprocessing_fn, specs, base_temp_dir)
    607   """Trace TF2 graph for `preprocessing_fn`."""
    608   concrete_fn = get_traced_transform_fn(preprocessing_fn, specs,
--> 609                                         base_temp_dir).get_concrete_function()
    610   return (concrete_fn.graph,
    611           tf2_utils.get_structured_inputs_from_func_graph(concrete_fn.graph),

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
   1297       ValueError: if this object has not yet been called on concrete values.
   1298     """
-> 1299     concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
   1300     concrete._garbage_collector.release()  # pylint: disable=protected-access
   1301     return concrete

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
   1203       if self._stateful_fn is None:
   1204         initializers = []
-> 1205         self._initialize(args, kwargs, add_initializers_to=initializers)
   1206         self._initialize_uninitialized_variables(initializers)
   1207 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    724     self._concrete_stateful_fn = (
    725         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 726             *args, **kwds))
    727 
    728     def invalid_creator_scope(*unused_args, **unused_kwds):

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   2967       args, kwargs = None, None
   2968     with self._lock:
-> 2969       graph_function, _ = self._maybe_define_function(args, kwargs)
   2970     return graph_function
   2971 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3359 
   3360           self._function_cache.missed.add(call_context_key)
-> 3361           graph_function = self._create_graph_function(args, kwargs)
   3362           self._function_cache.primary[cache_key] = graph_function
   3363 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3204             arg_names=arg_names,
   3205             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3206             capture_by_value=self._capture_by_value),
   3207         self._function_attributes,
   3208         function_spec=self.function_spec,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    988         _, original_func = tf_decorator.unwrap(python_func)
    989 
--> 990       func_outputs = python_func(*func_args, **func_kwargs)
    991 
    992       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    632             xla_context.Exit()
    633         else:
--> 634           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    635         return out
    636 

/opt/conda/lib/python3.7/site-packages/tensorflow_transform/impl_helper.py in transform_fn(inputs)
    568     with graph_context.TFGraphContext(
    569         temp_dir=base_temp_dir, evaluated_replacements=tensor_replacement_map):
--> 570       transformed_features = preprocessing_fn(inputs_copy)
    571     # An empty `TENSOR_REPLACEMENTS` collection symbolizes that there is no
    572     # analyzer left for Transform to evaluate. Either if this collection is

/app/pipeline/components/transform.py in preprocessing_fn(inputs)
     17         feature_outputs = [
     18             _preprocess_float_features(inputs, config),
---> 19             _preprocess_categorical_features(inputs, config),
     20             _preprocess_labels(inputs, config, deals),
     21             _add_sample_weight(inputs, config),

/app/pipeline/components/transform.py in _preprocess_categorical_features(inputs, config)
     83         transformed_key = _transformed_name(key)
     84         print(key, tf.sparse.to_dense(inputs[key]))
---> 85         outputs[transformed_key] = _preprocess_multivalent_feature(tf.sparse.to_dense(inputs[key]), 200)
     86 
     87     cat_features = set(config._CATEGORICAL_FEATURES) - set(config._MULTIVALENT_FEATURES)

/app/pipeline/components/transform.py in _preprocess_multivalent_feature(feature, ncats)
     68 def _preprocess_multivalent_feature(feature, ncats):
     69     raw_values = tf.strings.split(feature, ',')
---> 70     coded_values = tft.compute_and_apply_vocabulary(raw_values, num_oov_buckets=1)
     71     return coded_values.to_tensor()
     72     multi_hot = tf.keras.layers.CategoryEncoding(

/opt/conda/lib/python3.7/site-packages/tensorflow_transform/common.py in wrapped_fn(*args, **kwargs)
     78             collection.append(collections.Counter())
     79           collection[0][fn.__name__] += 1
---> 80           return fn(*args, **kwargs)
     81       else:
     82         return fn(*args, **kwargs)

/opt/conda/lib/python3.7/site-packages/tensorflow_transform/mappers.py in compute_and_apply_vocabulary(x, default_value, top_k, frequency_threshold, num_oov_buckets, vocab_filename, weights, labels, use_adjusted_mutual_info, min_diff_from_avg, coverage_top_k, coverage_frequency_threshold, key_fn, fingerprint_shuffle, file_format, name)
   1036         key_fn=key_fn,
   1037         fingerprint_shuffle=fingerprint_shuffle,
-> 1038         file_format=file_format)
   1039     return apply_vocabulary(
   1040         x,

/opt/conda/lib/python3.7/site-packages/tensorflow_transform/common.py in wrapped_fn(*args, **kwargs)
     80           return fn(*args, **kwargs)
     81       else:
---> 82         return fn(*args, **kwargs)
     83 
     84     # We use tf_decorator here so that TF can correctly introspect into

/opt/conda/lib/python3.7/site-packages/tensorflow_transform/analyzers.py in vocabulary(x, top_k, frequency_threshold, vocab_filename, store_frequency, weights, labels, use_adjusted_mutual_info, min_diff_from_avg, coverage_top_k, coverage_frequency_threshold, key_fn, fingerprint_shuffle, file_format, name)
   1852         x=x,
   1853         labels=labels,
-> 1854         weights=weights)
   1855     return _vocabulary_analyzer_nodes(
   1856         analyzer_inputs=analyzer_inputs,

/opt/conda/lib/python3.7/site-packages/tensorflow_transform/analyzers.py in _get_vocabulary_analyzer_inputs(vocab_ordering_type, x, labels, weights)
   1909     return [reduced_batch.unique_x, reduced_batch.summed_weights_per_x]
   1910   else:
-> 1911     reduced_batch = tf_utils.reduce_batch_weighted_counts(x)
   1912     assert reduced_batch.summed_weights_per_x is None
   1913     assert reduced_batch.summed_positive_per_x_and_y is None

/opt/conda/lib/python3.7/site-packages/tensorflow_transform/tf_utils.py in reduce_batch_weighted_counts(x, weights)
    106     # TODO(b/112916494): Always do batch wise reduction once possible.
    107 
--> 108     return ReducedBatchWeightedCounts(tf.reshape(x, [-1]), None, None, None)
    109   # TODO(b/134075780): Revisit expected weights shape when input is sparse.
    110   x, weights = assert_same_shape(x, weights)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
    199     """Call target, and fall back on dispatchers if there is a TypeError."""
    200     try:
--> 201       return target(*args, **kwargs)
    202     except (TypeError, ValueError):
    203       # Note: convert_to_eager_tensor currently raises a ValueError, not a

/opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py in reshape(tensor, shape, name)
    193     A `Tensor`. Has the same type as `tensor`.
    194   """
--> 195   result = gen_array_ops.reshape(tensor, shape, name)
    196   tensor_util.maybe_set_static_shape(result, shape)
    197   return result

/opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/gen_array_ops.py in reshape(tensor, shape, name)
   8376   # Add nodes to the TensorFlow graph.
   8377   _, _, _op, _outputs = _op_def_library._apply_op_helper(
-> 8378         "Reshape", tensor=tensor, shape=shape, name=name)
   8379   _result = _outputs[:]
   8380   if _execute.must_record_gradient():

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(op_type_name, name, **keywords)
    523         except TypeError as err:
    524           if dtype is None:
--> 525             raise err
    526           else:
    527             raise TypeError(

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(op_type_name, name, **keywords)
    520                 dtype=dtype,
    521                 as_ref=input_arg.is_ref,
--> 522                 preferred_dtype=default_dtype)
    523         except TypeError as err:
    524           if dtype is None:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/profiler/trace.py in wrapped(*args, **kwargs)
    161         with Trace(trace_name, **trace_kwargs):
    162           return func(*args, **kwargs)
--> 163       return func(*args, **kwargs)
    164 
    165     return wrapped

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in convert_to_tensor(value, dtype, name, as_ref, preferred_dtype, dtype_hint, ctx, accepted_result_types)
   1538 
   1539     if ret is None:
-> 1540       ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
   1541 
   1542     if ret is NotImplemented:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/constant_op.py in _constant_tensor_conversion_function(v, dtype, name, as_ref)
    337                                          as_ref=False):
    338   _ = as_ref
--> 339   return constant(v, dtype=dtype, name=name)
    340 
    341 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/constant_op.py in constant(value, dtype, shape, name)
    263   """
    264   return _constant_impl(value, dtype, shape, name, verify_shape=False,
--> 265                         allow_broadcast=True)
    266 
    267 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/constant_op.py in _constant_impl(value, dtype, shape, name, verify_shape, allow_broadcast)
    281       tensor_util.make_tensor_proto(
    282           value, dtype=dtype, shape=shape, verify_shape=verify_shape,
--> 283           allow_broadcast=allow_broadcast))
    284   dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
    285   attrs = {"value": tensor_value, "dtype": dtype_value}

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape, allow_broadcast)
    551       raise TypeError("Failed to convert object of type %s to Tensor. "
    552                       "Contents: %s. Consider casting elements to a "
--> 553                       "supported type." % (type(values), values))
    554     tensor_proto.string_val.extend(str_values)
    555     return tensor_proto

TypeError: Failed to convert object of type <class 'tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor'> to Tensor. Contents: tf.RaggedTensor(values=tf.RaggedTensor(values=Tensor("StringSplit/StringSplit/StringSplit/StringSplitV2:1", shape=(None,), dtype=string), row_splits=Tensor("StringSplit/StringSplit/StringSplit/RaggedFromValueRowIds/RowPartitionFromValueRowIds/concat:0", shape=(None,), dtype=int64)), row_splits=Tensor("StringSplit/RaggedFromTensor/RaggedFromUniformRowLength/RowPartitionFromUniformRowLength/mul:0", shape=(None,), dtype=int64)). Consider casting elements to a supported type.
