Commit 1d1bace
For Issue #1600
pseudotensor committed May 7, 2024
1 parent 3a7cc42 commit 1d1bace
Showing 5 changed files with 19 additions and 3 deletions.
src/cli.py (2 additions, 0 deletions)

@@ -107,6 +107,8 @@ def run_cli( # for local function:
         allow_chat_system_prompt=None,
         src_lang=None, tgt_lang=None, concurrency_count=None, save_dir=None, sanitize_bot_response=None,
         model_state0=None,
+        use_auth_token=None,
+        trust_remote_code=None,
         score_model_state0=None,
         max_max_new_tokens=None,
         is_public=None,
src/eval.py (2 additions, 0 deletions)

@@ -127,6 +127,8 @@ def run_eval( # for local function:
         allow_chat_system_prompt=None,
         src_lang=None, tgt_lang=None, concurrency_count=None, save_dir=None, sanitize_bot_response=None,
         model_state0=None,
+        use_auth_token=None,
+        trust_remote_code=None,
         score_model_state0=None,
         max_max_new_tokens=None,
         is_public=None,
src/gen.py (6 additions, 0 deletions)

@@ -1500,6 +1500,8 @@ def main(

     # allow set token directly
     use_auth_token = os.environ.get("HUGGING_FACE_HUB_TOKEN", use_auth_token)
+    if isinstance(use_auth_token, str) and use_auth_token and 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
+        os.environ['HUGGING_FACE_HUB_TOKEN'] = use_auth_token
     allow_upload_to_user_data = bool(
         int(os.environ.get("allow_upload_to_user_data", str(int(allow_upload_to_user_data)))))
     allow_upload_to_my_data = bool(int(os.environ.get("allow_upload_to_my_data", str(int(allow_upload_to_my_data)))))

@@ -3854,6 +3856,8 @@ def evaluate(
         save_dir=None,
         sanitize_bot_response=False,
         model_state0=None,
+        use_auth_token=None,
+        trust_remote_code=None,
         memory_restriction_level=None,
         max_max_new_tokens=None,
         is_public=None,

@@ -5170,6 +5174,8 @@ def evaluate(
         remove_invalid_values=True,
         use_cache=use_cache,
         max_new_tokens=max_new_tokens,  # unsure if required here
+        token=use_auth_token,
+        trust_remote_code=trust_remote_code,
     )
     if do_sample:
         gen_config_kwargs.update(dict(temperature=float(temperature),
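A minimal sketch (illustrative, not from the repo) of the token-resolution order the main() hunk above implements; the helper name resolve_hf_token is made up for illustration. The HUGGING_FACE_HUB_TOKEN environment variable takes precedence over the use_auth_token argument, and when only the argument is supplied it is exported so later environment lookups, such as the os.getenv call added in src/stopping.py below or the token= kwarg added to gen_config_kwargs in evaluate(), see the same value.

import os

def resolve_hf_token(use_auth_token=None):
    # Environment variable wins over the argument (same fallback as in main()).
    use_auth_token = os.environ.get("HUGGING_FACE_HUB_TOKEN", use_auth_token)
    # If only the argument was supplied, export it so downstream code that
    # reads the environment variable gets the same token.
    if isinstance(use_auth_token, str) and use_auth_token and 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
        os.environ['HUGGING_FACE_HUB_TOKEN'] = use_auth_token
    return use_auth_token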
src/stopping.py (8 additions, 2 deletions)

@@ -1,3 +1,4 @@
+import os
 import time

 import torch

@@ -7,7 +8,8 @@
 from src.prompter_utils import get_use_chat_template


-def update_terminate_responses(terminate_response, tokenizer=None):
+def update_terminate_responses(terminate_response, tokenizer=None, trust_remote_code=True):
+    # FIXME: make trust_remote_code passed in from above, but generation config should be relatively safe
     if terminate_response is None:
         terminate_response = []
     if tokenizer is not None:

@@ -22,7 +24,11 @@ def update_terminate_responses(terminate_response, tokenizer=None):

     if hasattr(tokenizer, 'name_or_path') and hasattr(tokenizer, 'vocab'):
         reverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
-        generate_eos_token_id = GenerationConfig.from_pretrained(tokenizer.name_or_path).eos_token_id
+        generate_eos_token_id = GenerationConfig.from_pretrained(tokenizer.name_or_path,
+                                                                 token=os.getenv('HUGGING_FACE_HUB_TOKEN'),
+                                                                 trust_remote_code=trust_remote_code,
+
+                                                                 ).eos_token_id
         if isinstance(generate_eos_token_id, list):
             for eos_token_id in generate_eos_token_id:
                 terminate_response.extend([reverse_vocab[eos_token_id]])
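A hedged usage sketch of the updated helper, assuming the token was already exported as in the gen.py hunk above; the model id is only an example and is not implied by this commit.

import os
from transformers import AutoTokenizer
from src.stopping import update_terminate_responses

# The helper reads HUGGING_FACE_HUB_TOKEN itself when fetching the GenerationConfig,
# so the caller only provides the tokenizer and the trust_remote_code flag.
tokenizer = AutoTokenizer.from_pretrained('h2oai/h2ogpt-4096-llama2-7b-chat',
                                          token=os.getenv('HUGGING_FACE_HUB_TOKEN'))
terminate_response = update_terminate_responses([], tokenizer=tokenizer, trust_remote_code=False)
print(terminate_response)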
src/version.py (1 addition, 1 deletion)

@@ -1 +1 @@
-__version__ = "c25144e9cb94ec52c1f98e4e78c59dae14401c40"
+__version__ = "3a7cc42211091acc4846311d869e08e725eb865c"
