
Remove broken final state loop #874

Merged
merged 3 commits into outlines-dev:main on May 11, 2024

Conversation

@br3no (Contributor) commented May 7, 2024

Fixes #856

The code this PR removes introduces an artificial and erroneous loop transition in every final state that is always traversed, regardless of what has been generated.

The accompanying comment doesn't make sense in my opinion, as the if-statement directly above already handles exactly this case.

Removing this piece of code fixes the bug that surfaced when upgrading outlines in the vLLM integration.

@br3no (Contributor, Author) commented May 7, 2024

A related matter:

is_finished = is_generation_finished(fsms, fsm_states)

This code is fundamentally broken, in my opinion, because it always stops generation once a final state is reached, regardless of any outgoing transitions that state may have. Instead, the condition for stopping should be that a stop token has been generated. Right?
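
For illustration, a minimal sketch of the stopping condition being proposed here; all names are illustrative, not outlines' actual API:

def is_generation_finished(generated_token_ids: list, eos_token_id: int) -> bool:
    # Stop only once the stop/EOS token has actually been produced,
    # not merely because the FSM has reached a final state.
    return len(generated_token_ids) > 0 and generated_token_ids[-1] == eos_token_id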

@rlouf (Member) commented May 7, 2024

Is there a minimal reproducing example we could add as a test?

@br3no (Contributor, Author) commented May 7, 2024

The issue does not show up in the transformers integration because of the line I posted in my previous comment. So the most minimal example at the moment is the code provided in #856.

@br3no (Contributor, Author) commented May 7, 2024

There are some test errors. I believe there is still a condition to be checked for the case where the state does not exist in the transitions table. I'll invest some time later today.

@rlouf I didn't run the tests, because I didn't know how.

@br3no (Contributor, Author) commented May 7, 2024

I believe this should do it now.

@br3no (Contributor, Author) commented May 8, 2024

Okay, obviously not.

I will invest some time into this today and hopefully come to a solution.

@br3no (Contributor, Author) commented May 8, 2024

I've looked into the failing tests:

____________________________ test_regex_final_state ____________________________

    def test_regex_final_state():
        """Make sure that the FSM stays in the final state as we keep generating"""
    
        class MockTokenizer:
            vocabulary = {"`": 101, ".": 102, "\n": 103, "eos": 104}
            special_tokens = {"eos"}
            eos_token_id = 104
    
            def convert_token_to_string(self, token):
                return token
    
        regex_str = r"`\n(\.\n)?`\n"
        tokenizer = MockTokenizer()
    
        with pytest.warns(UserWarning):
            fsm = RegexFSM(regex_str, tokenizer)
    
        state = fsm.next_state(state=4, token_id=103)
        assert state == 5
        assert fsm.is_final_state(state)
    
        state = fsm.next_state(state=5, token_id=103)
>       assert state == 5
E       assert -1 == 5

tests/fsm/test_fsm.py:85: AssertionError
____________________________ test_regex_final_state ____________________________

    def test_regex_final_state():
        """Make sure that the FSM stays in the final state as we keep generating"""
    
        class MockTokenizer:
            vocabulary = {"`": 101, ".": 102, "\n": 103, "eos": 104}
            special_tokens = {"eos"}
            eos_token_id = 104
    
            def convert_token_to_string(self, token):
                return token
    
        regex_str = r"`\n(\.\n)?`\n"
        tokenizer = MockTokenizer()
        fsm = RegexGuide(regex_str, tokenizer)
    
        state = fsm.get_next_state(state=4, token_id=103)
        assert state == 5
        assert fsm.is_final_state(state)
    
        state = fsm.get_next_state(state=5, token_id=103)
>       assert state == 5
E       assert -1 == 5

tests/fsm/test_guide.py:183: AssertionError

I believe these tests are testing an assumption that is fundamentally wrong: final states can have outbound transitions, including transitions into non-final states.
I have attached an SVG file with a rendering of the state machine described by the state-to-token map for the following very simple regular expression:

"(12){1,3}"

[Attached image: state_machine.svg — rendering of the FSM for "(12){1,3}", in which final states still have outgoing transitions]


I'm wondering whether the right thing to do is to remove these tests, or what we would want to test instead.
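
As a quick sketch of how this claim can be verified with interegular (the library outlines builds its FSMs on), assuming its finals and map attributes:

import interegular

fsm = interegular.parse_pattern("(12){1,3}").to_fsm()
print("final states:", fsm.finals)
for state in fsm.finals:
    # fsm.map is the transition table: state -> {transition: next_state}.
    # A non-empty entry for a final state shows that final states can
    # indeed have outbound transitions.
    print(state, "->", fsm.map.get(state, {}))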

@br3no (Contributor, Author) commented May 10, 2024

I have changed the tests to verify that if we are in a final state with no outbound transitions, a new generation will lead to us staying in a final state.
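
Roughly, the adjusted expectation looks like the sketch below, reusing the names from the failing tests above; the committed tests may differ in detail:

state = fsm.get_next_state(state=5, token_id=103)
assert fsm.is_final_state(state)  # we stay in a final state, not necessarily state 5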

@lapp0 (Contributor) left a comment

Basically the same as #884

Should we leave test_fsm.py as is? Otherwise looks good.

@br3no (Contributor, Author) commented May 10, 2024

> Basically the same as #884
>
> Should we leave test_fsm.py as is? Otherwise looks good.

Do you mean reverting the changes in test_fsm.py? If so, that will break the build. In essence, both test_fsm.py and test_guide.py are testing the same thing, since they share the underlying implementation.

@lapp0 (Contributor) commented May 10, 2024

Thanks, looks good to me!

@brandonwillard added the labels bug and structured generation on May 10, 2024
@rlouf merged commit 78852b0 into outlines-dev:main on May 11, 2024
5 checks passed
@ekagra-ranjan commented May 13, 2024

Damn! I was facing this issue on Friday and spent a couple of days finally figuring out the solution, only to find that this PR already existed :)

@ekagra-ranjan commented:
@br3no can you please share how you generated the FSM plot here? #874 (comment)

@br3no (Contributor, Author) commented May 14, 2024

@ekagra-ranjan sure. I used graphviz for that. Here's an example:

import outlines
from transformers import AutoTokenizer
from graphviz import Digraph

def draw_state_machine(graph: dict, final_states: set, tokenizer):
    dot = Digraph()

    # Add nodes
    for state in graph:
        if state in final_states:
            dot.node(str(state), str(state), color='salmon', style='filled', fillcolor='salmon')
        else:
            dot.node(str(state), str(state), color='lightblue', style='filled', fillcolor='lightblue')

    # Prepare edge labels by aggregating transitions between the same nodes
    edge_labels = {}
    for state, transitions in graph.items():
        for transition, end_state in transitions.items():
            if end_state not in graph:
                # Add end states not in the state map
                dot.node(str(end_state), str(end_state), color='salmon', style='filled', fillcolor='salmon')
            label = tokenizer.decode(int(transition))
            edge_key = (str(state), str(end_state))
            if edge_key not in edge_labels:
                edge_labels[edge_key] = label
            else:
                # Append new label to existing label, separated by a comma
                edge_labels[edge_key] += ", " + label

    # Add edges with aggregated labels
    for (start_state, end_state), label in edge_labels.items():
        dot.edge(start_state, end_state, label=label)

    # Render and view the graph
    dot.render('state_machine', view=True, format='svg')

tokenizer_zephyr = AutoTokenizer.from_pretrained("stabilityai/stablelm-2-zephyr-1_6b")

regex = r"(12){1,2}"

model = outlines.models.transformers("stabilityai/stablelm-2-zephyr-1_6b")

generator_zephyr = outlines.generate.regex(
    model,
    regex,
)

draw_state_machine(generator_zephyr.fsm.states_to_token_maps, generator_zephyr.fsm.final_states, tokenizer_zephyr)
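
Note that the graphviz Python package only generates the DOT source; rendering to SVG with dot.render additionally requires the Graphviz system binaries (dot) to be installed.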

Review thread on the get_next_state change:

@@ -193,12 +193,8 @@ def get_next_state(self, state: int, token_id: int) -> int:
             The new state of the guide.

         """
-        if token_id == self.eos_token_id:
+        if token_id == self.eos_token_id or state not in self.states_to_token_maps:
@ekagra-ranjan commented:

@br3no I was wondering if we really need the 2nd condition, state not in self.states_to_token_maps? The condition basically checks for states which do not have outgoing edges. But such states would be part of the final states in the FSM, and this block of code adds EOS as an edge to such states, which gives them at least one outgoing edge. Therefore, no state in the FSM should be absent from the states_to_token_maps. Wdyt?

@br3no (Contributor, Author) replied:

@ekagra-ranjan yes, we do need this second condition.

The block of code you linked to does not add an EOS outbound transition to these states. It only adds transitions to final states that are present in states_to_token_maps; these states are not present there, so states_to_token_subsets.get(state) will return None for them.

I'm not really knowledgeable about the way Outlines (and interegular) builds the state machines out of regexes. The fact of the matter is that states_to_token_maps does not contain all states that are reachable. I noticed this while debugging the code for some example regexes.

This is not a problem in principle, as these states are considered final, and the check states_to_token_subsets.get(state) is None is used all over the code to handle this special case (as in the block you linked to).

I actually believe this could be improved: Outlines would profit from removing this special case, which needs to be kept in mind all over the place and could lead to bugs. But, as I said, this is not a problem in principle.
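
For concreteness, here is a minimal sketch of the special-case pattern described above, assuming the attribute names from the diff (states_to_token_maps, eos_token_id); it is illustrative, not the verbatim library code:

# Inside a method of the guide object (illustrative sketch):
transitions = self.states_to_token_maps.get(state)
if transitions is None:
    # The state is reachable but absent from the map: treat it as final,
    # so the only permissible continuation is the EOS token.
    allowed_token_ids = [self.eos_token_id]
else:
    allowed_token_ids = list(transitions.keys())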

Labels: bug, structured generation

Projects: Status: Done

Development: Successfully merging this pull request may close these issues:
Endless generation bug popped up during migration to Guide in vLLM integration (#856)

5 participants