diff --git a/app.py b/app.py
index f688161..8fc3e6b 100644
--- a/app.py
+++ b/app.py
@@ -195,6 +195,7 @@ def load_lora(lora_name, progress=gr.Progress(track_tqdm=True)):
                     )
 
                 self.generate_btn = gr.Button('Generate', variant='primary')
+                self.cancel_btn = gr.Button('Cancel', variant='primary')
 
             with gr.Row():
                 with gr.Column():
@@ -249,28 +250,37 @@ def load_lora(lora_name, progress=gr.Progress(track_tqdm=True)):
 
 
         def generate(
-            prompt,
-            do_sample,
-            max_new_tokens,
-            num_beams,
-            repeat_penalty,
-            temperature,
+            prompt,
+            do_sample,
+            max_new_tokens,
+            num_beams,
+            repeat_penalty,
+            temperature,
             top_p,
             top_k,
             progress=gr.Progress(track_tqdm=True)
         ):
-            return self.trainer.generate(
-                prompt,
-                do_sample=do_sample,
-                max_new_tokens=max_new_tokens,
-                num_beams=num_beams,
-                repetition_penalty=repeat_penalty,
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k
-            )
+            # Iteratively generate tokens until we either emit max_new_tokens or stop getting new output
+            for i in range(max_new_tokens):
+                output_this_iteration = self.trainer.generate(
+                    prompt,
+                    do_sample=do_sample,
+                    max_new_tokens=1,
+                    num_beams=num_beams,
+                    repetition_penalty=repeat_penalty,
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k
+                )
+                # If we have the same output as last iteration, generation is done
+                if len(prompt) == len(output_this_iteration):
+                    break
+
+                prompt = output_this_iteration
+                yield output_this_iteration
+
 
-        self.generate_btn.click(
+        generate_event = self.generate_btn.click(
             fn=generate,
             inputs=[
                 self.prompt,
@@ -285,6 +295,8 @@ def generate(
             outputs=[self.output]
         )
 
+        self.cancel_btn.click(fn=None, inputs=None, outputs=None, cancels=[generate_event])
+
     def layout(self):
         with gr.Blocks() as demo:
             with gr.Row():
diff --git a/trainer.py b/trainer.py
index d4b20f3..278612c 100644
--- a/trainer.py
+++ b/trainer.py
@@ -47,6 +47,8 @@ def load_model(self, model_name, force=False, **kwargs):
             load_in_8bit=True,
             torch_dtype=torch.float16,
         )
+        # Clear the collection that tracks which adapters are loaded, as they are associated with self.model
+        self.loras = {}
 
         if model_name.startswith('decapoda-research/llama'):
             self.tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name)
@@ -74,7 +76,6 @@ def load_lora(self, lora_name, replace_model=True):
         if peft_config.base_model_name_or_path != self.model_name:
             self.load_model(peft_config.base_model_name_or_path)
 
-        self.loras = {}
 
         assert self.model_name is not None
         assert self.model is not None