
How to use tfdf.builder.CARTBuilder to build/train a decision tree by hand #184

Open
Realvincentyuan opened this issue Jun 24, 2023 · 19 comments
Labels
documentation Improvements or additions to documentation

Comments

@Realvincentyuan

Realvincentyuan commented Jun 24, 2023

Expectation

Use tfdf.builder.CARTBuilder to build a decision tree structure, train it on the actual dataset, and refine the tree structure based on its performance.

The process is like manually replicating the training process of tfdf.keras.CartModel, but with the benefit that I can adjust the tree structure as needed instead of focusing only on model performance. This is helpful when intuitive rules are needed.

Sample code

I tried to use tfdf.builder.CARTBuilder to build the structure and then fit/predict, but the results are not as expected: fitting does not change the leaf predictions.

Below is some sample code, with a sample dataset, running in Colab:

import tensorflow_decision_forests as tfdf

import os
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import math
import collections


# Download the dataset
!wget -q https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins.csv -O /tmp/penguins.csv

# Load a dataset into a Pandas Dataframe.
dataset_df = pd.read_csv("/tmp/penguins.csv")

model_trial_idx = 10

# Create the model builder

model_trial_idx += 1
model_path = f"/tmp/manual_model/{model_trial_idx}"

builder = tfdf.builder.CARTBuilder(
    path=model_path,
    objective=tfdf.py_tree.objective.ClassificationObjective(
        label="species", classes=["Adelie", "Non-Adelie"]))


# Create some alias
Tree = tfdf.py_tree.tree.Tree
SimpleColumnSpec = tfdf.py_tree.dataspec.SimpleColumnSpec
ColumnType = tfdf.py_tree.dataspec.ColumnType
# Nodes
NonLeafNode = tfdf.py_tree.node.NonLeafNode
LeafNode = tfdf.py_tree.node.LeafNode
# Conditions
NumericalHigherThanCondition = tfdf.py_tree.condition.NumericalHigherThanCondition
CategoricalIsInCondition = tfdf.py_tree.condition.CategoricalIsInCondition
# Leaf values
ProbabilityValue = tfdf.py_tree.value.ProbabilityValue

builder.add_tree(
    Tree(
        NonLeafNode(
            condition=NumericalHigherThanCondition(
                feature=SimpleColumnSpec(name="bill_length_mm", type=ColumnType.NUMERICAL),
                threshold=40.0,
                missing_evaluation=False),
            
            pos_child=NonLeafNode(
                condition=CategoricalIsInCondition(
                    feature=SimpleColumnSpec(name="island",type=ColumnType.CATEGORICAL),
                    mask=["Dream", "Torgersen"],
                    missing_evaluation=False)
                ,pos_child=LeafNode(value=ProbabilityValue(probability=[0.8, 0.2], num_examples=10))
                ,neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.9], num_examples=20))
                ),
            
            neg_child=LeafNode(value=ProbabilityValue(probability=[0.2, 0.8], num_examples=30))
            )
        )
    
    )

builder.close()

manual_model = tf.keras.models.load_model(model_path)

# Convert the pandas dataframe into a tf dataset.

dataset_df['species_binary'] = dataset_df['species'] == 'Adelie'

dataset_tf_2 = tfdf.keras.pd_dataframe_to_tf_dataset(dataset_df[['bill_length_mm','island','species_binary']], label="species_binary")


# model compile and fit
manual_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss=tf.keras.losses.BinaryCrossentropy(),
              metrics=[tf.keras.metrics.BinaryAccuracy(),
                       tf.keras.metrics.FalseNegatives()])


manual_model.fit(dataset_tf_2)

Questions

  • The code above runs without errors, but the tree does not reflect the fitting results: the predicted probabilities and numbers of samples stay the same. It looks as if manual_model is a completely static model.
  • I assumed that tfdf.builder.CARTBuilder builds a shell, and that the prediction of each node would be updated after fitting/prediction. I am confused about why it requires me to define the leaf values, and why those values stay the same after fitting/prediction. Did I miss anything?

  • What is the best practice for using tfdf.builder.CARTBuilder to build a decision tree by hand?

Reference:

@rstz
Collaborator

rstz commented Jun 24, 2023

I believe there may be a fundamental misunderstanding.

CART trees are not generated the same way neural networks are; in fact, the two algorithms are fundamentally different.

For neural networks, the (very simplified) process is:

  1. Decide the layout of the network
  2. Feed the examples one-by-one (or in batches) to determine the weights in the network using (some variant of) gradient descent.
  3. Done

For CART trees, the process is (see Wikipedia for much more detail):

  1. Feed the entire dataset to the CART algorithm at once.
  2. The algorithm builds the "best" tree for the dataset, i.e., the algorithm decides the layout of the tree. There are no weights in the tree, and the algorithm does not use gradient descent.
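To make the contrast concrete, here is a minimal sketch (not TF-DF's or YDF's implementation) of the core CART step for one numerical feature: scan the entire dataset at once and pick the threshold that minimizes the weighted Gini impurity of the two children. The function names and toy values are illustrative only.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of class labels (0.0 means the node is pure)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def best_numeric_split(values, labels):
    """Return (threshold, weighted_gini) of the best `value >= threshold` split.

    The whole dataset is scanned at once; there are no weights and no gradient
    descent. Candidate thresholds are midpoints between consecutive distinct
    values. If no split improves on the parent, threshold is None.
    """
    pairs = sorted(zip(values, labels))
    n = len(pairs)
    best_threshold, best_score = None, gini(labels)  # "no split" baseline
    for i in range(1, n):
        lower, upper = pairs[i - 1][0], pairs[i][0]
        if lower == upper:
            continue  # identical values cannot be separated
        neg = [label for _, label in pairs[:i]]  # value < threshold
        pos = [label for _, label in pairs[i:]]  # value >= threshold
        score = (len(neg) * gini(neg) + len(pos) * gini(pos)) / n
        if score < best_score:
            best_threshold, best_score = (lower + upper) / 2, score
    return best_threshold, best_score
```

For example, `best_numeric_split([35.0, 37.5, 39.0, 45.0, 47.5, 50.0], ["Adelie"] * 3 + ["Other"] * 3)` returns `(42.0, 0.0)`, a perfect split. A full CART learner repeats this search over all features and recurses into each child; the layout of the tree is the output, not an input.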

This means that your assumption that the CARTBuilder is used to build a shell is incorrect. The CARTBuilder is an advanced tool for manually building the full model, which is not changed by further training. This can be useful, e.g., for experimenting with expert-designed trees and for research on decision trees. However, it is not useful for fitting a model to a dataset.

If you want to fit a model to a dataset, just ignore the CARTBuilder, create a model, and fit it directly as explained in the beginner tutorial.

@rstz rstz added not_specific_to_tfdf invalid This doesn't seem right and removed not_specific_to_tfdf labels Jun 24, 2023
@Realvincentyuan
Author

Realvincentyuan commented Jun 24, 2023

Hi @rstz,

Thank you very much for the explanation. I should have put it this way:

I agree that the builder is perfect for expert-designed trees, and this is what I am looking for: I want to build a tree whose features and thresholds are under my control, which helps me build intuitive and compliant rules (without using features that violate regulations and laws).

To this end, the CARTBuilder does address part of my needs, but what I do not follow is why the builder requires the probability and number of samples to be added before any data is run through the model. How could we know the probability and number of samples before running the model on a dataset?

Also, I do not expect the tree structure to be updated after fitting/prediction, but I am confused that the probability and number of samples (which are the results) of each node are not updated after running the model on data. Is there a way to use the CARTBuilder to design an expert tree structure without specifying the probability and number of samples?

Am I missing something?

@rstz
Collaborator

rstz commented Jun 26, 2023

Thank you for the clarification!

Probability and number of examples are used by the decision tree to output confidence estimates along with the predictions, and by, e.g., tree pruning algorithms. This can be useful for certain applications, but does not impact "raw" model inference. If the tree is fit using the CART algorithm, these values are set automatically from the statistics of the training dataset.

If you're building the tree manually and have no way of computing these values, you may simply set them to arbitrary values (e.g. always probability 1 for the class you want, and num_examples=10).

@rstz rstz removed the invalid This doesn't seem right label Jun 26, 2023
@Realvincentyuan
Author

If you're building the tree manually and have no way of computing these values, you may just set them to arbitrary values (e.g. always probability 1 for the class you want, and number_of_samples=10).

Hi @rstz,

Thanks for the reply. I would like some more clarity on the APIs.

Given a dataset, we usually do not know the probability and number of samples before running the model on any data. As you said, I can set them to arbitrary values, but could those values be updated after running the model on some data? If so, that would make a lot of sense and be very helpful!

@rstz
Collaborator

rstz commented Jun 27, 2023

Unfortunately, TF-DF does not offer a specific API for this :(

You're probably able to bootstrap this with a bit of code. Say you have n leaves in your hand-made tree:

  1. Build the tree with the CARTBuilder. For each leaf, assign a different n-dimensional unit vector as its probability.
  2. Load the tree you built in Step 1 as a model.
  3. Run predict() on your dataset. For d examples, this gives you a d x n matrix that tells you exactly which example is mapped to which leaf.
  4. Use the matrix to manually compute the correct leaf probabilities and num_examples values.
  5. If needed, rebuild the tree with the correct leaf probabilities.
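Steps 1, 3 and 4 can be sketched with NumPy as below. This is a hedged sketch, not a TF-DF API: `leaf_unit_vectors` and `refresh_leaf_stats` are hypothetical helpers. It assumes the builder's objective declares n artificial classes (one per leaf) and that predict() returns one column per class; with only two classes TF-DF may return a single positive-class column instead, so the trick is easiest with n >= 3 leaves. "Leaf i" means the leaf you gave the i-th unit vector to in step 1.

```python
import numpy as np


def leaf_unit_vectors(n_leaves):
    """Step 1: row i is the one-hot 'probability' assigned to the i-th leaf."""
    return np.eye(n_leaves)


def refresh_leaf_stats(leaf_outputs, labels, classes):
    """Steps 3-4: map each example to its leaf and recompute per-leaf statistics.

    leaf_outputs: (d, n) array from predict() on the unit-vector tree.
    labels: length-d sequence of true class names.
    classes: ordered class names of the real (non-artificial) objective.
    Returns (probabilities, num_examples) with shapes (n, len(classes)) and (n,).
    """
    leaf_outputs = np.asarray(leaf_outputs)
    leaf_idx = np.argmax(leaf_outputs, axis=1)  # column holding the 1.0
    class_idx = np.array([classes.index(label) for label in labels])
    counts = np.zeros((leaf_outputs.shape[1], len(classes)))
    np.add.at(counts, (leaf_idx, class_idx), 1)  # unbuffered, so repeats accumulate
    num_examples = counts.sum(axis=1)
    # Leaves that received no examples get an all-zero distribution.
    with np.errstate(invalid="ignore", divide="ignore"):
        probabilities = np.where(
            num_examples[:, None] > 0, counts / num_examples[:, None], 0.0)
    return probabilities, num_examples
```

Row i of `probabilities` and entry i of `num_examples` could then be passed back, e.g. as `ProbabilityValue(probability=probabilities[i].tolist(), num_examples=float(num_examples[i]))`, when the tree is rebuilt with the real classes in step 5. Because every leaf gets a distinct one-hot vector, two leaves can never be confused even if their real probabilities turn out identical.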

If you're able to get it to work, we'd be very happy to include this in our examples - just submit a PR!

@Realvincentyuan
Author

Realvincentyuan commented Jun 29, 2023

@rstz

I think what you proposed makes sense for real business use cases, and I am working on it.

On top of that, over the past few days I have been thinking about best practices for this builder, and the workflow below would make sense:

  1. ML: Build a tree model using a regular tfdf.keras model, say tfdf.keras.CartModel. It returns the tree structure along with the real number of samples and probabilities.
  2. Insert the ML model into the builder: extract the tree from the step above and add it to a builder, something like below:
sample_tree = inspector.extract_tree(tree_idx=0)

# Create the model builder

model_trial_idx = 1
model_trial_idx += 1
model_path = f"/tmp/manual_model/{model_trial_idx}"

builder = tfdf.builder.CARTBuilder(
    path=model_path,
    objective=tfdf.py_tree.objective.ClassificationObjective(
        label="species", classes=["Adelie", "Gentoo", "Chinstrap"])
    )

builder.add_tree(sample_tree)
  3. Human expertise: fine-tuning with expertise comes into play. Tweak the root, nodes, and leaves if necessary, then call builder.close().
  4. Activate the ML-expertise fusion model: run this model on the dataset and refresh the probabilities and numbers of samples to reflect the actual performance of each node.

This makes more sense than building the tree from scratch, since in general it works better to start from a result produced by an ML model.

Nonetheless, I went through the APIs, and it looks like some classes, such as Tree, do not have setters or similar functions. The code snippet below throws an error:

sample_tree.root = NonLeafNode(
            condition=NumericalHigherThanCondition(
                feature=SimpleColumnSpec(name="bill_length_mm", type=ColumnType.NUMERICAL),
                threshold=40.0,
                missing_evaluation=False),

            pos_child=NonLeafNode(
                condition=CategoricalIsInCondition(
                    feature=SimpleColumnSpec(name="island",type=ColumnType.CATEGORICAL),
                    mask=["Dream", "Torgersen"],
                    missing_evaluation=False)
                ,pos_child=LeafNode(value=ProbabilityValue(probability=[0.8, 0.2], num_examples=10))
                ,neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.9], num_examples=20))
                ),

            neg_child=LeafNode(value=ProbabilityValue(probability=[0.2, 0.8], num_examples=30))
            )

The error is:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-35-242d44038491> in <cell line: 1>()
----> 1 sample_tree.root = NonLeafNode(
      2             condition=NumericalHigherThanCondition(
      3                 feature=SimpleColumnSpec(name="bill_length_mm", type=ColumnType.NUMERICAL),
      4                 threshold=40.0,
      5                 missing_evaluation=False),

AttributeError: can't set attribute 'root'

Also, the builder class does not seem to have APIs to tweak individual nodes. Something like the following, while using the builder, would also be helpful:

builder.get_node[node_idx] = NonLeafNode(
                condition=CategoricalIsInCondition(
                    feature=SimpleColumnSpec(name="island",type=ColumnType.CATEGORICAL),
                    mask=["Dream", "Torgersen"],
                    missing_evaluation=False)
                ,pos_child=LeafNode(value=ProbabilityValue(probability=[0.8, 0.2], num_examples=10))
                ,neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.9], num_examples=20))
                )

Perhaps I have missed some APIs, as the Python API docs do not seem to be exactly aligned with the code. Please advise on my proposed workflow and on the questions regarding these setter methods. I appreciate it!

@rstz
Collaborator

rstz commented Jun 29, 2023

Hi,

Thank you for providing additional details. I looked into the tree implementation, and it seems that only the tree root lacks a setter method. Other properties (e.g. pos_child, neg_child, value, etc.) can be modified through Python:

sample_tree.root.pos_child = LeafNode(value=ProbabilityValue(probability=[0.2, 0.8], num_examples=30))

We might make the root modifiable in the next TF-DF version, but I don't believe this is blocking, since each attribute of the root can be modified.

TF-DF does not offer a tree traversal API such as get_node[node_idx]; you would have to implement that on your own.
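A small helper along these lines could stand in for such an API. This is a hypothetical sketch, not part of TF-DF: it numbers nodes in pre-order (node, then positive subtree, then negative subtree), which differs from the leaf `idx` printed by `pretty()`, and it only assumes that non-leaf nodes expose `pos_child`/`neg_child` attributes, as in the snippet above, while leaves do not.

```python
def iter_nodes(root):
    """Yield (index, node, parent, side) in pre-order.

    `side` is "pos" or "neg" (None for the root), so a caller can later
    reassign `parent.pos_child` or `parent.neg_child`.
    """
    stack = [(root, None, None)]
    index = 0
    while stack:
        node, parent, side = stack.pop()
        yield index, node, parent, side
        index += 1
        if hasattr(node, "pos_child"):  # non-leaf node
            # Push neg first so the pos subtree is visited first.
            stack.append((node.neg_child, node, "neg"))
            stack.append((node.pos_child, node, "pos"))


def get_node(root, node_idx):
    """Return the node with the given pre-order index."""
    for index, node, _, _ in iter_nodes(root):
        if index == node_idx:
            return node
    raise IndexError(node_idx)


def set_node(root, node_idx, new_node):
    """Replace a non-root node in place through its parent's child attribute."""
    for index, _, parent, side in iter_nodes(root):
        if index == node_idx:
            if parent is None:
                raise ValueError("the root has no setter; it cannot be replaced here")
            setattr(parent, f"{side}_child", new_node)
            return
    raise IndexError(node_idx)
```

For a tree returned by `inspector.extract_tree()`, `get_node(sample_tree.root, 1)` would be the root's positive child, and `set_node(sample_tree.root, 1, new_node)` would replace it. Index 0, the root itself, is rejected because the root cannot be reassigned.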

@Realvincentyuan
Author

Realvincentyuan commented Jul 1, 2023

Hi @rstz,

Thanks for the comments. Adding a setter method for the root is necessary in some cases. I also tried your suggested way of adjusting nodes, and it worked.

Nonetheless, I ran into an error that other people have also hit when calling a SavedModel, and I added a comment in issue #136 too. I've copied the comment here:

In my case, I am building a model the builder way: I add an existing tree to the builder to build a new model, and I can tweak nodes if necessary. The sample code of my workflow is below:

model = tfdf.keras.CartModel()
model.fit(x=dataset_tf)


inspector = model.make_inspector()

sample_tree = inspector.extract_tree(tree_idx=0)

# Create some alias
Tree = tfdf.py_tree.tree.Tree
SimpleColumnSpec = tfdf.py_tree.dataspec.SimpleColumnSpec
ColumnType = tfdf.py_tree.dataspec.ColumnType
# Nodes
NonLeafNode = tfdf.py_tree.node.NonLeafNode
LeafNode = tfdf.py_tree.node.LeafNode
# Conditions
NumericalHigherThanCondition = tfdf.py_tree.condition.NumericalHigherThanCondition
CategoricalIsInCondition = tfdf.py_tree.condition.CategoricalIsInCondition
# Leaf values
ProbabilityValue = tfdf.py_tree.value.ProbabilityValue


sample_tree = inspector.extract_tree(tree_idx=0)
print(sample_tree)


# Build a model
model_trial_idx = 1

# Create the model builder

model_trial_idx += 1
model_path = f"/tmp/manual_model/{model_trial_idx}"

builder = tfdf.builder.CARTBuilder(
    path=model_path,
    objective=tfdf.py_tree.objective.ClassificationObjective(
        label="species", classes=["Adelie", "Gentoo", "Chinstrap"])
    )


builder.add_tree(sample_tree)
builder.close()

manual_model = tf.keras.models.load_model(model_path)
manual_model.predict(dataset_tf)

The error is as below:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-47-d9c72aa7e6da> in <cell line: 1>()
----> 1 manual_model.predict(dataset_tf)

1 frames
/usr/local/lib/python3.10/dist-packages/keras/engine/training.py in tf__predict_function(iterator)
     13                 try:
     14                     do_return = True
---> 15                     retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
     16                 except:
     17                     do_return = False

ValueError: in user code:

    File "/usr/local/lib/python3.10/dist-packages/keras/engine/training.py", line 2169, in predict_function  *
        return step_function(self, iterator)
    File "/usr/local/lib/python3.10/dist-packages/keras/engine/training.py", line 2155, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.10/dist-packages/keras/engine/training.py", line 2143, in run_step  **
        outputs = model.predict_step(data)
    File "/usr/local/lib/python3.10/dist-packages/keras/engine/training.py", line 2111, in predict_step
        return self(x, training=False)
    File "/usr/local/lib/python3.10/dist-packages/keras/utils/traceback_utils.py", line 70, in error_handler
        raise e.with_traceback(filtered_tb) from None

    ValueError: Could not find matching concrete function to call loaded from the SavedModel. Got:
      Positional arguments (2 total):
        * {'bill_depth_mm': <tf.Tensor 'inputs_2:0' shape=(None,) dtype=float32>,
     'bill_length_mm': <tf.Tensor 'inputs_1:0' shape=(None,) dtype=float32>,
     'body_mass_g': <tf.Tensor 'inputs_4:0' shape=(None,) dtype=float32>,
     'flipper_length_mm': <tf.Tensor 'inputs_3:0' shape=(None,) dtype=float32>,
     'island': <tf.Tensor 'inputs:0' shape=(None,) dtype=string>,
     'sex': <tf.Tensor 'inputs_5:0' shape=(None,) dtype=string>,
     'year': <tf.Tensor 'inputs_6:0' shape=(None,) dtype=int64>}
        * False
      Keyword arguments: {}
    
     Expected these arguments to match one of the following 4 option(s):
    
    Option 1:
      Positional arguments (2 total):
        * {'bill_depth_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='bill_depth_mm'),
     'bill_length_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='bill_length_mm'),
     'flipper_length_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='flipper_length_mm'),
     'island': TensorSpec(shape=(None,), dtype=tf.string, name='island')}
        * True
      Keyword arguments: {}
    
    Option 2:
      Positional arguments (2 total):
        * {'bill_depth_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='bill_depth_mm'),
     'bill_length_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='bill_length_mm'),
     'flipper_length_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='flipper_length_mm'),
     'island': TensorSpec(shape=(None,), dtype=tf.string, name='island')}
        * False
      Keyword arguments: {}
    
    Option 3:
      Positional arguments (2 total):
        * {'bill_depth_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='inputs_bill_depth_mm'),
     'bill_length_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='inputs_bill_length_mm'),
     'flipper_length_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='inputs_flipper_length_mm'),
     'island': TensorSpec(shape=(None,), dtype=tf.string, name='inputs_island')}
        * True
      Keyword arguments: {}
    
    Option 4:
      Positional arguments (2 total):
        * {'bill_depth_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='inputs_bill_depth_mm'),
     'bill_length_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='inputs_bill_length_mm'),
     'flipper_length_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='inputs_flipper_length_mm'),
     'island': TensorSpec(shape=(None,), dtype=tf.string, name='inputs_island')}
        * False
      Keyword arguments: {}

Note: the dataset is identical to the one used to train the initial model, so it is very strange that predicting on it with the manual model still raises an error. This seems to be a TensorFlow issue: people using other TF packages have also hit it, and the suggested solution is to install tf-nightly: tensorflow/tensorflow#35446 (comment)

If that is the case, my package versions are:

  • 2.12.0 for TF
  • 1.3.0 for TF-DF

They are compatible, so why does this issue still occur?

Appendix

@rstz
Collaborator

rstz commented Jul 3, 2023

TL;DR: You need to manually fix the input signature:

def copy_model_sig(m):
  """
  Copy the model signature to a new model.
  """
  spec = m.save_spec()[0][0]
  return lambda insp: spec
...

builder = tfdf.builder.CARTBuilder(
    ...
    input_signature_example_fn=copy_model_sig(model)
    )

Hi,

For some background, a TF-DF model is actually a Yggdrasil Decision Forests (YDF) model wrapped in a Keras model. The builder and the inspector both operate on the YDF model, since Keras itself has no concept of trees, forests, etc. Unfortunately, some information gets lost between the two formats. We are actively working on improving this, but it requires a major rewrite of parts of TF-DF / YDF.

Here, the TF-DF model does not know the correct input signature. Instead, it asks the YDF model for its signature and tries to guess a signature from it. Since YDF generally does not distinguish between different integer representations (everything is a float32) and may ignore features that the model does not use, the guessed signature often does not match what you fed to it through TF-DF. In the error above, for example, the saved signatures lack body_mass_g, sex, and year, which the dataset contains.

Looking at this code, it seems useful to integrate it somewhere in the library or, at least, to document it more explicitly. I'll think about that.

The full, fixed code is:

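import tensorflow as tf
import tensorflow_decision_forests as tfdf

# dataset_tf is the penguins tf.data.Dataset with label "species", e.g. built with
# tfdf.keras.pd_dataframe_to_tf_dataset(dataset_df, label="species").
model = tfdf.keras.CartModel()
model.fit(x=dataset_tf)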

inspector = model.make_inspector()
def copy_model_sig(m):
  """
  Copy the model signature to a new model.
  """
  spec = m.save_spec()[0][0]
  return lambda inspector: spec

sample_tree = inspector.extract_tree(tree_idx=0)

# Create some alias
Tree = tfdf.py_tree.tree.Tree
SimpleColumnSpec = tfdf.py_tree.dataspec.SimpleColumnSpec
ColumnType = tfdf.py_tree.dataspec.ColumnType
# Nodes
NonLeafNode = tfdf.py_tree.node.NonLeafNode
LeafNode = tfdf.py_tree.node.LeafNode
# Conditions
NumericalHigherThanCondition = tfdf.py_tree.condition.NumericalHigherThanCondition
CategoricalIsInCondition = tfdf.py_tree.condition.CategoricalIsInCondition
# Leaf values
ProbabilityValue = tfdf.py_tree.value.ProbabilityValue

sample_tree = inspector.extract_tree(tree_idx=0)
print(sample_tree)

# Build a model
model_trial_idx = 1

# Create the model builder

model_trial_idx += 1
model_path = f"/tmp/manual_model/{model_trial_idx}"

builder = tfdf.builder.CARTBuilder(
    path=model_path,
    objective=tfdf.py_tree.objective.ClassificationObjective(
        label="species", classes=["Adelie", "Gentoo", "Chinstrap"]),
    input_signature_example_fn=copy_model_sig(model)
    )

builder.add_tree(sample_tree)
builder.close()

Note that you can also modify the signature manually if necessary (copied from one of the tests):

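from typing import Any

import tensorflow as tf
# `keras` is assumed to be tensorflow_decision_forests.keras, as in the TF-DF tests.
from tensorflow_decision_forests import keras
from tensorflow_decision_forests.component.inspector import inspector as inspector_lib


def custom_model_input_signature(
    inspector: inspector_lib.AbstractInspector,
) -> Any:
  # Start from the signature TF-DF would guess from the YDF model.
  input_spec = keras.build_default_input_model_signature(inspector)
  # Those features are stored as int64 in the dataset.
  for feature_name in [
      "age",
      "capital_gain",
      "capital_loss",
      "education_num",
      "fnlwgt",
      "hours_per_week",
  ]:
    input_spec[feature_name] = tf.TensorSpec(shape=[None], dtype=tf.int64)
  return input_spec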

@Realvincentyuan
Author

Realvincentyuan commented Jul 4, 2023

Hi @rstz,

Thanks for the explanation and the workaround. I appreciate your prompt replies over the past few days!

I tested it and it worked! This thread is a great discussion. There is some work left on my side to address my needs:

  • Write functions that map the results back to the model made by the builder, so that the number of samples and probability reflect the actual values after running the model on a dataset
  • Write a parsing function that reformats the output of the pretty() function, i.e. flattens the tree structure and presents the rules one by one, from this:
(flipper_length_mm >= 206.5; miss=False, score=0.5436033606529236)
    ├─(pos)─ (bill_depth_mm >= 17.649999618530273; miss=False, score=0.2061920464038849)
    │        ├─(pos)─ ProbabilityValue([0.3333333333333333, 0.6666666666666666, 0.0],n=6.0) (idx=4)
    │        └─(neg)─ ProbabilityValue([0.0, 0.0, 1.0],n=108.0) (idx=3)
    └─(neg)─ (island in ['Biscoe', 'Torgersen']; miss=True, score=0.23399487137794495)
             ├─(pos)─ ProbabilityValue([0.9767441860465116, 0.0, 0.023255813953488372],n=86.0) (idx=2)
             └─(neg)─ (bill_length_mm >= 42.349998474121094; miss=True, score=0.5646106004714966)
                      ├─(pos)─ ProbabilityValue([0.03278688524590164, 0.9672131147540983, 0.0],n=61.0) (idx=1)
                      └─(neg)─ ProbabilityValue([0.9795918367346939, 0.02040816326530612, 0.0],n=49.0) (idx=0)

to the format below, along with some statistics for each node:

flipper_length_mm >= 206.5 and bill_depth_mm >= 17.64999961853027;
 flipper_length_mm >= 206.5 and bill_depth_mm <  17.64999961853027
....
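The flattening described above could be sketched as a recursive walk that collects one rule per leaf. This is a hypothetical helper, not TF-DF code: it takes a `describe` callback so that each condition type can be rendered separately, and `describe_numerical` assumes numerical conditions expose `feature.name` and `threshold` (mirroring the constructor arguments of `NumericalHigherThanCondition` used earlier). Categorical conditions would need their own describer, and missing-value routing is ignored.

```python
def flatten_rules(node, describe, path=()):
    """Return one (conditions, leaf) pair per leaf, positive branch first.

    `describe(condition, positive)` renders a condition, or its negation when
    positive is False, as text. Nodes with `pos_child` are treated as splits.
    """
    if not hasattr(node, "pos_child"):  # leaf: the path so far is its rule
        return [(list(path), node)]
    return (
        flatten_rules(node.pos_child, describe, path + (describe(node.condition, True),))
        + flatten_rules(node.neg_child, describe, path + (describe(node.condition, False),))
    )


def describe_numerical(condition, positive):
    # Assumption: the condition exposes `feature.name` and `threshold`.
    op = ">=" if positive else "<"
    return f"{condition.feature.name} {op} {condition.threshold}"


def format_rules(rules):
    """Join each rule's conditions with 'and', one string per leaf."""
    return [" and ".join(conditions) for conditions, _ in rules]
```

Each returned pair keeps the leaf object, so leaf statistics such as its probability and number of examples can be printed next to the rule.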

Possible follow-up on the TF-DF side like you mentioned:

  • Add a setter method for the root node in the next few releases (by the way, in real business use cases we often need to replace the root directly while keeping everything else the same for testing)
  • Add docs/patches for the model format mismatches

Again, thanks for your help, great discussion! TFDF long live! 🫡🫡🫡

@Realvincentyuan
Author

Update on your earlier point, @rstz:

Run predict() on your dataset. This gives your d * n - dimensional matrix that tells you exactly which example is mapped to which leaf.

After applying input_signature_example_fn=copy_model_sig(model), the predict method runs well, but it only outputs probabilities, so I have no direct way to see which node each instance in the dataset goes to.

Looking through the API, predict_get_leaves or call_get_leaves seem to be the methods to call, and predict_get_leaves works well with a native CartModel.

Nonetheless, the manual model does not seem to be able to call the predict_get_leaves or call_get_leaves methods:

  • The model created by the builder does not have a predict_get_leaves() method; the error is below. There are not even warnings when creating the builder, adding the sample tree, and closing the builder.
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-38-f1aae9bea4bf> in <cell line: 1>()
----> 1 manual_model.predict_get_leaves(dataset_tf)

AttributeError: 'InferenceCoreModel' object has no attribute 'predict_get_leaves'
  • For call_get_leaves, the error is as below:
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-39-de7dfe55a161> in <cell line: 1>()
----> 1 manual_model.call_get_leaves(dataset_tf)

1 frames
/usr/local/lib/python3.10/dist-packages/tensorflow/python/saved_model/function_deserialization.py in restored_function_body(*args, **kwargs)
    259     """Calls a restored function or raises an error if no matching function."""
    260     if not saved_function.concrete_functions:
--> 261       raise ValueError("Found zero restored functions for caller function.")
    262     # This is the format of function.graph.structured_input_signature. At this
    263     # point, the args and kwargs have already been canonicalized.

ValueError: Found zero restored functions for caller function.

Are both errors caused by the way the builder saves the model and the way it is then loaded?

@rstz
Collaborator

rstz commented Jul 11, 2023

Hi,

Unfortunately, TF-DF cannot serialize all of its methods within the Keras model format. This means that some functions, including predict_get_leaves, are not available in saved models, and this includes all models created with the builder. So there's no way to use them (we really tried to make it work, but couldn't find a satisfactory way for now).

However, my initial proposal was to first create n artificial output classes (one per leaf) in the builder, with n-dimensional unit vectors as the probabilities. This allows a direct mapping from probability to leaf index. Does this work?

@Realvincentyuan
Author

Realvincentyuan commented Jul 13, 2023

Hi @rstz,

I managed to traverse the tree and get the node assignments. Now I am trying to automate the process with methods that update nodes and refresh the true probability and number of samples programmatically. However, this method likely leads to issues if the predicted probabilities of some nodes are identical (say, the nodes are pure), which might require an additional probability reset. I will open pull requests once testing is complete.

One thing I noticed is that non-leaf nodes do not have an id while leaf nodes do. Why does this inconsistency exist? I need to add extra code to identify non-leaf nodes or reset the index. This is not a huge effort; I am just curious.

@rstz
Copy link
Collaborator

rstz commented Jul 16, 2023

Hi @Realvincentyuan,

looking forward to your PRs!
I don't remember a specific reason for not including an id for non-leaf nodes. IIRC, there just wasn't a use case for them...

@Realvincentyuan
Author

Hi @rstz,

I have created functions for tree traversal, for resetting the probabilities and numbers of samples, and for refreshing them once the true results are available. I did not modify any of the native classes you built; these are mostly independent helper functions.
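
A minimal sketch of the reset and refresh steps, assuming the leaf reached by each example is already known (for example from a traversal) and that class labels are integer indices. The function names and the `LeafStats` layout (leaf id mapped to a probability vector and an example count) are hypothetical, not the helpers from the PR.

```python
from collections import Counter, defaultdict
from typing import Dict, List, Sequence, Tuple

# Hypothetical layout: leaf id -> (class probabilities, number of examples).
LeafStats = Dict[int, Tuple[List[float], int]]


def reset_leaf_stats(leaf_ids: Sequence[int], num_classes: int) -> LeafStats:
    """Sets every leaf to an all-zero probability vector and zero examples."""
    return {leaf: ([0.0] * num_classes, 0) for leaf in leaf_ids}


def refresh_leaf_stats(
    stats: LeafStats,
    assignments: Sequence[int],
    labels: Sequence[int],
    num_classes: int,
) -> LeafStats:
    """Recomputes each leaf's class probabilities and example count.

    `assignments[i]` is the leaf reached by example i and `labels[i]` is its
    true class index. Leaves that receive no examples keep their current
    values. The input dict is not modified.
    """
    counts: Dict[int, Counter] = defaultdict(Counter)
    for leaf, label in zip(assignments, labels):
        counts[leaf][label] += 1
    refreshed = dict(stats)
    for leaf, class_counts in counts.items():
        total = sum(class_counts.values())
        probability = [class_counts[c] / total for c in range(num_classes)]
        refreshed[leaf] = (probability, total)
    return refreshed


stats = reset_leaf_stats(leaf_ids=[0, 1, 2], num_classes=2)
assignments = [0, 2, 1, 0, 2, 2]
labels = [0, 1, 0, 0, 1, 0]
stats = refresh_leaf_stats(stats, assignments, labels, num_classes=2)
for leaf, (probability, n) in sorted(stats.items()):
    print(leaf, probability, n)
```

In this example leaves 0 and 1 both end up pure with probability [1.0, 0.0], which illustrates why predicted probabilities alone cannot tell such leaves apart.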

I am not sure where these functions should go in a PR: contrib or tools? I will document the functions and examples later and create a PR.

├── configure: Project configuration.
├── documentation: User and developer documentation. Contains the colabs.
├── examples: Collection of usage examples.
├── tensorflow_decision_forests: The library
│   ├── component: Utilities.
│   │   ├── builder: Create models "by hand".
│   │   ├── inspector: Inspection of structure and meta-data of models.
│   │   ├── model_plotter: Plotting of model tree structure.
│   │   ├── py_tree: Representation of a decision tree as a python object.
│   │   └── tuner: TF-DF's own hyper-parameter tuner.
│   ├── contrib: Additional functionality outside the project's main scope.
│   ├── keras: Keras logic. Depends on tensorflow logic.
│   │   └── wrapper: Python code generator for Keras models.
│   ├── tensorflow: TensorFlow logic.
│   │   └── ops: Custom C++ ops.
│   │       ├── inference: ... for inference.
│   │       └── training: ... for training.
│   └── test_data: Datasets for unit tests and benchmarks.
├── third_party: Bazel configuration for dependencies.
└── tools: Tools for the management of the project and code.

Please note that I cannot refresh the values of non-leaf nodes: a model's prediction only looks at the leaf nodes, and I can only get the leaf assignment of each example in the input dataset. So far, I have not been able to update the non-leaf nodes of models built by hand.

I feel that updating each non-leaf node would require simulating every split of the tree from the root to each leaf. In my case this is not necessary, since I only need the actual performance of each leaf, so I will not include this feature in the PR. Any thoughts on this?
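
One possible alternative to re-simulating every split, assuming each example reaches exactly one leaf: the examples passing through a non-leaf node are exactly those reaching the leaves below it, so its class counts (and therefore its probabilities and number of examples) are the sums over its two children. A minimal bottom-up sketch; `StatsNode` is a hypothetical stand-in, not a TF-DF class.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class StatsNode:
    """Hypothetical node holding per-class example counts.

    Leaves start with their counts filled in (e.g. from refreshed leaf
    stats); non-leaf nodes have two children and start empty.
    """
    class_counts: List[int] = field(default_factory=list)
    pos_child: Optional["StatsNode"] = None
    neg_child: Optional["StatsNode"] = None

    @property
    def is_leaf(self) -> bool:
        return self.pos_child is None and self.neg_child is None


def aggregate_counts(node: StatsNode) -> List[int]:
    """Post-order pass: a non-leaf node's counts are its children's sum."""
    if node.is_leaf:
        return node.class_counts
    pos = aggregate_counts(node.pos_child)
    neg = aggregate_counts(node.neg_child)
    node.class_counts = [p + n for p, n in zip(pos, neg)]
    return node.class_counts


def to_probability(counts: List[int]) -> List[float]:
    """Normalizes counts; the number of examples is simply sum(counts)."""
    total = sum(counts)
    return [c / total for c in counts] if total else [0.0] * len(counts)


root = StatsNode(
    pos_child=StatsNode(
        pos_child=StatsNode([0, 3]),
        neg_child=StatsNode([1, 0]),
    ),
    neg_child=StatsNode([2, 0]),
)
aggregate_counts(root)
print(root.class_counts)  # [3, 3]
print(root.pos_child.class_counts)  # [1, 3]
print(to_probability(root.pos_child.class_counts))  # [0.25, 0.75]
```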

@rstz
Collaborator

rstz commented Jul 19, 2023

Hi, a subdirectory of contrib is probably best for this. Let's discuss the other question on the PR directly.

@Realvincentyuan
Author

Realvincentyuan commented Jul 28, 2023

Hi @rstz, I first wrote a blog post about these helper functions: https://vincentyuan.us/build-a-decision-tree-by-hand-with-tensorflow/

Please review it when you get a chance. As you can see in the post:

  • if every node had an ID, traversing non-leaf nodes could be more efficient;
  • the root also needs a setter method.

It looks like adding a test script is encouraged, but I am not familiar with tf.test.TestCase. Is it mandatory? If so, give me some time to go through some tf.test.TestCase tutorials. Otherwise, I might instead contribute by enhancing the example tutorial: https://www.tensorflow.org/decision_forests/tutorials/advanced_colab
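
For reference, a test file can stay small. tf.test.TestCase builds on Python's unittest.TestCase (adding helpers such as assertAllClose), so a minimal sketch using only the standard library has the same shape; `leaf_index_from_probabilities` is a hypothetical helper used only for illustration.

```python
import unittest
from typing import Sequence


def leaf_index_from_probabilities(probabilities: Sequence[float]) -> int:
    """Hypothetical helper under test: index of the largest probability."""
    return max(range(len(probabilities)), key=lambda i: probabilities[i])


class LeafIndexTest(unittest.TestCase):
    """With TensorFlow, the base class would be tf.test.TestCase instead."""

    def test_unit_vector_maps_to_its_leaf(self):
        self.assertEqual(leaf_index_from_probabilities([0.0, 1.0, 0.0]), 1)

    def test_first_leaf(self):
        self.assertEqual(leaf_index_from_probabilities([1.0, 0.0, 0.0]), 0)

    def test_empty_vector_raises(self):
        # max() over an empty range raises ValueError.
        with self.assertRaises(ValueError):
            leaf_index_from_probabilities([])
```

It can be run with `python -m unittest <file>.py`; a TensorFlow-style test would typically end with `tf.test.main()` instead.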

@rstz
Collaborator

rstz commented Sep 11, 2023

Hi, sorry for not getting back to you for a while. The article looks good! I wonder if the best strategy would be to just create an "External examples" file that links directly to your (and other) articles?

@rstz rstz added the documentation Improvements or additions to documentation label Sep 11, 2023
@Realvincentyuan
Author

Hi,

No worries, I have not had a chance to write exhaustive unit tests so far. If possible, I would rather contribute to the tutorial notebooks you made instead of adding an external link, since the URL might become outdated. What do you think?
