diff --git a/core/src/core.ts b/core/src/browser/core.ts similarity index 99% rename from core/src/core.ts rename to core/src/browser/core.ts index 32244e7846..6bbae7c856 100644 --- a/core/src/core.ts +++ b/core/src/browser/core.ts @@ -1,4 +1,4 @@ -import { DownloadRequest, FileStat, NetworkConfig, SystemInformation } from './types' +import { DownloadRequest, FileStat, NetworkConfig, SystemInformation } from '../types' /** * Execute a extension module function in main process diff --git a/core/src/events.ts b/core/src/browser/events.ts similarity index 100% rename from core/src/events.ts rename to core/src/browser/events.ts diff --git a/core/src/extension.ts b/core/src/browser/extension.ts similarity index 100% rename from core/src/extension.ts rename to core/src/browser/extension.ts diff --git a/core/src/extensions/assistant.ts b/core/src/browser/extensions/assistant.ts similarity index 90% rename from core/src/extensions/assistant.ts rename to core/src/browser/extensions/assistant.ts index 5c3114f41b..d025c67868 100644 --- a/core/src/extensions/assistant.ts +++ b/core/src/browser/extensions/assistant.ts @@ -1,4 +1,4 @@ -import { Assistant, AssistantInterface } from '../index' +import { Assistant, AssistantInterface } from '../../types' import { BaseExtension, ExtensionTypeEnum } from '../extension' /** diff --git a/core/src/extensions/conversational.ts b/core/src/browser/extensions/conversational.ts similarity index 97% rename from core/src/extensions/conversational.ts rename to core/src/browser/extensions/conversational.ts index a49a4e6895..ec53fbbbf9 100644 --- a/core/src/extensions/conversational.ts +++ b/core/src/browser/extensions/conversational.ts @@ -1,4 +1,4 @@ -import { Thread, ThreadInterface, ThreadMessage, MessageInterface } from '../index' +import { Thread, ThreadInterface, ThreadMessage, MessageInterface } from '../../types' import { BaseExtension, ExtensionTypeEnum } from '../extension' /** diff --git a/core/src/extensions/ai-engines/AIEngine.ts 
b/core/src/browser/extensions/engines/AIEngine.ts similarity index 56% rename from core/src/extensions/ai-engines/AIEngine.ts rename to core/src/browser/extensions/engines/AIEngine.ts index 8af89f3365..c4f8168297 100644 --- a/core/src/extensions/ai-engines/AIEngine.ts +++ b/core/src/browser/extensions/engines/AIEngine.ts @@ -2,7 +2,8 @@ import { getJanDataFolderPath, joinPath } from '../../core' import { events } from '../../events' import { BaseExtension } from '../../extension' import { fs } from '../../fs' -import { Model, ModelEvent } from '../../types' +import { MessageRequest, Model, ModelEvent } from '../../../types' +import { EngineManager } from './EngineManager' /** * Base AIEngine @@ -11,30 +12,71 @@ import { Model, ModelEvent } from '../../types' export abstract class AIEngine extends BaseExtension { // The inference engine abstract provider: string - // The model folder - modelFolder: string = 'models' + /** + * On extension load, subscribe to events. + */ + override onLoad() { + this.registerEngine() + + events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model)) + events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model)) + + this.prePopulateModels() + } + + /** + * Defines models + */ models(): Promise { return Promise.resolve([]) } /** - * On extension load, subscribe to events. + * Registers AI Engines */ - onLoad() { - this.prePopulateModels() + registerEngine() { + EngineManager.instance().register(this) } + /** + * Loads the model. + */ + async loadModel(model: Model): Promise { + if (model.engine.toString() !== this.provider) return Promise.resolve() + events.emit(ModelEvent.OnModelReady, model) + return Promise.resolve() + } + /** + * Stops the model. + */ + async unloadModel(model?: Model): Promise { + if (model?.engine && model.engine.toString() !== this.provider) return Promise.resolve() + events.emit(ModelEvent.OnModelStopped, model ?? 
{}) + return Promise.resolve() + } + + /* + * Inference request + */ + inference(data: MessageRequest) {} + + /** + * Stop inference + */ + stopInference() {} + /** * Pre-populate models to App Data Folder */ prePopulateModels(): Promise { + const modelFolder = 'models' return this.models().then((models) => { const prePoluateOperations = models.map((model) => getJanDataFolderPath() .then((janDataFolder) => // Attempt to create the model folder - joinPath([janDataFolder, this.modelFolder, model.id]).then((path) => + joinPath([janDataFolder, modelFolder, model.id]).then((path) => fs .mkdir(path) .catch() diff --git a/core/src/browser/extensions/engines/EngineManager.ts b/core/src/browser/extensions/engines/EngineManager.ts new file mode 100644 index 0000000000..2980c5c65e --- /dev/null +++ b/core/src/browser/extensions/engines/EngineManager.ts @@ -0,0 +1,32 @@ +import { AIEngine } from './AIEngine' + +/** + * Manages the registration and retrieval of inference engines. + */ +export class EngineManager { + public engines = new Map() + + /** + * Registers an engine. + * @param engine - The engine to register. + */ + register(engine: T) { + this.engines.set(engine.provider, engine) + } + + /** + * Retrieves an engine by provider. + * @param provider - The name of the engine to retrieve. + * @returns The engine, if found. + */ + get(provider: string): T | undefined { + return this.engines.get(provider) as T | undefined + } + + /** + * The instance of the engine manager. + */ + static instance(): EngineManager { + return window.core?.engineManager as EngineManager ?? 
new EngineManager() + } +} diff --git a/core/src/extensions/ai-engines/LocalOAIEngine.ts b/core/src/browser/extensions/engines/LocalOAIEngine.ts similarity index 71% rename from core/src/extensions/ai-engines/LocalOAIEngine.ts rename to core/src/browser/extensions/engines/LocalOAIEngine.ts index f6557cd8f3..ab5a2622c2 100644 --- a/core/src/extensions/ai-engines/LocalOAIEngine.ts +++ b/core/src/browser/extensions/engines/LocalOAIEngine.ts @@ -1,6 +1,6 @@ import { executeOnMain, getJanDataFolderPath, joinPath, systemInformation } from '../../core' import { events } from '../../events' -import { Model, ModelEvent } from '../../types' +import { Model, ModelEvent } from '../../../types' import { OAIEngine } from './OAIEngine' /** @@ -16,7 +16,7 @@ export abstract class LocalOAIEngine extends OAIEngine { /** * On extension load, subscribe to events. */ - onLoad() { + override onLoad() { super.onLoad() // These events are applicable to local inference providers events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model)) @@ -26,10 +26,10 @@ export abstract class LocalOAIEngine extends OAIEngine { /** * Load the model. 
*/ - async loadModel(model: Model) { + override async loadModel(model: Model): Promise { if (model.engine.toString() !== this.provider) return - - const modelFolder = await joinPath([await getJanDataFolderPath(), this.modelFolder, model.id]) + const modelFolderName = 'models' + const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id]) const systemInfo = await systemInformation() const res = await executeOnMain( this.nodeModule, @@ -42,24 +42,22 @@ export abstract class LocalOAIEngine extends OAIEngine { ) if (res?.error) { - events.emit(ModelEvent.OnModelFail, { - ...model, - error: res.error, - }) - return + events.emit(ModelEvent.OnModelFail, { error: res.error }) + return Promise.reject(res.error) } else { this.loadedModel = model events.emit(ModelEvent.OnModelReady, model) + return Promise.resolve() } } /** * Stops the model. */ - unloadModel(model: Model) { - if (model.engine && model.engine?.toString() !== this.provider) return - this.loadedModel = undefined + override async unloadModel(model?: Model): Promise { + if (model?.engine && model.engine?.toString() !== this.provider) return Promise.resolve() - executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => { + this.loadedModel = undefined + return executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => { events.emit(ModelEvent.OnModelStopped, {}) }) } diff --git a/core/src/extensions/ai-engines/OAIEngine.ts b/core/src/browser/extensions/engines/OAIEngine.ts similarity index 94% rename from core/src/extensions/ai-engines/OAIEngine.ts rename to core/src/browser/extensions/engines/OAIEngine.ts index 5936005bbf..41b08f4598 100644 --- a/core/src/extensions/ai-engines/OAIEngine.ts +++ b/core/src/browser/extensions/engines/OAIEngine.ts @@ -13,7 +13,7 @@ import { ModelInfo, ThreadContent, ThreadMessage, -} from '../../types' +} from '../../../types' import { events } from '../../events' /** @@ -34,7 +34,7 @@ export abstract class OAIEngine extends 
AIEngine { /** * On extension load, subscribe to events. */ - onLoad() { + override onLoad() { super.onLoad() events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data)) events.on(InferenceEvent.OnInferenceStopped, () => this.stopInference()) @@ -43,12 +43,12 @@ export abstract class OAIEngine extends AIEngine { /** * On extension unload */ - onUnload(): void {} + override onUnload(): void {} /* * Inference request */ - inference(data: MessageRequest) { + override inference(data: MessageRequest) { if (data.model?.engine?.toString() !== this.provider) return const timestamp = Date.now() @@ -106,6 +106,7 @@ export abstract class OAIEngine extends AIEngine { return } message.status = MessageStatus.Error + message.error_code = err.code events.emit(MessageEvent.OnMessageUpdate, message) }, }) @@ -114,7 +115,7 @@ export abstract class OAIEngine extends AIEngine { /** * Stops the inference. */ - stopInference() { + override stopInference() { this.isCancelled = true this.controller?.abort() } diff --git a/core/src/browser/extensions/engines/RemoteOAIEngine.ts b/core/src/browser/extensions/engines/RemoteOAIEngine.ts new file mode 100644 index 0000000000..2d5126c6b9 --- /dev/null +++ b/core/src/browser/extensions/engines/RemoteOAIEngine.ts @@ -0,0 +1,26 @@ +import { OAIEngine } from './OAIEngine' + +/** + * Base OAI Remote Inference Provider + * Adds API-key based authorization headers to the inference request + */ +export abstract class RemoteOAIEngine extends OAIEngine { + // The API key sent in the inference request headers + abstract apiKey: string + /** + * On extension load, subscribe to events. 
+ */ + override onLoad() { + super.onLoad() + } + + /** + * Headers for the inference request + */ + override headers(): HeadersInit { + return { + 'Authorization': `Bearer ${this.apiKey}`, + 'api-key': `${this.apiKey}`, + } + } +} diff --git a/core/src/extensions/ai-engines/helpers/sse.ts b/core/src/browser/extensions/engines/helpers/sse.ts similarity index 85% rename from core/src/extensions/ai-engines/helpers/sse.ts rename to core/src/browser/extensions/engines/helpers/sse.ts index 723d0dc13f..def017ebc6 100644 --- a/core/src/extensions/ai-engines/helpers/sse.ts +++ b/core/src/browser/extensions/engines/helpers/sse.ts @@ -1,5 +1,5 @@ import { Observable } from 'rxjs' -import { ModelRuntimeParams } from '../../../types' +import { ErrorCode, ModelRuntimeParams } from '../../../../types' /** * Sends a request to the inference server to generate a response based on the recent messages. * @param recentMessages - An array of recent messages to use as context for the inference. @@ -34,6 +34,16 @@ export function requestInference( signal: controller?.signal, }) .then(async (response) => { + if (!response.ok) { + const data = await response.json() + const error = { + message: data.error?.message ?? 'Error occurred.', + code: data.error?.code ?? ErrorCode.Unknown, + } + subscriber.error(error) + subscriber.complete() + return + } if (model.parameters.stream === false) { const data = await response.json() subscriber.next(data.choices[0]?.message?.content ?? 
'') diff --git a/core/src/extensions/ai-engines/index.ts b/core/src/browser/extensions/engines/index.ts similarity index 79% rename from core/src/extensions/ai-engines/index.ts rename to core/src/browser/extensions/engines/index.ts index fc341380ab..34ef45afd1 100644 --- a/core/src/extensions/ai-engines/index.ts +++ b/core/src/browser/extensions/engines/index.ts @@ -2,3 +2,4 @@ export * from './AIEngine' export * from './OAIEngine' export * from './LocalOAIEngine' export * from './RemoteOAIEngine' +export * from './EngineManager' diff --git a/core/src/extensions/huggingface.ts b/core/src/browser/extensions/huggingface.ts similarity index 92% rename from core/src/extensions/huggingface.ts rename to core/src/browser/extensions/huggingface.ts index 16a1d9b8af..b9c9626a00 100644 --- a/core/src/extensions/huggingface.ts +++ b/core/src/browser/extensions/huggingface.ts @@ -1,6 +1,6 @@ import { BaseExtension, ExtensionTypeEnum } from '../extension' -import { HuggingFaceInterface, HuggingFaceRepoData, Quantization } from '../types/huggingface' -import { Model } from '../types/model' +import { HuggingFaceInterface, HuggingFaceRepoData, Quantization } from '../../types/huggingface' +import { Model } from '../../types/model' /** * Hugging Face extension for converting HF models to GGUF. diff --git a/core/src/extensions/index.ts b/core/src/browser/extensions/index.ts similarity index 96% rename from core/src/extensions/index.ts rename to core/src/browser/extensions/index.ts index c049f3b3ab..768886d497 100644 --- a/core/src/extensions/index.ts +++ b/core/src/browser/extensions/index.ts @@ -32,4 +32,4 @@ export { HuggingFaceExtension } from './huggingface' /** * Base AI Engines. 
*/ -export * from './ai-engines' +export * from './engines' diff --git a/core/src/extensions/inference.ts b/core/src/browser/extensions/inference.ts similarity index 96% rename from core/src/extensions/inference.ts rename to core/src/browser/extensions/inference.ts index e8e51f9eb9..44c50f7f82 100644 --- a/core/src/extensions/inference.ts +++ b/core/src/browser/extensions/inference.ts @@ -1,4 +1,4 @@ -import { InferenceInterface, MessageRequest, ThreadMessage } from '../index' +import { InferenceInterface, MessageRequest, ThreadMessage } from '../../types' import { BaseExtension, ExtensionTypeEnum } from '../extension' /** diff --git a/core/src/extensions/model.ts b/core/src/browser/extensions/model.ts similarity index 97% rename from core/src/extensions/model.ts rename to core/src/browser/extensions/model.ts index 33eec0afce..6dd52f192e 100644 --- a/core/src/extensions/model.ts +++ b/core/src/browser/extensions/model.ts @@ -1,5 +1,5 @@ import { BaseExtension, ExtensionTypeEnum } from '../extension' -import { GpuSetting, ImportingModel, Model, ModelInterface, OptionType } from '../index' +import { GpuSetting, ImportingModel, Model, ModelInterface, OptionType } from '../../types' /** * Model extension for managing models. diff --git a/core/src/extensions/monitoring.ts b/core/src/browser/extensions/monitoring.ts similarity index 97% rename from core/src/extensions/monitoring.ts rename to core/src/browser/extensions/monitoring.ts index 2d75e0218b..c30766f6ef 100644 --- a/core/src/extensions/monitoring.ts +++ b/core/src/browser/extensions/monitoring.ts @@ -1,5 +1,5 @@ import { BaseExtension, ExtensionTypeEnum } from '../extension' -import { GpuSetting, MonitoringInterface, OperatingSystemInfo } from '../index' +import { GpuSetting, MonitoringInterface, OperatingSystemInfo } from '../../types' /** * Monitoring extension for system monitoring. 
diff --git a/core/src/fs.ts b/core/src/browser/fs.ts similarity index 98% rename from core/src/fs.ts rename to core/src/browser/fs.ts index 3a9a20afb7..164e3b6479 100644 --- a/core/src/fs.ts +++ b/core/src/browser/fs.ts @@ -1,4 +1,4 @@ -import { FileStat } from './types' +import { FileStat } from '../types' /** * Writes data to a file at the specified path. diff --git a/core/src/browser/index.ts b/core/src/browser/index.ts new file mode 100644 index 0000000000..a7803c7e04 --- /dev/null +++ b/core/src/browser/index.ts @@ -0,0 +1,35 @@ +/** + * Export Core module + * @module + */ +export * from './core' + +/** + * Export Event module. + * @module + */ +export * from './events' + +/** + * Export Filesystem module. + * @module + */ +export * from './fs' + +/** + * Export Extension module. + * @module + */ +export * from './extension' + +/** + * Export all base extensions. + * @module + */ +export * from './extensions' + +/** + * Export all base tools. + * @module + */ +export * from './tools' diff --git a/core/src/browser/tools/index.ts b/core/src/browser/tools/index.ts new file mode 100644 index 0000000000..24cd127804 --- /dev/null +++ b/core/src/browser/tools/index.ts @@ -0,0 +1,2 @@ +export * from './manager' +export * from './tool' diff --git a/core/src/browser/tools/manager.ts b/core/src/browser/tools/manager.ts new file mode 100644 index 0000000000..b323ad7ced --- /dev/null +++ b/core/src/browser/tools/manager.ts @@ -0,0 +1,47 @@ +import { AssistantTool, MessageRequest } from '../../types' +import { InferenceTool } from './tool' + +/** + * Manages the registration and retrieval of inference tools. + */ +export class ToolManager { + public tools = new Map() + + /** + * Registers a tool. + * @param tool - The tool to register. + */ + register(tool: T) { + this.tools.set(tool.name, tool) + } + + /** + * Retrieves a tool by its name. + * @param name - The name of the tool to retrieve. + * @returns The tool, if found. 
+ */ + get(name: string): T | undefined { + return this.tools.get(name) as T | undefined + } + + /* + ** Process the message request with the tools. + */ + process(request: MessageRequest, tools: AssistantTool[]): Promise { + return tools.reduce((prevPromise, currentTool) => { + return prevPromise.then((prevResult) => { + return currentTool.enabled + ? this.get(currentTool.type)?.process(prevResult, currentTool) ?? + Promise.resolve(prevResult) + : Promise.resolve(prevResult) + }) + }, Promise.resolve(request)) + } + + /** + * The instance of the tool manager. + */ + static instance(): ToolManager { + return (window.core?.toolManager as ToolManager) ?? new ToolManager() + } +} diff --git a/core/src/browser/tools/tool.ts b/core/src/browser/tools/tool.ts new file mode 100644 index 0000000000..0fd3429331 --- /dev/null +++ b/core/src/browser/tools/tool.ts @@ -0,0 +1,12 @@ +import { AssistantTool, MessageRequest } from '../../types' + +/** + * Represents a base inference tool. + */ +export abstract class InferenceTool { + abstract name: string + /* + ** Process a message request and return the processed message request. + */ + abstract process(request: MessageRequest, tool?: AssistantTool): Promise +} diff --git a/core/src/extensions/ai-engines/RemoteOAIEngine.ts b/core/src/extensions/ai-engines/RemoteOAIEngine.ts deleted file mode 100644 index 5e9804b23a..0000000000 --- a/core/src/extensions/ai-engines/RemoteOAIEngine.ts +++ /dev/null @@ -1,46 +0,0 @@ -import { events } from '../../events' -import { Model, ModelEvent } from '../../types' -import { OAIEngine } from './OAIEngine' - -/** - * Base OAI Remote Inference Provider - * Added the implementation of loading and unloading model (applicable to local inference providers) - */ -export abstract class RemoteOAIEngine extends OAIEngine { - // The inference engine - abstract apiKey: string - /** - * On extension load, subscribe to events. 
- */ - onLoad() { - super.onLoad() - // These events are applicable to local inference providers - events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model)) - events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model)) - } - - /** - * Load the model. - */ - async loadModel(model: Model) { - if (model.engine.toString() !== this.provider) return - events.emit(ModelEvent.OnModelReady, model) - } - /** - * Stops the model. - */ - unloadModel(model: Model) { - if (model.engine && model.engine.toString() !== this.provider) return - events.emit(ModelEvent.OnModelStopped, {}) - } - - /** - * Headers for the inference request - */ - override headers(): HeadersInit { - return { - 'Authorization': `Bearer ${this.apiKey}`, - 'api-key': `${this.apiKey}`, - } - } -} diff --git a/core/src/index.ts b/core/src/index.ts index 3505797b19..cfd69f93d1 100644 --- a/core/src/index.ts +++ b/core/src/index.ts @@ -2,42 +2,13 @@ * Export all types. * @module */ -export * from './types/index' +export * from './types' /** - * Export all routes - */ -export * from './api' - -/** - * Export Core module - * @module - */ -export * from './core' - -/** - * Export Event module. - * @module - */ -export * from './events' - -/** - * Export Filesystem module. - * @module - */ -export * from './fs' - -/** - * Export Extension module. - * @module - */ -export * from './extension' - -/** - * Export all base extensions. 
+ * Export browser module * @module */ -export * from './extensions/index' +export * from './browser' /** * Declare global object diff --git a/core/src/node/api/common/adapter.ts b/core/src/node/api/common/adapter.ts index 56f4cedb35..2beacf3254 100644 --- a/core/src/node/api/common/adapter.ts +++ b/core/src/node/api/common/adapter.ts @@ -4,7 +4,7 @@ import { ExtensionRoute, FileManagerRoute, FileSystemRoute, -} from '../../../api' +} from '../../../types/api' import { Downloader } from '../processors/download' import { FileSystem } from '../processors/fs' import { Extension } from '../processors/extension' diff --git a/core/src/node/api/common/handler.ts b/core/src/node/api/common/handler.ts index 4a39ae52a6..fb958dbd1b 100644 --- a/core/src/node/api/common/handler.ts +++ b/core/src/node/api/common/handler.ts @@ -1,4 +1,4 @@ -import { CoreRoutes } from '../../../api' +import { CoreRoutes } from '../../../types/api' import { RequestAdapter } from './adapter' export type Handler = (route: string, args: any) => any diff --git a/core/src/node/api/processors/download.ts b/core/src/node/api/processors/download.ts index 8e8e08f2f6..98464dd52d 100644 --- a/core/src/node/api/processors/download.ts +++ b/core/src/node/api/processors/download.ts @@ -1,5 +1,5 @@ import { resolve, sep } from 'path' -import { DownloadEvent } from '../../../api' +import { DownloadEvent } from '../../../types/api' import { normalizeFilePath } from '../../helper/path' import { getJanDataFolderPath } from '../../helper' import { DownloadManager } from '../../helper/download' diff --git a/core/src/node/api/restful/app/download.ts b/core/src/node/api/restful/app/download.ts index b5919659b1..5e0c83d01a 100644 --- a/core/src/node/api/restful/app/download.ts +++ b/core/src/node/api/restful/app/download.ts @@ -1,4 +1,4 @@ -import { DownloadRoute } from '../../../../api' +import { DownloadRoute } from '../../../../types/api' import { DownloadManager } from '../../../helper/download' import { HttpServer } 
from '../../HttpServer' diff --git a/core/src/node/index.ts b/core/src/node/index.ts index 02d921fd64..eb60270752 100644 --- a/core/src/node/index.ts +++ b/core/src/node/index.ts @@ -5,4 +5,4 @@ export * from './extension/store' export * from './api' export * from './helper' export * from './../types' -export * from './../api' +export * from '../types/api' diff --git a/core/src/api/index.ts b/core/src/types/api/index.ts similarity index 100% rename from core/src/api/index.ts rename to core/src/types/api/index.ts diff --git a/core/src/types/index.ts b/core/src/types/index.ts index 295d054e7e..291c735246 100644 --- a/core/src/types/index.ts +++ b/core/src/types/index.ts @@ -8,3 +8,4 @@ export * from './file' export * from './config' export * from './huggingface' export * from './miscellaneous' +export * from './api' diff --git a/core/src/types/model/modelEntity.ts b/core/src/types/model/modelEntity.ts index d62a7c3871..a313847b69 100644 --- a/core/src/types/model/modelEntity.ts +++ b/core/src/types/model/modelEntity.ts @@ -7,7 +7,6 @@ export type ModelInfo = { settings: ModelSettingParams parameters: ModelRuntimeParams engine?: InferenceEngine - proxy_model?: InferenceEngine } /** @@ -21,8 +20,6 @@ export enum InferenceEngine { groq = 'groq', triton_trtllm = 'triton_trtllm', nitro_tensorrt_llm = 'nitro-tensorrt-llm', - - tool_retrieval_enabled = 'tool_retrieval_enabled', } export type ModelArtifact = { @@ -94,8 +91,6 @@ export type Model = { * The model engine. 
*/ engine: InferenceEngine - - proxy_model?: InferenceEngine } export type ModelMetadata = { diff --git a/extensions/assistant-extension/src/index.ts b/extensions/assistant-extension/src/index.ts index 97a1cb2207..64528b0e09 100644 --- a/extensions/assistant-extension/src/index.ts +++ b/extensions/assistant-extension/src/index.ts @@ -1,26 +1,21 @@ import { fs, Assistant, - MessageRequest, events, - InferenceEngine, - MessageEvent, - InferenceEvent, joinPath, - executeOnMain, AssistantExtension, AssistantEvent, + ToolManager, } from '@janhq/core' +import { RetrievalTool } from './tools/retrieval' export default class JanAssistantExtension extends AssistantExtension { private static readonly _homeDir = 'file://assistants' - private static readonly _threadDir = 'file://threads' - - controller = new AbortController() - isCancelled = false - retrievalThreadId: string | undefined = undefined async onLoad() { + // Register the retrieval tool + ToolManager.instance().register(new RetrievalTool()) + // making the assistant directory const assistantDirExist = await fs.existsSync( JanAssistantExtension._homeDir @@ -38,140 +33,6 @@ export default class JanAssistantExtension extends AssistantExtension { // Update the assistant list events.emit(AssistantEvent.OnAssistantsUpdate, {}) } - - // Events subscription - events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => - JanAssistantExtension.handleMessageRequest(data, this) - ) - - events.on(InferenceEvent.OnInferenceStopped, () => { - JanAssistantExtension.handleInferenceStopped(this) - }) - } - - private static async handleInferenceStopped(instance: JanAssistantExtension) { - instance.isCancelled = true - instance.controller?.abort() - } - - private static async handleMessageRequest( - data: MessageRequest, - instance: JanAssistantExtension - ) { - instance.isCancelled = false - instance.controller = new AbortController() - - if ( - data.model?.engine !== InferenceEngine.tool_retrieval_enabled || - !data.messages || - 
// TODO: Since the engine is defined, its unsafe to assume that assistant tools are defined - // That could lead to an issue where thread stuck at generating response - !data.thread?.assistants[0]?.tools - ) { - return - } - - const latestMessage = data.messages[data.messages.length - 1] - - // 1. Ingest the document if needed - if ( - latestMessage && - latestMessage.content && - typeof latestMessage.content !== 'string' && - latestMessage.content.length > 1 - ) { - const docFile = latestMessage.content[1]?.doc_url?.url - if (docFile) { - await executeOnMain( - NODE, - 'toolRetrievalIngestNewDocument', - docFile, - data.model?.proxy_model - ) - } - } else if ( - // Check whether we need to ingest document or not - // Otherwise wrong context will be sent - !(await fs.existsSync( - await joinPath([ - JanAssistantExtension._threadDir, - data.threadId, - 'memory', - ]) - )) - ) { - // No document ingested, reroute the result to inference engine - const output = { - ...data, - model: { - ...data.model, - engine: data.model.proxy_model, - }, - } - events.emit(MessageEvent.OnMessageSent, output) - return - } - // 2. Load agent on thread changed - if (instance.retrievalThreadId !== data.threadId) { - await executeOnMain(NODE, 'toolRetrievalLoadThreadMemory', data.threadId) - - instance.retrievalThreadId = data.threadId - - // Update the text splitter - await executeOnMain( - NODE, - 'toolRetrievalUpdateTextSplitter', - data.thread.assistants[0].tools[0]?.settings?.chunk_size ?? 4000, - data.thread.assistants[0].tools[0]?.settings?.chunk_overlap ?? 200 - ) - } - - // 3. Using the retrieval template with the result and query - if (latestMessage.content) { - const prompt = - typeof latestMessage.content === 'string' - ? 
latestMessage.content - : latestMessage.content[0].text - // Retrieve the result - const retrievalResult = await executeOnMain( - NODE, - 'toolRetrievalQueryResult', - prompt - ) - console.debug('toolRetrievalQueryResult', retrievalResult) - - // Update message content - if (data.thread?.assistants[0]?.tools && retrievalResult) - data.messages[data.messages.length - 1].content = - data.thread.assistants[0].tools[0].settings?.retrieval_template - ?.replace('{CONTEXT}', retrievalResult) - .replace('{QUESTION}', prompt) - } - - // Filter out all the messages that are not text - data.messages = data.messages.map((message) => { - if ( - message.content && - typeof message.content !== 'string' && - (message.content.length ?? 0) > 0 - ) { - return { - ...message, - content: [message.content[0]], - } - } - return message - }) - - // 4. Reroute the result to inference engine - const output = { - ...data, - model: { - ...data.model, - engine: data.model.proxy_model, - }, - } - events.emit(MessageEvent.OnMessageSent, output) } /** diff --git a/extensions/assistant-extension/src/tools/retrieval.ts b/extensions/assistant-extension/src/tools/retrieval.ts new file mode 100644 index 0000000000..35738fd8e0 --- /dev/null +++ b/extensions/assistant-extension/src/tools/retrieval.ts @@ -0,0 +1,108 @@ +import { + AssistantTool, + executeOnMain, + fs, + InferenceTool, + joinPath, + MessageRequest, +} from '@janhq/core' + +export class RetrievalTool extends InferenceTool { + private _threadDir = 'file://threads' + private retrievalThreadId: string | undefined = undefined + + name: string = 'retrieval' + + async process( + data: MessageRequest, + tool?: AssistantTool + ): Promise { + if (!data.model || !data.messages) { + return Promise.resolve(data) + } + + const latestMessage = data.messages[data.messages.length - 1] + + // 1. 
Ingest the document if needed + if ( + latestMessage && + latestMessage.content && + typeof latestMessage.content !== 'string' && + latestMessage.content.length > 1 + ) { + const docFile = latestMessage.content[1]?.doc_url?.url + if (docFile) { + await executeOnMain( + NODE, + 'toolRetrievalIngestNewDocument', + docFile, + data.model?.engine + ) + } + } else if ( + // Check whether we need to ingest document or not + // Otherwise wrong context will be sent + !(await fs.existsSync( + await joinPath([this._threadDir, data.threadId, 'memory']) + )) + ) { + // No document ingested, reroute the result to inference engine + + return Promise.resolve(data) + } + // 2. Load agent on thread changed + if (this.retrievalThreadId !== data.threadId) { + await executeOnMain(NODE, 'toolRetrievalLoadThreadMemory', data.threadId) + + this.retrievalThreadId = data.threadId + + // Update the text splitter + await executeOnMain( + NODE, + 'toolRetrievalUpdateTextSplitter', + tool?.settings?.chunk_size ?? 4000, + tool?.settings?.chunk_overlap ?? 200 + ) + } + + // 3. Using the retrieval template with the result and query + if (latestMessage?.content) { + const prompt = + typeof latestMessage.content === 'string' + ? latestMessage.content + : latestMessage.content[0].text + // Retrieve the result + const retrievalResult = await executeOnMain( + NODE, + 'toolRetrievalQueryResult', + prompt + ) + console.debug('toolRetrievalQueryResult', retrievalResult) + + // Update message content + if (retrievalResult) + data.messages[data.messages.length - 1].content = + tool?.settings?.retrieval_template + ?.replace('{CONTEXT}', retrievalResult) + .replace('{QUESTION}', prompt) ?? prompt + } + + // Filter out all the messages that are not text + data.messages = data.messages.map((message) => { + if ( + message.content && + typeof message.content !== 'string' && + (message.content.length ?? 0) > 0 + ) { + return { + ...message, + content: [message.content[0]], + } + } + return message + }) + + // 4.
Reroute the result to inference engine + return Promise.resolve(data) + } +} diff --git a/extensions/inference-nitro-extension/src/index.ts b/extensions/inference-nitro-extension/src/index.ts index 3a23082baf..313b67365b 100644 --- a/extensions/inference-nitro-extension/src/index.ts +++ b/extensions/inference-nitro-extension/src/index.ts @@ -91,15 +91,14 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine { return super.loadModel(model) } - override unloadModel(model: Model): void { - super.unloadModel(model) - - if (model.engine && model.engine !== this.provider) return + override async unloadModel(model?: Model) { + if (model?.engine && model.engine !== this.provider) return // stop the periocally health check if (this.getNitroProcesHealthIntervalId) { clearInterval(this.getNitroProcesHealthIntervalId) this.getNitroProcesHealthIntervalId = undefined } + return super.unloadModel(model) } } diff --git a/web/containers/DropdownListSidebar/index.tsx b/web/containers/DropdownListSidebar/index.tsx index 5022c83f1b..b0953cdea1 100644 --- a/web/containers/DropdownListSidebar/index.tsx +++ b/web/containers/DropdownListSidebar/index.tsx @@ -271,26 +271,7 @@ const DropdownListSidebar = ({ )} >
- {x.engine === InferenceEngine.openai && ( - - - - )} -
+
{x.name} @@ -307,8 +288,7 @@ const DropdownListSidebar = ({
{x.id} diff --git a/web/containers/Providers/EventHandler.tsx b/web/containers/Providers/EventHandler.tsx index d44c950e11..4d5555a469 100644 --- a/web/containers/Providers/EventHandler.tsx +++ b/web/containers/Providers/EventHandler.tsx @@ -8,26 +8,17 @@ import { ExtensionTypeEnum, MessageStatus, MessageRequest, - Model, ConversationalExtension, MessageEvent, MessageRequestType, ModelEvent, Thread, - ModelInitFailed, + EngineManager, } from '@janhq/core' import { useAtomValue, useSetAtom } from 'jotai' import { ulid } from 'ulidx' -import { - activeModelAtom, - loadModelErrorAtom, - stateModelAtom, -} from '@/hooks/useActiveModel' - -import { queuedMessageAtom } from '@/hooks/useSendChatMessage' - -import { toaster } from '../Toast' +import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel' import { extensionManager } from '@/extension' import { @@ -51,8 +42,6 @@ export default function EventHandler({ children }: { children: ReactNode }) { const activeModel = useAtomValue(activeModelAtom) const setActiveModel = useSetAtom(activeModelAtom) const setStateModel = useSetAtom(stateModelAtom) - const setQueuedMessage = useSetAtom(queuedMessageAtom) - const setLoadModelError = useSetAtom(loadModelErrorAtom) const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom) const threads = useAtomValue(threadsAtom) @@ -88,44 +77,11 @@ export default function EventHandler({ children }: { children: ReactNode }) { [addNewMessage] ) - const onModelReady = useCallback( - (model: Model) => { - setActiveModel(model) - toaster({ - title: 'Success!', - description: `Model ${model.id} has been started.`, - type: 'success', - }) - setStateModel(() => ({ - state: 'stop', - loading: false, - model: model.id, - })) - }, - [setActiveModel, setStateModel] - ) - const onModelStopped = useCallback(() => { - setTimeout(() => { - setActiveModel(undefined) - setStateModel({ state: 'start', loading: false, model: '' }) - }, 500) + setActiveModel(undefined) + 
setStateModel({ state: 'start', loading: false, model: '' }) }, [setActiveModel, setStateModel]) - const onModelInitFailed = useCallback( - (res: ModelInitFailed) => { - console.error('Failed to load model: ', res.error.message) - setStateModel(() => ({ - state: 'start', - loading: false, - model: res.id, - })) - setLoadModelError(res.error.message) - setQueuedMessage(false) - }, - [setStateModel, setQueuedMessage, setLoadModelError] - ) - const updateThreadTitle = useCallback( (message: ThreadMessage) => { // Update only when it's finished @@ -274,7 +230,10 @@ export default function EventHandler({ children }: { children: ReactNode }) { // 2. Update the title with the result of the inference setTimeout(() => { - events.emit(MessageEvent.OnMessageSent, messageRequest) + const engine = EngineManager.instance().get( + messageRequest.model?.engine ?? activeModelRef.current?.engine ?? '' + ) + engine?.inference(messageRequest) }, 1000) } } @@ -283,23 +242,16 @@ export default function EventHandler({ children }: { children: ReactNode }) { if (window.core?.events) { events.on(MessageEvent.OnMessageResponse, onNewMessageResponse) events.on(MessageEvent.OnMessageUpdate, onMessageResponseUpdate) - events.on(ModelEvent.OnModelReady, onModelReady) - events.on(ModelEvent.OnModelFail, onModelInitFailed) events.on(ModelEvent.OnModelStopped, onModelStopped) } - }, [ - onNewMessageResponse, - onMessageResponseUpdate, - onModelReady, - onModelInitFailed, - onModelStopped, - ]) + }, [onNewMessageResponse, onMessageResponseUpdate, onModelStopped]) useEffect(() => { return () => { events.off(MessageEvent.OnMessageResponse, onNewMessageResponse) events.off(MessageEvent.OnMessageUpdate, onMessageResponseUpdate) + events.off(ModelEvent.OnModelStopped, onModelStopped) } - }, [onNewMessageResponse, onMessageResponseUpdate]) + }, [onNewMessageResponse, onMessageResponseUpdate, onModelStopped]) return {children} } diff --git a/web/extension/ExtensionManager.ts 
b/web/extension/ExtensionManager.ts index c976010c67..6d96d71b5e 100644 --- a/web/extension/ExtensionManager.ts +++ b/web/extension/ExtensionManager.ts @@ -1,6 +1,6 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ -import { BaseExtension, ExtensionTypeEnum } from '@janhq/core' +import { AIEngine, BaseExtension, ExtensionTypeEnum } from '@janhq/core' import Extension from './Extension' @@ -8,14 +8,26 @@ import Extension from './Extension' * Manages the registration and retrieval of extensions. */ export class ExtensionManager { + // Registered extensions private extensions = new Map() + // Registered inference engines + private engines = new Map() + /** * Registers an extension. * @param extension - The extension to register. */ register(name: string, extension: T) { this.extensions.set(extension.type() ?? name, extension) + + // Register AI Engines + if ('provider' in extension && typeof extension.provider === 'string') { + this.engines.set( + extension.provider as unknown as string, + extension as unknown as AIEngine + ) + } } /** @@ -29,6 +41,15 @@ export class ExtensionManager { return this.extensions.get(type) as T | undefined } + /** + * Retrieves a registered AI engine extension by its provider name. + * @param engine - The provider name the engine was registered under. + * @returns The engine extension, if found. + */ + getEngine(engine: string): T | undefined { + return this.engines.get(engine) as T | undefined + } + /** * Loads all registered extension.
*/ diff --git a/web/hooks/useActiveModel.ts b/web/hooks/useActiveModel.ts index 98433c2ea8..0da28efe4f 100644 --- a/web/hooks/useActiveModel.ts +++ b/web/hooks/useActiveModel.ts @@ -1,6 +1,6 @@ import { useCallback, useEffect, useRef } from 'react' -import { events, Model, ModelEvent } from '@janhq/core' +import { EngineManager, Model } from '@janhq/core' import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' import { toaster } from '@/containers/Toast' @@ -38,19 +38,13 @@ export function useActiveModel() { (stateModel.model === modelId && stateModel.loading) ) { console.debug(`Model ${modelId} is already initialized. Ignore..`) - return + return Promise.resolve() } let model = downloadedModelsRef?.current.find((e) => e.id === modelId) - // Switch between engines - if (model && activeModel && activeModel.engine !== model.engine) { - stopModel() - // TODO: Refactor inference provider would address this - await new Promise((res) => setTimeout(res, 1000)) - } + await stopModel().catch() - // TODO: incase we have multiple assistants, the configuration will be from assistant setLoadModelError(undefined) setActiveModel(undefined) @@ -68,7 +62,8 @@ export function useActiveModel() { loading: false, model: '', })) - return + + return Promise.reject(`Model ${modelId} not found!`) } /// Apply thread model settings @@ -83,15 +78,52 @@ export function useActiveModel() { } localStorage.setItem(LAST_USED_MODEL_ID, model.id) - events.emit(ModelEvent.OnModelInit, model) + const engine = EngineManager.instance().get(model.engine) + return engine + ?.loadModel(model) + .then(() => { + setActiveModel(model) + setStateModel(() => ({ + state: 'stop', + loading: false, + model: model.id, + })) + toaster({ + title: 'Success!', + description: `Model ${model.id} has been started.`, + type: 'success', + }) + }) + .catch((error) => { + setStateModel(() => ({ + state: 'start', + loading: false, + model: model.id, + })) + + toaster({ + title: 'Failed!', + description: `Model 
${model.id} failed to start.`, + type: 'error', + }) + setLoadModelError(error) + return Promise.reject(error) + }) } const stopModel = useCallback(async () => { if (activeModel) { setStateModel({ state: 'stop', loading: true, model: activeModel.id }) - events.emit(ModelEvent.OnModelStop, activeModel) + const engine = EngineManager.instance().get(activeModel.engine) + await engine + ?.unloadModel(activeModel) + .catch((error) => console.error(error)) + .then(() => { + setActiveModel(undefined) + setStateModel({ state: 'start', loading: false, model: '' }) + }) } - }, [activeModel, setStateModel]) + }, [activeModel, setActiveModel, setStateModel]) return { activeModel, startModel, stopModel, stateModel } } diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts index 1ba68f85e0..b380320091 100644 --- a/web/hooks/useSendChatMessage.ts +++ b/web/hooks/useSendChatMessage.ts @@ -2,27 +2,18 @@ import { useEffect, useRef } from 'react' import { - ChatCompletionMessage, ChatCompletionRole, - ContentType, - MessageRequest, MessageRequestType, - MessageStatus, ExtensionTypeEnum, Thread, ThreadMessage, - events, Model, ConversationalExtension, - MessageEvent, - InferenceEngine, - ChatCompletionMessageContentType, - AssistantTool, + EngineManager, + ToolManager, } from '@janhq/core' import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' -import { ulid } from 'ulidx' - import { selectedModelAtom } from '@/containers/DropdownListSidebar' import { currentPromptAtom, @@ -31,8 +22,11 @@ import { } from '@/containers/Providers/Jotai' import { compressImage, getBase64 } from '@/utils/base64' +import { MessageRequestBuilder } from '@/utils/messageRequestBuilder' import { toRuntimeParams, toSettingParams } from '@/utils/modelParam' +import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder' + import { loadModelErrorAtom, useActiveModel } from './useActiveModel' import { extensionManager } from '@/extension/ExtensionManager' @@ -65,7 +59,6 @@ export default function
useSendChatMessage() { const currentMessages = useAtomValue(getCurrentChatMessagesAtom) const selectedModel = useAtomValue(selectedModelAtom) const { activeModel, startModel } = useActiveModel() - const setQueuedMessage = useSetAtom(queuedMessageAtom) const loadModelFailed = useAtomValue(loadModelErrorAtom) const modelRef = useRef() @@ -78,6 +71,7 @@ export default function useSendChatMessage() { const [fileUpload, setFileUpload] = useAtom(fileUploadAtom) const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom) const activeThreadRef = useRef() + const setQueuedMessage = useSetAtom(queuedMessageAtom) const selectedModelRef = useRef() @@ -103,51 +97,27 @@ export default function useSendChatMessage() { return } updateThreadWaiting(activeThreadRef.current.id, true) - const messages: ChatCompletionMessage[] = [ - activeThreadRef.current.assistants[0]?.instructions, - ] - .filter((e) => e && e.trim() !== '') - .map((instructions) => { - const systemMessage: ChatCompletionMessage = { - role: ChatCompletionRole.System, - content: instructions, - } - return systemMessage - }) - .concat( - currentMessages - .filter( - (e) => - (currentMessage.role === ChatCompletionRole.User || - e.id !== currentMessage.id) && - e.status !== MessageStatus.Error - ) - .map((msg) => ({ - role: msg.role, - content: msg.content[0]?.text.value ?? '', - })) - ) - - const messageRequest: MessageRequest = { - id: ulid(), - type: MessageRequestType.Thread, - messages: messages, - threadId: activeThreadRef.current.id, - model: - activeThreadRef.current.assistants[0].model ?? selectedModelRef.current, - } + + const requestBuilder = new MessageRequestBuilder( + MessageRequestType.Thread, + activeThreadRef.current.assistants[0].model ?? selectedModelRef.current, + activeThreadRef.current, + currentMessages + ).addSystemMessage(activeThreadRef.current.assistants[0]?.instructions) const modelId = selectedModelRef.current?.id ?? 
activeThreadRef.current.assistants[0].model.id if (modelRef.current?.id !== modelId) { - setQueuedMessage(true) - startModel(modelId) - await waitForModelStarting(modelId) - setQueuedMessage(false) + const error = await startModel(modelId).catch((error: Error) => error) + if (error) { + return + } } + setIsGeneratingResponse(true) + if (currentMessage.role !== ChatCompletionRole.User) { // Delete last response before regenerating deleteMessage(currentMessage.id ?? '') @@ -160,9 +130,22 @@ export default function useSendChatMessage() { ) } } - events.emit(MessageEvent.OnMessageSent, messageRequest) + // Process message request with Assistants tools + const request = await ToolManager.instance().process( + requestBuilder.build(), + activeThreadRef.current.assistants?.flatMap( + (assistant) => assistant.tools ?? [] + ) ?? [] + ) + + const engine = + requestBuilder.model?.engine ?? selectedModelRef.current?.engine ?? '' + + EngineManager.instance().get(engine)?.inference(request) } + // Define interface extending Array prototype + const sendChatMessage = async (message: string) => { if (!message || message.trim().length === 0) return @@ -176,8 +159,9 @@ export default function useSendChatMessage() { const runtimeParams = toRuntimeParams(activeModelParams) const settingParams = toSettingParams(activeModelParams) - updateThreadWaiting(activeThreadRef.current.id, true) const prompt = message.trim() + + updateThreadWaiting(activeThreadRef.current.id, true) setCurrentPrompt('') setEditPrompt('') @@ -185,69 +169,12 @@ export default function useSendChatMessage() { ? 
await getBase64(fileUpload[0].file) : undefined - const fileContentType = fileUpload[0]?.type - - const msgId = ulid() - - const isDocumentInput = base64Blob && fileContentType === 'pdf' - const isImageInput = base64Blob && fileContentType === 'image' - - if (isImageInput && base64Blob) { + if (base64Blob && fileUpload[0]?.type === 'image') { // Compress image base64Blob = await compressImage(base64Blob, 512) } - const messages: ChatCompletionMessage[] = [ - activeThreadRef.current.assistants[0]?.instructions, - ] - .filter((e) => e && e.trim() !== '') - .map((instructions) => { - const systemMessage: ChatCompletionMessage = { - role: ChatCompletionRole.System, - content: instructions, - } - return systemMessage - }) - .concat( - currentMessages - .filter((e) => e.status !== MessageStatus.Error) - .map((msg) => ({ - role: msg.role, - content: msg.content[0]?.text.value ?? '', - })) - .concat([ - { - role: ChatCompletionRole.User, - content: - selectedModelRef.current && base64Blob - ? [ - { - type: ChatCompletionMessageContentType.Text, - text: prompt, - }, - isDocumentInput - ? { - type: ChatCompletionMessageContentType.Doc, - doc_url: { - url: `threads/${activeThreadRef.current.id}/files/${msgId}.pdf`, - }, - } - : null, - isImageInput - ? { - type: ChatCompletionMessageContentType.Image, - image_url: { - url: base64Blob, - }, - } - : null, - ].filter((e) => e !== null) - : prompt, - } as ChatCompletionMessage, - ]) - ) - - let modelRequest = + const modelRequest = selectedModelRef?.current ?? 
activeThreadRef.current.assistants[0].model // Fallback support for previous broken threads @@ -261,131 +188,83 @@ export default function useSendChatMessage() { if (runtimeParams.stream == null) { runtimeParams.stream = true } - // Add middleware to the model request with tool retrieval enabled - if ( - activeThreadRef.current.assistants[0].tools?.some( - (tool: AssistantTool) => tool.type === 'retrieval' && tool.enabled - ) - ) { - modelRequest = { - ...modelRequest, - // Tool retrieval support document input only for now - ...(isDocumentInput - ? { - engine: InferenceEngine.tool_retrieval_enabled, - proxy_model: modelRequest.engine, - } - : {}), - } - } - const messageRequest: MessageRequest = { - id: msgId, - type: MessageRequestType.Thread, - threadId: activeThreadRef.current.id, - messages, - model: { + + // Build Message Request + const requestBuilder = new MessageRequestBuilder( + MessageRequestType.Thread, + { ...modelRequest, settings: settingParams, parameters: runtimeParams, }, - thread: activeThreadRef.current, - } + activeThreadRef.current, + currentMessages + ).addSystemMessage(activeThreadRef.current.assistants[0].instructions) - const timestamp = Date.now() - const content: any = [] + requestBuilder.pushMessage(prompt, base64Blob, fileUpload[0]?.type) - if (base64Blob && fileUpload[0]?.type === 'image') { - content.push({ - type: ContentType.Image, - text: { - value: prompt, - annotations: [base64Blob], - }, - }) - } + // Build Thread Message to persist + const threadMessageBuilder = new ThreadMessageBuilder( + requestBuilder + ).pushMessage(prompt, base64Blob, fileUpload) - if (base64Blob && fileUpload[0]?.type === 'pdf') { - content.push({ - type: ContentType.Pdf, - text: { - value: prompt, - annotations: [base64Blob], - name: fileUpload[0].file.name, - size: fileUpload[0].file.size, - }, - }) - } + const newMessage = threadMessageBuilder.build() - if (prompt && !base64Blob) { - content.push({ - type: ContentType.Text, - text: { - value: prompt, 
- annotations: [], - }, - }) - } - - const threadMessage: ThreadMessage = { - id: msgId, - thread_id: activeThreadRef.current.id, - role: ChatCompletionRole.User, - status: MessageStatus.Ready, - created: timestamp, - updated: timestamp, - object: 'thread.message', - content: content, - } - - addNewMessage(threadMessage) - if (base64Blob) { - setFileUpload([]) - } + // Push to states + addNewMessage(newMessage) + // Update thread state const updatedThread: Thread = { ...activeThreadRef.current, - updated: timestamp, + updated: newMessage.created, metadata: { ...(activeThreadRef.current.metadata ?? {}), lastMessage: prompt, }, } - - // change last update thread when send message updateThread(updatedThread) + // Add message await extensionManager .get(ExtensionTypeEnum.Conversational) - ?.addNewMessage(threadMessage) + ?.addNewMessage(newMessage) + // Start Model if not started const modelId = selectedModelRef.current?.id ?? activeThreadRef.current.assistants[0].model.id if (modelRef.current?.id !== modelId) { setQueuedMessage(true) - startModel(modelId) - await waitForModelStarting(modelId) + const error = await startModel(modelId).catch((error: Error) => error) setQueuedMessage(false) + if (error) { + updateThreadWaiting(activeThreadRef.current.id, false) + return + } } setIsGeneratingResponse(true) - events.emit(MessageEvent.OnMessageSent, messageRequest) + // Process message request with Assistants tools + const request = await ToolManager.instance().process( + requestBuilder.build(), + activeThreadRef.current.assistants?.flatMap( + (assistant) => assistant.tools ?? [] + ) ?? [] + ) + + // Request for inference + EngineManager.instance() + .get(requestBuilder.model?.engine ?? modelRequest.engine ?? 
'') + ?.inference(request) + + // Reset states setReloadModel(false) setEngineParamsUpdate(false) - } - const waitForModelStarting = async (modelId: string) => { - return new Promise((resolve) => { - setTimeout(async () => { - if (modelRef.current?.id !== modelId && !loadModelFailedRef.current) { - await waitForModelStarting(modelId) - resolve() - } else { - resolve() - } - }, 200) - }) + if (base64Blob) { + setFileUpload([]) + } } return { diff --git a/web/screens/Chat/ErrorMessage/index.tsx b/web/screens/Chat/ErrorMessage/index.tsx index 5be87a59d8..2104beb92d 100644 --- a/web/screens/Chat/ErrorMessage/index.tsx +++ b/web/screens/Chat/ErrorMessage/index.tsx @@ -74,7 +74,8 @@ const ErrorMessage = ({ message }: { message: ThreadMessage }) => {

- ) : loadModelError?.includes('EXTENSION_IS_NOT_INSTALLED') ? ( + ) : loadModelError && + loadModelError?.includes('EXTENSION_IS_NOT_INSTALLED') ? (
{
-
-
+
+
{fileUpload[0].file.name.replaceAll(/[-._]/g, ' ')}

diff --git a/web/screens/Chat/SimpleTextMessage/index.tsx b/web/screens/Chat/SimpleTextMessage/index.tsx index 7ea3e434b1..489def1c11 100644 --- a/web/screens/Chat/SimpleTextMessage/index.tsx +++ b/web/screens/Chat/SimpleTextMessage/index.tsx @@ -260,8 +260,8 @@ const SimpleTextMessage: React.FC = (props) => { -

-
+
+
{props.content[0].text.name?.replaceAll(/[-._]/g, ' ')}

diff --git a/web/services/coreService.ts b/web/services/coreService.ts index a483cc452a..aeb1cca1a9 100644 --- a/web/services/coreService.ts +++ b/web/services/coreService.ts @@ -1,3 +1,5 @@ +import { EngineManager, ToolManager } from '@janhq/core' + import { appService } from './appService' import { EventEmitter } from './eventsService' import { restAPI } from './restService' @@ -12,6 +14,8 @@ export const setupCoreServices = () => { if (!window.core) { window.core = { events: new EventEmitter(), + engineManager: new EngineManager(), + toolManager: new ToolManager(), api: { ...(window.electronAPI ? window.electronAPI : restAPI), ...appService, diff --git a/web/utils/messageRequestBuilder.ts b/web/utils/messageRequestBuilder.ts new file mode 100644 index 0000000000..e214b03ea4 --- /dev/null +++ b/web/utils/messageRequestBuilder.ts @@ -0,0 +1,130 @@ +import { + ChatCompletionMessage, + ChatCompletionMessageContent, + ChatCompletionMessageContentText, + ChatCompletionMessageContentType, + ChatCompletionRole, + MessageRequest, + MessageRequestType, + MessageStatus, + ModelInfo, + Thread, + ThreadMessage, +} from '@janhq/core' +import { ulid } from 'ulidx' + +import { FileType } from '@/containers/Providers/Jotai' + +export class MessageRequestBuilder { + msgId: string + type: MessageRequestType + messages: ChatCompletionMessage[] + model: ModelInfo + thread: Thread + + constructor( + type: MessageRequestType, + model: ModelInfo, + thread: Thread, + messages: ThreadMessage[] + ) { + this.msgId = ulid() + this.type = type + this.model = model + this.thread = thread + this.messages = messages + .filter((e) => e.status !== MessageStatus.Error) + .map((msg) => ({ + role: msg.role, + content: msg.content[0]?.text.value ?? 
'', + })) + } + + // Chainable + pushMessage( + message: string, + base64Blob: string | undefined, + fileContentType: FileType + ) { + if (base64Blob && fileContentType === 'pdf') + return this.addDocMessage(message) + else if (base64Blob && fileContentType === 'image') { + return this.addImageMessage(message, base64Blob) + } + this.messages = [ + ...this.messages, + { + role: ChatCompletionRole.User, + content: message, + }, + ] + return this + } + + // Chainable + addSystemMessage(message: string | undefined) { + if (!message || message.trim() === '') return this + this.messages = [ + { + role: ChatCompletionRole.System, + content: message, + }, + ...this.messages, + ] + return this + } + + // Chainable + addDocMessage(prompt: string) { + const message: ChatCompletionMessage = { + role: ChatCompletionRole.User, + content: [ + { + type: ChatCompletionMessageContentType.Text, + text: prompt, + } as ChatCompletionMessageContentText, + { + type: ChatCompletionMessageContentType.Doc, + doc_url: { + url: `threads/${this.thread.id}/files/${this.msgId}.pdf`, + }, + }, + ] as ChatCompletionMessageContent, + } + this.messages = [...this.messages, message] + return this + } + + // Chainable + addImageMessage(prompt: string, base64: string) { + const message: ChatCompletionMessage = { + role: ChatCompletionRole.User, + content: [ + { + type: ChatCompletionMessageContentType.Text, + text: prompt, + } as ChatCompletionMessageContentText, + { + type: ChatCompletionMessageContentType.Image, + image_url: { + url: base64, + }, + }, + ] as ChatCompletionMessageContent, + } + + this.messages = [...this.messages, message] + return this + } + + build(): MessageRequest { + return { + id: this.msgId, + type: this.type, + threadId: this.thread.id, + messages: this.messages, + model: this.model, + thread: this.thread, + } + } +} diff --git a/web/utils/threadMessageBuilder.ts b/web/utils/threadMessageBuilder.ts new file mode 100644 index 0000000000..92e51e5742 --- /dev/null +++
b/web/utils/threadMessageBuilder.ts @@ -0,0 +1,74 @@ +import { + ChatCompletionRole, + ContentType, + MessageStatus, + ThreadContent, + ThreadMessage, +} from '@janhq/core' + +import { FileInfo } from '@/containers/Providers/Jotai' + +import { MessageRequestBuilder } from './messageRequestBuilder' + +export class ThreadMessageBuilder { + messageRequest: MessageRequestBuilder + + content: ThreadContent[] = [] + + constructor(messageRequest: MessageRequestBuilder) { + this.messageRequest = messageRequest + } + + build(): ThreadMessage { + const timestamp = Date.now() + return { + id: this.messageRequest.msgId, + thread_id: this.messageRequest.thread.id, + role: ChatCompletionRole.User, + status: MessageStatus.Ready, + created: timestamp, + updated: timestamp, + object: 'thread.message', + content: this.content, + } + } + + pushMessage( + prompt: string, + base64: string | undefined, + fileUpload: FileInfo[] + ) { + if (base64 && fileUpload[0]?.type === 'image') { + this.content.push({ + type: ContentType.Image, + text: { + value: prompt, + annotations: [base64], + }, + }) + } + + if (base64 && fileUpload[0]?.type === 'pdf') { + this.content.push({ + type: ContentType.Pdf, + text: { + value: prompt, + annotations: [base64], + name: fileUpload[0].file.name, + size: fileUpload[0].file.size, + }, + }) + } + + if (prompt && !base64) { + this.content.push({ + type: ContentType.Text, + text: { + value: prompt, + annotations: [], + }, + }) + } + return this + } +}