More powerful session API #495

Open
Mathnerd314 opened this issue Sep 2, 2023 · 1 comment

@Mathnerd314

This is based on a Discord discussion between me and @borzunov in the webui channel; I'm filing it here so it doesn't get lost.

So consider a simple program using sessions:

with model.inference_session(max_length=500) as sess:
    output1 = model.generate('[Input 1]', max_new_tokens=2, session=sess, do_sample=True, temperature=0.9, top_p=0.6)
    output2 = model.generate('[Input 2]', max_new_tokens=2, session=sess, do_sample=True, temperature=0.9, top_p=0.6)

Currently, the session API keeps history: the first call generates [Output 1], and the second call then generates starting from [Input 1][Output 1][Input 2]. Compared to the usual transformers API this is quite restrictive; it is really only useful for chat-like applications where you can never go back and edit anything.

The API would be more powerful if the .generate() calls instead acted as though they were unrelated/independent, with the session managing the reuse logic internally. For example, to get the old behavior you would call model.generate("[Input 1][Output 1][Input 2]") in the second call, but if you didn't want history you could still do model.generate("[Input 2]"). It's fairly cheap to process a buffer of tokens in Python and analyze it for potential reuse.
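
As a hypothetical sketch of what this could look like from the caller's side (the call pattern below is an assumption about the proposed semantics, not an existing Petals API):

# Hypothetical usage under the proposed semantics: every generate() call
# passes the full prompt it wants, and the session transparently reuses
# whatever cached computation matches it.
with model.inference_session(max_length=500) as sess:
    output1 = model.generate('[Input 1]', max_new_tokens=2, session=sess)
    # Old chat-style behavior: spell out the history explicitly; the
    # session recognizes '[Input 1]' + output1 as already processed.
    output2 = model.generate('[Input 1]' + output1 + '[Input 2]',
                             max_new_tokens=2, session=sess)
    # Independent prompt: no implicit history is prepended.
    output3 = model.generate('[Input 2]', max_new_tokens=2, session=sess)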

As for the reuse logic, I have sketched a little algorithm that I think will work in most cases.

def do_generation(old, new):
    # Find the first position in the cached token buffer that matches
    # the start of the new prompt.
    start = 0
    while start < len(old) and old[start] != new[0]:
        start += 1
    # Extend the match while both buffers still agree.
    length = 0
    while (start + length < len(old) and length < len(new)
           and old[start + length] == new[length]):
        length += 1
    reuse_inference(old[start:start + length])  # reuse cached blocks for the match
    reprocess_fresh(new[length:])               # run fresh inference on the rest

This supports three main use cases (demonstrated in the sketch after this list):

  • prefix - similar to the old behavior: doing a then ab will reuse the cached blocks for a
  • prefix with a different suffix - doing ab then ac will reuse a
  • rolling - doing abc then bcd will drop a and reuse bc, which matters once long prompts exceed the context length
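
As a quick sanity check of the matching step on those three cases (find_reuse is a hypothetical helper that returns the (start, length) span the algorithm above would reuse; it is not part of Petals):

def find_reuse(old, new):
    # Hypothetical helper mirroring do_generation's matching step.
    if not old or not new:
        return 0, 0
    start = 0
    while start < len(old) and old[start] != new[0]:
        start += 1
    length = 0
    while (start + length < len(old) and length < len(new)
           and old[start + length] == new[length]):
        length += 1
    return start, length

assert find_reuse(list('a'), list('ab')) == (0, 1)     # prefix: reuse a
assert find_reuse(list('ab'), list('ac')) == (0, 1)    # different suffix: reuse a
assert find_reuse(list('abc'), list('bcd')) == (1, 2)  # rolling: drop a, reuse bc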

Per Alexander B., this would be fairly easy to implement in Petals, but it has not been implemented yet.

@borzunov
Collaborator

Hi @Mathnerd314,

Your suggestions sound reasonable. We'll start with an option to slice an inference session (reuse_inference(old[start:end])) - I hope to add it in one of the next releases.
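
A purely hypothetical sketch of what session slicing might look like from the caller's side (the slicing syntax is an assumption, not a committed Petals API):

# Hypothetical: keep only part of the session's cached state and
# continue generating from the sliced prefix.
with model.inference_session(max_length=500) as sess:
    model.generate('[Input 1]', max_new_tokens=2, session=sess)
    sess = sess[2:6]  # hypothetical slice: keep cached tokens 2..5 only
    model.generate('[Input 2]', max_new_tokens=2, session=sess)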
