Fix llama_v2_7b_16h for torch.jit.trace
Original error: Attention using SDPA can not be traced with torch.jit.trace
when no attention_mask is provided. To solve this issue, please either load
your model with the argument attn_implementation="eager" or
pass an attention_mask input when tracing the model.
Thiago Crepaldi committed Jan 17, 2024
1 parent 6a8b941 commit b6f2439
Showing 1 changed file with 6 additions and 0 deletions.
torchbenchmark/util/framework/huggingface/model_factory.py
@@ -89,6 +89,12 @@ def __init__(self, name, test, device, batch_size=None, extra_args=[]):
         if class_models[name][2] == "ReformerConfig()" and not config.num_buckets:
             # silence "config.num_buckets is not set. Setting config.num_buckets to 128"
             config.num_buckets = 128
+        if name == "llama_v2_7b_16h":
+            # Workaround for the following error:
+            # Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided.
+            # To solve this issue, please either load your model with the argument attn_implementation="eager" or
+            # pass an attention_mask input when tracing the model.
+            config._attn_implementation = "eager"
         class_ctor = getattr(transformers, class_models[name][3])
         kwargs = {}
         hugging_face_models_requiring_trust_remote_code = ["hf_Falcon_7b", "hf_MPT_7b_instruct", "phi_1_5", "hf_Yi"]
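For reference, a minimal sketch of what the two fixes suggested by the error message look like when loading a Hugging Face model directly; the checkpoint name and the torchscript flag are illustrative and not part of this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint, not from the commit
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Fix 1: force eager attention so SDPA's mask-inference path is never hit,
# which is what this commit does via config._attn_implementation = "eager".
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation="eager",
    torchscript=True,  # makes forward() return tuples, which jit.trace requires
)
model.eval()

# Fix 2: pass an explicit attention_mask as a trace input, so SDPA does not
# have to infer one at trace time.
inputs = tokenizer("Hello, world", return_tensors="pt")
traced = torch.jit.trace(model, (inputs["input_ids"], inputs["attention_mask"]))

The commit takes the first route, flipping the private config._attn_implementation field after the config is built rather than passing attn_implementation at load time.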
