Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Inference API] Add Azure AI Studio Embeddings and Chat Completion Support #108472

Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
dfc9de1
redo after messy merge commit
markjhoy May 9, 2024
f0832eb
cleanups; refactoring; and added a few tests
markjhoy May 9, 2024
0e35525
filter xContent ratelimit; reduce boilerplate code
markjhoy May 9, 2024
3c79a9a
fix checkstyle issue
markjhoy May 9, 2024
b8ca2e2
... and spotlessApply
markjhoy May 9, 2024
26fd405
set lower rate limit 240; rename back files
markjhoy May 11, 2024
05f0ac5
Merge branch 'main' into markjhoy/azure_ai_studio_integration_inferen…
markjhoy May 11, 2024
35dbbad
clean lint
markjhoy May 11, 2024
e984bb3
Merge branch 'main' into markjhoy/azure_ai_studio_integration_inferen…
markjhoy May 13, 2024
ab9831b
fix code and tests after merge
markjhoy May 13, 2024
5534370
change completion temp and top_p to double
markjhoy May 13, 2024
d7d8e36
Merge branch 'main' into markjhoy/azure_ai_studio_integration_inferen…
markjhoy May 14, 2024
b0c218f
clean lint
markjhoy May 14, 2024
2bc99b7
Merge remote-tracking branch 'upstream/main' into markjhoy/azure_ai_s…
markjhoy May 14, 2024
b08b4d7
Merge remote-tracking branch 'upstream/main' into markjhoy/azure_ai_s…
markjhoy May 14, 2024
768d1dc
add default max_new_tokens of 64
markjhoy May 14, 2024
f76c582
Merge remote-tracking branch 'upstream/main' into markjhoy/azure_ai_s…
markjhoy May 14, 2024
f21b075
constrain top_p temperature to 0.0-2.0 range
markjhoy May 14, 2024
46a99ac
Merge branch 'main' into markjhoy/azure_ai_studio_integration_inferen…
markjhoy May 15, 2024
4a0c9ea
remove Snowflake provider; cleanups
markjhoy May 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ROLLUP_USAGE = def(8_653_00_0);
public static final TransportVersion SECURITY_ROLE_DESCRIPTION = def(8_654_00_0);
public static final TransportVersion ML_INFERENCE_AZURE_OPENAI_COMPLETIONS = def(8_655_00_0);
public static final TransportVersion ML_INFERENCE_AZURE_AI_STUDIO = def(8_656_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettings;
Expand Down Expand Up @@ -67,106 +71,122 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
new NamedWriteableRegistry.Entry(InferenceResults.class, LegacyTextEmbeddingResults.NAME, LegacyTextEmbeddingResults::new)
);

// Inference results
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceServiceResults.class, SparseEmbeddingResults.NAME, SparseEmbeddingResults::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceServiceResults.class, TextEmbeddingResults.NAME, TextEmbeddingResults::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceServiceResults.class, TextEmbeddingByteResults.NAME, TextEmbeddingByteResults::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceServiceResults.class, ChatCompletionResults.NAME, ChatCompletionResults::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceServiceResults.class, RankedDocsResults.NAME, RankedDocsResults::new)
);
addInferenceResultsNamedWriteables(namedWriteables);
markjhoy marked this conversation as resolved.
Show resolved Hide resolved
addChunkedInferenceResultsNamedWriteables(namedWriteables);

// Empty default task settings
namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new));

// Default secret settings
namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, DefaultSecretSettings.NAME, DefaultSecretSettings::new));

addInternalElserNamedWriteables(namedWriteables);

// Chunked inference results
// Internal TextEmbedding service config
namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceServiceResults.class,
ErrorChunkedInferenceResults.NAME,
ErrorChunkedInferenceResults::new
ServiceSettings.class,
ElasticsearchInternalServiceSettings.NAME,
ElasticsearchInternalServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceServiceResults.class,
ChunkedSparseEmbeddingResults.NAME,
ChunkedSparseEmbeddingResults::new
ServiceSettings.class,
MultilingualE5SmallInternalServiceSettings.NAME,
MultilingualE5SmallInternalServiceSettings::new
)
);

addHuggingFaceNamedWriteables(namedWriteables);
addOpenAiNamedWriteables(namedWriteables);
addCohereNamedWriteables(namedWriteables);
addAzureOpenAiNamedWriteables(namedWriteables);
addAzureAiStudioNamedWriteables(namedWriteables);

return namedWriteables;
}

private static void addAzureAiStudioNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceServiceResults.class,
ChunkedTextEmbeddingResults.NAME,
ChunkedTextEmbeddingResults::new
ServiceSettings.class,
AzureAiStudioEmbeddingsServiceSettings.NAME,
AzureAiStudioEmbeddingsServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceServiceResults.class,
ChunkedTextEmbeddingFloatResults.NAME,
ChunkedTextEmbeddingFloatResults::new
TaskSettings.class,
AzureAiStudioEmbeddingsTaskSettings.NAME,
AzureAiStudioEmbeddingsTaskSettings::new
)
);

namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceServiceResults.class,
ChunkedTextEmbeddingByteResults.NAME,
ChunkedTextEmbeddingByteResults::new
ServiceSettings.class,
AzureAiStudioChatCompletionServiceSettings.NAME,
AzureAiStudioChatCompletionServiceSettings::new
)
);

// Empty default task settings
namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new));

// Default secret settings
namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, DefaultSecretSettings.NAME, DefaultSecretSettings::new));

// Internal ELSER config
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, ElserInternalServiceSettings.NAME, ElserInternalServiceSettings::new)
new NamedWriteableRegistry.Entry(
TaskSettings.class,
AzureAiStudioChatCompletionTaskSettings.NAME,
AzureAiStudioChatCompletionTaskSettings::new
)
);
}

private static void addAzureOpenAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, ElserMlNodeTaskSettings.NAME, ElserMlNodeTaskSettings::new)
new NamedWriteableRegistry.Entry(
AzureOpenAiSecretSettings.class,
AzureOpenAiSecretSettings.NAME,
AzureOpenAiSecretSettings::new
)
);

// Internal TextEmbedding service config
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
ElasticsearchInternalServiceSettings.NAME,
ElasticsearchInternalServiceSettings::new
AzureOpenAiEmbeddingsServiceSettings.NAME,
AzureOpenAiEmbeddingsServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
MultilingualE5SmallInternalServiceSettings.NAME,
MultilingualE5SmallInternalServiceSettings::new
TaskSettings.class,
AzureOpenAiEmbeddingsTaskSettings.NAME,
AzureOpenAiEmbeddingsTaskSettings::new
)
);
}

// Hugging Face config
private static void addCohereNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, CohereServiceSettings.NAME, CohereServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
HuggingFaceElserServiceSettings.NAME,
HuggingFaceElserServiceSettings::new
CohereEmbeddingsServiceSettings.NAME,
CohereEmbeddingsServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, HuggingFaceServiceSettings.NAME, HuggingFaceServiceSettings::new)
new NamedWriteableRegistry.Entry(TaskSettings.class, CohereEmbeddingsTaskSettings.NAME, CohereEmbeddingsTaskSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(SecretSettings.class, HuggingFaceElserSecretSettings.NAME, HuggingFaceElserSecretSettings::new)
new NamedWriteableRegistry.Entry(ServiceSettings.class, CohereRerankServiceSettings.NAME, CohereRerankServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, CohereRerankTaskSettings.NAME, CohereRerankTaskSettings::new)
);
}

// OpenAI
private static void addOpenAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
Expand All @@ -191,52 +211,86 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
OpenAiChatCompletionTaskSettings::new
)
);
}

// Cohere
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, CohereServiceSettings.NAME, CohereServiceSettings::new)
);
private static void addHuggingFaceNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
CohereEmbeddingsServiceSettings.NAME,
CohereEmbeddingsServiceSettings::new
HuggingFaceElserServiceSettings.NAME,
HuggingFaceElserServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, CohereEmbeddingsTaskSettings.NAME, CohereEmbeddingsTaskSettings::new)
new NamedWriteableRegistry.Entry(ServiceSettings.class, HuggingFaceServiceSettings.NAME, HuggingFaceServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, CohereRerankServiceSettings.NAME, CohereRerankServiceSettings::new)
new NamedWriteableRegistry.Entry(SecretSettings.class, HuggingFaceElserSecretSettings.NAME, HuggingFaceElserSecretSettings::new)
);
}

private static void addInternalElserNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, CohereRerankTaskSettings.NAME, CohereRerankTaskSettings::new)
new NamedWriteableRegistry.Entry(ServiceSettings.class, ElserInternalServiceSettings.NAME, ElserInternalServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, ElserMlNodeTaskSettings.NAME, ElserMlNodeTaskSettings::new)
);
}

// Azure OpenAI
private static void addChunkedInferenceResultsNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(
AzureOpenAiSecretSettings.class,
AzureOpenAiSecretSettings.NAME,
AzureOpenAiSecretSettings::new
InferenceServiceResults.class,
ErrorChunkedInferenceResults.NAME,
ErrorChunkedInferenceResults::new
)
);

namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
AzureOpenAiEmbeddingsServiceSettings.NAME,
AzureOpenAiEmbeddingsServiceSettings::new
InferenceServiceResults.class,
ChunkedSparseEmbeddingResults.NAME,
ChunkedSparseEmbeddingResults::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
TaskSettings.class,
AzureOpenAiEmbeddingsTaskSettings.NAME,
AzureOpenAiEmbeddingsTaskSettings::new
InferenceServiceResults.class,
ChunkedTextEmbeddingResults.NAME,
ChunkedTextEmbeddingResults::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceServiceResults.class,
ChunkedTextEmbeddingFloatResults.NAME,
ChunkedTextEmbeddingFloatResults::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceServiceResults.class,
ChunkedTextEmbeddingByteResults.NAME,
ChunkedTextEmbeddingByteResults::new
)
);
}

return namedWriteables;
/**
 * Registers the readers for the non-chunked inference result types.
 * Every entry is added under the {@link InferenceServiceResults} category, in this order:
 * sparse embeddings, text embeddings, byte text embeddings, chat completions and ranked docs.
 *
 * @param namedWriteables the list the new registry entries are appended to
 */
private static void addInferenceResultsNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
    // All entries below share the same category class; only the name and reader differ.
    final Class<InferenceServiceResults> category = InferenceServiceResults.class;

    namedWriteables.add(new NamedWriteableRegistry.Entry(category, SparseEmbeddingResults.NAME, SparseEmbeddingResults::new));
    namedWriteables.add(new NamedWriteableRegistry.Entry(category, TextEmbeddingResults.NAME, TextEmbeddingResults::new));
    namedWriteables.add(new NamedWriteableRegistry.Entry(category, TextEmbeddingByteResults.NAME, TextEmbeddingByteResults::new));
    namedWriteables.add(new NamedWriteableRegistry.Entry(category, ChatCompletionResults.NAME, ChatCompletionResults::new));
    namedWriteables.add(new NamedWriteableRegistry.Entry(category, RankedDocsResults.NAME, RankedDocsResults::new));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import org.elasticsearch.xpack.inference.rest.RestInferenceAction;
import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
Expand Down Expand Up @@ -182,6 +183,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
context -> new OpenAiService(httpFactory.get(), serviceComponents.get()),
context -> new CohereService(httpFactory.get(), serviceComponents.get()),
context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get()),
context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.azureaistudio;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.AzureAiStudioRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;

import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;

/**
 * {@link ExecutableAction} that hands inference inputs to a {@link Sender} together with an
 * {@link AzureAiStudioRequestManager}, reporting any failure through the supplied listener.
 */
public class AzureAiStudioAction implements ExecutableAction {
    /** Sender the request is dispatched through. */
    protected final Sender sender;
    /** Request manager passed to the sender on every {@code execute} call. */
    protected final AzureAiStudioRequestManager requestCreator;
    /** Message used when wrapping failures and when building the internal server error. */
    protected final String errorMessage;

    /**
     * @param sender         sender used to dispatch the request
     * @param requestCreator request manager passed to {@code sender.send}
     * @param errorMessage   message attached to failures surfaced by this action
     */
    protected AzureAiStudioAction(Sender sender, AzureAiStudioRequestManager requestCreator, String errorMessage) {
        this.sender = sender;
        this.requestCreator = requestCreator;
        this.errorMessage = errorMessage;
    }

    /**
     * Sends the inputs through the sender. The listener given to the sender is first decorated with
     * {@code wrapFailuresInElasticsearchException(errorMessage, listener)}.
     * <p>
     * Nothing thrown while dispatching escapes this method: an {@link ElasticsearchException} is
     * forwarded to {@code listener} unchanged, and any other exception is converted with
     * {@code createInternalServerError} before being forwarded.
     *
     * @param inferenceInputs inputs passed to the sender
     * @param timeout         timeout passed to the sender
     * @param listener        receives the results or the failure
     */
    @Override
    public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
        try {
            sender.send(requestCreator, inferenceInputs, timeout, wrapFailuresInElasticsearchException(errorMessage, listener));
        } catch (Exception failure) {
            if (failure instanceof ElasticsearchException) {
                // Already an Elasticsearch exception: pass it on unchanged.
                listener.onFailure(failure);
            } else {
                listener.onFailure(createInternalServerError(failure, errorMessage));
            }
        }
    }
}