Fix 1st token latency time (#1091)
libinta committed Jun 28, 2024
1 parent 9e1319f commit d73f3c9
Showing 1 changed file with 16 additions and 17 deletions.
optimum/habana/transformers/generation/utils.py (33 changes: 16 additions & 17 deletions)
@@ -1863,6 +1863,14 @@ def _greedy_search(
                     input_ids, scores, token_idx=cur_len, ignore_eos=ignore_eos, eos_token_id=eos_token_id
                 )
                 this_peer_finished = unfinished_sequences.max() == 0
+            hb_profer.step()
+            if hb_gen_time is not None:
+                if not time_to_first_token_done:
+                    time_to_first_token_done = True
+                    import habana_frameworks.torch.hpu as torch_hpu
+
+                    torch_hpu.synchronize()
+                hb_gen_time.step()

             if (
                 not model_kwargs.get("pad_done", False)
@@ -1873,14 +1881,6 @@
                 # before starting the decode phase.
                 self._pad_past_key_values(model_kwargs)
                 model_kwargs["pad_done"] = True
-            hb_profer.step()
-            if hb_gen_time is not None:
-                if not time_to_first_token_done:
-                    time_to_first_token_done = True
-                    import habana_frameworks.torch.hpu as torch_hpu
-
-                    torch_hpu.synchronize()
-                hb_gen_time.step()

             if (
                 model_kwargs.get("use_hpu_graphs", False)
@@ -2282,6 +2282,14 @@ def _sample(
                     input_ids, scores, token_idx=cur_len, ignore_eos=ignore_eos, eos_token_id=eos_token_id
                 )
                 this_peer_finished = unfinished_sequences.max() == 0
+            hb_profer.step()
+            if hb_gen_time is not None:
+                if not time_to_first_token_done:
+                    time_to_first_token_done = True
+                    import habana_frameworks.torch.hpu as torch_hpu
+
+                    torch_hpu.synchronize()
+                hb_gen_time.step()

             if (
                 not model_kwargs.get("pad_done", False)
@@ -2293,15 +2301,6 @@
                 self._pad_past_key_values(model_kwargs)
                 model_kwargs["pad_done"] = True
-
-            hb_profer.step()
-            if hb_gen_time is not None:
-                if not time_to_first_token_done:
-                    time_to_first_token_done = True
-                    import habana_frameworks.torch.hpu as torch_hpu
-
-                    torch_hpu.synchronize()
-                hb_gen_time.step()

             if (
                 model_kwargs.get("use_hpu_graphs", False)
                 and model_kwargs.get("limit_hpu_graphs", False)
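
Both hunks make the same change in _greedy_search and _sample: the per-step profiler/timer block (hb_profer.step() plus the hb_gen_time bookkeeping, which calls torch_hpu.synchronize() before recording the first token) now runs immediately after a token is produced, before _pad_past_key_values pads the prefill KV cache for the decode phase. Previously the timer fired after that one-time padding work, so the padding cost was folded into the reported time-to-first-token. Below is a minimal, self-contained sketch of the effect; Timer, pad_past_key_values, and decode_loop are hypothetical stand-ins, not the optimum-habana API, and the real code additionally synchronizes the HPU before reading the clock.

    # Sketch only: hypothetical names standing in for optimum-habana internals.
    import time


    class Timer:
        """Records elapsed time per generation step; laps[0] is the TTFT."""

        def __init__(self):
            self.start = time.perf_counter()
            self.laps = []

        def step(self):
            now = time.perf_counter()
            self.laps.append(now - self.start)
            self.start = now


    def pad_past_key_values():
        time.sleep(0.05)  # stand-in for padding prefill KV caches to max length


    def decode_loop(step_before_padding, tokens=3):
        timer = Timer()
        padded = False
        for _ in range(tokens):
            time.sleep(0.01)  # stand-in for one decode forward pass
            if step_before_padding:
                timer.step()  # fixed ordering: timestamp taken right after the token
            if not padded:
                pad_past_key_values()  # one-time work between prefill and decode
                padded = True
            if not step_before_padding:
                timer.step()  # old ordering: padding cost leaks into the first lap
        return timer.laps


    if __name__ == "__main__":
        ttft_old = decode_loop(step_before_padding=False)[0]
        ttft_new = decode_loop(step_before_padding=True)[0]
        print(f"TTFT with old ordering: {ttft_old * 1e3:.0f} ms")  # ~60 ms
        print(f"TTFT with new ordering: {ttft_new * 1e3:.0f} ms")  # ~10 ms

With the stand-in durations above, the old ordering reports roughly 60 ms for the first token while the new ordering reports the true ~10 ms; later tokens are unaffected because the padding runs only once, guarded by the pad_done flag.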
