Skip to content

Commit

Permalink
feat: add nomic embedder
Browse files Browse the repository at this point in the history
  • Loading branch information
henomis committed Feb 1, 2024
1 parent b65bda0 commit f9a564a
Show file tree
Hide file tree
Showing 3 changed files with 262 additions and 0 deletions.
85 changes: 85 additions & 0 deletions embedder/nomic/api.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package nomicembedder

import (
"bytes"
"encoding/json"
"io"

"github.com/henomis/lingoose/embedder"
"github.com/henomis/restclientgo"
)

// Model is the name of an embedding model. Embed sends it as the "model"
// field of the request payload.
type Model string

// Model names declared by this package.
const (
	// ModelNomicEmbedTextV1 selects "nomic-embed-text-v1"; it is the model
	// used when none is configured explicitly.
	ModelNomicEmbedTextV1 Model = "nomic-embed-text-v1"
	// ModelAllMiniLML6V2 selects "all-MiniLM-L6-v2".
	ModelAllMiniLML6V2 Model = "all-MiniLM-L6-v2"
)

// TaskType is the value sent as the "task_type" field of the request payload.
// NOTE(review): the exact effect of each value is defined by the remote API
// and is not visible here; confirm against the Nomic API documentation.
type TaskType string

// Task type values declared by this package.
const (
	// TaskTypeSearchQuery is the "search_query" task type.
	TaskTypeSearchQuery TaskType = "search_query"
	// TaskTypeSearchDocument is the "search_document" task type.
	TaskTypeSearchDocument TaskType = "search_document"
	// TaskTypeClustering is the "clustering" task type.
	TaskTypeClustering TaskType = "clustering"
	// TaskTypeClassification is the "classification" task type.
	TaskTypeClassification TaskType = "classification"
)

// request is the JSON payload posted to the text embedding endpoint.
type request struct {
	// Model is the name of the embedding model to use.
	Model string `json:"model"`
	// Texts holds the input strings to embed.
	Texts []string `json:"texts"`
	// TaskType is left out of the JSON entirely when it is empty.
	TaskType TaskType `json:"task_type,omitempty"`
}

// Path reports the endpoint path, relative to the client's base URL, that
// the request is sent to. It never returns an error.
func (r *request) Path() (string, error) {
	const embeddingPath = "/embedding/text"

	return embeddingPath, nil
}

// Encode serializes the request as JSON and returns a reader over the
// resulting bytes. A marshalling failure is returned unchanged.
func (r *request) Encode() (io.Reader, error) {
	payload, marshalErr := json.Marshal(r)
	if marshalErr != nil {
		return nil, marshalErr
	}

	var body io.Reader = bytes.NewReader(payload)
	return body, nil
}

// ContentType reports the media type of the encoded request body.
func (r *request) ContentType() string {
	const jsonMediaType = "application/json"

	return jsonMediaType
}

// response holds the decoded reply of the embedding endpoint together with
// transport details filled in through its Set* methods.
type response struct {
	// HTTPStatusCode is stored by SetStatusCode; it is not part of the JSON body.
	HTTPStatusCode int `json:"-"`
	// Embeddings is read from the "embeddings" field of the JSON body.
	Embeddings []embedder.Embedding `json:"embeddings"`
	// Usage is read from the "usage" field of the JSON body.
	Usage Usage `json:"usage"`
	// RawBody is stored by SetBody; it is not part of the JSON body.
	RawBody string `json:"-"`
}

// Usage mirrors the "usage" object of the embedding response.
type Usage struct {
	// TotalTokens is read from the "total_tokens" field.
	TotalTokens int `json:"total_tokens"`
}

// Decode reads one JSON value from body into the response and returns any
// decoding error.
func (r *response) Decode(body io.Reader) error {
	decoder := json.NewDecoder(body)

	return decoder.Decode(r)
}

// SetBody reads body to the end and keeps its content, undecoded, in RawBody.
// If reading fails the error is returned and RawBody is left untouched.
func (r *response) SetBody(body io.Reader) error {
	raw, readErr := io.ReadAll(body)
	if readErr != nil {
		return readErr
	}

	r.RawBody = string(raw)

	return nil
}

// AcceptContentType reports the media type the response expects to decode.
func (r *response) AcceptContentType() string {
	const jsonMediaType = "application/json"

	return jsonMediaType
}

// SetStatusCode stores the HTTP status code of the reply in HTTPStatusCode.
// It never returns an error.
func (r *response) SetStatusCode(code int) error {
	r.HTTPStatusCode = code

	return nil
}

// SetHeaders ignores the response headers and always returns nil.
func (r *response) SetHeaders(_ restclientgo.Headers) error {
	return nil
}
74 changes: 74 additions & 0 deletions embedder/nomic/nomic.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package nomicembedder

import (
	"context"
	"fmt"
	"net/http"
	"os"

	"github.com/henomis/lingoose/embedder"
	"github.com/henomis/restclientgo"
)

const (
	// defaultEndpoint is the base URL the REST client is created with.
	defaultEndpoint = "https://api-atlas.nomic.ai/v1"
	// defaultModel is the model New configures on a fresh Embedder.
	defaultModel = ModelNomicEmbedTextV1
)

// Embedder requests text embeddings through a REST client. Create one with
// New and adjust it with the With* methods.
type Embedder struct {
	// taskType is sent with every request; the zero value omits it from the payload.
	taskType TaskType
	// model is the model name sent with every request.
	model Model
	// restClient performs the HTTP calls and attaches the Authorization header.
	restClient *restclientgo.RestClient
}

// New returns an Embedder that uses the default model and authenticates with
// the API key found in the NOMIC_API_KEY environment variable. The variable
// is read once, when New is called.
func New() *Embedder {
	e := &Embedder{model: defaultModel}

	// WithAPIKey installs a client for the default endpoint that sends the key
	// as a bearer token.
	return e.WithAPIKey(os.Getenv("NOMIC_API_KEY"))
}

// WithAPIKey replaces the REST client with one for the default endpoint that
// sends apiKey as a bearer token on every request. It returns the receiver so
// calls can be chained.
func (e *Embedder) WithAPIKey(apiKey string) *Embedder {
	authorize := func(req *http.Request) *http.Request {
		req.Header.Set("Authorization", "Bearer "+apiKey)
		return req
	}

	e.restClient = restclientgo.New(defaultEndpoint).WithRequestModifier(authorize)

	return e
}

// WithTaskType sets the task type sent with subsequent requests and returns
// the receiver so calls can be chained.
func (e *Embedder) WithTaskType(taskType TaskType) *Embedder {
	e.taskType = taskType

	return e
}

// WithModel sets the model used for subsequent requests and returns the
// receiver so calls can be chained.
func (e *Embedder) WithModel(model Model) *Embedder {
	e.model = model

	return e
}

// Embed posts texts to the embedding endpoint, using the configured model and
// task type, and returns the embeddings found in the response.
//
// An error is returned when the request cannot be completed, when the reply
// cannot be decoded, or when the API answers with an HTTP status of 400 or
// above.
func (e *Embedder) Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error) {
	var resp response
	err := e.restClient.Post(
		ctx,
		&request{
			Texts:    texts,
			Model:    string(e.model),
			TaskType: e.taskType,
		},
		&resp,
	)
	if err != nil {
		return nil, fmt.Errorf("nomic embedder: %w", err)
	}

	// A completed exchange can still carry an error status (for example a
	// rejected API key). Such a reply may decode cleanly with no embeddings,
	// so without this check the caller would receive nil, nil.
	if resp.HTTPStatusCode >= http.StatusBadRequest {
		if resp.RawBody != "" {
			return nil, fmt.Errorf(
				"nomic embedder: unexpected HTTP status %d: %s",
				resp.HTTPStatusCode, resp.RawBody,
			)
		}
		return nil, fmt.Errorf("nomic embedder: unexpected HTTP status %d", resp.HTTPStatusCode)
	}

	return resp.Embeddings, nil
}
103 changes: 103 additions & 0 deletions examples/embeddings/nomic/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package main

import (
"context"
"fmt"

nomicembedder "github.com/henomis/lingoose/embedder/nomic"
"github.com/henomis/lingoose/index"
indexoption "github.com/henomis/lingoose/index/option"
"github.com/henomis/lingoose/index/vectordb/jsondb"
"github.com/henomis/lingoose/llm/openai"
"github.com/henomis/lingoose/loader"
"github.com/henomis/lingoose/prompt"
"github.com/henomis/lingoose/textsplitter"
)

// Before running this example, download the sample document into the working directory:
// https://raw.githubusercontent.com/hwchase17/chat-your-data/master/state_of_the_union.txt

// main builds an index persisted to db.json that embeds documents with the
// Nomic embedder, ingests the local .txt files when the index is empty,
// retrieves the three chunks most similar to a fixed question, and asks an
// OpenAI completion model to answer the question using them as context.
// Any failure aborts the program with a panic.
func main() {
	ctx := context.Background()

	// Named docIndex so the local variable does not shadow the index package.
	docIndex := index.New(
		jsondb.New().WithPersist("db.json"),
		nomicembedder.New(),
	).WithIncludeContents(true).WithAddDataCallback(func(data *index.Data) error {
		// Checked assertion: an entry without string content is stored
		// without a length instead of crashing the ingestion.
		if content, ok := data.Metadata["content"].(string); ok {
			data.Metadata["contentLen"] = len(content)
		}
		return nil
	})

	// Do not ignore this error: a failed emptiness check would otherwise
	// silently skip ingestion and surface later as a confusing query result.
	indexIsEmpty, err := docIndex.IsEmpty(ctx)
	if err != nil {
		panic(err)
	}

	if indexIsEmpty {
		if err := ingestData(docIndex); err != nil {
			panic(err)
		}
	}

	query := "What is the purpose of the NATO Alliance?"
	similarities, err := docIndex.Query(
		ctx,
		query,
		indexoption.WithTopK(3),
	)
	if err != nil {
		panic(err)
	}

	// Print every match and, in the same pass, collect the matched text that
	// is handed to the model as context.
	documentContext := ""
	for _, similarity := range similarities {
		fmt.Printf("Similarity: %f\n", similarity.Score)
		fmt.Printf("Document: %s\n", similarity.Content())
		fmt.Println("Metadata: ", similarity.Metadata)
		fmt.Println("----------")

		documentContext += similarity.Content() + "\n\n"
	}

	llmOpenAI := openai.NewCompletion().WithVerbose(true)
	prompt1 := prompt.NewPromptTemplate(
		"Based on the following context answer to the question.\n\nContext:\n{{.context}}\n\nQuestion: {{.query}}").WithInputs(
		map[string]string{
			"query":   query,
			"context": documentContext,
		},
	)

	err = prompt1.Format(nil)
	if err != nil {
		panic(err)
	}

	output, err := llmOpenAI.Completion(ctx, prompt1.String())
	if err != nil {
		panic(err)
	}

	fmt.Println(output)
}

// ingestData loads every .txt file from the current directory, splits the
// documents with a recursive character splitter (1000, 20) and loads the
// resulting chunks into index. Progress is printed to standard output. The
// first loader or index error is returned unchanged.
func ingestData(index *index.Index) error {
	ctx := context.Background()

	fmt.Print("Ingesting data...")

	docs, err := loader.NewDirectoryLoader(".", ".txt").Load(ctx)
	if err != nil {
		return err
	}

	splitter := textsplitter.NewRecursiveCharacterTextSplitter(1000, 20)
	chunks := splitter.SplitDocuments(docs)

	if err := index.LoadFromDocuments(ctx, chunks); err != nil {
		return err
	}

	fmt.Print("Done!\n")

	return nil
}

0 comments on commit f9a564a

Please sign in to comment.