From 1130979008de90e56a60fec24a3661542b7ce473 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 15 May 2024 17:27:37 +0700 Subject: [PATCH] fix: cohere stream param does not work (#2907) --- .../browser/extensions/engines/helpers/sse.ts | 19 +++++++++++------- .../inference-cohere-extension/src/index.ts | 20 +++++++++++++------ 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/core/src/browser/extensions/engines/helpers/sse.ts b/core/src/browser/extensions/engines/helpers/sse.ts index 7ae68142f2..024ced4703 100644 --- a/core/src/browser/extensions/engines/helpers/sse.ts +++ b/core/src/browser/extensions/engines/helpers/sse.ts @@ -68,14 +68,19 @@ export function requestInference( let cachedLines = '' for (const line of lines) { try { - const toParse = cachedLines + line - if (!line.includes('data: [DONE]')) { - const data = JSON.parse(toParse.replace('data: ', '')) - content += data.choices[0]?.delta?.content ?? '' - if (content.startsWith('assistant: ')) { - content = content.replace('assistant: ', '') + if (transformResponse) { + content += transformResponse(line) + subscriber.next(content ?? '') + } else { + const toParse = cachedLines + line + if (!line.includes('data: [DONE]')) { + const data = JSON.parse(toParse.replace('data: ', '')) + content += data.choices[0]?.delta?.content ?? '' + if (content.startsWith('assistant: ')) { + content = content.replace('assistant: ', '') + } + if (content !== '') subscriber.next(content) } - if (content !== '') subscriber.next(content) } } catch { cachedLines = line diff --git a/extensions/inference-cohere-extension/src/index.ts b/extensions/inference-cohere-extension/src/index.ts index 24cc5935bc..dd7f033174 100644 --- a/extensions/inference-cohere-extension/src/index.ts +++ b/extensions/inference-cohere-extension/src/index.ts @@ -26,8 +26,8 @@ enum RoleType { type CoherePayloadType = { chat_history?: Array<{ role: RoleType; message: string }> - message?: string, - preamble?: string, + message?: string + preamble?: string } /** @@ -82,18 +82,24 @@ export default class JanInferenceCohereExtension extends RemoteOAIEngine { if (payload.messages.length === 0) { return {} } + + const { messages, ...params } = payload const convertedData: CoherePayloadType = { + ...params, chat_history: [], message: '', } - payload.messages.forEach((item, index) => { + messages.forEach((item, index) => { // Assign the message of the last item to the `message` property - if (index === payload.messages.length - 1) { + if (index === messages.length - 1) { convertedData.message = item.content as string return } if (item.role === ChatCompletionRole.User) { - convertedData.chat_history.push({ role: RoleType.user, message: item.content as string }) + convertedData.chat_history.push({ + role: RoleType.user, + message: item.content as string, + }) } else if (item.role === ChatCompletionRole.Assistant) { convertedData.chat_history.push({ role: RoleType.chatbot, @@ -106,5 +112,7 @@ export default class JanInferenceCohereExtension extends RemoteOAIEngine { return convertedData } - transformResponse = (data: any) => data.text + transformResponse = (data: any) => { + return typeof data === 'object' ? data.text : JSON.parse(data).text ?? '' + } }