Skip to content

Commit

Permalink
support alibaba_tongyi stream output (#5271)
Browse files Browse the repository at this point in the history
* support alibaba_tongyi stream output

* fix:type error and add a test
  • Loading branch information
sinajia committed May 14, 2024
1 parent 7f9859e commit e35aef9
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 1 deletion.
110 changes: 109 additions & 1 deletion libs/langchain-community/src/chat_models/alibaba_tongyi.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { type CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import {
BaseChatModel,
type BaseChatModelParams,
Expand All @@ -6,10 +7,12 @@ import {
AIMessage,
type BaseMessage,
ChatMessage,
AIMessageChunk,
} from "@langchain/core/messages";
import { type ChatResult } from "@langchain/core/outputs";
import { type CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { ChatGenerationChunk } from "@langchain/core/outputs";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { IterableReadableStream } from "@langchain/core/utils/stream";

/**
* Type representing the role of a message in the Tongyi chat model.
Expand Down Expand Up @@ -493,6 +496,111 @@ export class ChatAlibabaTongyi
return this.caller.call(makeCompletionRequest);
}

/**
 * Streams a chat completion from Tongyi as a sequence of generation chunks.
 *
 * The request is issued with `stream` and `incremental_output` enabled on top
 * of `invocationParams()`. Every event read from the stream is turned into one
 * `ChatGenerationChunk`, and its text is then reported to the run manager via
 * `handleLLMNewToken`.
 *
 * @param messages Conversation history to send; each message's content is
 *   passed through as a string.
 * @param options Parsed call options; only `signal` is read here and it is
 *   forwarded to the HTTP request.
 * @param runManager Optional callback manager notified of each new text piece.
 * @returns An async generator of chunks. `generationInfo` is populated only on
 *   the event whose `finish_reason` is `"stop"`.
 */
async *_streamResponseChunks(
  messages: BaseMessage[],
  options?: this["ParsedCallOptions"],
  runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
  const streamingParams = {
    ...this.invocationParams(),
    stream: true,
    incremental_output: true,
  };

  const tongyiMessages: TongyiMessage[] = [];
  for (const message of messages) {
    tongyiMessages.push({
      role: messageToTongyiRole(message),
      content: message.content as string,
    });
  }

  // The request object is built inside the callback so it is re-created on
  // every attempt made by the caller.
  const openStream = async () =>
    this.createTongyiStream(
      {
        model: this.model,
        parameters: streamingParams,
        input: { messages: tongyiMessages },
      },
      options?.signal
    );
  const events = await this.caller.call(openStream);

  for await (const event of events) {
    const { text, finish_reason } = event.output;

    // Request id and usage are attached only to the terminating "stop" event.
    let generationInfo: Record<string, unknown> | undefined;
    if (finish_reason === "stop") {
      generationInfo = {
        finish_reason,
        request_id: event.request_id,
        usage: event.usage,
      };
    } else {
      generationInfo = undefined;
    }

    yield new ChatGenerationChunk({
      text,
      message: new AIMessageChunk({ content: text }),
      generationInfo,
    });
    await runManager?.handleLLMNewToken(text);
  }
}

/**
 * POSTs a completion request to `this.apiUrl` asking for an event stream and
 * yields the JSON payload of every `data:` line in the response body.
 *
 * Lines that do not start with `data:` are ignored. A `data:` line whose
 * payload is not valid JSON is reported with `console.warn` and skipped.
 *
 * @param request Request body; it is serialized with `JSON.stringify`.
 * @param signal Optional abort signal handed to `fetch`.
 * @returns An async generator of parsed event payloads.
 * @throws Error when the HTTP status is not OK (the `Response` is attached to
 *   the error as `response`), or when the response has no body.
 */
private async *createTongyiStream(
  request: ChatCompletionRequest,
  signal?: AbortSignal
) {
  const response = await fetch(this.apiUrl, {
    method: "POST",
    headers: {
      Authorization: `Bearer ${this.alibabaApiKey}`,
      Accept: "text/event-stream",
      "Content-Type": "application/json",
    },
    body: JSON.stringify(request),
    signal,
  });

  if (!response.ok) {
    let error;
    const responseText = await response.text();
    try {
      const json = JSON.parse(responseText);
      error = new Error(
        `Tongyi call failed with status code ${response.status}: ${json.error}`
      );
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
    } catch (e: any) {
      // Body was not JSON: fall back to reporting the raw text.
      error = new Error(
        `Tongyi call failed with status code ${response.status}: ${responseText}`
      );
    }
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    (error as any).response = response;
    throw error;
  }
  if (!response.body) {
    throw new Error(
      "Could not begin Tongyi stream. Please check the given URL and try again."
    );
  }

  // Parses complete lines, yielding the JSON payload of each `data:` line.
  function* parseDataLines(lines: string[]) {
    for (const line of lines) {
      if (!line.startsWith("data:")) {
        continue;
      }
      try {
        yield JSON.parse(line.slice("data:".length).trim());
      } catch (e) {
        console.warn(`Received a non-JSON parseable chunk: ${line}`);
      }
    }
  }

  const stream = IterableReadableStream.fromReadableStream(response.body);
  const decoder = new TextDecoder();
  // Holds the trailing partial line carried over between network chunks.
  let extra = "";
  for await (const chunk of stream) {
    // `stream: true` makes the decoder buffer an incomplete multi-byte UTF-8
    // sequence until the next chunk instead of emitting U+FFFD for it.
    const decoded = extra + decoder.decode(chunk, { stream: true });
    const lines = decoded.split("\n");
    extra = lines.pop() ?? "";
    yield* parseDataLines(lines);
  }

  // Flush bytes still buffered in the decoder, then parse a final line that
  // was not terminated by a newline so the last event is not lost.
  extra += decoder.decode();
  if (extra.length > 0) {
    yield* parseDataLines([extra]);
  }
}

/**
 * Returns the string that identifies this chat model type.
 */
_llmType(): string {
  const modelTypeId = "alibaba_tongyi";
  return modelTypeId;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ const runTest = async ({
const res = await chat.invoke(messages);
console.log({ res });

// test streaming call
const stream = await chat.stream(
`Translate "I love programming" into Chinese.`
);
const chunks = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
expect(chunks.length).toBeGreaterThan(0);

if (passedConfig.streaming) {
expect(nrNewTokens > 0).toBe(true);
expect(res.text).toBe(streamedCompletion);
Expand Down

0 comments on commit e35aef9

Please sign in to comment.