Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Gemini text chat - prevent sending broken messageContent and history #822

Merged
merged 4 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/sharp-knives-ring.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"llamaindex": patch
---

Improve Gemini message and context preparation
1 change: 1 addition & 0 deletions packages/core/src/Prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ export function messagesToHistoryStr(messages: ChatMessage[]) {
}

export const defaultContextSystemPrompt = ({ context = "" }) => {
if (!context) return "";
marcusschiesser marked this conversation as resolved.
Show resolved Hide resolved
return `Context information is below.
---------------------
${context}
Expand Down
74 changes: 30 additions & 44 deletions packages/core/src/llm/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class GeminiHelper {
> = {
user: "user",
system: "user",
assistant: "user",
assistant: "model",
memory: "user",
};

Expand All @@ -152,38 +152,26 @@ class GeminiHelper {
};

/**
 * Collapse neighboring messages that map to the same Gemini role into a
 * single message by concatenating their `parts` arrays.
 *
 * Gemini's chat API rejects histories containing two consecutive messages
 * with the same role, so runs of same-role messages must be merged before
 * being passed to `startChat` / `sendMessage`.
 *
 * Pure with respect to the caller: the input array and the message objects
 * it contains are never mutated; merged entries are freshly created objects.
 * (The previous implementation reassigned `.parts` on objects shared with
 * the caller's array.)
 *
 * @param messages - Gemini-formatted messages (`{ role, parts }`), already
 *   converted via `chatMessageToGemini`.
 * @returns A new array in which no two adjacent messages share a role.
 */
public static mergeNeighboringSameRoleMessages(
  messages: GeminiMessageContent[],
): GeminiMessageContent[] {
  const merged: GeminiMessageContent[] = [];
  for (const message of messages) {
    const previous = merged[merged.length - 1];
    if (previous && previous.role === message.role) {
      // Same role as the previous entry: fold this message's parts in.
      // Reassigning a fresh array (rather than pushing into the old one)
      // keeps the caller's original `parts` arrays untouched.
      previous.parts = [...previous.parts, ...message.parts];
    } else {
      // Shallow-copy so a later merge never mutates the caller's object.
      merged.push({ ...message });
    }
  }
  return merged;
}

public static messageContentToGeminiParts(content: MessageContent): Part[] {
Expand Down Expand Up @@ -214,8 +202,8 @@ class GeminiHelper {
message: ChatMessage,
): GeminiMessageContent {
return {
role: this.ROLES_TO_GEMINI[message.role],
parts: this.messageContentToGeminiParts(message.content),
role: GeminiHelper.ROLES_TO_GEMINI[message.role],
parts: GeminiHelper.messageContentToGeminiParts(message.content),
};
}
}
Expand Down Expand Up @@ -260,22 +248,20 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
chat: ChatSession;
messageContent: Part[];
} {
const { messages } = params;
const mergedMessages =
GeminiHelper.mergeNeighboringSameRoleMessages(messages);
const history = mergedMessages.slice(0, -1);
const nextMessage = mergedMessages[mergedMessages.length - 1];
const messageContent = GeminiHelper.chatMessageToGemini(nextMessage).parts;
const messages = GeminiHelper.mergeNeighboringSameRoleMessages(
params.messages.map(GeminiHelper.chatMessageToGemini),
);

const history = messages.slice(0, -1);

const client = this.session.gemini.getGenerativeModel(this.metadata);

const chat = client.startChat({
history: history.map(GeminiHelper.chatMessageToGemini),
history,
});

return {
chat,
messageContent,
messageContent: messages[messages.length - 1].parts,
};
}

Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/llm/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ export {
Anthropic,
} from "./anthropic.js";
export { FireworksLLM } from "./fireworks.js";
export { GEMINI_MODEL, Gemini } from "./gemini.js";
export { GEMINI_MODEL, Gemini, GeminiSession } from "./gemini.js";
export { Groq } from "./groq.js";
export { HuggingFaceInferenceAPI } from "./huggingface.js";
export {
Expand Down