Skip to content

Commit

Permalink
Merge pull request #2470 from janhq/chore/load-unload-model-sync
Browse files Browse the repository at this point in the history
  • Loading branch information
louis-jan committed Mar 25, 2024
2 parents 50f819f + d290ae1 commit 66f7d3d
Show file tree
Hide file tree
Showing 46 changed files with 742 additions and 574 deletions.
2 changes: 1 addition & 1 deletion core/src/core.ts → core/src/browser/core.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { DownloadRequest, FileStat, NetworkConfig, SystemInformation } from './types'
import { DownloadRequest, FileStat, NetworkConfig, SystemInformation } from '../types'

/**
* Execute a extension module function in main process
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Assistant, AssistantInterface } from '../index'
import { Assistant, AssistantInterface } from '../../types'
import { BaseExtension, ExtensionTypeEnum } from '../extension'

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Thread, ThreadInterface, ThreadMessage, MessageInterface } from '../index'
import { Thread, ThreadInterface, ThreadMessage, MessageInterface } from '../../types'
import { BaseExtension, ExtensionTypeEnum } from '../extension'

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ import { getJanDataFolderPath, joinPath } from '../../core'
import { events } from '../../events'
import { BaseExtension } from '../../extension'
import { fs } from '../../fs'
import { Model, ModelEvent } from '../../types'
import { MessageRequest, Model, ModelEvent } from '../../../types'
import { EngineManager } from './EngineManager'

/**
* Base AIEngine
Expand All @@ -11,30 +12,71 @@ import { Model, ModelEvent } from '../../types'
export abstract class AIEngine extends BaseExtension {
// The inference engine
abstract provider: string
// The model folder
modelFolder: string = 'models'

/**
* On extension load, subscribe to events.
*/
override onLoad() {
this.registerEngine()

events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))

this.prePopulateModels()
}

/**
* Defines models
*/
models(): Promise<Model[]> {
return Promise.resolve([])
}

/**
* On extension load, subscribe to events.
* Registers AI Engines
*/
onLoad() {
this.prePopulateModels()
registerEngine() {
EngineManager.instance().register(this)
}

/**
* Loads the model.
*/
async loadModel(model: Model): Promise<any> {
if (model.engine.toString() !== this.provider) return Promise.resolve()
events.emit(ModelEvent.OnModelReady, model)
return Promise.resolve()
}
/**
* Stops the model.
*/
async unloadModel(model?: Model): Promise<any> {
if (model?.engine && model.engine.toString() !== this.provider) return Promise.resolve()
events.emit(ModelEvent.OnModelStopped, model ?? {})
return Promise.resolve()
}

/*
* Inference request
*/
inference(data: MessageRequest) {}

/**
* Stop inference
*/
stopInference() {}

/**
* Pre-populate models to App Data Folder
*/
prePopulateModels(): Promise<void> {
const modelFolder = 'models'
return this.models().then((models) => {
const prePoluateOperations = models.map((model) =>
getJanDataFolderPath()
.then((janDataFolder) =>
// Attempt to create the model folder
joinPath([janDataFolder, this.modelFolder, model.id]).then((path) =>
joinPath([janDataFolder, modelFolder, model.id]).then((path) =>
fs
.mkdir(path)
.catch()
Expand Down
32 changes: 32 additions & 0 deletions core/src/browser/extensions/engines/EngineManager.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import { AIEngine } from './AIEngine'

/**
* Manages the registration and retrieval of inference engines.
*/
export class EngineManager {
public engines = new Map<string, AIEngine>()

/**
* Registers an engine.
* @param engine - The engine to register.
*/
register<T extends AIEngine>(engine: T) {
this.engines.set(engine.provider, engine)
}

/**
* Retrieves a engine by provider.
* @param provider - The name of the engine to retrieve.
* @returns The engine, if found.
*/
get<T extends AIEngine>(provider: string): T | undefined {
return this.engines.get(provider) as T | undefined
}

/**
* The instance of the engine manager.
*/
static instance(): EngineManager {
return window.core?.engineManager as EngineManager ?? new EngineManager()
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { executeOnMain, getJanDataFolderPath, joinPath, systemInformation } from '../../core'
import { events } from '../../events'
import { Model, ModelEvent } from '../../types'
import { Model, ModelEvent } from '../../../types'
import { OAIEngine } from './OAIEngine'

/**
Expand All @@ -16,7 +16,7 @@ export abstract class LocalOAIEngine extends OAIEngine {
/**
* On extension load, subscribe to events.
*/
onLoad() {
override onLoad() {
super.onLoad()
// These events are applicable to local inference providers
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
Expand All @@ -26,10 +26,10 @@ export abstract class LocalOAIEngine extends OAIEngine {
/**
* Load the model.
*/
async loadModel(model: Model) {
override async loadModel(model: Model): Promise<void> {
if (model.engine.toString() !== this.provider) return

const modelFolder = await joinPath([await getJanDataFolderPath(), this.modelFolder, model.id])
const modelFolderName = 'models'
const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id])
const systemInfo = await systemInformation()
const res = await executeOnMain(
this.nodeModule,
Expand All @@ -42,24 +42,22 @@ export abstract class LocalOAIEngine extends OAIEngine {
)

if (res?.error) {
events.emit(ModelEvent.OnModelFail, {
...model,
error: res.error,
})
return
events.emit(ModelEvent.OnModelFail, { error: res.error })
return Promise.reject(res.error)
} else {
this.loadedModel = model
events.emit(ModelEvent.OnModelReady, model)
return Promise.resolve()
}
}
/**
* Stops the model.
*/
unloadModel(model: Model) {
if (model.engine && model.engine?.toString() !== this.provider) return
this.loadedModel = undefined
override async unloadModel(model?: Model): Promise<void> {
if (model?.engine && model.engine?.toString() !== this.provider) return Promise.resolve()

executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
this.loadedModel = undefined
return executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
events.emit(ModelEvent.OnModelStopped, {})
})
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import {
ModelInfo,
ThreadContent,
ThreadMessage,
} from '../../types'
} from '../../../types'
import { events } from '../../events'

/**
Expand All @@ -34,7 +34,7 @@ export abstract class OAIEngine extends AIEngine {
/**
* On extension load, subscribe to events.
*/
onLoad() {
override onLoad() {
super.onLoad()
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data))
events.on(InferenceEvent.OnInferenceStopped, () => this.stopInference())
Expand All @@ -43,12 +43,12 @@ export abstract class OAIEngine extends AIEngine {
/**
* On extension unload
*/
onUnload(): void {}
override onUnload(): void {}

/*
* Inference request
*/
inference(data: MessageRequest) {
override inference(data: MessageRequest) {
if (data.model?.engine?.toString() !== this.provider) return

const timestamp = Date.now()
Expand Down Expand Up @@ -106,6 +106,7 @@ export abstract class OAIEngine extends AIEngine {
return
}
message.status = MessageStatus.Error
message.error_code = err.code
events.emit(MessageEvent.OnMessageUpdate, message)
},
})
Expand All @@ -114,7 +115,7 @@ export abstract class OAIEngine extends AIEngine {
/**
* Stops the inference.
*/
stopInference() {
override stopInference() {
this.isCancelled = true
this.controller?.abort()
}
Expand Down
26 changes: 26 additions & 0 deletions core/src/browser/extensions/engines/RemoteOAIEngine.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import { OAIEngine } from './OAIEngine'

/**
* Base OAI Remote Inference Provider
* Added the implementation of loading and unloading model (applicable to local inference providers)
*/
export abstract class RemoteOAIEngine extends OAIEngine {
// The inference engine
abstract apiKey: string
/**
* On extension load, subscribe to events.
*/
override onLoad() {
super.onLoad()
}

/**
* Headers for the inference request
*/
override headers(): HeadersInit {
return {
'Authorization': `Bearer ${this.apiKey}`,
'api-key': `${this.apiKey}`,
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Observable } from 'rxjs'
import { ModelRuntimeParams } from '../../../types'
import { ErrorCode, ModelRuntimeParams } from '../../../../types'
/**
* Sends a request to the inference server to generate a response based on the recent messages.
* @param recentMessages - An array of recent messages to use as context for the inference.
Expand Down Expand Up @@ -34,6 +34,16 @@ export function requestInference(
signal: controller?.signal,
})
.then(async (response) => {
if (!response.ok) {
const data = await response.json()
const error = {
message: data.error?.message ?? 'Error occurred.',
code: data.error?.code ?? ErrorCode.Unknown,
}
subscriber.error(error)
subscriber.complete()
return
}
if (model.parameters.stream === false) {
const data = await response.json()
subscriber.next(data.choices[0]?.message?.content ?? '')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ export * from './AIEngine'
export * from './OAIEngine'
export * from './LocalOAIEngine'
export * from './RemoteOAIEngine'
export * from './EngineManager'
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { BaseExtension, ExtensionTypeEnum } from '../extension'
import { HuggingFaceInterface, HuggingFaceRepoData, Quantization } from '../types/huggingface'
import { Model } from '../types/model'
import { HuggingFaceInterface, HuggingFaceRepoData, Quantization } from '../../types/huggingface'
import { Model } from '../../types/model'

/**
* Hugging Face extension for converting HF models to GGUF.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ export { HuggingFaceExtension } from './huggingface'
/**
* Base AI Engines.
*/
export * from './ai-engines'
export * from './engines'
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { InferenceInterface, MessageRequest, ThreadMessage } from '../index'
import { InferenceInterface, MessageRequest, ThreadMessage } from '../../types'
import { BaseExtension, ExtensionTypeEnum } from '../extension'

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { BaseExtension, ExtensionTypeEnum } from '../extension'
import { GpuSetting, ImportingModel, Model, ModelInterface, OptionType } from '../index'
import { GpuSetting, ImportingModel, Model, ModelInterface, OptionType } from '../../types'

/**
* Model extension for managing models.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { BaseExtension, ExtensionTypeEnum } from '../extension'
import { GpuSetting, MonitoringInterface, OperatingSystemInfo } from '../index'
import { GpuSetting, MonitoringInterface, OperatingSystemInfo } from '../../types'

/**
* Monitoring extension for system monitoring.
Expand Down
2 changes: 1 addition & 1 deletion core/src/fs.ts → core/src/browser/fs.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { FileStat } from './types'
import { FileStat } from '../types'

/**
* Writes data to a file at the specified path.
Expand Down
35 changes: 35 additions & 0 deletions core/src/browser/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/**
* Export Core module
* @module
*/
export * from './core'

/**
* Export Event module.
* @module
*/
export * from './events'

/**
* Export Filesystem module.
* @module
*/
export * from './fs'

/**
* Export Extension module.
* @module
*/
export * from './extension'

/**
* Export all base extensions.
* @module
*/
export * from './extensions'

/**
* Export all base tools.
* @module
*/
export * from './tools'
2 changes: 2 additions & 0 deletions core/src/browser/tools/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
export * from './manager'
export * from './tool'
Loading

0 comments on commit 66f7d3d

Please sign in to comment.