Allow specifying base model as model param in OpenAI API (#331)
tgaddair committed Mar 14, 2024
1 parent 9f32e51 commit 1e2a94f
Showing 2 changed files with 29 additions and 32 deletions.
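With this change, an OpenAI-compatible client can pass the deployment's base model ID as the model parameter; the router rewrites it to an empty adapter_id, so the request runs against the base model instead of being treated as an adapter ID to download. A minimal sketch using the openai Python client (the endpoint URL and model ID below are hypothetical placeholders for a LoRAX deployment):

from openai import OpenAI

# Placeholder deployment: a LoRAX server started with
# --model-id mistralai/Mistral-7B-Instruct-v0.1, listening on localhost:8080.
client = OpenAI(
    base_url="http://localhost:8080/v1",  # LoRAX OpenAI-compatible endpoint
    api_key="EMPTY",                      # local deployments typically need no key
)

# Passing the base model ID now works: the router maps it to an empty adapter_id.
completion = client.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.1",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(completion.choices[0].message.content)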
14 changes: 12 additions & 2 deletions router/src/server.rs
@@ -115,7 +115,12 @@ async fn completions_v1(
     req_headers: HeaderMap,
     req: Json<CompletionRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
-    let req = req.0;
+    let mut req = req.0;
+    if req.model == MODEL_ID.get().unwrap().as_str() {
+        // Allow user to specify the base model, but treat it as an empty adapter_id
+        tracing::info!("Replacing base model {0} with empty adapter_id", req.model);
+        req.model = "".to_string();
+    }
     let mut gen_req = CompatGenerateRequest::from(req);
 
     // default return_full_text given the pipeline_tag
@@ -176,7 +181,12 @@ async fn chat_completions_v1(
     req_headers: HeaderMap,
     req: Json<ChatCompletionRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
-    let req = req.0;
+    let mut req = req.0;
+    if req.model == MODEL_ID.get().unwrap().as_str() {
+        // Allow user to specify the base model, but treat it as an empty adapter_id
+        tracing::info!("Replacing base model {0} with empty adapter_id", req.model);
+        req.model = "".to_string();
+    }
     let mut gen_req = CompatGenerateRequest::from(req);
 
     // default return_full_text given the pipeline_tag
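Both handlers apply the same rewrite before building the CompatGenerateRequest. A minimal sketch of the behavior, in Python for illustration (MODEL_ID stands in for the router's once-initialized base model ID; the value below is a placeholder):

# Sketch of the rewrite both handlers now perform: a request naming the base
# model is treated as having an empty adapter_id.
MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.1"  # placeholder base model ID

def normalize_model_param(model: str) -> str:
    """Map the base model ID to an empty adapter_id; pass adapter IDs through."""
    if model == MODEL_ID:
        return ""  # an empty adapter_id selects the base model downstream
    return model

assert normalize_model_param(MODEL_ID) == ""
assert normalize_model_param("my-org/my-lora") == "my-org/my-lora"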
47 changes: 17 additions & 30 deletions server/lorax_server/server.py
@@ -144,37 +144,24 @@ async def DownloadAdapter(self, request: generate_pb2.DownloadAdapterRequest, co
             if adapter_source == PBASE:
                 adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token)
                 adapter_source = S3
-            try:
-                if adapter_source == HUB:
-                    # Quick auth check on the repo against the token
-                    HfApi(token=api_token).model_info(adapter_id, revision=None)
-                # fail fast if ID is not an adapter (i.e. it is a full model)
-                # TODO(geoffrey): do this for S3– can't do it this way because the
-                # files are not yet downloaded locally at this point.
-                config_path = get_config_path(adapter_id, adapter_source)
-                PeftConfig.from_pretrained(config_path, token=api_token)
-
-                _download_weights(
-                    adapter_id, source=adapter_source, api_token=api_token
-                )
+
+            if adapter_source == HUB:
+                # Quick auth check on the repo against the token
+                HfApi(token=api_token).model_info(adapter_id, revision=None)
+
+            # fail fast if ID is not an adapter (i.e. it is a full model)
+            # TODO(geoffrey): do this for S3– can't do it this way because the
+            # files are not yet downloaded locally at this point.
+            config_path = get_config_path(adapter_id, adapter_source)
+            PeftConfig.from_pretrained(config_path, token=api_token)
+
+            _download_weights(
+                adapter_id, source=adapter_source, api_token=api_token
+            )
 
-                # Calculate size of adapter to be loaded
-                source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token)
-                adapter_bytes += source.get_weight_bytes()
-            except Exception:
-                logger.exception("Error when downloading adapter")
-
-                if adapter_source != LOCAL:
-                    # delete safetensors files if there is an issue downloading or converting
-                    # the weights to prevent cache hits by subsequent calls
-                    try:
-                        local_path = get_local_dir(adapter_id, adapter_source)
-                        shutil.rmtree(local_path)
-                    except Exception as e:
-                        logger.warning(f"Error cleaning up safetensors files after "
-                                       f"download error: {e}\nIgnoring.")
-                raise
+            # Calculate size of adapter to be loaded
+            source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token)
+            adapter_bytes += source.get_weight_bytes()
 
             adapter_memory_size = self.model.adapter_memory_size()
             if adapter_memory_size > 0:
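The restructuring above drops the try/except cleanup wrapper while keeping the fail-fast validation: loading a PeftConfig succeeds only for adapter repos, so a full-model ID errors out before any weights are downloaded. A standalone sketch of that check (hypothetical repo IDs; assumes the peft package is installed):

from peft import PeftConfig

def is_adapter_repo(model_id: str, token: str | None = None) -> bool:
    """Return True if model_id is a PEFT adapter repo rather than a full model."""
    try:
        # An adapter repo must ship a loadable adapter config (adapter_config.json).
        PeftConfig.from_pretrained(model_id, token=token)
        return True
    except Exception:
        return False

# Hypothetical usage: an adapter repo passes, a base model repo does not.
# is_adapter_repo("some-org/some-lora-adapter")   -> True
# is_adapter_repo("mistralai/Mistral-7B-v0.1")    -> False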
