Skip to content

Commit

Permalink
Use BOS token as a fallback
Browse files (browse the repository at this point in the history)
  • Loading branch information
Rocketknight1 committed Apr 30, 2024
1 parent 6bc93c7 commit a01cc68
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 0 deletions.
2 changes: 2 additions & 0 deletions examples/pytorch/question-answering/run_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,8 @@ def prepare_train_features(examples):
input_ids = tokenized_examples["input_ids"][i]
if tokenizer.cls_token_id in input_ids:
cls_index = input_ids.index(tokenizer.cls_token_id)
elif tokenizer.bos_token_id in input_ids:
cls_index = input_ids.index(tokenizer.bos_token_id)
else:
cls_index = 0

Expand Down
4 changes: 4 additions & 0 deletions examples/pytorch/question-answering/run_qa_beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,8 @@ def prepare_train_features(examples):
input_ids = tokenized_examples["input_ids"][i]
if tokenizer.cls_token_id in input_ids:
cls_index = input_ids.index(tokenizer.cls_token_id)
elif tokenizer.bos_token_id in input_ids:
cls_index = input_ids.index(tokenizer.bos_token_id)
else:
cls_index = 0
tokenized_examples["cls_index"].append(cls_index)
Expand Down
Expand Up @@ -539,6 +541,8 @@ def prepare_validation_features(examples):
# Find the CLS token in the input ids.
if tokenizer.cls_token_id in input_ids:
cls_index = input_ids.index(tokenizer.cls_token_id)
elif tokenizer.bos_token_id in input_ids:
cls_index = input_ids.index(tokenizer.bos_token_id)
else:
cls_index = 0
tokenized_examples["cls_index"].append(cls_index)
Expand Down
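4 changes: 4 additions & 0 deletions (file path missing from this capture; presumably examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py)
Original file line number Diff line number Diff line change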
Expand Up @@ -446,6 +446,8 @@ def prepare_train_features(examples):
input_ids = tokenized_examples["input_ids"][i]
if tokenizer.cls_token_id in input_ids:
cls_index = input_ids.index(tokenizer.cls_token_id)
elif tokenizer.bos_token_id in input_ids:
cls_index = input_ids.index(tokenizer.bos_token_id)
else:
cls_index = 0
tokenized_examples["cls_index"].append(cls_index)
Expand Down
Expand Up @@ -568,6 +570,8 @@ def prepare_validation_features(examples):
# Find the CLS token in the input ids.
if tokenizer.cls_token_id in input_ids:
cls_index = input_ids.index(tokenizer.cls_token_id)
elif tokenizer.bos_token_id in input_ids:
cls_index = input_ids.index(tokenizer.bos_token_id)
else:
cls_index = 0
tokenized_examples["cls_index"].append(cls_index)
Expand Down
2 changes: 2 additions & 0 deletions examples/pytorch/question-answering/run_qa_no_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,8 @@ def prepare_train_features(examples):
input_ids = tokenized_examples["input_ids"][i]
if tokenizer.cls_token_id in input_ids:
cls_index = input_ids.index(tokenizer.cls_token_id)
elif tokenizer.bos_token_id in input_ids:
cls_index = input_ids.index(tokenizer.bos_token_id)
else:
cls_index = 0

Expand Down

0 comments on commit a01cc68

Please sign in to comment.