Skip to content

Commit

Permalink
dbeaver/pro#2748 Completion instructions fix (#29998)
Browse files Browse the repository at this point in the history
* dbeaver/pro#2824 Custom scope performance fix

* dbeaver/pro#2748 Fix completion instructions

* dbeaver/pro#2748 Fix message chunk concatenation

* dbeaver/pro#2748 Process completion response

* dbeaver/pro#2748 Update completion instructions

* dbeaver/pro#2748 Update completion instructions

* dbeaver/pro#2748 Update completion instructions

---------

Co-authored-by: MashaKorax <[email protected]>
  • Loading branch information
ShadelessFox and MashaKorax committed May 15, 2024
1 parent d69b476 commit e97d1de
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,45 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.jkiss.dbeaver.ui.editors.sql.ai.model;
package org.jkiss.dbeaver.model.ai;

import org.jkiss.code.NotNull;
import org.jkiss.code.Nullable;
import org.jkiss.dbeaver.model.DBPDataSource;
import org.jkiss.dbeaver.model.ai.completion.DAICompletionMessage;
import org.jkiss.dbeaver.model.sql.SQLUtils;
import org.jkiss.dbeaver.runtime.DBWorkbench;

import java.util.ArrayList;
import java.util.List;

sealed public interface MessageChunk {
@NotNull
String toRawString();
// All these ideally should be a part of a given AI engine
public class AITextUtils {
private AITextUtils() {
// prevents instantiation
}

record Text(@NotNull String text) implements MessageChunk {
@NotNull
@Override
public String toRawString() {
return text;
@NotNull
public static String convertToSQL(
@NotNull DAICompletionMessage prompt,
@NotNull MessageChunk[] response,
@Nullable DBPDataSource dataSource
) {
final StringBuilder builder = new StringBuilder();

if (DBWorkbench.getPlatform().getPreferenceStore().getBoolean(AICompletionConstants.AI_INCLUDE_SOURCE_TEXT_IN_QUERY_COMMENT)) {
builder.append(SQLUtils.generateCommentLine(dataSource, prompt.getContent()));
}
}

record Code(@NotNull String text, @NotNull String language) implements MessageChunk {
@NotNull
@Override
public String toRawString() {
return "```" + language + "\n" + text + "\n```";
for (MessageChunk chunk : response) {
if (chunk instanceof MessageChunk.Code code) {
builder.append(code.text()).append(System.lineSeparator());
} else if (chunk instanceof MessageChunk.Text text) {
builder.append(SQLUtils.generateCommentLine(dataSource, text.text()));
}
}

return builder.toString().trim();
}

@NotNull
Expand Down Expand Up @@ -92,5 +106,4 @@ public static MessageChunk[] splitIntoChunks(@NotNull String text) {

return chunks.toArray(MessageChunk[]::new);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* DBeaver - Universal Database Manager
* Copyright (C) 2010-2024 DBeaver Corp and others
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.jkiss.dbeaver.model.ai;

import org.jkiss.code.NotNull;

sealed public interface MessageChunk {
@NotNull
String toRawString();

record Text(@NotNull String text) implements MessageChunk {
@NotNull
@Override
public String toRawString() {
return text;
}
}

record Code(@NotNull String text, @NotNull String language) implements MessageChunk {
@NotNull
@Override
public String toRawString() {
return "```" + language + "\n" + text + "\n```";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,21 @@ protected List<DAICompletionMessage> filterMessages(List<DAICompletionMessage> m

protected abstract int getMaxTokens();

/**
* Provides instructions for the AI so it can generate more accurate completions
*
* @param chatCompletion if the completion is for the chat mode, or for a single completion request.
*/
@NotNull
protected String getInstructions(boolean chatCompletion) {
return """
You are SQL assistant. You must produce SQL code for given prompt.
You must produce valid SQL statement enclosed with Markdown code block and terminated with semicolon.
All comments MUST be placed before query outside markdown code block.
Be polite.
""";
}

@Nullable
abstract protected String requestCompletion(
@NotNull DBRProgressMonitor monitor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,37 +111,15 @@ public DAICompletionMessage createMetadataMessage(
@NotNull DAICompletionContext context,
@Nullable DBSObjectContainer mainObject,
@NotNull IAIFormatter formatter,
boolean isChatAPI,
int maxRequestTokens,
boolean chatCompletion
@NotNull String instructions,
int maxRequestTokens
) throws DBException {
if (mainObject == null || mainObject.getDataSource() == null) {
throw new DBException("Invalid completion request");
}

final DBCExecutionContext executionContext = context.getExecutionContext();

final StringBuilder sb = new StringBuilder();

if (chatCompletion && isChatAPI) {
sb.append(
"""
You MUST perform SQL completion.
Your query must start with "SELECT" and MUST be enclosed with Markdown code block.
Talk naturally, as if you were talking to a human.
""");
} else if (isChatAPI) {
sb.append(
"""
Perform SQL completion.
Your query must start with "SELECT" and MUST be enclosed with Markdown code block.
Any comments MUST be placed in SQL multiline comment block at start of the query.
AVOID single line comments.
""");
} else {
sb.append("Perform SQL completion. Your query must start with \"SELECT\" and MUST be enclosed with Markdown code block.\n");
}

final StringBuilder sb = new StringBuilder(instructions);
final String extraInstructions = formatter.getExtraInstructions(monitor, mainObject, executionContext);
if (CommonUtils.isNotEmpty(extraInstructions)) {
sb.append(", ").append(extraInstructions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,15 +156,13 @@ protected String requestCompletion(
final DBCExecutionContext executionContext = context.getExecutionContext();
DBSObjectContainer mainObject = getScopeObject(context, executionContext);

final GPTModel model = getModel();
final DAICompletionMessage metadataMessage = MetadataProcessor.INSTANCE.createMetadataMessage(
monitor,
context,
mainObject,
formatter,
model.isChatAPI(),
getMaxTokens() - AIConstants.MAX_RESPONSE_TOKENS,
chatCompletion
getInstructions(chatCompletion),
getMaxTokens() - AIConstants.MAX_RESPONSE_TOKENS
);

final List<DAICompletionMessage> mergedMessages = new ArrayList<>();
Expand All @@ -186,7 +184,7 @@ protected String requestCompletion(
mainObject,
completionText,
formatter,
model.isChatAPI()
getModel().isChatAPI()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,10 @@
import org.jkiss.dbeaver.model.runtime.AbstractJob;
import org.jkiss.dbeaver.model.runtime.DBRProgressMonitor;
import org.jkiss.dbeaver.model.sql.SQLScriptElement;
import org.jkiss.dbeaver.model.sql.SQLUtils;
import org.jkiss.dbeaver.runtime.DBWorkbench;
import org.jkiss.dbeaver.ui.UIUtils;
import org.jkiss.dbeaver.ui.editors.sql.SQLEditor;
import org.jkiss.dbeaver.ui.editors.sql.ai.AIUIUtils;
import org.jkiss.dbeaver.ui.editors.sql.ai.model.MessageChunk;
import org.jkiss.dbeaver.ui.editors.sql.ai.popup.AISuggestionPopup;
import org.jkiss.dbeaver.ui.editors.sql.ai.preferences.AIPreferencePage;
import org.jkiss.dbeaver.utils.GeneralUtils;
Expand Down Expand Up @@ -187,24 +185,13 @@ private void doAutoCompletion(
}

DAICompletionResponse response = completionResult.get(0);
MessageChunk[] messageChunks = MessageChunk.splitIntoChunks(CommonUtils.notEmpty(response.getResultCompletion()));
MessageChunk[] messageChunks = AITextUtils.splitIntoChunks(CommonUtils.notEmpty(response.getResultCompletion()));

if (messageChunks.length == 0) {
return;
}
StringBuilder completion = new StringBuilder();
if (DBWorkbench.getPlatform().getPreferenceStore().getBoolean(AICompletionConstants.AI_INCLUDE_SOURCE_TEXT_IN_QUERY_COMMENT)) {
completion.append(SQLUtils.generateCommentLine(executionContext.getDataSource(), message.getContent()));
}
for (MessageChunk messageChunk : messageChunks) {
if (messageChunk instanceof MessageChunk.Code code) {
completion.append(code.text());
} else if (messageChunk instanceof MessageChunk.Text text) {
completion.append(SQLUtils.generateCommentLine(executionContext.getDataSource(), text.text()));
}
}

final String finalCompletion = completion.toString();
final String completion = AITextUtils.convertToSQL(message, messageChunks, executionContext.getDataSource());

// Save to history
new AbstractJob("Save smart completion history") {
Expand All @@ -216,7 +203,7 @@ protected IStatus run(DBRProgressMonitor monitor) {
lDataSource,
executionContext,
message.getContent(),
finalCompletion);
completion);
} catch (DBException e) {
return GeneralUtils.makeExceptionStatus(e);
}
Expand All @@ -231,21 +218,19 @@ protected IStatus run(DBRProgressMonitor monitor) {
int offset = ((TextSelection) selection).getOffset();
int length = ((TextSelection) selection).getLength();
SQLScriptElement query = editor.extractQueryAtPos(offset);
String text = completion;
if (query != null) {
offset = query.getOffset();
length = query.getLength();
// Trim trailing semicolon if needed
if (length > 0 && !query.getText().endsWith(";") && !completion.isEmpty()) {
if (completion.charAt(completion.length() - 1) == ';') {
completion.setLength(completion.length() - 1);
if (length > 0 && !query.getText().endsWith(";") && !text.isEmpty()) {
if (text.charAt(text.length() - 1) == ';') {
text = text.substring(0, text.length() - 1);
}
}
}
document.replace(
offset,
length,
completion.toString());
editor.getSelectionProvider().setSelection(new TextSelection(offset + completion.length(), 0));
document.replace(offset, length, text);
editor.getSelectionProvider().setSelection(new TextSelection(offset + text.length(), 0));
} catch (BadLocationException e) {
DBWorkbench.getPlatformUI().showError("Insert SQL", "Error inserting SQL completion in text editor", e);
}
Expand Down

0 comments on commit e97d1de

Please sign in to comment.