feat: client parses context
ivynya committed Dec 26, 2023
Parent: d52eefc
Commit: c8c851f
Showing 2 changed files with 6 additions and 13 deletions.
client/generate.go: 4 changes (3 additions, 1 deletion)
@@ -15,7 +15,9 @@ func generate(c *websocket.Conn, req *Request) ([]*llms.Generation, error) {
         log.Fatal(err)
     }
     ctx := context.Background()
-    completion, err := llm.Generate(ctx, []string{req.Generate.Prompt},
+    completion, err := llm.Generate(ctx,
+        []string{req.Generate.Prompt},
+        req.Generate.Context,
         llms.WithTemperature(0.8),
         llms.WithStreamingFunc(func(ctx context.Context, chunk []byte) error {
             resp, err := encodeRequest("response", string(chunk))
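After this change the client forwards Ollama's rolling context tokens, a []int carried in the request, into the generation call, which is what lets a conversation continue across websocket requests. Below is a minimal sketch of a caller wiring one turn into the next, assuming the Request field shapes visible in the diff (Generate.Prompt string, Generate.Context []int); the function name, prompt, and parameters are illustrative only, not part of the commit.

// Sketch only: feeds the context tokens from a previous turn into the
// next request. prevContext is assumed to come from an earlier Ollama
// response; an empty slice starts a fresh conversation.
func nextTurn(conn *websocket.Conn, prevContext []int) ([]*llms.Generation, error) {
    req := &Request{}
    req.Generate.Prompt = "And why does it turn red at sunset?"
    req.Generate.Context = prevContext
    return generate(conn, req)
}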
ollama/llama.go: 15 changes (3 additions, 12 deletions)
@@ -7,7 +7,6 @@ import (
 
     "github.com/tmc/langchaingo/callbacks"
     "github.com/tmc/langchaingo/llms"
-    "github.com/tmc/langchaingo/schema"
 )
 
 var (

@@ -22,11 +21,6 @@ type LLM struct {
     options options
 }
 
-var (
-    _ llms.LLM           = (*LLM)(nil)
-    _ llms.LanguageModel = (*LLM)(nil)
-)
-
 // New creates a new ollama LLM implementation.
 func New(opts ...Option) (*LLM, error) {
     o := options{}

@@ -44,7 +38,7 @@ func New(opts ...Option) (*LLM, error) {
 
 // Call Implement the call interface for LLM.
 func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) {
-    r, err := o.Generate(ctx, []string{prompt}, options...)
+    r, err := o.Generate(ctx, []string{prompt}, []int{}, options...)
     if err != nil {
         return "", err
     }

@@ -55,7 +49,7 @@ func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) {
 }
 
 // Generate implemente the generate interface for LLM.
-func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) {
+func (o *LLM) Generate(ctx context.Context, prompts []string, chatContext []int, options ...llms.CallOption) ([]*llms.Generation, error) {
     if o.CallbacksHandler != nil {
         o.CallbacksHandler.HandleLLMStart(ctx, prompts)
     }

@@ -92,6 +86,7 @@ func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) {
             Prompt:   prompt,
             Template: o.options.customModelTemplate,
             Options:  ollamaOptions,
+            Context:  chatContext,
             Stream:   func(b bool) *bool { return &b }(opts.StreamingFunc != nil),
         }
 

@@ -153,10 +148,6 @@ func (o *LLM) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]flo
     return embeddings, nil
 }
 
-func (o *LLM) GeneratePrompt(ctx context.Context, prompts []schema.PromptValue, options ...llms.CallOption) (llms.LLMResult, error) { //nolint:lll
-    return llms.GeneratePrompt(ctx, o, prompts, options...)
-}
-
 func (o *LLM) GetNumTokens(text string) int {
     return llms.CountTokens(o.options.model, text)
 }
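On the provider side, the new third parameter makes the API shape explicit: Generate now accepts the rolling context, while Call, the single-prompt convenience wrapper, always passes an empty slice, so multi-turn state has to flow through Generate. Here is a minimal usage sketch of the new signature; the import path for this package is a hypothetical placeholder, and the prompt and option values are illustrative.

package main

import (
    "context"
    "fmt"
    "log"

    "github.com/tmc/langchaingo/llms"

    ollama "example.com/ai/ollama" // hypothetical import path for the package above
)

func main() {
    llm, err := ollama.New()
    if err != nil {
        log.Fatal(err)
    }

    // First turn: no prior conversation, so the chat context is empty.
    // A follow-up call would pass the context tokens recovered from
    // this response instead of []int{}.
    gens, err := llm.Generate(context.Background(),
        []string{"Why is the sky blue?"},
        []int{},
        llms.WithTemperature(0.8),
    )
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println(gens[0].Text)
}

Note the design consequence visible in the diff: adding the chatContext parameter means Generate no longer matches langchaingo's llms.LLM interface, which is presumably why this commit also drops the llms.LLM and llms.LanguageModel interface assertions and the GeneratePrompt wrapper.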
