diff --git a/adapters/handlers/grpc/v1/parse_search_request.go b/adapters/handlers/grpc/v1/parse_search_request.go index 10b52d6750..1ca8400753 100644 --- a/adapters/handlers/grpc/v1/parse_search_request.go +++ b/adapters/handlers/grpc/v1/parse_search_request.go @@ -368,12 +368,6 @@ func extractTargetVectors(req *pb.SearchRequest, class *models.Class) (*[]string var targetVectors *[]string if hs := req.HybridSearch; hs != nil { targetVectors = &hs.TargetVectors - if hs.NearText != nil { - targetVectors = &hs.NearText.TargetVectors - } - if hs.NearVector != nil { - targetVectors = &hs.NearVector.TargetVectors - } } if na := req.NearAudio; na != nil { targetVectors = &na.TargetVectors diff --git a/adapters/handlers/grpc/v1/parse_search_request_test.go b/adapters/handlers/grpc/v1/parse_search_request_test.go index 6badcdcc5c..4701db1c3d 100644 --- a/adapters/handlers/grpc/v1/parse_search_request_test.go +++ b/adapters/handlers/grpc/v1/parse_search_request_test.go @@ -204,11 +204,11 @@ func TestGRPCRequest(t *testing.T) { Alpha: 1.0, Query: "nearvecquery", NearVector: &pb.NearVector{ - VectorBytes: byteops.Float32ToByteVector([]float32{1, 2, 3}), - TargetVectors: []string{"custom"}, - Certainty: &one, - Distance: &one, + VectorBytes: byteops.Float32ToByteVector([]float32{1, 2, 3}), + Certainty: &one, + Distance: &one, }, + TargetVectors: []string{"custom"}, }, }, out: dto.GetParams{ @@ -221,12 +221,12 @@ func TestGRPCRequest(t *testing.T) { Query: "nearvecquery", FusionAlgorithm: 1, NearVectorParams: &searchparams.NearVector{ - Vector: []float32{1, 2, 3}, - TargetVectors: []string{"custom"}, - Certainty: 1.0, - Distance: 1.0, - WithDistance: true, + Vector: []float32{1, 2, 3}, + Certainty: 1.0, + Distance: 1.0, + WithDistance: true, }, + TargetVectors: []string{"custom"}, }, }, error: false, diff --git a/adapters/handlers/grpc/v1/tenants.go b/adapters/handlers/grpc/v1/tenants.go index 0c4b24e1b6..5a1e2bdd1e 100644 --- a/adapters/handlers/grpc/v1/tenants.go +++ b/adapters/handlers/grpc/v1/tenants.go @@ -27,7 +27,7 @@ func (s *Service) tenantsGet(ctx context.Context, principal *models.Principal, r var err error var tenants []*models.Tenant if req.Params == nil { - tenants, err = s.schemaManager.GetConsistentTenants(ctx, principal, req.Collection, req.IsConsistent, []string{}) + tenants, err = s.schemaManager.GetConsistentTenants(ctx, principal, req.Collection, true, []string{}) if err != nil { return nil, err } @@ -38,7 +38,7 @@ func (s *Service) tenantsGet(ctx context.Context, principal *models.Principal, r if len(requestedNames) == 0 { return nil, fmt.Errorf("must specify at least one tenant name") } - tenants, err = s.schemaManager.GetConsistentTenants(ctx, principal, req.Collection, req.IsConsistent, requestedNames) + tenants, err = s.schemaManager.GetConsistentTenants(ctx, principal, req.Collection, true, requestedNames) if err != nil { return nil, err } diff --git a/adapters/handlers/rest/clusterapi/indices.go b/adapters/handlers/rest/clusterapi/indices.go index 3dffd8de77..d15d664aca 100644 --- a/adapters/handlers/rest/clusterapi/indices.go +++ b/adapters/handlers/rest/clusterapi/indices.go @@ -1105,7 +1105,7 @@ func (i *indices) postShard() http.Handler { func (i *indices) putShardReinit() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { args := i.regexpShardReinit.FindStringSubmatch(r.URL.Path) - fmt.Println(args) + if len(args) != 3 { http.Error(w, "invalid URI", http.StatusBadRequest) return diff --git a/adapters/handlers/rest/configure_api.go b/adapters/handlers/rest/configure_api.go index 44ac0df139..8a71dc3930 100644 --- a/adapters/handlers/rest/configure_api.go +++ b/adapters/handlers/rest/configure_api.go @@ -193,6 +193,8 @@ func MakeAppState(ctx context.Context, options *swag.CommandLineOptionsGroup) *s MemtablesMaxSizeMB: appState.ServerConfig.Config.Persistence.MemtablesMaxSizeMB, MemtablesMinActiveSeconds: appState.ServerConfig.Config.Persistence.MemtablesMinActiveDurationSeconds, MemtablesMaxActiveSeconds: appState.ServerConfig.Config.Persistence.MemtablesMaxActiveDurationSeconds, + MaxSegmentSize: appState.ServerConfig.Config.Persistence.LSMMaxSegmentSize, + HNSWMaxLogSize: appState.ServerConfig.Config.Persistence.HNSWMaxLogSize, RootPath: appState.ServerConfig.Config.Persistence.DataPath, QueryLimit: appState.ServerConfig.Config.QueryDefaults.Limit, QueryMaximumResults: appState.ServerConfig.Config.QueryMaximumResults, @@ -251,8 +253,6 @@ func MakeAppState(ctx context.Context, options *swag.CommandLineOptionsGroup) *s remoteIndexClient, appState.Logger, appState.ServerConfig.Config.Persistence.DataPath) appState.Scaler = scaler - /// TODO-RAFT START - // server2port, err := parseNode2Port(appState) if len(server2port) == 0 || err != nil { appState.Logger. @@ -268,31 +268,31 @@ func MakeAppState(ctx context.Context, options *swag.CommandLineOptionsGroup) *s dataPath := appState.ServerConfig.Config.Persistence.DataPath rConfig := rStore.Config{ - WorkDir: filepath.Join(dataPath, "raft"), - NodeID: nodeName, - Host: addrs[0], - RaftPort: appState.ServerConfig.Config.Raft.Port, - RPCPort: appState.ServerConfig.Config.Raft.InternalRPCPort, - RaftRPCMessageMaxSize: appState.ServerConfig.Config.Raft.RPCMessageMaxSize, - ServerName2PortMap: server2port, - BootstrapTimeout: appState.ServerConfig.Config.Raft.BootstrapTimeout, - BootstrapExpect: appState.ServerConfig.Config.Raft.BootstrapExpect, - HeartbeatTimeout: appState.ServerConfig.Config.Raft.HeartbeatTimeout, - RecoveryTimeout: appState.ServerConfig.Config.Raft.RecoveryTimeout, - ElectionTimeout: appState.ServerConfig.Config.Raft.ElectionTimeout, - SnapshotInterval: appState.ServerConfig.Config.Raft.SnapshotInterval, - SnapshotThreshold: appState.ServerConfig.Config.Raft.SnapshotThreshold, - UpdateWaitTimeout: time.Second * 10, // TODO-RAFT read from the flag - MetadataOnlyVoters: appState.ServerConfig.Config.Raft.MetadataOnlyVoters, - DB: nil, - Parser: schema.NewParser(appState.Cluster, vectorIndex.ParseAndValidateConfig, migrator), - AddrResolver: appState.Cluster, - Logger: appState.Logger, - LogLevel: logLevel(), - LogJSONFormat: !logTextFormat(), - IsLocalHost: appState.ServerConfig.Config.Cluster.Localhost, - LoadLegacySchema: schemaRepo.LoadLegacySchema, - SaveLegacySchema: schemaRepo.SaveLegacySchema, + WorkDir: filepath.Join(dataPath, config.DefaultRaftDir), + NodeID: nodeName, + Host: addrs[0], + RaftPort: appState.ServerConfig.Config.Raft.Port, + RPCPort: appState.ServerConfig.Config.Raft.InternalRPCPort, + RaftRPCMessageMaxSize: appState.ServerConfig.Config.Raft.RPCMessageMaxSize, + ServerName2PortMap: server2port, + BootstrapTimeout: appState.ServerConfig.Config.Raft.BootstrapTimeout, + BootstrapExpect: appState.ServerConfig.Config.Raft.BootstrapExpect, + HeartbeatTimeout: appState.ServerConfig.Config.Raft.HeartbeatTimeout, + RecoveryTimeout: appState.ServerConfig.Config.Raft.RecoveryTimeout, + ElectionTimeout: appState.ServerConfig.Config.Raft.ElectionTimeout, + SnapshotInterval: appState.ServerConfig.Config.Raft.SnapshotInterval, + SnapshotThreshold: appState.ServerConfig.Config.Raft.SnapshotThreshold, + ConsistencyWaitTimeout: appState.ServerConfig.Config.Raft.ConsistencyWaitTimeout, + MetadataOnlyVoters: appState.ServerConfig.Config.Raft.MetadataOnlyVoters, + DB: nil, + Parser: schema.NewParser(appState.Cluster, vectorIndex.ParseAndValidateConfig, migrator), + AddrResolver: appState.Cluster, + Logger: appState.Logger, + LogLevel: logLevel(), + LogJSONFormat: !logTextFormat(), + IsLocalHost: appState.ServerConfig.Config.Cluster.Localhost, + LoadLegacySchema: schemaRepo.LoadLegacySchema, + SaveLegacySchema: schemaRepo.SaveLegacySchema, } for _, name := range appState.ServerConfig.Config.Raft.Join[:rConfig.BootstrapExpect] { if strings.Contains(name, rConfig.NodeID) { diff --git a/adapters/handlers/rest/doc.go b/adapters/handlers/rest/doc.go index 46bf8ea705..433ae863f3 100644 --- a/adapters/handlers/rest/doc.go +++ b/adapters/handlers/rest/doc.go @@ -18,7 +18,7 @@ // https // Host: localhost // BasePath: /v1 -// Version: 1.25.0-rc.0 +// Version: 1.25.0 // Contact: Weaviate https://github.com/weaviate // // Consumes: diff --git a/adapters/handlers/rest/embedded_spec.go b/adapters/handlers/rest/embedded_spec.go index 47b9929785..385beb7c4b 100644 --- a/adapters/handlers/rest/embedded_spec.go +++ b/adapters/handlers/rest/embedded_spec.go @@ -48,7 +48,7 @@ func init() { "url": "https://github.com/weaviate", "email": "hello@weaviate.io" }, - "version": "1.25.0-rc.0" + "version": "1.25.0" }, "basePath": "/v1", "paths": { @@ -5294,7 +5294,7 @@ func init() { "url": "https://github.com/weaviate", "email": "hello@weaviate.io" }, - "version": "1.25.0-rc.0" + "version": "1.25.0" }, "basePath": "/v1", "paths": { diff --git a/adapters/repos/db/aggregations_fixtures_for_test.go b/adapters/repos/db/aggregations_fixtures_for_test.go index 37482a1799..40ed8bb899 100644 --- a/adapters/repos/db/aggregations_fixtures_for_test.go +++ b/adapters/repos/db/aggregations_fixtures_for_test.go @@ -36,6 +36,25 @@ var productClass = &models.Class{ }, } +func boolRef(b bool) *bool { + return &b +} + +var notIndexedClass = &models.Class{ + Class: "NotIndexedClass", + VectorIndexConfig: enthnsw.NewDefaultUserConfig(), + InvertedIndexConfig: invertedConfig(), + Properties: []*models.Property{ + { + Name: "name", + DataType: schema.DataTypeText.PropString(), + Tokenization: models.PropertyTokenizationWhitespace, + IndexFilterable: boolRef(false), + IndexInverted: boolRef(false), + }, + }, +} + var companyClass = &models.Class{ Class: "AggregationsTestCompany", VectorIndexConfig: enthnsw.NewDefaultUserConfig(), diff --git a/adapters/repos/db/aggregations_integration_test.go b/adapters/repos/db/aggregations_integration_test.go index 094e8cceb1..4fe3b75575 100644 --- a/adapters/repos/db/aggregations_integration_test.go +++ b/adapters/repos/db/aggregations_integration_test.go @@ -129,6 +129,7 @@ func prepareCompanyTestSchemaAndData(repo *DB, Objects: &models.Schema{ Classes: []*models.Class{ productClass, + notIndexedClass, companyClass, arrayTypesClass, customerClass, @@ -147,6 +148,8 @@ func prepareCompanyTestSchemaAndData(repo *DB, migrator.AddClass(context.Background(), arrayTypesClass, schemaGetter.shardState)) require.Nil(t, migrator.AddClass(context.Background(), customerClass, schemaGetter.shardState)) + require.Nil(t, + migrator.AddClass(context.Background(), notIndexedClass, schemaGetter.shardState)) }) schemaGetter.schema = schema @@ -165,6 +168,20 @@ func prepareCompanyTestSchemaAndData(repo *DB, } }) + t.Run("import products into notIndexed class", func(t *testing.T) { + for i, schema := range products { + t.Run(fmt.Sprintf("importing product %d", i), func(t *testing.T) { + fixture := models.Object{ + Class: notIndexedClass.Class, + ID: productsIds[i], + Properties: schema, + } + require.Nil(t, + repo.PutObject(context.Background(), &fixture, []float32{0.1, 0.2, 0.01, 0.2}, nil, nil, 0)) + }) + } + }) + t.Run("import companies", func(t *testing.T) { for j := 0; j < importFactor; j++ { for i, schema := range companies { diff --git a/adapters/repos/db/aggregator/filtered.go b/adapters/repos/db/aggregator/filtered.go index ff205f3e85..3478d90b3b 100644 --- a/adapters/repos/db/aggregator/filtered.go +++ b/adapters/repos/db/aggregator/filtered.go @@ -124,6 +124,9 @@ func (fa *filteredAggregator) bm25Objects(ctx context.Context, kw *searchparams. return nil, nil, fmt.Errorf("bm25 objects: could not find class %s in schema", fa.params.ClassName) } cfg := inverted.ConfigFromModel(class.InvertedIndexConfig) + + kw.ChooseSearchableProperties(class) + objs, scores, err := inverted.NewBM25Searcher(cfg.BM25, fa.store, fa.getSchema.ReadOnlyClass, propertyspecific.Indices{}, fa.classSearcher, fa.GetPropertyLengthTracker(), fa.logger, fa.shardVersion, diff --git a/adapters/repos/db/aggregator/hybrid.go b/adapters/repos/db/aggregator/hybrid.go index c97e9370a4..231364471c 100644 --- a/adapters/repos/db/aggregator/hybrid.go +++ b/adapters/repos/db/aggregator/hybrid.go @@ -48,6 +48,8 @@ func (a *Aggregator) bm25Objects(ctx context.Context, kw *searchparams.KeywordRa } cfg := inverted.ConfigFromModel(class.InvertedIndexConfig) + kw.ChooseSearchableProperties(class) + objs, dists, err := inverted.NewBM25Searcher(cfg.BM25, a.store, a.getSchema.ReadOnlyClass, propertyspecific.Indices{}, a.classSearcher, a.GetPropertyLengthTracker(), a.logger, a.shardVersion, diff --git a/adapters/repos/db/bm25f_test.go b/adapters/repos/db/bm25f_test.go index cf144644eb..e7fba29308 100644 --- a/adapters/repos/db/bm25f_test.go +++ b/adapters/repos/db/bm25f_test.go @@ -114,6 +114,14 @@ func SetupClass(t require.TestingT, repo *DB, schemaGetter *fakeSchemaGetter, lo IndexFilterable: &vFalse, IndexSearchable: &vTrue, }, + // Test that bm25f handles this property being unsearchable + { + Name: "notSearchable", + DataType: schema.DataTypeTextArray.PropString(), + Tokenization: models.PropertyTokenizationWhitespace, + IndexFilterable: &vFalse, + IndexSearchable: &vFalse, + }, }, } diff --git a/adapters/repos/db/helper_for_test.go b/adapters/repos/db/helper_for_test.go index d1a9b8df7d..0d2ea8bab7 100644 --- a/adapters/repos/db/helper_for_test.go +++ b/adapters/repos/db/helper_for_test.go @@ -26,11 +26,11 @@ import ( "github.com/weaviate/weaviate/adapters/repos/db/indexcheckpoint" "github.com/weaviate/weaviate/adapters/repos/db/inverted" "github.com/weaviate/weaviate/adapters/repos/db/inverted/stopwords" - "github.com/weaviate/weaviate/entities/locks" "github.com/weaviate/weaviate/entities/models" "github.com/weaviate/weaviate/entities/schema" schemaConfig "github.com/weaviate/weaviate/entities/schema/config" "github.com/weaviate/weaviate/entities/storobj" + esync "github.com/weaviate/weaviate/entities/sync" enthnsw "github.com/weaviate/weaviate/entities/vectorindex/hnsw" "github.com/weaviate/weaviate/usecases/memwatch" ) @@ -258,6 +258,7 @@ func testShardWithSettings(t *testing.T, ctx context.Context, class *models.Clas RootPath: tmpDir, ClassName: schema.ClassName(class.Class), QueryMaximumResults: maxResults, + ReplicationFactor: NewAtomicInt64(1), }, invertedIndexConfig: iic, vectorIndexUserConfig: vic, @@ -267,8 +268,8 @@ func testShardWithSettings(t *testing.T, ctx context.Context, class *models.Clas stopwords: sd, indexCheckpoints: checkpts, allocChecker: memwatch.NewDummyMonitor(), - shardCreateLocks: locks.NewNamedLocks(), - shardInUseLocks: locks.NewNamedRWLocks(), + shardCreateLocks: esync.NewKeyLocker(), + shardInUseLocks: esync.NewKeyRWLocker(), } idx.closingCtx, idx.closingCancel = context.WithCancel(context.Background()) idx.initCycleCallbacksNoop() @@ -294,7 +295,7 @@ func testObject(className string) *storobj.Object { } } -func createRandomObjects(r *rand.Rand, className string, numObj int) []*storobj.Object { +func createRandomObjects(r *rand.Rand, className string, numObj int, vectorDim int) []*storobj.Object { obj := make([]*storobj.Object, numObj) for i := 0; i < numObj; i++ { @@ -304,7 +305,11 @@ func createRandomObjects(r *rand.Rand, className string, numObj int) []*storobj. ID: strfmt.UUID(uuid.NewString()), Class: className, }, - Vector: []float32{r.Float32(), r.Float32(), r.Float32(), r.Float32()}, + Vector: make([]float32, vectorDim), + } + + for d := 0; d < vectorDim; d++ { + obj[i].Vector[d] = r.Float32() } } return obj diff --git a/adapters/repos/db/index.go b/adapters/repos/db/index.go index 70c075b0c6..11e37d2d3d 100644 --- a/adapters/repos/db/index.go +++ b/adapters/repos/db/index.go @@ -44,7 +44,6 @@ import ( "github.com/weaviate/weaviate/entities/errorcompounder" enterrors "github.com/weaviate/weaviate/entities/errors" "github.com/weaviate/weaviate/entities/filters" - "github.com/weaviate/weaviate/entities/locks" "github.com/weaviate/weaviate/entities/models" "github.com/weaviate/weaviate/entities/multi" "github.com/weaviate/weaviate/entities/schema" @@ -52,6 +51,7 @@ import ( "github.com/weaviate/weaviate/entities/search" "github.com/weaviate/weaviate/entities/searchparams" "github.com/weaviate/weaviate/entities/storobj" + esync "github.com/weaviate/weaviate/entities/sync" "github.com/weaviate/weaviate/usecases/config" "github.com/weaviate/weaviate/usecases/memwatch" "github.com/weaviate/weaviate/usecases/modules" @@ -198,8 +198,8 @@ type Index struct { // loading will be set to true once the last shard was loaded. allShardsReady atomic.Bool allocChecker memwatch.AllocChecker - shardCreateLocks *locks.NamedLocks - shardInUseLocks *locks.NamedRWLocks + shardCreateLocks *esync.KeyLocker + shardInUseLocks *esync.KeyRWLocker modules *modules.Provider } @@ -269,8 +269,8 @@ func NewIndex(ctx context.Context, cfg IndexConfig, backupMutex: backupMutex{log: logger, retryDuration: mutexRetryDuration, notifyDuration: mutexNotifyDuration}, indexCheckpoints: indexCheckpoints, allocChecker: allocChecker, - shardCreateLocks: locks.NewNamedLocks(), - shardInUseLocks: locks.NewNamedRWLocks(), + shardCreateLocks: esync.NewKeyLocker(), + shardInUseLocks: esync.NewKeyRWLocker(), } index.closingCtx, index.closingCancel = context.WithCancel(context.Background()) @@ -555,7 +555,9 @@ type IndexConfig struct { MemtablesMaxSizeMB int MemtablesMinActiveSeconds int MemtablesMaxActiveSeconds int - ReplicationFactor int64 + MaxSegmentSize int64 + HNSWMaxLogSize int64 + ReplicationFactor *atomic.Int64 AvoidMMap bool DisableLazyLoadShards bool @@ -674,19 +676,15 @@ func (i *Index) IncomingPutObject(ctx context.Context, shardName string, shard, release, err := i.getOrInitLocalShardNoShutdown(ctx, shardName) if err != nil { - return ErrShardNotFound - } - defer release() - - if err := shard.PutObject(ctx, object); err != nil { return err } + defer release() - return nil + return shard.PutObject(ctx, object) } func (i *Index) replicationEnabled() bool { - return i.Config.ReplicationFactor > 1 + return i.Config.ReplicationFactor.Load() > 1 } // parseDateFieldsInProps checks the schema for the current class for which @@ -900,7 +898,7 @@ func (i *Index) IncomingBatchPutObjects(ctx context.Context, shardName string, shard, release, err := i.getOrInitLocalShardNoShutdown(ctx, shardName) if err != nil { - return duplicateErr(ErrShardNotFound, len(objects)) + return duplicateErr(err, len(objects)) } defer release() @@ -975,7 +973,7 @@ func (i *Index) IncomingBatchAddReferences(ctx context.Context, shardName string shard, release, err := i.getOrInitLocalShardNoShutdown(ctx, shardName) if err != nil { - return duplicateErr(ErrShardNotFound, len(refs)) + return duplicateErr(err, len(refs)) } defer release() @@ -1040,7 +1038,7 @@ func (i *Index) IncomingGetObject(ctx context.Context, shardName string, ) (*storobj.Object, error) { shard, release, err := i.getOrInitLocalShardNoShutdown(ctx, shardName) if err != nil { - return nil, ErrShardNotFound + return nil, err } defer release() @@ -1052,7 +1050,7 @@ func (i *Index) IncomingMultiGetObjects(ctx context.Context, shardName string, ) ([]*storobj.Object, error) { shard, release, err := i.getOrInitLocalShardNoShutdown(ctx, shardName) if err != nil { - return nil, ErrShardNotFound + return nil, err } defer release() @@ -1188,7 +1186,7 @@ func (i *Index) IncomingExists(ctx context.Context, shardName string, ) (bool, error) { shard, release, err := i.getOrInitLocalShardNoShutdown(ctx, shardName) if err != nil { - return false, ErrShardNotFound + return false, err } defer release() @@ -1611,7 +1609,7 @@ func (i *Index) IncomingSearch(ctx context.Context, shardName string, ) ([]*storobj.Object, []float32, error) { shard, release, err := i.getOrInitLocalShardNoShutdown(ctx, shardName) if err != nil { - return nil, nil, ErrShardNotFound + return nil, nil, err } defer release() @@ -1687,7 +1685,7 @@ func (i *Index) IncomingDeleteObject(ctx context.Context, shardName string, shard, release, err := i.getOrInitLocalShardNoShutdown(ctx, shardName) if err != nil { - return ErrShardNotFound + return err } defer release() @@ -1721,12 +1719,12 @@ func (i *Index) getOrInitLocalShardNoShutdown(ctx context.Context, shardName str shard, err := i.getOrInitLocalShard(ctx, shardName) if err != nil { - return nil, func() {}, err + return nil, func() {}, fmt.Errorf("get/init local shard %q, no shutdown: %w", shardName, err) } release, err := shard.preventShutdown() if err != nil { - return nil, func() {}, err + return nil, func() {}, fmt.Errorf("get/init local shard %q, no shutdown: %w", shardName, err) } return shard, release, nil } @@ -1816,7 +1814,7 @@ func (i *Index) IncomingMergeObject(ctx context.Context, shardName string, shard, release, err := i.getOrInitLocalShardNoShutdown(ctx, shardName) if err != nil { - return ErrShardNotFound + return err } defer release() @@ -1869,7 +1867,7 @@ func (i *Index) IncomingAggregate(ctx context.Context, shardName string, ) (*aggregation.Result, error) { shard, release, err := i.getOrInitLocalShardNoShutdown(ctx, shardName) if err != nil { - return nil, ErrShardNotFound + return nil, err } defer release() @@ -2043,7 +2041,7 @@ func (i *Index) getShardsQueueSize(ctx context.Context, tenant string) (map[stri func (i *Index) IncomingGetShardQueueSize(ctx context.Context, shardName string) (int64, error) { shard, release, err := i.getOrInitLocalShardNoShutdown(ctx, shardName) if err != nil { - return 0, ErrShardNotFound + return 0, err } defer release() @@ -2098,7 +2096,7 @@ func (i *Index) getShardsStatus(ctx context.Context, tenant string) (map[string] func (i *Index) IncomingGetShardStatus(ctx context.Context, shardName string) (string, error) { shard, release, err := i.getOrInitLocalShardNoShutdown(ctx, shardName) if err != nil { - return "", ErrShardNotFound + return "", err } defer release() @@ -2120,7 +2118,7 @@ func (i *Index) updateShardStatus(ctx context.Context, shardName, targetStatus s func (i *Index) IncomingUpdateShardStatus(ctx context.Context, shardName, targetStatus string, schemaVersion uint64) error { shard, release, err := i.getOrInitLocalShardNoShutdown(ctx, shardName) if err != nil { - return ErrShardNotFound + return err } defer release() @@ -2179,7 +2177,7 @@ func (i *Index) IncomingFindUUIDs(ctx context.Context, shardName string, ) ([]strfmt.UUID, error) { shard, release, err := i.getOrInitLocalShardNoShutdown(ctx, shardName) if err != nil { - return nil, ErrShardNotFound + return nil, err } defer release() @@ -2255,7 +2253,7 @@ func (i *Index) IncomingDeleteObjectBatch(ctx context.Context, shardName string, shard, release, err := i.getOrInitLocalShardNoShutdown(ctx, shardName) if err != nil { return objects.BatchSimpleObjects{ - objects.BatchSimpleObject{Err: ErrShardNotFound}, + objects.BatchSimpleObject{Err: err}, } } defer release() diff --git a/adapters/repos/db/index_integration_test.go b/adapters/repos/db/index_integration_test.go index 89a0dcbda1..1e07438ce6 100644 --- a/adapters/repos/db/index_integration_test.go +++ b/adapters/repos/db/index_integration_test.go @@ -103,8 +103,9 @@ func TestIndex_DropWithDataAndRecreateWithDataIndex(t *testing.T) { // create index with data shardState := singleShardState() index, err := NewIndex(testCtx(), IndexConfig{ - RootPath: dirName, - ClassName: schema.ClassName(class.Class), + RootPath: dirName, + ClassName: schema.ClassName(class.Class), + ReplicationFactor: NewAtomicInt64(1), }, shardState, inverted.ConfigFromModel(class.InvertedIndexConfig), hnsw.NewDefaultUserConfig(), nil, &fakeSchemaGetter{ schema: fakeSchema, shardState: shardState, @@ -163,8 +164,9 @@ func TestIndex_DropWithDataAndRecreateWithDataIndex(t *testing.T) { // recreate the index index, err = NewIndex(testCtx(), IndexConfig{ - RootPath: dirName, - ClassName: schema.ClassName(class.Class), + RootPath: dirName, + ClassName: schema.ClassName(class.Class), + ReplicationFactor: NewAtomicInt64(1), }, shardState, inverted.ConfigFromModel(class.InvertedIndexConfig), hnsw.NewDefaultUserConfig(), nil, &fakeSchemaGetter{ schema: fakeSchema, @@ -273,8 +275,9 @@ func TestIndex_DropReadOnlyIndexWithData(t *testing.T) { shardState := singleShardState() index, err := NewIndex(ctx, IndexConfig{ - RootPath: dirName, - ClassName: schema.ClassName(class.Class), + RootPath: dirName, + ClassName: schema.ClassName(class.Class), + ReplicationFactor: NewAtomicInt64(1), }, shardState, inverted.ConfigFromModel(class.InvertedIndexConfig), hnsw.NewDefaultUserConfig(), nil, &fakeSchemaGetter{ schema: fakeSchema, shardState: shardState, @@ -332,6 +335,7 @@ func emptyIdx(t *testing.T, rootDir string, class *models.Class) *Index { RootPath: rootDir, ClassName: schema.ClassName(class.Class), DisableLazyLoadShards: true, + ReplicationFactor: NewAtomicInt64(1), }, shardState, inverted.ConfigFromModel(invertedConfig()), hnsw.NewDefaultUserConfig(), nil, &fakeSchemaGetter{ shardState: shardState, diff --git a/adapters/repos/db/index_queue_test.go b/adapters/repos/db/index_queue_test.go index 0e0ca552d2..697ca9017a 100644 --- a/adapters/repos/db/index_queue_test.go +++ b/adapters/repos/db/index_queue_test.go @@ -601,7 +601,7 @@ func TestIndexQueue(t *testing.T) { t.Run("compression", func(t *testing.T) { var idx mockBatchIndexer - called := make(chan struct{}) + called := make(chan struct{}, 1) idx.shouldCompress = true idx.threshold = 4 idx.alreadyIndexed.Store(6) @@ -613,7 +613,7 @@ func TestIndexQueue(t *testing.T) { callback() }() - close(called) + called <- struct{}{} return nil } diff --git a/adapters/repos/db/init.go b/adapters/repos/db/init.go index 3a2ec59d6d..7b6e314b8e 100644 --- a/adapters/repos/db/init.go +++ b/adapters/repos/db/init.go @@ -16,6 +16,7 @@ import ( "fmt" "os" "path" + "sync/atomic" "time" enterrors "github.com/weaviate/weaviate/entities/errors" @@ -89,10 +90,12 @@ func (db *DB) init(ctx context.Context) error { MemtablesMaxSizeMB: db.config.MemtablesMaxSizeMB, MemtablesMinActiveSeconds: db.config.MemtablesMinActiveSeconds, MemtablesMaxActiveSeconds: db.config.MemtablesMaxActiveSeconds, + MaxSegmentSize: db.config.MaxSegmentSize, + HNSWMaxLogSize: db.config.HNSWMaxLogSize, TrackVectorDimensions: db.config.TrackVectorDimensions, AvoidMMap: db.config.AvoidMMap, DisableLazyLoadShards: db.config.DisableLazyLoadShards, - ReplicationFactor: class.ReplicationConfig.Factor, + ReplicationFactor: NewAtomicInt64(class.ReplicationConfig.Factor), }, db.schemaGetter.CopyShardingState(class.Class), inverted.ConfigFromModel(invertedConfig), convertToVectorIndexConfig(class.VectorIndexConfig), @@ -166,3 +169,9 @@ func fileExists(file string) (bool, error) { } return true, nil } + +func NewAtomicInt64(val int64) *atomic.Int64 { + aval := &atomic.Int64{} + aval.Store(val) + return aval +} diff --git a/adapters/repos/db/inverted/bm25_searcher.go b/adapters/repos/db/inverted/bm25_searcher.go index 7c3ca27eea..eec2325d01 100644 --- a/adapters/repos/db/inverted/bm25_searcher.go +++ b/adapters/repos/db/inverted/bm25_searcher.go @@ -79,6 +79,7 @@ func (b *BM25Searcher) BM25F(ctx context.Context, filterDocIds helpers.AllowList return nil, nil, inverted.NewMissingSearchableIndexError(property) } } + class := b.getClass(className.String()) if class == nil { return nil, nil, fmt.Errorf("could not find class %s in schema", className) diff --git a/adapters/repos/db/lsmkv/bucket.go b/adapters/repos/db/lsmkv/bucket.go index e14453ec45..65275f9b29 100644 --- a/adapters/repos/db/lsmkv/bucket.go +++ b/adapters/repos/db/lsmkv/bucket.go @@ -118,6 +118,10 @@ type Bucket struct { // optionally supplied to prevent starting memory-intensive // processes when memory pressure is high allocChecker memwatch.AllocChecker + + // optional segment size limit. If set, a compaction will skip segments that + // sum to more than the specified value. + maxSegmentSize int64 } func NewBucketCreator() *Bucket { return &Bucket{} } @@ -178,6 +182,7 @@ func (*Bucket) NewBucket(ctx context.Context, dir, rootDir string, logger logrus forceCompaction: b.forceCompaction, useBloomFilter: b.useBloomFilter, calcCountNetAdditions: b.calcCountNetAdditions, + maxSegmentSize: b.maxSegmentSize, }, b.allocChecker) if err != nil { return nil, fmt.Errorf("init disk segments: %w", err) @@ -1002,6 +1007,11 @@ func (b *Bucket) atomicallyAddDiskSegmentAndRemoveFlushing() error { b.flushLock.Lock() defer b.flushLock.Unlock() + if b.flushing.Size() == 0 { + b.flushing = nil + return nil + } + path := b.flushing.path if err := b.disk.add(path + ".db"); err != nil { return err diff --git a/adapters/repos/db/lsmkv/bucket_options.go b/adapters/repos/db/lsmkv/bucket_options.go index c0e004991e..bc53fc43c8 100644 --- a/adapters/repos/db/lsmkv/bucket_options.go +++ b/adapters/repos/db/lsmkv/bucket_options.go @@ -148,6 +148,13 @@ func WithCalcCountNetAdditions(calcCountNetAdditions bool) BucketOption { } } +func WithMaxSegmentSize(maxSegmentSize int64) BucketOption { + return func(b *Bucket) error { + b.maxSegmentSize = maxSegmentSize + return nil + } +} + /* Background for this option: diff --git a/adapters/repos/db/lsmkv/segment_group.go b/adapters/repos/db/lsmkv/segment_group.go index 668d89a7c3..0e643494b0 100644 --- a/adapters/repos/db/lsmkv/segment_group.go +++ b/adapters/repos/db/lsmkv/segment_group.go @@ -62,7 +62,8 @@ type SegmentGroup struct { calcCountNetAdditions bool // see bucket for more datails compactLeftOverSegments bool // see bucket for more datails - allocChecker memwatch.AllocChecker + allocChecker memwatch.AllocChecker + maxSegmentSize int64 } type sgConfig struct { @@ -75,6 +76,7 @@ type sgConfig struct { useBloomFilter bool calcCountNetAdditions bool forceCompaction bool + maxSegmentSize int64 } func newSegmentGroup(logger logrus.FieldLogger, metrics *Metrics, @@ -99,6 +101,7 @@ func newSegmentGroup(logger logrus.FieldLogger, metrics *Metrics, useBloomFilter: cfg.useBloomFilter, calcCountNetAdditions: cfg.calcCountNetAdditions, compactLeftOverSegments: cfg.forceCompaction, + maxSegmentSize: cfg.maxSegmentSize, allocChecker: allocChecker, } @@ -125,7 +128,11 @@ func newSegmentGroup(logger logrus.FieldLogger, metrics *Metrics, jointSegmentsIDs := strings.Split(jointSegments, "_") if len(jointSegmentsIDs) != 2 { - return nil, fmt.Errorf("invalid compacted segment file name %q", entry.Name()) + logger.WithField("action", "lsm_segment_init"). + WithField("path", filepath.Join(sg.dir, entry.Name())). + Warn("ignored (partially written) LSM compacted segment generated with a version older than v1.24.0") + + continue } leftSegmentFilename := fmt.Sprintf("segment-%s.db", jointSegmentsIDs[0]) @@ -217,7 +224,7 @@ func newSegmentGroup(logger logrus.FieldLogger, metrics *Metrics, logger.WithField("action", "lsm_segment_init"). WithField("path", filepath.Join(sg.dir, entry.Name())). WithField("wal_path", walFileName). - Info("Discarded (partially written) LSM segment, because an active WAL for " + + Info("discarded (partially written) LSM segment, because an active WAL for " + "the same segment was found. A recovery from the WAL will follow.") continue diff --git a/adapters/repos/db/lsmkv/segment_group_compaction.go b/adapters/repos/db/lsmkv/segment_group_compaction.go index 80e02b818f..1cb0fe7fc9 100644 --- a/adapters/repos/db/lsmkv/segment_group_compaction.go +++ b/adapters/repos/db/lsmkv/segment_group_compaction.go @@ -146,6 +146,12 @@ func (sg *SegmentGroup) compactOnce() (bool, error) { leftSegment := sg.segmentAtPos(pair[0]) rightSegment := sg.segmentAtPos(pair[1]) + if !sg.compactionFitsSizeLimit(leftSegment, rightSegment) { + // nothing to do this round, let's wait for the next round in the hopes + // that we'll find smaller (lower-level) segments that can still fit. + return false, nil + } + path := filepath.Join(sg.dir, "segment-"+segmentID(leftSegment.path)+"_"+segmentID(rightSegment.path)+".db.tmp") f, err := os.Create(path) @@ -470,3 +476,13 @@ func (s *segmentLevelStats) report(metrics *Metrics, }).Set(float64(count)) } } + +func (sg *SegmentGroup) compactionFitsSizeLimit(left, right *segment) bool { + if sg.maxSegmentSize == 0 { + // no limit is set, always return true + return true + } + + totalSize := left.size + right.size + return totalSize <= sg.maxSegmentSize +} diff --git a/adapters/repos/db/lsmkv/segment_group_compaction_test.go b/adapters/repos/db/lsmkv/segment_group_compaction_test.go new file mode 100644 index 0000000000..b5d4a06abf --- /dev/null +++ b/adapters/repos/db/lsmkv/segment_group_compaction_test.go @@ -0,0 +1,101 @@ +// _ _ +// __ _____ __ ___ ___ __ _| |_ ___ +// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ +// \ V V / __/ (_| |\ V /| | (_| | || __/ +// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| +// +// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. +// +// CONTACT: hello@weaviate.io +// + +package lsmkv + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSegmentGroup_BestCompactionPair(t *testing.T) { + var maxSegmentSize int64 = 10000 + + tests := []struct { + name string + segments []*segment + expectedPair []string + }{ + { + name: "single segment", + segments: []*segment{ + {size: 1000, path: "segment0", level: 0}, + }, + expectedPair: nil, + }, + { + name: "two segments, same level", + segments: []*segment{ + {size: 1000, path: "segment0", level: 0}, + {size: 1000, path: "segment1", level: 0}, + }, + expectedPair: []string{"segment0", "segment1"}, + }, + { + name: "multiple segments, multiple levels, lowest level is picked", + segments: []*segment{ + {size: 4000, path: "segment0", level: 2}, + {size: 4000, path: "segment1", level: 2}, + {size: 2000, path: "segment2", level: 1}, + {size: 2000, path: "segment3", level: 1}, + {size: 1000, path: "segment4", level: 0}, + {size: 1000, path: "segment5", level: 0}, + }, + expectedPair: []string{"segment4", "segment5"}, + }, + { + name: "two segments that don't fit the max size, but eliglbe segments of a lower level are present", + segments: []*segment{ + {size: 8000, path: "segment0", level: 3}, + {size: 8000, path: "segment1", level: 3}, + {size: 4000, path: "segment2", level: 2}, + {size: 4000, path: "segment3", level: 2}, + }, + expectedPair: []string{"segment2", "segment3"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + sg := &SegmentGroup{ + segments: test.segments, + maxSegmentSize: maxSegmentSize, + } + pair := sg.bestCompactionCandidatePair() + if test.expectedPair == nil { + assert.Nil(t, pair) + } else { + leftPath := test.segments[pair[0]].path + rightPath := test.segments[pair[1]].path + assert.Equal(t, test.expectedPair, []string{leftPath, rightPath}) + } + }) + } +} + +func TestSegmenGroup_CompactionLargerThanMaxSize(t *testing.T) { + maxSegmentSize := int64(10000) + // this test only tests the unhappy path which has an early exist condition, + // meaning we don't need real segments, it is only metadata that is evaluated + // here. + sg := &SegmentGroup{ + segments: []*segment{ + {size: 8000, path: "segment0", level: 3}, + {size: 8000, path: "segment1", level: 3}, + }, + maxSegmentSize: maxSegmentSize, + } + + ok, err := sg.compactOnce() + assert.False(t, ok, "segments are too large to run") + assert.Nil(t, err) +} diff --git a/adapters/repos/db/lsmkv/store.go b/adapters/repos/db/lsmkv/store.go index 46a7f2b0d2..d709933268 100644 --- a/adapters/repos/db/lsmkv/store.go +++ b/adapters/repos/db/lsmkv/store.go @@ -61,7 +61,7 @@ func New(dir, rootDir string, logger logrus.FieldLogger, metrics *Metrics, dir: dir, rootDir: rootDir, bucketsByName: map[string]*Bucket{}, - bucketsLocks: wsync.New(), + bucketsLocks: wsync.NewKeyLocker(), bcreator: NewBucketCreator(), logger: logger, metrics: metrics, diff --git a/adapters/repos/db/migrator.go b/adapters/repos/db/migrator.go index e3376c1750..b4d3a5a1bd 100644 --- a/adapters/repos/db/migrator.go +++ b/adapters/repos/db/migrator.go @@ -63,10 +63,12 @@ func (m *Migrator) AddClass(ctx context.Context, class *models.Class, MemtablesMaxSizeMB: m.db.config.MemtablesMaxSizeMB, MemtablesMinActiveSeconds: m.db.config.MemtablesMinActiveSeconds, MemtablesMaxActiveSeconds: m.db.config.MemtablesMaxActiveSeconds, + MaxSegmentSize: m.db.config.MaxSegmentSize, + HNSWMaxLogSize: m.db.config.HNSWMaxLogSize, TrackVectorDimensions: m.db.config.TrackVectorDimensions, AvoidMMap: m.db.config.AvoidMMap, DisableLazyLoadShards: m.db.config.DisableLazyLoadShards, - ReplicationFactor: class.ReplicationConfig.Factor, + ReplicationFactor: NewAtomicInt64(class.ReplicationConfig.Factor), }, shardState, // no backward-compatibility check required, since newly added classes will @@ -454,6 +456,16 @@ func (m *Migrator) UpdateInvertedIndexConfig(ctx context.Context, className stri return idx.updateInvertedIndexConfig(ctx, conf) } +func (m *Migrator) UpdateReplicationFactor(ctx context.Context, className string, factor int64) error { + idx := m.db.GetIndex(schema.ClassName(className)) + if idx == nil { + return errors.Errorf("cannot update replication factor of non-existing index for %s", className) + } + + idx.Config.ReplicationFactor.Store(factor) + return nil +} + func (m *Migrator) RecalculateVectorDimensions(ctx context.Context) error { count := 0 m.logger. diff --git a/adapters/repos/db/node_wide_metrics_test.go b/adapters/repos/db/node_wide_metrics_test.go index 1ec6be79cc..0c88699e0b 100644 --- a/adapters/repos/db/node_wide_metrics_test.go +++ b/adapters/repos/db/node_wide_metrics_test.go @@ -28,14 +28,16 @@ func TestShardActivity(t *testing.T) { indices: map[string]*Index{ "Col1": { Config: IndexConfig{ - ClassName: "Col1", + ClassName: "Col1", + ReplicationFactor: NewAtomicInt64(1), }, partitioningEnabled: true, shards: shardMap{}, }, "NonMT": { Config: IndexConfig{ - ClassName: "NonMT", + ClassName: "NonMT", + ReplicationFactor: NewAtomicInt64(1), }, partitioningEnabled: false, shards: shardMap{}, diff --git a/adapters/repos/db/nodes.go b/adapters/repos/db/nodes.go index 873e15dc98..726f7d3c0a 100644 --- a/adapters/repos/db/nodes.go +++ b/adapters/repos/db/nodes.go @@ -291,7 +291,7 @@ func (db *DB) localNodeStatistics() (*models.Statistics, error) { IsVoter: stats["is_voter"].(bool), Open: stats["open"].(bool), Bootstrapped: stats["bootstrapped"].(bool), - InitialLastAppliedIndex: stats["initial_last_applied_index"].(uint64), + InitialLastAppliedIndex: stats["last_store_log_applied_index"].(uint64), DbLoaded: stats["db_loaded"].(bool), Candidates: stats["candidates"], Raft: raft, diff --git a/adapters/repos/db/replication.go b/adapters/repos/db/replication.go index 2161a1c052..17a170a614 100644 --- a/adapters/repos/db/replication.go +++ b/adapters/repos/db/replication.go @@ -20,6 +20,7 @@ import ( "path/filepath" "github.com/go-openapi/strfmt" + "github.com/pkg/errors" "github.com/weaviate/weaviate/entities/additional" "github.com/weaviate/weaviate/entities/multi" "github.com/weaviate/weaviate/entities/schema" @@ -253,12 +254,34 @@ func (i *Index) IncomingCreateShard(ctx context.Context, className string, shard func (i *Index) IncomingReinitShard(ctx context.Context, shardName string, ) error { - shard, err := i.getOrInitLocalShard(ctx, shardName) - if err != nil { - return fmt.Errorf("shard %q does not exist locally", shardName) + shard := func() ShardLike { + i.shardInUseLocks.Lock(shardName) + defer i.shardInUseLocks.Unlock(shardName) + + return i.shards.Load(shardName) + }() + + if shard != nil { + err := func() error { + i.shardCreateLocks.Lock(shardName) + defer i.shardCreateLocks.Unlock(shardName) + + i.shards.LoadAndDelete(shardName) + + if err := shard.Shutdown(ctx); err != nil { + if !errors.Is(err, errAlreadyShutdown) { + return err + } + } + return nil + }() + if err != nil { + return err + } } - return shard.reinit(ctx) + _, err := i.getOrInitLocalShard(ctx, shardName) + return err } func (s *Shard) filePutter(ctx context.Context, @@ -280,31 +303,6 @@ func (s *Shard) filePutter(ctx context.Context, return f, nil } -func (s *Shard) reinit(ctx context.Context) error { - if err := s.Shutdown(ctx); err != nil { - return fmt.Errorf("shutdown shard: %w", err) - } - - if err := s.initNonVector(ctx, nil); err != nil { - return fmt.Errorf("reinit non-vector: %w", err) - } - - if s.hasTargetVectors() { - if err := s.initTargetVectors(ctx); err != nil { - return fmt.Errorf("reinit vector: %w", err) - } - } else { - if err := s.initLegacyVector(ctx); err != nil { - return fmt.Errorf("reinit vector: %w", err) - } - } - - s.initCycleCallbacks() - s.initDimensionTracking() - - return nil -} - // OverwriteObjects if their state didn't change in the meantime // It returns nil if all object have been successfully overwritten // and otherwise a list of failed operations. diff --git a/adapters/repos/db/repo.go b/adapters/repos/db/repo.go index 46248d6af1..dce18b4ded 100644 --- a/adapters/repos/db/repo.go +++ b/adapters/repos/db/repo.go @@ -196,6 +196,8 @@ type Config struct { MemtablesMaxSizeMB int MemtablesMinActiveSeconds int MemtablesMaxActiveSeconds int + MaxSegmentSize int64 + HNSWMaxLogSize int64 TrackVectorDimensions bool ServerVersion string GitHash string diff --git a/adapters/repos/db/shard.go b/adapters/repos/db/shard.go index f5cda0b711..5e89ea7557 100644 --- a/adapters/repos/db/shard.go +++ b/adapters/repos/db/shard.go @@ -132,7 +132,6 @@ type ShardLike interface { commitReplication(context.Context, string, *backupMutex) interface{} abortReplication(context.Context, string) replica.SimpleResponse - reinit(context.Context) error filePutter(context.Context, string) (io.WriteCloser, error) // TODO tests only @@ -424,7 +423,11 @@ func (s *Shard) initVectorIndex(ctx context.Context, MakeCommitLoggerThunk: func() (hnsw.CommitLogger, error) { return hnsw.NewCommitLogger(s.path(), vecIdxID, s.index.logger, s.cycleCallbacks.vectorCommitLoggerCallbacks, - hnsw.WithAllocChecker(s.index.allocChecker)) + hnsw.WithAllocChecker(s.index.allocChecker), + hnsw.WithCommitlogThresholdForCombining(s.index.Config.HNSWMaxLogSize), + // consistent with previous logic where the individual limit is 1/5 of the combined limit + hnsw.WithCommitlogThreshold(s.index.Config.HNSWMaxLogSize/5), + ) }, AllocChecker: s.index.allocChecker, }, hnswUserConfig, s.cycleCallbacks.vectorTombstoneCleanupCallbacks, @@ -592,6 +595,7 @@ func (s *Shard) initLSMStore(ctx context.Context) error { s.dynamicMemtableSizing(), s.memtableDirtyConfig(), lsmkv.WithAllocChecker(s.index.allocChecker), + lsmkv.WithMaxSegmentSize(s.index.Config.MaxSegmentSize), } if s.metrics != nil && !s.metrics.grouped { @@ -718,6 +722,7 @@ func (s *Shard) addIDProperty(ctx context.Context) error { lsmkv.WithStrategy(lsmkv.StrategySetCollection), lsmkv.WithPread(s.index.Config.AvoidMMap), lsmkv.WithAllocChecker(s.index.allocChecker), + lsmkv.WithMaxSegmentSize(s.index.Config.MaxSegmentSize), ) } @@ -733,6 +738,7 @@ func (s *Shard) addDimensionsProperty(ctx context.Context) error { lsmkv.WithStrategy(lsmkv.StrategyMapCollection), lsmkv.WithPread(s.index.Config.AvoidMMap), lsmkv.WithAllocChecker(s.index.allocChecker), + lsmkv.WithMaxSegmentSize(s.index.Config.MaxSegmentSize), ) if err != nil { return err @@ -763,6 +769,7 @@ func (s *Shard) addCreationTimeUnixProperty(ctx context.Context) error { lsmkv.WithStrategy(lsmkv.StrategyRoaringSet), lsmkv.WithPread(s.index.Config.AvoidMMap), lsmkv.WithAllocChecker(s.index.allocChecker), + lsmkv.WithMaxSegmentSize(s.index.Config.MaxSegmentSize), ) } @@ -773,6 +780,7 @@ func (s *Shard) addLastUpdateTimeUnixProperty(ctx context.Context) error { lsmkv.WithStrategy(lsmkv.StrategyRoaringSet), lsmkv.WithPread(s.index.Config.AvoidMMap), lsmkv.WithAllocChecker(s.index.allocChecker), + lsmkv.WithMaxSegmentSize(s.index.Config.MaxSegmentSize), ) } @@ -838,6 +846,7 @@ func (s *Shard) createPropertyValueIndex(ctx context.Context, prop *models.Prope s.dynamicMemtableSizing(), lsmkv.WithPread(s.index.Config.AvoidMMap), lsmkv.WithAllocChecker(s.index.allocChecker), + lsmkv.WithMaxSegmentSize(s.index.Config.MaxSegmentSize), } if inverted.HasFilterableIndex(prop) { @@ -898,6 +907,7 @@ func (s *Shard) createPropertyLengthIndex(ctx context.Context, prop *models.Prop lsmkv.WithStrategy(lsmkv.StrategyRoaringSet), lsmkv.WithPread(s.index.Config.AvoidMMap), lsmkv.WithAllocChecker(s.index.allocChecker), + lsmkv.WithMaxSegmentSize(s.index.Config.MaxSegmentSize), ) } @@ -911,6 +921,7 @@ func (s *Shard) createPropertyNullIndex(ctx context.Context, prop *models.Proper lsmkv.WithStrategy(lsmkv.StrategyRoaringSet), lsmkv.WithPread(s.index.Config.AvoidMMap), lsmkv.WithAllocChecker(s.index.allocChecker), + lsmkv.WithMaxSegmentSize(s.index.Config.MaxSegmentSize), ) } diff --git a/adapters/repos/db/shard_lazyloader.go b/adapters/repos/db/shard_lazyloader.go index 9b926b7421..62128e7a1a 100644 --- a/adapters/repos/db/shard_lazyloader.go +++ b/adapters/repos/db/shard_lazyloader.go @@ -504,13 +504,6 @@ func (l *LazyLoadShard) abortReplication(ctx context.Context, shardID string) re return l.shard.abortReplication(ctx, shardID) } -func (l *LazyLoadShard) reinit(ctx context.Context) error { - if err := l.Load(ctx); err != nil { - return err - } - return l.shard.reinit(ctx) -} - func (l *LazyLoadShard) filePutter(ctx context.Context, shardID string) (io.WriteCloser, error) { if err := l.Load(ctx); err != nil { return nil, err diff --git a/adapters/repos/db/shard_test.go b/adapters/repos/db/shard_test.go index 0b4cba4b93..6b8057a561 100644 --- a/adapters/repos/db/shard_test.go +++ b/adapters/repos/db/shard_test.go @@ -29,8 +29,10 @@ import ( "github.com/stretchr/testify/require" "github.com/weaviate/weaviate/adapters/repos/db/lsmkv" "github.com/weaviate/weaviate/entities/additional" + "github.com/weaviate/weaviate/entities/models" "github.com/weaviate/weaviate/entities/storagestate" "github.com/weaviate/weaviate/entities/storobj" + "github.com/weaviate/weaviate/entities/vectorindex/hnsw" ) func TestShard_UpdateStatus(t *testing.T) { @@ -171,7 +173,7 @@ func TestShard_ParallelBatches(t *testing.T) { r := getRandomSeed() batches := make([][]*storobj.Object, 4) for i := range batches { - batches[i] = createRandomObjects(r, "TestClass", 1000) + batches[i] = createRandomObjects(r, "TestClass", 1000, 4) } totalObjects := 1000 * len(batches) ctx := testCtx() @@ -191,3 +193,33 @@ func TestShard_ParallelBatches(t *testing.T) { require.Equal(t, totalObjects, int(shd.Counter().Get())) require.Nil(t, idx.drop()) } + +func TestShard_InvalidVectorBatches(t *testing.T) { + ctx := testCtx() + + class := &models.Class{Class: "TestClass"} + + shd, idx := testShardWithSettings(t, ctx, class, hnsw.NewDefaultUserConfig(), false, false) + + testShard(t, context.Background(), class.Class) + + r := getRandomSeed() + + batchSize := 1000 + + validBatch := createRandomObjects(r, class.Class, batchSize, 4) + + shd.PutObjectBatch(ctx, validBatch) + require.Equal(t, batchSize, int(shd.Counter().Get())) + + invalidBatch := createRandomObjects(r, class.Class, batchSize, 5) + + errs := shd.PutObjectBatch(ctx, invalidBatch) + require.Len(t, errs, batchSize) + for _, err := range errs { + require.ErrorContains(t, err, "new node has a vector with length 5. Existing nodes have vectors with length 4") + } + require.Equal(t, batchSize, int(shd.Counter().Get())) + + require.Nil(t, idx.drop()) +} diff --git a/adapters/repos/db/shard_write_put.go b/adapters/repos/db/shard_write_put.go index f2fb459e33..7e59161507 100644 --- a/adapters/repos/db/shard_write_put.go +++ b/adapters/repos/db/shard_write_put.go @@ -43,26 +43,6 @@ func (s *Shard) PutObject(ctx context.Context, object *storobj.Object) error { } func (s *Shard) putOne(ctx context.Context, uuid []byte, object *storobj.Object) error { - if s.hasTargetVectors() { - if len(object.Vectors) > 0 { - for targetVector, vector := range object.Vectors { - if vectorIndex := s.VectorIndexForName(targetVector); vectorIndex != nil { - if err := vectorIndex.ValidateBeforeInsert(vector); err != nil { - return errors.Wrapf(err, "Validate vector index %s for target vector %s", targetVector, object.ID()) - } - } - } - } - } else { - if object.Vector != nil { - // validation needs to happen before any changes are done. Otherwise, insertion is aborted somewhere in-between. - err := s.vectorIndex.ValidateBeforeInsert(object.Vector) - if err != nil { - return errors.Wrapf(err, "Validate vector index for %s", object.ID()) - } - } - } - status, err := s.putObjectLSM(object, uuid) if err != nil { return errors.Wrap(err, "store object in LSM store") @@ -231,13 +211,32 @@ func fetchObject(bucket *lsmkv.Bucket, idBytes []byte) (*storobj.Object, error) } func (s *Shard) putObjectLSM(obj *storobj.Object, idBytes []byte, -) (objectInsertStatus, error) { +) (status objectInsertStatus, err error) { before := time.Now() defer s.metrics.PutObject(before) + if s.hasTargetVectors() { + if len(obj.Vectors) > 0 { + for targetVector, vector := range obj.Vectors { + if vectorIndex := s.VectorIndexForName(targetVector); vectorIndex != nil { + if err := vectorIndex.ValidateBeforeInsert(vector); err != nil { + return status, errors.Wrapf(err, "Validate vector index %s for target vector %s", targetVector, obj.ID()) + } + } + } + } + } else { + if obj.Vector != nil { + // validation needs to happen before any changes are done. Otherwise, insertion is aborted somewhere in-between. + err := s.vectorIndex.ValidateBeforeInsert(obj.Vector) + if err != nil { + return status, errors.Wrapf(err, "Validate vector index for %s", obj.ID()) + } + } + } + bucket := s.store.Bucket(helpers.ObjectsBucketLSM) var prevObj *storobj.Object - var status objectInsertStatus // First the object bucket is checked if an object with the same uuid is alreadypresent, // to determine if it is insert or an update. diff --git a/adapters/repos/db/vector/hnsw/neighbor_connections.go b/adapters/repos/db/vector/hnsw/neighbor_connections.go index 0cc1fdedeb..bec86934df 100644 --- a/adapters/repos/db/vector/hnsw/neighbor_connections.go +++ b/adapters/repos/db/vector/hnsw/neighbor_connections.go @@ -118,14 +118,19 @@ func (n *neighborFinderConnector) processRecursively(from uint64, results *prior n.graph.handleDeletedNode(from) return nil } + // lock the nodes slice + n.graph.shardedNodeLocks.Lock(from) + // lock the node itself n.graph.nodes[from].Lock() if level >= len(n.graph.nodes[from].connections) { n.graph.nodes[from].Unlock() + n.graph.shardedNodeLocks.Unlock(from) return nil } connections := make([]uint64, len(n.graph.nodes[from].connections[level])) copy(connections, n.graph.nodes[from].connections[level]) n.graph.nodes[from].Unlock() + n.graph.shardedNodeLocks.Unlock(from) for _, id := range connections { if visited.Visited(id) { continue diff --git a/cluster/cluster.go b/cluster/cluster.go index c0037d0fc9..3d6fb2b307 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -72,7 +72,9 @@ func (c *Service) Open(ctx context.Context, db store.Indexer) error { c.client, c.config.NodeID, c.raftAddr, - c.config.AddrResolver) + c.config.AddrResolver, + c.Service.Ready, + ) bCtx, bCancel := context.WithTimeout(ctx, c.config.BootstrapTimeout) defer bCancel() diff --git a/cluster/store/bootstrap.go b/cluster/store/bootstrap.go index 52943d1dfc..fd689e68e8 100644 --- a/cluster/store/bootstrap.go +++ b/cluster/store/bootstrap.go @@ -32,6 +32,7 @@ type joiner interface { type Bootstrapper struct { joiner joiner addrResolver addressResolver + isStoreReady func() bool localRaftAddr string localNodeID string @@ -41,7 +42,7 @@ type Bootstrapper struct { } // NewBootstrapper constructs a new bootsrapper -func NewBootstrapper(joiner joiner, raftID, raftAddr string, r addressResolver) *Bootstrapper { +func NewBootstrapper(joiner joiner, raftID, raftAddr string, r addressResolver, isStoreReady func() bool) *Bootstrapper { return &Bootstrapper{ joiner: joiner, addrResolver: r, @@ -49,6 +50,7 @@ func NewBootstrapper(joiner joiner, raftID, raftAddr string, r addressResolver) jitter: time.Second, localNodeID: raftID, localRaftAddr: raftAddr, + isStoreReady: isStoreReady, } } @@ -65,19 +67,45 @@ func (b *Bootstrapper) Do(ctx context.Context, serverPortMap map[string]int, lg case <-ctx.Done(): return ctx.Err() case <-ticker.C: + if b.isStoreReady() { + lg.WithField("action", "bootstrap").Info("node reporting ready, node has probably recovered cluster from raft config. Exiting bootstrap process") + return nil + } + + // If we have found no other server, there is nobody to contact + if len(servers) == 0 { + continue + } + // try to join an existing cluster - if leader, err := b.join(ctx, servers, voter); err == nil { - lg.WithField("leader", leader).Info("successfully joined cluster") + if leader, err := b.join(ctx, servers, voter); err != nil { + lg.WithFields(logrus.Fields{ + "servers": servers, + "action": "bootstrap", + "voter": voter, + }).WithError(err).Warning("failed to join cluster, will notify next if voter") + } else { + lg.WithFields(logrus.Fields{ + "action": "bootstrap", + "leader": leader, + }).Info("successfully joined cluster") return nil } if voter { // notify other servers about readiness of this node to be joined if err := b.notify(ctx, servers); err != nil { - lg.WithField("servers", servers).WithError(err).Error("notify all peers") + lg.WithFields(logrus.Fields{ + "action": "bootstrap", + "servers": servers, + }).WithError(err).Error("notify all peers") + continue } + lg.WithFields(logrus.Fields{ + "action": "bootstrap", + "servers": servers, + }).Info("notified peers this node is ready to join as voter") } - } } } @@ -86,6 +114,11 @@ func (b *Bootstrapper) Do(ctx context.Context, serverPortMap map[string]int, lg func (b *Bootstrapper) join(ctx context.Context, servers []string, voter bool) (leader string, err error) { var resp *cmd.JoinPeerResponse req := &cmd.JoinPeerRequest{Id: b.localNodeID, Address: b.localRaftAddr, Voter: voter} + // For each server, try to join. + // If we have no error then we have a leader + // If we have an error check for err == NOT_FOUND and leader != "" -> we contacted a non-leader node part of the + // cluster, let's join the leader. + // If no server allows us to join a cluster, return an error for _, addr := range servers { resp, err = b.joiner.Join(ctx, addr, req) if err == nil { diff --git a/cluster/store/bootstrap_test.go b/cluster/store/bootstrap_test.go index 632e12a280..539c12469a 100644 --- a/cluster/store/bootstrap_test.go +++ b/cluster/store/bootstrap_test.go @@ -32,6 +32,7 @@ func TestBootStrapper(t *testing.T) { voter bool servers map[string]int doBefore func(*MockJoiner) + isReady func() bool success bool }{ { @@ -41,6 +42,7 @@ func TestBootStrapper(t *testing.T) { doBefore: func(m *MockJoiner) { m.On("Join", anything, anything, anything).Return(&cmd.JoinPeerResponse{}, nil) }, + isReady: func() bool { return false }, success: false, }, { @@ -50,6 +52,7 @@ func TestBootStrapper(t *testing.T) { doBefore: func(m *MockJoiner) { m.On("Join", anything, anything, anything).Return(&cmd.JoinPeerResponse{}, nil) }, + isReady: func() bool { return false }, success: true, }, { @@ -63,6 +66,7 @@ func TestBootStrapper(t *testing.T) { m.On("Notify", anything, "S1:1", anything).Return(&cmd.NotifyPeerResponse{}, nil) m.On("Notify", anything, "S2:2", anything).Return(&cmd.NotifyPeerResponse{}, errAny) }, + isReady: func() bool { return false }, success: false, }, { @@ -75,23 +79,35 @@ func TestBootStrapper(t *testing.T) { m.On("Join", anything, "S2:2", anything).Return(&cmd.JoinPeerResponse{Leader: "S3"}, err) m.On("Join", anything, "S3", anything).Return(&cmd.JoinPeerResponse{}, nil) }, + isReady: func() bool { return false }, success: true, }, + { + name: "exit early on cluster ready", + voter: true, + servers: servers, + doBefore: func(m *MockJoiner) {}, + isReady: func() bool { return true }, + success: true, + }, } for _, test := range tests { - m := &MockJoiner{} - b := NewBootstrapper(m, "RID", "ADDR", &MockAddressResolver{func(id string) string { return id }}) - b.retryPeriod = time.Millisecond - b.jitter = time.Millisecond - test.doBefore(m) - ctx, cancel := context.WithTimeout(ctx, time.Millisecond*100) - err := b.Do(ctx, test.servers, NewMockLogger(t).Logger, test.voter, make(chan struct{})) - cancel() - if test.success && err != nil { - t.Errorf("%s: %v", test.name, err) - } else if !test.success && err == nil { - t.Errorf("%s: test must fail", test.name) - } + test := test + t.Run(test.name, func(t *testing.T) { + m := &MockJoiner{} + b := NewBootstrapper(m, "RID", "ADDR", &MockAddressResolver{func(id string) string { return id }}, test.isReady) + b.retryPeriod = time.Millisecond + b.jitter = time.Millisecond + test.doBefore(m) + ctx, cancel := context.WithTimeout(ctx, time.Millisecond*100) + err := b.Do(ctx, test.servers, NewMockLogger(t).Logger, test.voter, make(chan struct{})) + cancel() + if test.success && err != nil { + t.Errorf("%s: %v", test.name, err) + } else if !test.success && err == nil { + t.Errorf("%s: test must fail", test.name) + } + }) } } diff --git a/cluster/store/db.go b/cluster/store/db.go index 9fd2569db1..800a1de3a0 100644 --- a/cluster/store/db.go +++ b/cluster/store/db.go @@ -112,28 +112,21 @@ func (db *localDB) UpdateClass(cmd *command.ApplyRequest, nodeID string, schemaO meta.Class.VectorIndexConfig = u.VectorIndexConfig meta.Class.InvertedIndexConfig = u.InvertedIndexConfig meta.Class.VectorConfig = u.VectorConfig - // TODO: fix PushShard issues before enabling scale out - // https://github.com/weaviate/weaviate/issues/4840 - // meta.Class.ReplicationConfig = u.ReplicationConfig + meta.Class.ReplicationConfig = u.ReplicationConfig meta.Class.MultiTenancyConfig = u.MultiTenancyConfig meta.ClassVersion = cmd.Version - // TODO: fix PushShard issues before enabling scale out - // https://github.com/weaviate/weaviate/issues/4840 - // if req.State != nil { - // meta.Sharding = *req.State - // } + if req.State != nil { + meta.Sharding = *req.State + } return nil } return db.apply( applyOp{ - op: cmd.GetType().String(), - updateSchema: func() error { return db.Schema.updateClass(req.Class.Class, update) }, - updateStore: func() error { return db.store.UpdateClass(req) }, - schemaOnly: schemaOnly, - // Apply the DB change last otherwise we will error on the parsing of the class while updating the store. - // We need the schema to first parse the update and apply it so that we can use it in the DB update. - applyDbUpdateFirst: false, + op: cmd.GetType().String(), + updateSchema: func() error { return db.Schema.updateClass(req.Class.Class, update) }, + updateStore: func() error { return db.store.UpdateClass(req) }, + schemaOnly: schemaOnly, triggerSchemaCallback: true, }, ) @@ -146,7 +139,6 @@ func (db *localDB) DeleteClass(cmd *command.ApplyRequest, schemaOnly bool) error updateSchema: func() error { db.Schema.deleteClass(cmd.Class); return nil }, updateStore: func() error { return db.store.DeleteClass(cmd.Class) }, schemaOnly: schemaOnly, - applyDbUpdateFirst: true, triggerSchemaCallback: true, }, ) @@ -163,14 +155,10 @@ func (db *localDB) AddProperty(cmd *command.ApplyRequest, schemaOnly bool) error return db.apply( applyOp{ - op: cmd.GetType().String(), - updateSchema: func() error { return db.Schema.addProperty(cmd.Class, cmd.Version, req.Properties...) }, - updateStore: func() error { return db.store.AddProperty(cmd.Class, req) }, - schemaOnly: schemaOnly, - // Apply the DB first to ensure the underlying buckets related to properties are created/deleted *before* the - // schema is updated. This allows us to have object write waiting on the right schema version to proceed only - // once the buck buckets are present. - applyDbUpdateFirst: true, + op: cmd.GetType().String(), + updateSchema: func() error { return db.Schema.addProperty(cmd.Class, cmd.Version, req.Properties...) }, + updateStore: func() error { return db.store.AddProperty(cmd.Class, req) }, + schemaOnly: schemaOnly, triggerSchemaCallback: true, }, ) @@ -184,11 +172,10 @@ func (db *localDB) UpdateShardStatus(cmd *command.ApplyRequest, schemaOnly bool) return db.apply( applyOp{ - op: cmd.GetType().String(), - updateSchema: func() error { return nil }, - updateStore: func() error { return db.store.UpdateShardStatus(&req) }, - schemaOnly: schemaOnly, - applyDbUpdateFirst: true, + op: cmd.GetType().String(), + updateSchema: func() error { return nil }, + updateStore: func() error { return db.store.UpdateShardStatus(&req) }, + schemaOnly: schemaOnly, }, ) } @@ -248,19 +235,6 @@ func (db *localDB) Load(ctx context.Context, nodeID string) error { return nil } -// Reload updates an already opened local database with the newest schema. -// It updates existing indexes and adds new ones as necessary -func (db *localDB) Reload() error { - cs := make([]command.UpdateClassRequest, len(db.Schema.Classes)) - i := 0 - for _, v := range db.Schema.Classes { - cs[i] = command.UpdateClassRequest{Class: &v.Class, State: &v.Sharding} - i++ - } - db.store.ReloadLocalDB(context.Background(), cs) - return nil -} - func (db *localDB) Close(ctx context.Context) (err error) { return db.store.Close(ctx) } @@ -270,7 +244,6 @@ type applyOp struct { updateSchema func() error updateStore func() error schemaOnly bool - applyDbUpdateFirst bool triggerSchemaCallback bool } @@ -287,28 +260,20 @@ func (op applyOp) validate() error { return nil } +// apply does apply commands from RAFT to schema 1st and then db func (db *localDB) apply(op applyOp) error { if err := op.validate(); err != nil { return fmt.Errorf("could not validate raft apply op: %s", err) } - // To avoid a if/else with repeated logic, setup op1 and op2 to either updateSchema or updateStore depending on - // op.applyDbUpdateFirst and op.schemaOnly - op1, op2 := op.updateSchema, op.updateStore - msg1, msg2 := errSchema, errDB - if op.applyDbUpdateFirst && !op.schemaOnly { - op1, op2 = op.updateStore, op.updateSchema - msg1, msg2 = errDB, errSchema - } - - if err := op1(); err != nil { - return fmt.Errorf("%w: %s: %w", msg1, op.op, err) + // schema applied 1st to make sure any validation happen before applying it to db + if err := op.updateSchema(); err != nil { + return fmt.Errorf("%w: %s: %w", errSchema, op.op, err) } - // If the operation is schema only, op1 is always the schemaUpdate so we can skip op2 if !op.schemaOnly { - if err := op2(); err != nil { - return fmt.Errorf("%w: %s: %w", msg2, op.op, err) + if err := op.updateStore(); err != nil { + return fmt.Errorf("%w: %s: %w", errDB, op.op, err) } } diff --git a/cluster/store/store.go b/cluster/store/store.go index 9e8eb308b5..c23ccc6ec0 100644 --- a/cluster/store/store.go +++ b/cluster/store/store.go @@ -27,6 +27,7 @@ import ( raftbolt "github.com/hashicorp/raft-boltdb/v2" "github.com/sirupsen/logrus" "github.com/weaviate/weaviate/cluster/proto/api" + command "github.com/weaviate/weaviate/cluster/proto/api" "github.com/weaviate/weaviate/entities/models" "google.golang.org/protobuf/proto" gproto "google.golang.org/protobuf/proto" @@ -112,9 +113,9 @@ type Config struct { RecoveryTimeout time.Duration SnapshotInterval time.Duration BootstrapTimeout time.Duration - // UpdateWaitTimeout Timeout duration for waiting for the update to be propagated to this follower node. - UpdateWaitTimeout time.Duration - SnapshotThreshold uint64 + // ConsistencyWaitTimeout is the duration we will wait for a schema version to land on that node + ConsistencyWaitTimeout time.Duration + SnapshotThreshold uint64 DB Indexer Parser Parser @@ -149,9 +150,8 @@ type Store struct { // applyTimeout timeout limit the amount of time raft waits for a command to be started applyTimeout time.Duration - - // UpdateWaitTimeout Timeout duration for waiting for the update to be propagated to this follower node. - updateWaitTimeout time.Duration + // consistencyWaitTimeout is the duration we will wait for a schema version to land on that node + consistencyWaitTimeout time.Duration nodeID string host string @@ -173,10 +173,8 @@ type Store struct { mutex sync.Mutex candidates map[string]string - // initialLastAppliedIndex represents the index of the last applied command when the store is opened. - initialLastAppliedIndex uint64 - - // lastIndex atomic.Uint64 + // lastAppliedIndexOnStart represents the index of the last applied command when the store is opened. + lastAppliedIndexOnStart atomic.Uint64 // lastAppliedIndex index of latest update to the store lastAppliedIndex atomic.Uint64 @@ -191,25 +189,25 @@ type Store struct { func New(cfg Config) Store { return Store{ - raftDir: cfg.WorkDir, - raftPort: cfg.RaftPort, - voter: cfg.Voter, - bootstrapExpect: cfg.BootstrapExpect, - candidates: make(map[string]string, cfg.BootstrapExpect), - recoveryTimeout: cfg.RecoveryTimeout, - heartbeatTimeout: cfg.HeartbeatTimeout, - electionTimeout: cfg.ElectionTimeout, - snapshotInterval: cfg.SnapshotInterval, - snapshotThreshold: cfg.SnapshotThreshold, - updateWaitTimeout: cfg.UpdateWaitTimeout, - applyTimeout: time.Second * 20, - nodeID: cfg.NodeID, - host: cfg.Host, - addResolver: newAddrResolver(&cfg), - db: &localDB{NewSchema(cfg.NodeID, cfg.DB), cfg.DB, cfg.Parser, cfg.Logger}, - log: cfg.Logger, - logLevel: cfg.LogLevel, - logJsonFormat: cfg.LogJSONFormat, + raftDir: cfg.WorkDir, + raftPort: cfg.RaftPort, + voter: cfg.Voter, + bootstrapExpect: cfg.BootstrapExpect, + candidates: make(map[string]string, cfg.BootstrapExpect), + recoveryTimeout: cfg.RecoveryTimeout, + heartbeatTimeout: cfg.HeartbeatTimeout, + electionTimeout: cfg.ElectionTimeout, + snapshotInterval: cfg.SnapshotInterval, + snapshotThreshold: cfg.SnapshotThreshold, + consistencyWaitTimeout: cfg.ConsistencyWaitTimeout, + applyTimeout: time.Second * 20, + nodeID: cfg.NodeID, + host: cfg.Host, + addResolver: newAddrResolver(&cfg), + db: &localDB{NewSchema(cfg.NodeID, cfg.DB), cfg.DB, cfg.Parser, cfg.Logger}, + log: cfg.Logger, + logLevel: cfg.LogLevel, + logJsonFormat: cfg.LogJSONFormat, // if true voters will only serve schema metadataOnlyVoters: cfg.MetadataOnlyVoters, @@ -239,14 +237,11 @@ func (st *Store) Open(ctx context.Context) (err error) { } rLog := rLog{st.logStore} - st.initialLastAppliedIndex, err = rLog.LastAppliedCommand() + l, err := rLog.LastAppliedCommand() if err != nil { return fmt.Errorf("read log last command: %w", err) } - lastSnapshotIndex := snapshotIndex(st.snapshotStore) - if st.initialLastAppliedIndex == 0 { // empty node - st.loadDatabase(ctx) - } + st.lastAppliedIndexOnStart.Store(l) st.log.WithFields(logrus.Fields{ "name": st.nodeID, @@ -256,13 +251,20 @@ func (st *Store) Open(ctx context.Context) (err error) { if err != nil { return fmt.Errorf("raft.NewRaft %v %w", st.transport.LocalAddr(), err) } + if st.lastAppliedIndexOnStart.Load() <= st.raft.LastIndex() { + // this should include empty and non empty node + st.openDatabase(ctx) + } + st.lastAppliedIndex.Store(st.raft.AppliedIndex()) + st.log.WithFields(logrus.Fields{ - "raft_applied_index": st.raft.AppliedIndex(), - "raft_last_index": st.raft.LastIndex(), - "last_log_applied_index": st.initialLastAppliedIndex, - "last_snapshot_index": lastSnapshotIndex, - }).Info("raft node") + "raft_applied_index": st.raft.AppliedIndex(), + "raft_last_index": st.raft.LastIndex(), + "last_store_log_applied_index": st.lastAppliedIndexOnStart.Load(), + "last_store_applied_index": st.lastAppliedIndex.Load(), + "last_snapshot_index": snapshotIndex(st.snapshotStore), + }).Info("raft node constructed") // There's no hard limit on the migration, so it should take as long as necessary. // However, we believe that 1 day should be more than sufficient. @@ -448,7 +450,7 @@ func (st *Store) WaitForAppliedIndex(ctx context.Context, period time.Duration, if idx := st.lastAppliedIndex.Load(); idx >= version { return nil } - ctx, cancel := context.WithTimeout(ctx, st.updateWaitTimeout) + ctx, cancel := context.WithTimeout(ctx, st.consistencyWaitTimeout) defer cancel() ticker := time.NewTicker(period) defer ticker.Stop() @@ -513,8 +515,8 @@ func (f *Store) FindSimilarClass(name string) string { // The value of "candidates" is a map[string]string of the current candidates IDs/addresses, // see Store.candidates. // -// The value of "initial_last_applied_index" is the index of the last applied command found when -// the store was opened, see Store.initialLastAppliedIndex. +// The value of "last_store_log_applied_index" is the index of the last applied command found when +// the store was opened, see Store.lastAppliedIndexOnStart. // // The value of "last_applied_index" is the index of the latest update to the store, // see Store.lastAppliedIndex. @@ -537,7 +539,7 @@ func (st *Store) Stats() map[string]any { stats["open"] = st.open.Load() stats["bootstrapped"] = st.bootstrapped.Load() stats["candidates"] = st.candidates - stats["initial_last_applied_index"] = st.initialLastAppliedIndex + stats["last_store_log_applied_index"] = st.lastAppliedIndexOnStart.Load() stats["last_applied_index"] = st.lastAppliedIndex.Load() stats["db_loaded"] = st.dbLoaded.Load() @@ -620,9 +622,10 @@ func (st *Store) Execute(req *api.ApplyRequest) (uint64, error) { func (st *Store) Apply(l *raft.Log) interface{} { ret := Response{Version: l.Index} st.log.WithFields(logrus.Fields{ - "type": l.Type, - "index": l.Index, - }).Debug("apply command") + "log_type": l.Type, + "log_name": l.Type.String(), + "log_index": l.Index, + }).Debug("apply fsm store command") if l.Type != raft.LogCommand { st.log.WithFields(logrus.Fields{ "type": l.Type, @@ -636,24 +639,52 @@ func (st *Store) Apply(l *raft.Log) interface{} { panic("error proto un-marshalling log data") } - schemaOnly := l.Index <= st.initialLastAppliedIndex + // schemaOnly is necessary so that on restart when we are re-applying RAFT log entries to our in-memory schema we + // don't update the database. This can lead to dataloss for example if we drop then re-add a class. + // If we don't have any last applied index on start, schema only is always false. + schemaOnly := st.lastAppliedIndexOnStart.Load() != 0 && l.Index <= st.lastAppliedIndexOnStart.Load() defer func() { - st.lastAppliedIndex.Store(l.Index) - // If the local db has not been loaded, wait until we reach the state - // from the local raft log before loading the db. - // This is necessary because the database operations are not idempotent - if !st.dbLoaded.Load() && l.Index >= st.initialLastAppliedIndex { - st.loadDatabase(context.Background()) + // If we have an applied index from the previous store (i.e from disk). Then reload the DB once we catch up as + // that means we're done doing schema only. + if st.lastAppliedIndexOnStart.Load() != 0 && l.Index == st.lastAppliedIndexOnStart.Load() { + st.log.WithFields(logrus.Fields{ + "log_type": l.Type, + "log_name": l.Type.String(), + "log_index": l.Index, + "last_store_log_applied_index": st.lastAppliedIndexOnStart.Load(), + }).Debug("reloading local DB as RAFT and local DB are now caught up") + cs := make([]command.UpdateClassRequest, len(st.db.Schema.Classes)) + i := 0 + for _, v := range st.db.Schema.Classes { + cs[i] = command.UpdateClassRequest{Class: &v.Class, State: &v.Sharding} + i++ + } + st.db.store.ReloadLocalDB(context.Background(), cs) } + + st.lastAppliedIndex.Store(l.Index) if ret.Error != nil { st.log.WithFields(logrus.Fields{ - "type": l.Type, - "index": l.Index, + "log_type": l.Type, + "log_name": l.Type.String(), + "log_index": l.Index, + "cmd_type": cmd.Type, + "cmd_type_name": cmd.Type.String(), + "cmd_class": cmd.Class, }).WithError(ret.Error).Error("apply command") } }() cmd.Version = l.Index + st.log.WithFields(logrus.Fields{ + "log_type": l.Type, + "log_name": l.Type.String(), + "log_index": l.Index, + "cmd_type": cmd.Type, + "cmd_type_name": cmd.Type.String(), + "cmd_class": cmd.Class, + "cmd_schema_only": schemaOnly, + }).Debug("server.apply") switch cmd.Type { case api.ApplyRequest_TYPE_ADD_CLASS: @@ -723,13 +754,13 @@ func (st *Store) Restore(rc io.ReadCloser) error { } st.log.Info("successfully restored schema from snapshot") - if st.reloadDB() { + if st.reloadDBFromSnapshot() { st.log.WithField("n", st.db.Schema.len()). Info("successfully reloaded indexes from snapshot") } if st.raft != nil { - st.lastAppliedIndex.Store(st.raft.AppliedIndex()) // TODO-RAFT: check if raft return the latest applied index + st.lastAppliedIndex.Store(st.raft.AppliedIndex()) } return nil @@ -768,7 +799,6 @@ func (st *Store) Remove(id string) error { // Notify signals this Store that a node is ready for bootstrapping at the specified address. // Bootstrapping will be initiated once the number of known nodes reaches the expected level, // which includes this node. - func (st *Store) Notify(id, addr string) (err error) { if !st.open.Load() { return ErrNotOpen @@ -784,9 +814,10 @@ func (st *Store) Notify(id, addr string) (err error) { st.candidates[id] = addr if len(st.candidates) < st.bootstrapExpect { st.log.WithFields(logrus.Fields{ + "action": "bootstrap", "expect": st.bootstrapExpect, "got": st.candidates, - }).Debug("number of candidates") + }).Debug("number of candidates lower than bootstrap expect param, stopping notify") return nil } candidates := make([]raft.Server, 0, len(st.candidates)) @@ -801,11 +832,14 @@ func (st *Store) Notify(id, addr string) (err error) { i++ } - st.log.WithField("candidates", candidates).Info("starting cluster bootstrapping") + st.log.WithFields(logrus.Fields{ + "action": "bootstrap", + "candidates": candidates, + }).Info("starting cluster bootstrapping") fut := st.raft.BootstrapCluster(raft.Configuration{Servers: candidates}) if err := fut.Error(); err != nil { - st.log.WithError(err).Error("bootstrapping cluster") + st.log.WithField("action", "bootstrap").WithError(err).Error("could not bootstrapping cluster") if !errors.Is(err, raft.ErrCantBootstrap) { return err } @@ -846,7 +880,7 @@ func (st *Store) raftConfig() *raft.Config { return cfg } -func (st *Store) loadDatabase(ctx context.Context) { +func (st *Store) openDatabase(ctx context.Context) { if st.dbLoaded.Load() { return } @@ -861,39 +895,41 @@ func (st *Store) loadDatabase(ctx context.Context) { st.log.WithField("n", st.db.Schema.len()).Info("database has been successfully loaded") } -// reloadDB reloads the node's local db. If the db is already loaded, it will be reloaded. +// reloadDBFromSnapshot reloads the node's local db. If the db is already loaded, it will be reloaded. // If a snapshot exists and its is up to date with the log, it will be loaded. // Otherwise, the database will be loaded when the node synchronizes its state with the leader. -// For more details, see apply() -> loadDatabase(). // // In specific scenarios where the follower's state is too far behind the leader's log, // the leader may decide to send a snapshot. Consequently, the follower must update its state accordingly. -func (st *Store) reloadDB() bool { +func (st *Store) reloadDBFromSnapshot() bool { ctx := context.Background() if !st.dbLoaded.CompareAndSwap(true, false) { // the snapshot already includes the state from the raft log snapIndex := snapshotIndex(st.snapshotStore) st.log.WithFields(logrus.Fields{ - "last_applied_index": st.lastAppliedIndex.Load(), - "initial_last_applied_index": st.initialLastAppliedIndex, - "last_snapshot_index": snapIndex, + "last_applied_index": st.lastAppliedIndex.Load(), + "last_store_log_applied_index": st.lastAppliedIndexOnStart.Load(), + "last_snapshot_index": snapIndex, }).Info("load local db from snapshot") - if st.initialLastAppliedIndex <= snapIndex { - st.loadDatabase(ctx) + if st.lastAppliedIndexOnStart.Load() <= snapIndex { + st.openDatabase(ctx) return true } return false } - st.log.Info("reload local db: loading indexes ...") - if err := st.db.Reload(); err != nil { - st.log.WithError(err).Error("cannot reload database") - panic(fmt.Sprintf("cannot reload database: %v", err)) + st.log.Info("reload local db: update schema ...") + cs := make([]command.UpdateClassRequest, len(st.db.Schema.Classes)) + i := 0 + for _, v := range st.db.Schema.Classes { + cs[i] = command.UpdateClassRequest{Class: &v.Class, State: &v.Sharding} + i++ } + st.db.store.ReloadLocalDB(context.Background(), cs) st.dbLoaded.Store(true) - st.initialLastAppliedIndex = 0 + st.lastAppliedIndexOnStart.Store(0) return true } diff --git a/cluster/store/store_test.go b/cluster/store/store_test.go index a0fda98d5c..e929b9d65c 100644 --- a/cluster/store/store_test.go +++ b/cluster/store/store_test.go @@ -387,12 +387,11 @@ func TestServicePanics(t *testing.T) { // Cannot Open File Store m.indexer.On("Open", mock.Anything).Return(errAny) - assert.Panics(t, func() { m.store.loadDatabase(context.TODO()) }) + assert.Panics(t, func() { m.store.openDatabase(context.TODO()) }) } func TestStoreApply(t *testing.T) { doFirst := func(m *MockStore) { - m.indexer.On("Open", mock.Anything).Return(nil) m.parser.On("ParseClass", mock.Anything).Return(nil) m.indexer.On("TriggerSchemaUpdateCallbacks").Return() } @@ -449,8 +448,12 @@ func TestStoreApply(t *testing.T) { cmd.ApplyRequest_TYPE_ADD_CLASS, cmd.AddClassRequest{Class: cls, State: ss}, nil)}, - resp: Response{Error: nil}, - doBefore: doFirst, + resp: Response{Error: nil}, + doBefore: func(m *MockStore) { + m.indexer.On("AddClass", mock.Anything).Return(nil) + m.parser.On("ParseClass", mock.Anything).Return(nil) + m.indexer.On("TriggerSchemaUpdateCallbacks").Return() + }, doAfter: func(ms *MockStore) error { _, ok := ms.store.db.Schema.Classes["C1"] if !ok { @@ -495,9 +498,9 @@ func TestStoreApply(t *testing.T) { nil)}, resp: Response{Error: nil}, doBefore: func(m *MockStore) { - m.indexer.On("Open", mock.Anything).Return(nil) m.parser.On("ParseClass", mock.Anything).Return(nil) m.indexer.On("RestoreClassDir", cls.Class).Return(nil) + m.indexer.On("AddClass", mock.Anything).Return(nil) m.indexer.On("TriggerSchemaUpdateCallbacks").Return() }, doAfter: func(ms *MockStore) error { @@ -550,6 +553,7 @@ func TestStoreApply(t *testing.T) { doBefore: func(m *MockStore) { m.indexer.On("Open", mock.Anything).Return(nil) m.parser.On("ParseClassUpdate", mock.Anything, mock.Anything).Return(mock.Anything, nil) + m.indexer.On("UpdateClass", mock.Anything).Return(nil) m.store.db.Schema.addClass(cls, ss, 1) m.indexer.On("TriggerSchemaUpdateCallbacks").Return() }, @@ -561,7 +565,7 @@ func TestStoreApply(t *testing.T) { nil)}, resp: Response{Error: nil}, doBefore: func(m *MockStore) { - m.indexer.On("Open", mock.Anything).Return(nil) + m.indexer.On("DeleteClass", mock.Anything).Return(nil) m.indexer.On("TriggerSchemaUpdateCallbacks").Return() }, doAfter: func(ms *MockStore) error { @@ -605,8 +609,8 @@ func TestStoreApply(t *testing.T) { }, resp: Response{Error: nil}, doBefore: func(m *MockStore) { - m.indexer.On("Open", mock.Anything).Return(nil) m.store.db.Schema.addClass(cls, ss, 1) + m.indexer.On("AddProperty", mock.Anything, mock.Anything).Return(nil) m.indexer.On("TriggerSchemaUpdateCallbacks").Return() }, doAfter: func(ms *MockStore) error { @@ -634,8 +638,12 @@ func TestStoreApply(t *testing.T) { name: "UpdateShard/Success", req: raft.Log{Data: cmdAsBytes("C1", cmd.ApplyRequest_TYPE_UPDATE_SHARD_STATUS, cmd.UpdateShardStatusRequest{Class: "C1"}, nil)}, - resp: Response{Error: nil}, - doBefore: doFirst, + resp: Response{Error: nil}, + doBefore: func(m *MockStore) { + m.parser.On("ParseClass", mock.Anything).Return(nil) + m.indexer.On("UpdateShardStatus", mock.Anything).Return(nil) + m.indexer.On("TriggerSchemaUpdateCallbacks").Return() + }, }, { name: "AddTenant/Unmarshal", @@ -659,10 +667,11 @@ func TestStoreApply(t *testing.T) { })}, resp: Response{Error: nil}, doBefore: func(m *MockStore) { - m.indexer.On("Open", mock.Anything).Return(nil) m.store.db.Schema.addClass(cls, &sharding.State{ Physical: map[string]sharding.Physical{"T1": {}}, }, 1) + + m.indexer.On("AddTenants", mock.Anything, mock.Anything).Return(nil) }, doAfter: func(ms *MockStore) error { if _, ok := ms.store.db.Schema.Classes["C1"].Sharding.Physical["T1"]; !ok { @@ -720,8 +729,8 @@ func TestStoreApply(t *testing.T) { BelongsToNodes: []string{"NODE-2"}, Status: models.TenantActivityStatusHOT, }}} - m.indexer.On("Open", mock.Anything).Return(nil) m.store.db.Schema.addClass(cls, ss, 1) + m.indexer.On("UpdateTenants", mock.Anything, mock.Anything).Return(nil) }, doAfter: func(ms *MockStore) error { want := map[string]sharding.Physical{"T1": { @@ -763,8 +772,8 @@ func TestStoreApply(t *testing.T) { nil, &cmd.DeleteTenantsRequest{Tenants: []string{"T1", "T2"}})}, resp: Response{Error: nil}, doBefore: func(m *MockStore) { - m.indexer.On("Open", mock.Anything).Return(nil) m.store.db.Schema.addClass(cls, &sharding.State{Physical: map[string]sharding.Physical{"T1": {}}}, 1) + m.indexer.On("DeleteTenants", mock.Anything, mock.Anything).Return(nil) }, doAfter: func(ms *MockStore) error { if len(ms.store.db.Schema.Classes["C1"].Sharding.Physical) != 0 { @@ -776,27 +785,29 @@ func TestStoreApply(t *testing.T) { } for _, tc := range tests { - m := NewMockStore(t, "Node-1", 9091) - store := m.Store(tc.doBefore) - ret := store.Apply(&tc.req) - resp, ok := ret.(Response) - if !ok { - t.Errorf("%s: response has wrong type", tc.name) - } - if got, want := resp.Error, tc.resp.Error; want != nil { - if !errors.Is(resp.Error, tc.resp.Error) { - t.Errorf("%s: error want: %v got: %v", tc.name, want, got) + t.Run(tc.name, func(t *testing.T) { + m := NewMockStore(t, "Node-1", 9091) + store := m.Store(tc.doBefore) + ret := store.Apply(&tc.req) + resp, ok := ret.(Response) + if !ok { + t.Errorf("%s: response has wrong type", tc.name) } - } else if got != nil { - t.Errorf("%s: error want: nil got: %v", tc.name, got) - } - if tc.doAfter != nil { - if err := tc.doAfter(&m); err != nil { - t.Errorf("%s check updates: %v", tc.name, err) + if got, want := resp.Error, tc.resp.Error; want != nil { + if !errors.Is(resp.Error, tc.resp.Error) { + t.Errorf("%s: error want: %v got: %v", tc.name, want, got) + } + } else if got != nil { + t.Errorf("%s: error want: nil got: %v", tc.name, got) } - m.indexer.AssertExpectations(t) - m.parser.AssertExpectations(t) - } + if tc.doAfter != nil { + if err := tc.doAfter(&m); err != nil { + t.Errorf("%s check updates: %v", tc.name, err) + } + m.indexer.AssertExpectations(t) + m.parser.AssertExpectations(t) + } + }) } } @@ -852,22 +863,22 @@ func NewMockStore(t *testing.T, nodeID string, raftPort int) MockStore { logger: logger, cfg: Config{ - WorkDir: t.TempDir(), - NodeID: nodeID, - Host: "localhost", - RaftPort: raftPort, - Voter: true, - BootstrapExpect: 1, - HeartbeatTimeout: 1 * time.Second, - ElectionTimeout: 1 * time.Second, - RecoveryTimeout: 500 * time.Millisecond, - SnapshotInterval: 2 * time.Second, - SnapshotThreshold: 125, - DB: indexer, - Parser: parser, - AddrResolver: &MockAddressResolver{}, - Logger: logger.Logger, - UpdateWaitTimeout: time.Millisecond * 50, + WorkDir: t.TempDir(), + NodeID: nodeID, + Host: "localhost", + RaftPort: raftPort, + Voter: true, + BootstrapExpect: 1, + HeartbeatTimeout: 1 * time.Second, + ElectionTimeout: 1 * time.Second, + RecoveryTimeout: 500 * time.Millisecond, + SnapshotInterval: 2 * time.Second, + SnapshotThreshold: 125, + DB: indexer, + Parser: parser, + AddrResolver: &MockAddressResolver{}, + Logger: logger.Logger, + ConsistencyWaitTimeout: time.Millisecond * 50, }, } s := New(ms.cfg) diff --git a/entities/locks/named_locks.go b/entities/locks/named_locks.go deleted file mode 100644 index 073d8b358b..0000000000 --- a/entities/locks/named_locks.go +++ /dev/null @@ -1,103 +0,0 @@ -// _ _ -// __ _____ __ ___ ___ __ _| |_ ___ -// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ -// \ V V / __/ (_| |\ V /| | (_| | || __/ -// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| -// -// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. -// -// CONTACT: hello@weaviate.io -// - -package locks - -import "sync" - -type NamedLocks struct { - main *sync.Mutex - locks map[string]*sync.Mutex -} - -func NewNamedLocks() *NamedLocks { - return &NamedLocks{ - main: new(sync.Mutex), - locks: make(map[string]*sync.Mutex), - } -} - -func (l *NamedLocks) Lock(name string) { - l.lock(name).Lock() -} - -func (l *NamedLocks) Unlock(name string) { - l.lock(name).Unlock() -} - -func (l *NamedLocks) Locked(name string, callback func()) { - l.Lock(name) - defer l.Unlock(name) - - callback() -} - -func (l *NamedLocks) lock(name string) *sync.Mutex { - l.main.Lock() - defer l.main.Unlock() - - if _, ok := l.locks[name]; !ok { - l.locks[name] = new(sync.Mutex) - } - return l.locks[name] -} - -type NamedRWLocks struct { - main *sync.Mutex - locks map[string]*sync.RWMutex -} - -func NewNamedRWLocks() *NamedRWLocks { - return &NamedRWLocks{ - main: new(sync.Mutex), - locks: make(map[string]*sync.RWMutex), - } -} - -func (l *NamedRWLocks) Lock(name string) { - l.lock(name).Lock() -} - -func (l *NamedRWLocks) Unlock(name string) { - l.lock(name).Unlock() -} - -func (l *NamedRWLocks) Locked(name string, callback func()) { - l.Lock(name) - defer l.Unlock(name) - - callback() -} - -func (l *NamedRWLocks) RLock(name string) { - l.lock(name).RLock() -} - -func (l *NamedRWLocks) RUnlock(name string) { - l.lock(name).RUnlock() -} - -func (l *NamedRWLocks) RLocked(name string, callback func()) { - l.RLock(name) - defer l.RUnlock(name) - - callback() -} - -func (l *NamedRWLocks) lock(name string) *sync.RWMutex { - l.main.Lock() - defer l.main.Unlock() - - if _, ok := l.locks[name]; !ok { - l.locks[name] = new(sync.RWMutex) - } - return l.locks[name] -} diff --git a/entities/searchparams/retrieval.go b/entities/searchparams/retrieval.go index ec66918c08..e53a38537c 100644 --- a/entities/searchparams/retrieval.go +++ b/entities/searchparams/retrieval.go @@ -11,6 +11,14 @@ package searchparams +import ( + "fmt" + "strings" + + "github.com/weaviate/weaviate/entities/models" + "github.com/weaviate/weaviate/entities/schema" +) + type NearVector struct { Vector []float32 `json:"vector"` Certainty float64 `json:"certainty"` @@ -26,6 +34,61 @@ type KeywordRanking struct { AdditionalExplanations bool `json:"additionalExplanations"` } +// Indicates whether property should be indexed +// Index holds document ids with property of/containing particular value +// and number of its occurrences in that property +// (index created using bucket of StrategyMapCollection) +func HasSearchableIndex(prop *models.Property) bool { + switch dt, _ := schema.AsPrimitive(prop.DataType); dt { + case schema.DataTypeText, schema.DataTypeTextArray: + // by default property has searchable index only for text/text[] props + if prop.IndexSearchable == nil { + return true + } + return *prop.IndexSearchable + default: + return false + } +} + +func PropertyHasSearchableIndex(class *models.Class, tentativePropertyName string) bool { + if class == nil { + return false + } + + propertyName := strings.Split(tentativePropertyName, "^")[0] + p, err := schema.GetPropertyByName(class, propertyName) + if err != nil { + return false + } + return HasSearchableIndex(p) +} + +// GetPropertyByName returns the class by its name +func GetPropertyByName(c *models.Class, propName string) (*models.Property, error) { + for _, prop := range c.Properties { + // Check if the name of the property is the given name, that's the property we need + if prop.Name == strings.Split(propName, ".")[0] { + return prop, nil + } + } + return nil, fmt.Errorf("Property %v not found %v", propName, c.Class) +} + +func (k *KeywordRanking) ChooseSearchableProperties(class *models.Class) { + var validProperties []string + for _, prop := range k.Properties { + property, err := GetPropertyByName(class, prop) + if err != nil { + continue + } + if HasSearchableIndex(property) { + validProperties = append(validProperties, prop) + } + } + k.Properties = validProperties +} + type WeightedSearchResult struct { SearchParams interface{} `json:"searchParams"` Weight float64 `json:"weight"` diff --git a/entities/storobj/storage_object.go b/entities/storobj/storage_object.go index a80ae34311..a105161634 100644 --- a/entities/storobj/storage_object.go +++ b/entities/storobj/storage_object.go @@ -535,6 +535,16 @@ func DocIDFromBinary(in []byte) (uint64, error) { // 4 | uint32 | length of target vectors segment (in bytes) // n | uint16+[]byte | target vectors segment: sequence of vec_length + vec (uint16 + []byte), (uint16 + []byte) ... +const ( + maxVectorLength int = math.MaxUint16 + maxClassNameLength int = math.MaxUint16 + maxSchemaLength int = math.MaxUint32 + maxMetaLength int = math.MaxUint32 + maxVectorWeightsLength int = math.MaxUint32 + maxTargetVectorsSegmentLength int = math.MaxUint32 + maxTargetVectorsOffsetsLength int = math.MaxUint32 +) + func (ko *Object) MarshalBinary() ([]byte, error) { if ko.MarshallerVersion != 1 { return nil, errors.Errorf("unsupported marshaller version %d", ko.MarshallerVersion) @@ -552,41 +562,75 @@ func (ko *Object) MarshalBinary() ([]byte, error) { if err != nil { return nil, err } + + if len(ko.Vector) > maxVectorLength { + return nil, fmt.Errorf("could not marshal '%s' max length exceeded (%d/%d)", "vector", len(ko.Vector), maxVectorLength) + } vectorLength := uint32(len(ko.Vector)) + className := []byte(ko.Class()) + if len(className) > maxClassNameLength { + return nil, fmt.Errorf("could not marshal '%s' max length exceeded (%d/%d)", "className", len(className), maxClassNameLength) + } classNameLength := uint32(len(className)) + schema, err := json.Marshal(ko.Properties()) if err != nil { return nil, err } + if len(schema) > maxSchemaLength { + return nil, fmt.Errorf("could not marshal '%s' max length exceeded (%d/%d)", "schema", len(schema), maxSchemaLength) + } schemaLength := uint32(len(schema)) + meta, err := json.Marshal(ko.AdditionalProperties()) if err != nil { return nil, err } + if len(meta) > maxMetaLength { + return nil, fmt.Errorf("could not marshal '%s' max length exceeded (%d/%d)", "meta", len(meta), maxMetaLength) + } metaLength := uint32(len(meta)) + vectorWeights, err := json.Marshal(ko.VectorWeights()) if err != nil { return nil, err } + if len(vectorWeights) > maxVectorWeightsLength { + return nil, fmt.Errorf("could not marshal '%s' max length exceeded (%d/%d)", "vectorWeights", len(vectorWeights), maxVectorWeightsLength) + } vectorWeightsLength := uint32(len(vectorWeights)) var targetVectorsOffsets []byte - targetVectorsOffsetsLength := uint32(0) - targetVectorsSegmentLength := uint32(0) + var targetVectorsOffsetsLength uint32 + var targetVectorsSegmentLength int targetVectorsOffsetOrder := make([]string, 0, len(ko.Vectors)) if len(ko.Vectors) > 0 { offsetsMap := map[string]uint32{} for name, vec := range ko.Vectors { - offsetsMap[name] = targetVectorsSegmentLength - targetVectorsSegmentLength += 2 + 4*uint32(len(vec)) // 2 for vec length + vec bytes + if len(vec) > maxVectorLength { + return nil, fmt.Errorf("could not marshal '%s' max length exceeded (%d/%d)", "vector", len(vec), maxVectorLength) + } + + offsetsMap[name] = uint32(targetVectorsSegmentLength) + targetVectorsSegmentLength += 2 + 4*len(vec) // 2 for vec length + vec bytes + + if targetVectorsSegmentLength > maxTargetVectorsSegmentLength { + return nil, + fmt.Errorf("could not marshal '%s' max length exceeded (%d/%d)", + "targetVectorsSegmentLength", targetVectorsSegmentLength, maxTargetVectorsSegmentLength) + } + targetVectorsOffsetOrder = append(targetVectorsOffsetOrder, name) } targetVectorsOffsets, err = msgpack.Marshal(offsetsMap) if err != nil { - return nil, fmt.Errorf("Could not marshal target vectors offsets: %w", err) + return nil, fmt.Errorf("could not marshal target vectors offsets: %w", err) + } + if len(targetVectorsOffsets) > maxTargetVectorsOffsetsLength { + return nil, fmt.Errorf("could not marshal '%s' max length exceeded (%d/%d)", "targetVectorsOffsets", len(targetVectorsOffsets), maxTargetVectorsOffsetsLength) } targetVectorsOffsetsLength = uint32(len(targetVectorsOffsets)) } @@ -598,7 +642,7 @@ func (ko *Object) MarshalBinary() ([]byte, error) { 4 + metaLength + 4 + vectorWeightsLength + 4 + targetVectorsOffsetsLength + - 4 + targetVectorsSegmentLength + 4 + uint32(targetVectorsSegmentLength) byteBuffer := make([]byte, totalBufferLength) rw := byteops.NewReadWriter(byteBuffer) @@ -648,7 +692,7 @@ func (ko *Object) MarshalBinary() ([]byte, error) { } } - rw.WriteUint32(targetVectorsSegmentLength) + rw.WriteUint32(uint32(targetVectorsSegmentLength)) for _, name := range targetVectorsOffsetOrder { vec := ko.Vectors[name] vecLen := len(vec) diff --git a/entities/storobj/storage_object_test.go b/entities/storobj/storage_object_test.go index cb255abe6c..b214525578 100644 --- a/entities/storobj/storage_object_test.go +++ b/entities/storobj/storage_object_test.go @@ -12,6 +12,7 @@ package storobj import ( + "crypto/rand" "fmt" "testing" "time" @@ -700,3 +701,58 @@ func TestVectorFromBinary(t *testing.T) { require.Nil(t, err) assert.Equal(t, vector3, outVector3) } + +func TestStorageInvalidObjectMarshalling(t *testing.T) { + t.Run("invalid className", func(t *testing.T) { + invalidClassName := make([]byte, maxClassNameLength+1) + rand.Read(invalidClassName[:]) + + invalidObj := FromObject( + &models.Object{ + Class: string(invalidClassName), + CreationTimeUnix: 123456, + LastUpdateTimeUnix: 56789, + ID: strfmt.UUID("73f2eb5f-5abf-447a-81ca-74b1dd168247"), + }, + nil, + nil, + ) + + _, err := invalidObj.MarshalBinary() + require.ErrorContains(t, err, "could not marshal 'className' max length exceeded") + }) + + t.Run("invalid vector", func(t *testing.T) { + invalidObj := FromObject( + &models.Object{ + Class: "classA", + CreationTimeUnix: 123456, + LastUpdateTimeUnix: 56789, + ID: strfmt.UUID("73f2eb5f-5abf-447a-81ca-74b1dd168247"), + }, + make([]float32, maxVectorLength+1), + nil, + ) + + _, err := invalidObj.MarshalBinary() + require.ErrorContains(t, err, "could not marshal 'vector' max length exceeded") + }) + + t.Run("invalid named vector size", func(t *testing.T) { + invalidObj := FromObject( + &models.Object{ + Class: "classA", + CreationTimeUnix: 123456, + LastUpdateTimeUnix: 56789, + ID: strfmt.UUID("73f2eb5f-5abf-447a-81ca-74b1dd168247"), + }, + nil, + models.Vectors{ + "vector1": make(models.Vector, maxVectorLength+1), + }, + ) + + _, err := invalidObj.MarshalBinary() + require.ErrorContains(t, err, "could not marshal 'vector' max length exceeded") + }) +} diff --git a/entities/sync/sync.go b/entities/sync/sync.go index 40d41313c0..61b808b192 100644 --- a/entities/sync/sync.go +++ b/entities/sync/sync.go @@ -17,15 +17,16 @@ import ( // KeyLocker it is a thread safe wrapper of sync.Map // Usage: it's used in order to lock specific key in a map -// to synchronizes concurrent access to a code block. -// locker.Lock(id) -// defer locker.Unlock(id) +// to synchronize concurrent access to a code block. +// +// locker.Lock(id) +// defer locker.Unlock(id) type KeyLocker struct { m sync.Map } -// New creates Keylocker -func New() *KeyLocker { +// NewKeyLocker creates Keylocker +func NewKeyLocker() *KeyLocker { return &KeyLocker{ m: sync.Map{}, } @@ -44,9 +45,68 @@ func (s *KeyLocker) Lock(ID string) { } // Unlock it unlocks a specific item by it's ID -// and it will delete it from the shared locks map func (s *KeyLocker) Unlock(ID string) { iLocks, _ := s.m.Load(ID) iLock := iLocks.(*sync.Mutex) iLock.Unlock() } + +// KeyRWLocker it is a thread safe wrapper of sync.Map +// Usage: it's used in order to lock/rlock specific key in a map +// to synchronize concurrent access to a code block. +// +// locker.Lock(id) +// defer locker.Unlock(id) +// +// or +// +// locker.RLock(id) +// defer locker.RUnlock(id) +type KeyRWLocker struct { + m sync.Map +} + +// NewKeyLocker creates Keylocker +func NewKeyRWLocker() *KeyRWLocker { + return &KeyRWLocker{ + m: sync.Map{}, + } +} + +// Lock it locks a specific bucket by it's ID +// to hold ant concurrent access to that specific item +// +// do not forget calling Unlock() after locking it. +func (s *KeyRWLocker) Lock(ID string) { + iLock := &sync.RWMutex{} + iLocks, _ := s.m.LoadOrStore(ID, iLock) + + iLock = iLocks.(*sync.RWMutex) + iLock.Lock() +} + +// Unlock it unlocks a specific item by it's ID +func (s *KeyRWLocker) Unlock(ID string) { + iLocks, _ := s.m.Load(ID) + iLock := iLocks.(*sync.RWMutex) + iLock.Unlock() +} + +// RLock it rlocks a specific bucket by it's ID +// to hold ant concurrent access to that specific item +// +// do not forget calling RUnlock() after rlocking it. +func (s *KeyRWLocker) RLock(ID string) { + iLock := &sync.RWMutex{} + iLocks, _ := s.m.LoadOrStore(ID, iLock) + + iLock = iLocks.(*sync.RWMutex) + iLock.RLock() +} + +// RUnlock it runlocks a specific item by it's ID +func (s *KeyRWLocker) RUnlock(ID string) { + iLocks, _ := s.m.Load(ID) + iLock := iLocks.(*sync.RWMutex) + iLock.RUnlock() +} diff --git a/entities/sync/sync_test.go b/entities/sync/sync_test.go index 962ebc3495..70a3d00948 100644 --- a/entities/sync/sync_test.go +++ b/entities/sync/sync_test.go @@ -25,9 +25,32 @@ func mutexLocked(m *sync.Mutex) bool { return state.Int()&mLocked == mLocked } -func TestSyncLockUnlock(t *testing.T) { +func rwMutexLocked(m *sync.RWMutex) bool { + // can not RLock + rlocked := m.TryRLock() + if rlocked { + defer m.RUnlock() + } + return !rlocked +} + +func rwMutexRLocked(m *sync.RWMutex) bool { + // can not Lock, but can RLock + locked := m.TryLock() + if locked { + defer m.Unlock() + return false + } + rlocked := m.TryRLock() + if rlocked { + defer m.RUnlock() + } + return rlocked +} + +func TestKeyLockerLockUnlock(t *testing.T) { r := require.New(t) - s := New() + s := NewKeyLocker() s.Lock("t1") lock, _ := s.m.Load("t1") @@ -45,3 +68,48 @@ func TestSyncLockUnlock(t *testing.T) { lock, _ = s.m.Load("t2") r.False(mutexLocked(lock.(*sync.Mutex))) } + +func TestKeyRWLockerLockUnlock(t *testing.T) { + r := require.New(t) + s := NewKeyRWLocker() + + s.Lock("t1") + lock, _ := s.m.Load("t1") + r.True(rwMutexLocked(lock.(*sync.RWMutex))) + r.False(rwMutexRLocked(lock.(*sync.RWMutex))) + + s.Unlock("t1") + lock, _ = s.m.Load("t1") + r.False(rwMutexLocked(lock.(*sync.RWMutex))) + r.False(rwMutexRLocked(lock.(*sync.RWMutex))) + + s.Lock("t2") + lock, _ = s.m.Load("t2") + r.True(rwMutexLocked(lock.(*sync.RWMutex))) + r.False(rwMutexRLocked(lock.(*sync.RWMutex))) + + s.Unlock("t2") + lock, _ = s.m.Load("t2") + r.False(rwMutexLocked(lock.(*sync.RWMutex))) + r.False(rwMutexRLocked(lock.(*sync.RWMutex))) + + s.RLock("t1") + lock, _ = s.m.Load("t1") + r.False(rwMutexLocked(lock.(*sync.RWMutex))) + r.True(rwMutexRLocked(lock.(*sync.RWMutex))) + + s.RUnlock("t1") + lock, _ = s.m.Load("t1") + r.False(rwMutexLocked(lock.(*sync.RWMutex))) + r.False(rwMutexRLocked(lock.(*sync.RWMutex))) + + s.RLock("t2") + lock, _ = s.m.Load("t2") + r.False(rwMutexLocked(lock.(*sync.RWMutex))) + r.True(rwMutexRLocked(lock.(*sync.RWMutex))) + + s.RUnlock("t2") + lock, _ = s.m.Load("t2") + r.False(rwMutexLocked(lock.(*sync.RWMutex))) + r.False(rwMutexRLocked(lock.(*sync.RWMutex))) +} diff --git a/go.mod b/go.mod index b36e7d4ec8..70d52af630 100644 --- a/go.mod +++ b/go.mod @@ -47,6 +47,10 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/azcore v1.10.0 github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.3.1 github.com/KimMachineGun/automemlimit v0.3.0 + github.com/aws/aws-sdk-go-v2 v1.26.1 + github.com/aws/aws-sdk-go-v2/config v1.27.12 + github.com/aws/aws-sdk-go-v2/credentials v1.17.12 + github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.1 github.com/cenkalti/backoff/v4 v4.2.1 github.com/coreos/go-oidc/v3 v3.10.0 github.com/edsrzf/mmap-go v1.1.0 @@ -79,6 +83,17 @@ require ( github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect github.com/armon/go-metrics v0.4.1 // indirect github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.20.6 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.5 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.28.7 // indirect + github.com/aws/smithy-go v1.20.2 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/boltdb/bolt v1.3.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect diff --git a/go.sum b/go.sum index 4efa45f088..fcf7ac033d 100644 --- a/go.sum +++ b/go.sum @@ -54,6 +54,36 @@ github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d h1:Byv0BzEl github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= +github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA= +github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg= +github.com/aws/aws-sdk-go-v2/config v1.27.12 h1:vq88mBaZI4NGLXk8ierArwSILmYHDJZGJOeAc/pzEVQ= +github.com/aws/aws-sdk-go-v2/config v1.27.12/go.mod h1:IOrsf4IiN68+CgzyuyGUYTpCrtUQTbbMEAtR/MR/4ZU= +github.com/aws/aws-sdk-go-v2/credentials v1.17.12 h1:PVbKQ0KjDosI5+nEdRMU8ygEQDmkJTSHBqPjEX30lqc= +github.com/aws/aws-sdk-go-v2/credentials v1.17.12/go.mod h1:jlWtGFRtKsqc5zqerHZYmKmRkUXo3KPM14YJ13ZEjwE= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 h1:FVJ0r5XTHSmIHJV6KuDmdYhEpvlHpiSd38RQWhut5J4= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1/go.mod h1:zusuAeqezXzAB24LGuzuekqMAEgWkVYukBec3kr3jUg= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.1 h1:vTHgBjsGhgKWWIgioxd7MkBH5Ekr8C6Cb+/8iWf1dpc= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.1/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 h1:Ji0DY1xUsUr3I8cHps0G+XM3WWU16lP6yG8qu1GAZAs= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2/go.mod h1:5CsjAbs3NlGQyZNFACh+zztPDI7fU6eW9QsxjfnuBKg= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 h1:ogRAwT1/gxJBcSWDMZlgyFUM962F51A5CRhDLbxLdmo= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7/go.mod h1:YCsIZhXfRPLFFCl5xxY+1T9RKzOKjCut+28JSX2DnAk= +github.com/aws/aws-sdk-go-v2/service/sso v1.20.6 h1:o5cTaeunSpfXiLTIBx5xo2enQmiChtu1IBbzXnfU9Hs= +github.com/aws/aws-sdk-go-v2/service/sso v1.20.6/go.mod h1:qGzynb/msuZIE8I75DVRCUXw3o3ZyBmUvMwQ2t/BrGM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.5 h1:Ciiz/plN+Z+pPO1G0W2zJoYIIl0KtKzY0LJ78NXYTws= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.5/go.mod h1:mUYPBhaF2lGiukDEjJX2BLRRKTmoUSitGDUgM4tRxak= +github.com/aws/aws-sdk-go-v2/service/sts v1.28.7 h1:et3Ta53gotFR4ERLXXHIHl/Uuk1qYpP5uU7cvNql8ns= +github.com/aws/aws-sdk-go-v2/service/sts v1.28.7/go.mod h1:FZf1/nKNEkHdGGJP/cI2MoIMquumuRK6ol3QQJNDxmw= +github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= +github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bmatcuk/doublestar v1.1.3 h1:S4Ka/fLvUtm+5TqKuByWyuGenBjTP8w+Z/GpQIWB9Yg= diff --git a/grpc/generated/protocol/v1/base.pb.go b/grpc/generated/protocol/v1/base.pb.go index 30526c5fda..144b57dd30 100644 --- a/grpc/generated/protocol/v1/base.pb.go +++ b/grpc/generated/protocol/v1/base.pb.go @@ -157,7 +157,7 @@ type NumberArrayProperties struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - // Deprecated: Marked as deprecated in v1/base.proto. + // Deprecated: Do not use. Values []float64 `protobuf:"fixed64,1,rep,packed,name=values,proto3" json:"values,omitempty"` // will be removed in the future, use vector_bytes PropName string `protobuf:"bytes,2,opt,name=prop_name,json=propName,proto3" json:"prop_name,omitempty"` ValuesBytes []byte `protobuf:"bytes,3,opt,name=values_bytes,json=valuesBytes,proto3" json:"values_bytes,omitempty"` @@ -195,7 +195,7 @@ func (*NumberArrayProperties) Descriptor() ([]byte, []int) { return file_v1_base_proto_rawDescGZIP(), []int{0} } -// Deprecated: Marked as deprecated in v1/base.proto. +// Deprecated: Do not use. func (x *NumberArrayProperties) GetValues() []float64 { if x != nil { return x.Values @@ -791,11 +791,10 @@ type Filters struct { Operator Filters_Operator `protobuf:"varint,1,opt,name=operator,proto3,enum=weaviate.v1.Filters_Operator" json:"operator,omitempty"` // protolint:disable:next REPEATED_FIELD_NAMES_PLURALIZED // - // Deprecated: Marked as deprecated in v1/base.proto. + // Deprecated: Do not use. On []string `protobuf:"bytes,2,rep,name=on,proto3" json:"on,omitempty"` // will be removed in the future, use path Filters []*Filters `protobuf:"bytes,3,rep,name=filters,proto3" json:"filters,omitempty"` // Types that are assignable to TestValue: - // // *Filters_ValueText // *Filters_ValueInt // *Filters_ValueBoolean @@ -848,7 +847,7 @@ func (x *Filters) GetOperator() Filters_Operator { return Filters_OPERATOR_UNSPECIFIED } -// Deprecated: Marked as deprecated in v1/base.proto. +// Deprecated: Do not use. func (x *Filters) GetOn() []string { if x != nil { return x.On @@ -1169,7 +1168,6 @@ type FilterTarget struct { unknownFields protoimpl.UnknownFields // Types that are assignable to Target: - // // *FilterTarget_Property // *FilterTarget_SingleTarget // *FilterTarget_MultiTarget diff --git a/grpc/generated/protocol/v1/batch.pb.go b/grpc/generated/protocol/v1/batch.pb.go index a8af5f5779..561bae35a2 100644 --- a/grpc/generated/protocol/v1/batch.pb.go +++ b/grpc/generated/protocol/v1/batch.pb.go @@ -81,7 +81,7 @@ type BatchObject struct { Uuid string `protobuf:"bytes,1,opt,name=uuid,proto3" json:"uuid,omitempty"` // protolint:disable:next REPEATED_FIELD_NAMES_PLURALIZED // - // Deprecated: Marked as deprecated in v1/batch.proto. + // Deprecated: Do not use. Vector []float32 `protobuf:"fixed32,2,rep,packed,name=vector,proto3" json:"vector,omitempty"` // deprecated, will be removed Properties *BatchObject_Properties `protobuf:"bytes,3,opt,name=properties,proto3" json:"properties,omitempty"` Collection string `protobuf:"bytes,4,opt,name=collection,proto3" json:"collection,omitempty"` @@ -130,7 +130,7 @@ func (x *BatchObject) GetUuid() string { return "" } -// Deprecated: Marked as deprecated in v1/batch.proto. +// Deprecated: Do not use. func (x *BatchObject) GetVector() []float32 { if x != nil { return x.Vector diff --git a/grpc/generated/protocol/v1/batch_delete.pb.go b/grpc/generated/protocol/v1/batch_delete.pb.go index 0b5b9484ed..239923b98b 100644 --- a/grpc/generated/protocol/v1/batch_delete.pb.go +++ b/grpc/generated/protocol/v1/batch_delete.pb.go @@ -289,14 +289,14 @@ var file_v1_batch_delete_proto_rawDesc = []byte{ 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x66, 0x75, 0x6c, 0x12, 0x19, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x88, 0x01, 0x01, 0x42, 0x08, 0x0a, 0x06, 0x5f, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x42, - 0x6f, 0x0a, 0x23, 0x69, 0x6f, 0x2e, 0x77, 0x65, 0x61, 0x76, 0x69, 0x61, 0x74, 0x65, 0x2e, 0x63, + 0x75, 0x0a, 0x23, 0x69, 0x6f, 0x2e, 0x77, 0x65, 0x61, 0x76, 0x69, 0x61, 0x74, 0x65, 0x2e, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x2e, 0x76, 0x31, 0x42, 0x12, 0x57, 0x65, 0x61, 0x76, 0x69, 0x61, 0x74, 0x65, - 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x42, 0x61, 0x74, 0x63, 0x68, 0x5a, 0x34, 0x67, 0x69, 0x74, 0x68, - 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x77, 0x65, 0x61, 0x76, 0x69, 0x61, 0x74, 0x65, 0x2f, - 0x77, 0x65, 0x61, 0x76, 0x69, 0x61, 0x74, 0x65, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x67, 0x65, - 0x6e, 0x65, 0x72, 0x61, 0x74, 0x65, 0x64, 0x3b, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, - 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x63, 0x6f, 0x6c, 0x2e, 0x76, 0x31, 0x42, 0x18, 0x57, 0x65, 0x61, 0x76, 0x69, 0x61, 0x74, 0x65, + 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x42, 0x61, 0x74, 0x63, 0x68, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, + 0x5a, 0x34, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x77, 0x65, 0x61, + 0x76, 0x69, 0x61, 0x74, 0x65, 0x2f, 0x77, 0x65, 0x61, 0x76, 0x69, 0x61, 0x74, 0x65, 0x2f, 0x67, + 0x72, 0x70, 0x63, 0x2f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x65, 0x64, 0x3b, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/grpc/generated/protocol/v1/properties.pb.go b/grpc/generated/protocol/v1/properties.pb.go index cafded197f..73cc28c59d 100644 --- a/grpc/generated/protocol/v1/properties.pb.go +++ b/grpc/generated/protocol/v1/properties.pb.go @@ -71,7 +71,6 @@ type Value struct { unknownFields protoimpl.UnknownFields // Types that are assignable to Kind: - // // *Value_NumberValue // *Value_StringValue // *Value_BoolValue @@ -134,7 +133,7 @@ func (x *Value) GetNumberValue() float64 { return 0 } -// Deprecated: Marked as deprecated in v1/properties.proto. +// Deprecated: Do not use. func (x *Value) GetStringValue() string { if x, ok := x.GetKind().(*Value_StringValue); ok { return x.StringValue @@ -209,7 +208,7 @@ func (x *Value) GetNullValue() structpb.NullValue { if x, ok := x.GetKind().(*Value_NullValue); ok { return x.NullValue } - return structpb.NullValue(0) + return structpb.NullValue_NULL_VALUE } func (x *Value) GetTextValue() string { @@ -228,7 +227,7 @@ type Value_NumberValue struct { } type Value_StringValue struct { - // Deprecated: Marked as deprecated in v1/properties.proto. + // Deprecated: Do not use. StringValue string `protobuf:"bytes,2,opt,name=string_value,json=stringValue,proto3,oneof"` } @@ -307,10 +306,9 @@ type ListValue struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - // Deprecated: Marked as deprecated in v1/properties.proto. + // Deprecated: Do not use. Values []*Value `protobuf:"bytes,1,rep,name=values,proto3" json:"values,omitempty"` // Types that are assignable to Kind: - // // *ListValue_NumberValues // *ListValue_BoolValues // *ListValue_ObjectValues @@ -353,7 +351,7 @@ func (*ListValue) Descriptor() ([]byte, []int) { return file_v1_properties_proto_rawDescGZIP(), []int{2} } -// Deprecated: Marked as deprecated in v1/properties.proto. +// Deprecated: Do not use. func (x *ListValue) GetValues() []*Value { if x != nil { return x.Values @@ -468,7 +466,7 @@ type NumberValues struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - // * + //* // The values are stored as a byte array, where each 8 bytes represent a single float64 value. // The byte array is stored in little-endian order using uint64 encoding. Values []byte `protobuf:"bytes,1,opt,name=values,proto3" json:"values,omitempty"` @@ -518,9 +516,6 @@ type TextValues struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - // * - // The values are stored as a byte array, where each byte contains a single UTF-8 character. - // Individual text values are delimited by a ',' character within the overall byte array. Values []string `protobuf:"bytes,1,rep,name=values,proto3" json:"values,omitempty"` } @@ -568,9 +563,6 @@ type BoolValues struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - // * - // The values are stored as a byte array, where each byte represents a single boolean value. - // The byte array is stored in little-endian order using uint64 encoding. Values []bool `protobuf:"varint,1,rep,packed,name=values,proto3" json:"values,omitempty"` } @@ -665,9 +657,6 @@ type DateValues struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - // * - // The values are stored as a byte array, where each byte contains a single UTF-8 character. - // Individual date values are delimited by a ',' character within the overall byte array. Values []string `protobuf:"bytes,1,rep,name=values,proto3" json:"values,omitempty"` } @@ -715,9 +704,6 @@ type UuidValues struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - // * - // The values are stored as a byte array, where each byte contains a single UTF-8 character. - // Individual uuid values are delimited by a ',' character within the overall byte array. Values []string `protobuf:"bytes,1,rep,name=values,proto3" json:"values,omitempty"` } @@ -765,7 +751,7 @@ type IntValues struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - // * + //* // The values are stored as a byte array, where each 8 bytes represent a single int64 value. // The byte array is stored in little-endian order using uint64 encoding. Values []byte `protobuf:"bytes,1,opt,name=values,proto3" json:"values,omitempty"` diff --git a/grpc/generated/protocol/v1/search_get.pb.go b/grpc/generated/protocol/v1/search_get.pb.go index ffe4ab8a22..8aac90100c 100644 --- a/grpc/generated/protocol/v1/search_get.pb.go +++ b/grpc/generated/protocol/v1/search_get.pb.go @@ -72,7 +72,7 @@ type SearchRequest struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - // required + //required Collection string `protobuf:"bytes,1,opt,name=collection,proto3" json:"collection,omitempty"` // parameters Tenant string `protobuf:"bytes,10,opt,name=tenant,proto3" json:"tenant,omitempty"` @@ -103,7 +103,7 @@ type SearchRequest struct { NearImu *NearIMUSearch `protobuf:"bytes,51,opt,name=near_imu,json=nearImu,proto3,oneof" json:"near_imu,omitempty"` Generative *GenerativeSearch `protobuf:"bytes,60,opt,name=generative,proto3,oneof" json:"generative,omitempty"` Rerank *Rerank `protobuf:"bytes,61,opt,name=rerank,proto3,oneof" json:"rerank,omitempty"` - // Deprecated: Marked as deprecated in v1/search_get.proto. + // Deprecated: Do not use. Uses_123Api bool `protobuf:"varint,100,opt,name=uses_123_api,json=uses123Api,proto3" json:"uses_123_api,omitempty"` Uses_125Api bool `protobuf:"varint,101,opt,name=uses_125_api,json=uses125Api,proto3" json:"uses_125_api,omitempty"` } @@ -315,7 +315,7 @@ func (x *SearchRequest) GetRerank() *Rerank { return nil } -// Deprecated: Marked as deprecated in v1/search_get.proto. +// Deprecated: Do not use. func (x *SearchRequest) GetUses_123Api() bool { if x != nil { return x.Uses_123Api @@ -780,15 +780,14 @@ type Hybrid struct { Properties []string `protobuf:"bytes,2,rep,name=properties,proto3" json:"properties,omitempty"` // protolint:disable:next REPEATED_FIELD_NAMES_PLURALIZED // - // Deprecated: Marked as deprecated in v1/search_get.proto. + // Deprecated: Do not use. Vector []float32 `protobuf:"fixed32,3,rep,packed,name=vector,proto3" json:"vector,omitempty"` // will be removed in the future, use vector_bytes Alpha float32 `protobuf:"fixed32,4,opt,name=alpha,proto3" json:"alpha,omitempty"` FusionType Hybrid_FusionType `protobuf:"varint,5,opt,name=fusion_type,json=fusionType,proto3,enum=weaviate.v1.Hybrid_FusionType" json:"fusion_type,omitempty"` VectorBytes []byte `protobuf:"bytes,6,opt,name=vector_bytes,json=vectorBytes,proto3" json:"vector_bytes,omitempty"` TargetVectors []string `protobuf:"bytes,7,rep,name=target_vectors,json=targetVectors,proto3" json:"target_vectors,omitempty"` - NearText *NearTextSearch `protobuf:"bytes,8,opt,name=near_text,json=nearText,proto3" json:"near_text,omitempty"` - NearVector *NearVector `protobuf:"bytes,9,opt,name=near_vector,json=nearVector,proto3" json:"near_vector,omitempty"` - GroupBy *GroupBy `protobuf:"bytes,10,opt,name=group_by,json=groupBy,proto3" json:"group_by,omitempty"` + NearText *NearTextSearch `protobuf:"bytes,8,opt,name=near_text,json=nearText,proto3" json:"near_text,omitempty"` // target_vector in msg is ignored and should not be set for hybrid + NearVector *NearVector `protobuf:"bytes,9,opt,name=near_vector,json=nearVector,proto3" json:"near_vector,omitempty"` // same as above. Use the target vector in the hybrid message } func (x *Hybrid) Reset() { @@ -837,7 +836,7 @@ func (x *Hybrid) GetProperties() []string { return nil } -// Deprecated: Marked as deprecated in v1/search_get.proto. +// Deprecated: Do not use. func (x *Hybrid) GetVector() []float32 { if x != nil { return x.Vector @@ -887,13 +886,6 @@ func (x *Hybrid) GetNearVector() *NearVector { return nil } -func (x *Hybrid) GetGroupBy() *GroupBy { - if x != nil { - return x.GroupBy - } - return nil -} - type NearTextSearch struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -1415,7 +1407,6 @@ type BM25 struct { Query string `protobuf:"bytes,1,opt,name=query,proto3" json:"query,omitempty"` Properties []string `protobuf:"bytes,2,rep,name=properties,proto3" json:"properties,omitempty"` - GroupBy *GroupBy `protobuf:"bytes,3,opt,name=group_by,json=groupBy,proto3" json:"group_by,omitempty"` } func (x *BM25) Reset() { @@ -1464,13 +1455,6 @@ func (x *BM25) GetProperties() []string { return nil } -func (x *BM25) GetGroupBy() *GroupBy { - if x != nil { - return x.GroupBy - } - return nil -} - type RefPropertiesRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -1549,7 +1533,7 @@ type NearVector struct { // protolint:disable:next REPEATED_FIELD_NAMES_PLURALIZED // - // Deprecated: Marked as deprecated in v1/search_get.proto. + // Deprecated: Do not use. Vector []float32 `protobuf:"fixed32,1,rep,packed,name=vector,proto3" json:"vector,omitempty"` // will be removed in the future, use vector_bytes Certainty *float64 `protobuf:"fixed64,2,opt,name=certainty,proto3,oneof" json:"certainty,omitempty"` Distance *float64 `protobuf:"fixed64,3,opt,name=distance,proto3,oneof" json:"distance,omitempty"` @@ -1589,7 +1573,7 @@ func (*NearVector) Descriptor() ([]byte, []int) { return file_v1_search_get_proto_rawDescGZIP(), []int{17} } -// Deprecated: Marked as deprecated in v1/search_get.proto. +// Deprecated: Do not use. func (x *NearVector) GetVector() []float32 { if x != nil { return x.Vector @@ -2074,7 +2058,7 @@ type MetadataResult struct { Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` // protolint:disable:next REPEATED_FIELD_NAMES_PLURALIZED // - // Deprecated: Marked as deprecated in v1/search_get.proto. + // Deprecated: Do not use. Vector []float32 `protobuf:"fixed32,2,rep,packed,name=vector,proto3" json:"vector,omitempty"` CreationTimeUnix int64 `protobuf:"varint,3,opt,name=creation_time_unix,json=creationTimeUnix,proto3" json:"creation_time_unix,omitempty"` CreationTimeUnixPresent bool `protobuf:"varint,4,opt,name=creation_time_unix_present,json=creationTimeUnixPresent,proto3" json:"creation_time_unix_present,omitempty"` @@ -2138,7 +2122,7 @@ func (x *MetadataResult) GetId() string { return "" } -// Deprecated: Marked as deprecated in v1/search_get.proto. +// Deprecated: Do not use. func (x *MetadataResult) GetVector() []float32 { if x != nil { return x.Vector @@ -2298,22 +2282,22 @@ type PropertiesResult struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - // Deprecated: Marked as deprecated in v1/search_get.proto. + // Deprecated: Do not use. NonRefProperties *structpb.Struct `protobuf:"bytes,1,opt,name=non_ref_properties,json=nonRefProperties,proto3" json:"non_ref_properties,omitempty"` RefProps []*RefPropertiesResult `protobuf:"bytes,2,rep,name=ref_props,json=refProps,proto3" json:"ref_props,omitempty"` TargetCollection string `protobuf:"bytes,3,opt,name=target_collection,json=targetCollection,proto3" json:"target_collection,omitempty"` Metadata *MetadataResult `protobuf:"bytes,4,opt,name=metadata,proto3" json:"metadata,omitempty"` - // Deprecated: Marked as deprecated in v1/search_get.proto. + // Deprecated: Do not use. NumberArrayProperties []*NumberArrayProperties `protobuf:"bytes,5,rep,name=number_array_properties,json=numberArrayProperties,proto3" json:"number_array_properties,omitempty"` - // Deprecated: Marked as deprecated in v1/search_get.proto. + // Deprecated: Do not use. IntArrayProperties []*IntArrayProperties `protobuf:"bytes,6,rep,name=int_array_properties,json=intArrayProperties,proto3" json:"int_array_properties,omitempty"` - // Deprecated: Marked as deprecated in v1/search_get.proto. + // Deprecated: Do not use. TextArrayProperties []*TextArrayProperties `protobuf:"bytes,7,rep,name=text_array_properties,json=textArrayProperties,proto3" json:"text_array_properties,omitempty"` - // Deprecated: Marked as deprecated in v1/search_get.proto. + // Deprecated: Do not use. BooleanArrayProperties []*BooleanArrayProperties `protobuf:"bytes,8,rep,name=boolean_array_properties,json=booleanArrayProperties,proto3" json:"boolean_array_properties,omitempty"` - // Deprecated: Marked as deprecated in v1/search_get.proto. + // Deprecated: Do not use. ObjectProperties []*ObjectProperties `protobuf:"bytes,9,rep,name=object_properties,json=objectProperties,proto3" json:"object_properties,omitempty"` - // Deprecated: Marked as deprecated in v1/search_get.proto. + // Deprecated: Do not use. ObjectArrayProperties []*ObjectArrayProperties `protobuf:"bytes,10,rep,name=object_array_properties,json=objectArrayProperties,proto3" json:"object_array_properties,omitempty"` NonRefProps *Properties `protobuf:"bytes,11,opt,name=non_ref_props,json=nonRefProps,proto3" json:"non_ref_props,omitempty"` RefPropsRequested bool `protobuf:"varint,12,opt,name=ref_props_requested,json=refPropsRequested,proto3" json:"ref_props_requested,omitempty"` @@ -2351,7 +2335,7 @@ func (*PropertiesResult) Descriptor() ([]byte, []int) { return file_v1_search_get_proto_rawDescGZIP(), []int{26} } -// Deprecated: Marked as deprecated in v1/search_get.proto. +// Deprecated: Do not use. func (x *PropertiesResult) GetNonRefProperties() *structpb.Struct { if x != nil { return x.NonRefProperties @@ -2380,7 +2364,7 @@ func (x *PropertiesResult) GetMetadata() *MetadataResult { return nil } -// Deprecated: Marked as deprecated in v1/search_get.proto. +// Deprecated: Do not use. func (x *PropertiesResult) GetNumberArrayProperties() []*NumberArrayProperties { if x != nil { return x.NumberArrayProperties @@ -2388,7 +2372,7 @@ func (x *PropertiesResult) GetNumberArrayProperties() []*NumberArrayProperties { return nil } -// Deprecated: Marked as deprecated in v1/search_get.proto. +// Deprecated: Do not use. func (x *PropertiesResult) GetIntArrayProperties() []*IntArrayProperties { if x != nil { return x.IntArrayProperties @@ -2396,7 +2380,7 @@ func (x *PropertiesResult) GetIntArrayProperties() []*IntArrayProperties { return nil } -// Deprecated: Marked as deprecated in v1/search_get.proto. +// Deprecated: Do not use. func (x *PropertiesResult) GetTextArrayProperties() []*TextArrayProperties { if x != nil { return x.TextArrayProperties @@ -2404,7 +2388,7 @@ func (x *PropertiesResult) GetTextArrayProperties() []*TextArrayProperties { return nil } -// Deprecated: Marked as deprecated in v1/search_get.proto. +// Deprecated: Do not use. func (x *PropertiesResult) GetBooleanArrayProperties() []*BooleanArrayProperties { if x != nil { return x.BooleanArrayProperties @@ -2412,7 +2396,7 @@ func (x *PropertiesResult) GetBooleanArrayProperties() []*BooleanArrayProperties return nil } -// Deprecated: Marked as deprecated in v1/search_get.proto. +// Deprecated: Do not use. func (x *PropertiesResult) GetObjectProperties() []*ObjectProperties { if x != nil { return x.ObjectProperties @@ -2420,7 +2404,7 @@ func (x *PropertiesResult) GetObjectProperties() []*ObjectProperties { return nil } -// Deprecated: Marked as deprecated in v1/search_get.proto. +// Deprecated: Do not use. func (x *PropertiesResult) GetObjectArrayProperties() []*ObjectArrayProperties { if x != nil { return x.ObjectArrayProperties @@ -3187,36 +3171,34 @@ var file_v1_search_get_proto_depIdxs = []int32{ 0, // 22: weaviate.v1.Hybrid.fusion_type:type_name -> weaviate.v1.Hybrid.FusionType 9, // 23: weaviate.v1.Hybrid.near_text:type_name -> weaviate.v1.NearTextSearch 18, // 24: weaviate.v1.Hybrid.near_vector:type_name -> weaviate.v1.NearVector - 2, // 25: weaviate.v1.Hybrid.group_by:type_name -> weaviate.v1.GroupBy - 29, // 26: weaviate.v1.NearTextSearch.move_to:type_name -> weaviate.v1.NearTextSearch.Move - 29, // 27: weaviate.v1.NearTextSearch.move_away:type_name -> weaviate.v1.NearTextSearch.Move - 2, // 28: weaviate.v1.BM25.group_by:type_name -> weaviate.v1.GroupBy - 6, // 29: weaviate.v1.RefPropertiesRequest.properties:type_name -> weaviate.v1.PropertiesRequest - 5, // 30: weaviate.v1.RefPropertiesRequest.metadata:type_name -> weaviate.v1.MetadataRequest - 25, // 31: weaviate.v1.SearchReply.results:type_name -> weaviate.v1.SearchResult - 24, // 32: weaviate.v1.SearchReply.group_by_results:type_name -> weaviate.v1.GroupByResult - 25, // 33: weaviate.v1.GroupByResult.objects:type_name -> weaviate.v1.SearchResult - 22, // 34: weaviate.v1.GroupByResult.rerank:type_name -> weaviate.v1.RerankReply - 23, // 35: weaviate.v1.GroupByResult.generative:type_name -> weaviate.v1.GenerativeReply - 27, // 36: weaviate.v1.SearchResult.properties:type_name -> weaviate.v1.PropertiesResult - 26, // 37: weaviate.v1.SearchResult.metadata:type_name -> weaviate.v1.MetadataResult - 32, // 38: weaviate.v1.MetadataResult.vectors:type_name -> weaviate.v1.Vectors - 33, // 39: weaviate.v1.PropertiesResult.non_ref_properties:type_name -> google.protobuf.Struct - 28, // 40: weaviate.v1.PropertiesResult.ref_props:type_name -> weaviate.v1.RefPropertiesResult - 26, // 41: weaviate.v1.PropertiesResult.metadata:type_name -> weaviate.v1.MetadataResult - 34, // 42: weaviate.v1.PropertiesResult.number_array_properties:type_name -> weaviate.v1.NumberArrayProperties - 35, // 43: weaviate.v1.PropertiesResult.int_array_properties:type_name -> weaviate.v1.IntArrayProperties - 36, // 44: weaviate.v1.PropertiesResult.text_array_properties:type_name -> weaviate.v1.TextArrayProperties - 37, // 45: weaviate.v1.PropertiesResult.boolean_array_properties:type_name -> weaviate.v1.BooleanArrayProperties - 38, // 46: weaviate.v1.PropertiesResult.object_properties:type_name -> weaviate.v1.ObjectProperties - 39, // 47: weaviate.v1.PropertiesResult.object_array_properties:type_name -> weaviate.v1.ObjectArrayProperties - 40, // 48: weaviate.v1.PropertiesResult.non_ref_props:type_name -> weaviate.v1.Properties - 27, // 49: weaviate.v1.RefPropertiesResult.properties:type_name -> weaviate.v1.PropertiesResult - 50, // [50:50] is the sub-list for method output_type - 50, // [50:50] is the sub-list for method input_type - 50, // [50:50] is the sub-list for extension type_name - 50, // [50:50] is the sub-list for extension extendee - 0, // [0:50] is the sub-list for field type_name + 29, // 25: weaviate.v1.NearTextSearch.move_to:type_name -> weaviate.v1.NearTextSearch.Move + 29, // 26: weaviate.v1.NearTextSearch.move_away:type_name -> weaviate.v1.NearTextSearch.Move + 6, // 27: weaviate.v1.RefPropertiesRequest.properties:type_name -> weaviate.v1.PropertiesRequest + 5, // 28: weaviate.v1.RefPropertiesRequest.metadata:type_name -> weaviate.v1.MetadataRequest + 25, // 29: weaviate.v1.SearchReply.results:type_name -> weaviate.v1.SearchResult + 24, // 30: weaviate.v1.SearchReply.group_by_results:type_name -> weaviate.v1.GroupByResult + 25, // 31: weaviate.v1.GroupByResult.objects:type_name -> weaviate.v1.SearchResult + 22, // 32: weaviate.v1.GroupByResult.rerank:type_name -> weaviate.v1.RerankReply + 23, // 33: weaviate.v1.GroupByResult.generative:type_name -> weaviate.v1.GenerativeReply + 27, // 34: weaviate.v1.SearchResult.properties:type_name -> weaviate.v1.PropertiesResult + 26, // 35: weaviate.v1.SearchResult.metadata:type_name -> weaviate.v1.MetadataResult + 32, // 36: weaviate.v1.MetadataResult.vectors:type_name -> weaviate.v1.Vectors + 33, // 37: weaviate.v1.PropertiesResult.non_ref_properties:type_name -> google.protobuf.Struct + 28, // 38: weaviate.v1.PropertiesResult.ref_props:type_name -> weaviate.v1.RefPropertiesResult + 26, // 39: weaviate.v1.PropertiesResult.metadata:type_name -> weaviate.v1.MetadataResult + 34, // 40: weaviate.v1.PropertiesResult.number_array_properties:type_name -> weaviate.v1.NumberArrayProperties + 35, // 41: weaviate.v1.PropertiesResult.int_array_properties:type_name -> weaviate.v1.IntArrayProperties + 36, // 42: weaviate.v1.PropertiesResult.text_array_properties:type_name -> weaviate.v1.TextArrayProperties + 37, // 43: weaviate.v1.PropertiesResult.boolean_array_properties:type_name -> weaviate.v1.BooleanArrayProperties + 38, // 44: weaviate.v1.PropertiesResult.object_properties:type_name -> weaviate.v1.ObjectProperties + 39, // 45: weaviate.v1.PropertiesResult.object_array_properties:type_name -> weaviate.v1.ObjectArrayProperties + 40, // 46: weaviate.v1.PropertiesResult.non_ref_props:type_name -> weaviate.v1.Properties + 27, // 47: weaviate.v1.RefPropertiesResult.properties:type_name -> weaviate.v1.PropertiesResult + 48, // [48:48] is the sub-list for method output_type + 48, // [48:48] is the sub-list for method input_type + 48, // [48:48] is the sub-list for extension type_name + 48, // [48:48] is the sub-list for extension extendee + 0, // [0:48] is the sub-list for field type_name } func init() { file_v1_search_get_proto_init() } diff --git a/grpc/generated/protocol/v1/tenants.pb.go b/grpc/generated/protocol/v1/tenants.pb.go index 1792a4356d..ae3baa0ccb 100644 --- a/grpc/generated/protocol/v1/tenants.pb.go +++ b/grpc/generated/protocol/v1/tenants.pb.go @@ -77,12 +77,10 @@ type TenantsGetRequest struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Collection string `protobuf:"bytes,1,opt,name=collection,proto3" json:"collection,omitempty"` - IsConsistent bool `protobuf:"varint,2,opt,name=is_consistent,json=isConsistent,proto3" json:"is_consistent,omitempty"` + Collection string `protobuf:"bytes,1,opt,name=collection,proto3" json:"collection,omitempty"` // we might need to add a tenant-cursor api at some point, make this easily extendable // // Types that are assignable to Params: - // // *TenantsGetRequest_Names Params isTenantsGetRequest_Params `protobuf_oneof:"params"` } @@ -126,13 +124,6 @@ func (x *TenantsGetRequest) GetCollection() string { return "" } -func (x *TenantsGetRequest) GetIsConsistent() bool { - if x != nil { - return x.IsConsistent - } - return false -} - func (m *TenantsGetRequest) GetParams() isTenantsGetRequest_Params { if m != nil { return m.Params @@ -152,7 +143,7 @@ type isTenantsGetRequest_Params interface { } type TenantsGetRequest_Names struct { - Names *TenantNames `protobuf:"bytes,3,opt,name=names,proto3,oneof"` + Names *TenantNames `protobuf:"bytes,2,opt,name=names,proto3,oneof"` } func (*TenantsGetRequest_Names) isTenantsGetRequest_Params() {} @@ -319,50 +310,48 @@ var File_v1_tenants_proto protoreflect.FileDescriptor var file_v1_tenants_proto_rawDesc = []byte{ 0x0a, 0x10, 0x76, 0x31, 0x2f, 0x74, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0b, 0x77, 0x65, 0x61, 0x76, 0x69, 0x61, 0x74, 0x65, 0x2e, 0x76, 0x31, 0x22, - 0x94, 0x01, 0x0a, 0x11, 0x54, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x73, 0x47, 0x65, 0x74, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, - 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, 0x6c, 0x6c, 0x65, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x23, 0x0a, 0x0d, 0x69, 0x73, 0x5f, 0x63, 0x6f, 0x6e, 0x73, - 0x69, 0x73, 0x74, 0x65, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0c, 0x69, 0x73, - 0x43, 0x6f, 0x6e, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x74, 0x12, 0x30, 0x0a, 0x05, 0x6e, 0x61, - 0x6d, 0x65, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x77, 0x65, 0x61, 0x76, - 0x69, 0x61, 0x74, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x4e, 0x61, - 0x6d, 0x65, 0x73, 0x48, 0x00, 0x52, 0x05, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x42, 0x08, 0x0a, 0x06, - 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x22, 0x25, 0x0a, 0x0b, 0x54, 0x65, 0x6e, 0x61, 0x6e, 0x74, - 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x18, - 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x22, 0x54, 0x0a, - 0x0f, 0x54, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x73, 0x47, 0x65, 0x74, 0x52, 0x65, 0x70, 0x6c, 0x79, - 0x12, 0x12, 0x0a, 0x04, 0x74, 0x6f, 0x6f, 0x6b, 0x18, 0x01, 0x20, 0x01, 0x28, 0x02, 0x52, 0x04, - 0x74, 0x6f, 0x6f, 0x6b, 0x12, 0x2d, 0x0a, 0x07, 0x74, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x73, 0x18, - 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x77, 0x65, 0x61, 0x76, 0x69, 0x61, 0x74, 0x65, - 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x52, 0x07, 0x74, 0x65, 0x6e, 0x61, - 0x6e, 0x74, 0x73, 0x22, 0x68, 0x0a, 0x06, 0x54, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x12, 0x12, 0x0a, - 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, - 0x65, 0x12, 0x4a, 0x0a, 0x0f, 0x61, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x5f, 0x73, 0x74, - 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x21, 0x2e, 0x77, 0x65, 0x61, - 0x76, 0x69, 0x61, 0x74, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x41, - 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x0e, 0x61, - 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x2a, 0xc3, 0x01, - 0x0a, 0x14, 0x54, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x41, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, - 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x26, 0x0a, 0x22, 0x54, 0x45, 0x4e, 0x41, 0x4e, 0x54, - 0x5f, 0x41, 0x43, 0x54, 0x49, 0x56, 0x49, 0x54, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, - 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x1e, - 0x0a, 0x1a, 0x54, 0x45, 0x4e, 0x41, 0x4e, 0x54, 0x5f, 0x41, 0x43, 0x54, 0x49, 0x56, 0x49, 0x54, - 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x48, 0x4f, 0x54, 0x10, 0x01, 0x12, 0x1f, - 0x0a, 0x1b, 0x54, 0x45, 0x4e, 0x41, 0x4e, 0x54, 0x5f, 0x41, 0x43, 0x54, 0x49, 0x56, 0x49, 0x54, - 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x4f, 0x4c, 0x44, 0x10, 0x02, 0x12, - 0x1f, 0x0a, 0x1b, 0x54, 0x45, 0x4e, 0x41, 0x4e, 0x54, 0x5f, 0x41, 0x43, 0x54, 0x49, 0x56, 0x49, - 0x54, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x57, 0x41, 0x52, 0x4d, 0x10, 0x03, - 0x12, 0x21, 0x0a, 0x1d, 0x54, 0x45, 0x4e, 0x41, 0x4e, 0x54, 0x5f, 0x41, 0x43, 0x54, 0x49, 0x56, - 0x49, 0x54, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x46, 0x52, 0x4f, 0x5a, 0x45, - 0x4e, 0x10, 0x04, 0x42, 0x71, 0x0a, 0x23, 0x69, 0x6f, 0x2e, 0x77, 0x65, 0x61, 0x76, 0x69, 0x61, - 0x74, 0x65, 0x2e, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x76, 0x31, 0x42, 0x14, 0x57, 0x65, 0x61, 0x76, - 0x69, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x54, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x73, - 0x5a, 0x34, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x77, 0x65, 0x61, - 0x76, 0x69, 0x61, 0x74, 0x65, 0x2f, 0x77, 0x65, 0x61, 0x76, 0x69, 0x61, 0x74, 0x65, 0x2f, 0x67, - 0x72, 0x70, 0x63, 0x2f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x65, 0x64, 0x3b, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x6f, 0x0a, 0x11, 0x54, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x73, 0x47, 0x65, 0x74, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, 0x6c, 0x6c, 0x65, 0x63, + 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x30, 0x0a, 0x05, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x77, 0x65, 0x61, 0x76, 0x69, 0x61, 0x74, 0x65, 0x2e, 0x76, + 0x31, 0x2e, 0x54, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x48, 0x00, 0x52, + 0x05, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x42, 0x08, 0x0a, 0x06, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, + 0x22, 0x25, 0x0a, 0x0b, 0x54, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x12, + 0x16, 0x0a, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, + 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x22, 0x54, 0x0a, 0x0f, 0x54, 0x65, 0x6e, 0x61, 0x6e, + 0x74, 0x73, 0x47, 0x65, 0x74, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x6f, + 0x6f, 0x6b, 0x18, 0x01, 0x20, 0x01, 0x28, 0x02, 0x52, 0x04, 0x74, 0x6f, 0x6f, 0x6b, 0x12, 0x2d, + 0x0a, 0x07, 0x74, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x13, 0x2e, 0x77, 0x65, 0x61, 0x76, 0x69, 0x61, 0x74, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x65, + 0x6e, 0x61, 0x6e, 0x74, 0x52, 0x07, 0x74, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x73, 0x22, 0x68, 0x0a, + 0x06, 0x54, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x4a, 0x0a, 0x0f, 0x61, + 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x21, 0x2e, 0x77, 0x65, 0x61, 0x76, 0x69, 0x61, 0x74, 0x65, 0x2e, + 0x76, 0x31, 0x2e, 0x54, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x41, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, + 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x0e, 0x61, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, + 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x2a, 0xc3, 0x01, 0x0a, 0x14, 0x54, 0x65, 0x6e, 0x61, + 0x6e, 0x74, 0x41, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, + 0x12, 0x26, 0x0a, 0x22, 0x54, 0x45, 0x4e, 0x41, 0x4e, 0x54, 0x5f, 0x41, 0x43, 0x54, 0x49, 0x56, + 0x49, 0x54, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, + 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x1e, 0x0a, 0x1a, 0x54, 0x45, 0x4e, 0x41, + 0x4e, 0x54, 0x5f, 0x41, 0x43, 0x54, 0x49, 0x56, 0x49, 0x54, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, + 0x55, 0x53, 0x5f, 0x48, 0x4f, 0x54, 0x10, 0x01, 0x12, 0x1f, 0x0a, 0x1b, 0x54, 0x45, 0x4e, 0x41, + 0x4e, 0x54, 0x5f, 0x41, 0x43, 0x54, 0x49, 0x56, 0x49, 0x54, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, + 0x55, 0x53, 0x5f, 0x43, 0x4f, 0x4c, 0x44, 0x10, 0x02, 0x12, 0x1f, 0x0a, 0x1b, 0x54, 0x45, 0x4e, + 0x41, 0x4e, 0x54, 0x5f, 0x41, 0x43, 0x54, 0x49, 0x56, 0x49, 0x54, 0x59, 0x5f, 0x53, 0x54, 0x41, + 0x54, 0x55, 0x53, 0x5f, 0x57, 0x41, 0x52, 0x4d, 0x10, 0x03, 0x12, 0x21, 0x0a, 0x1d, 0x54, 0x45, + 0x4e, 0x41, 0x4e, 0x54, 0x5f, 0x41, 0x43, 0x54, 0x49, 0x56, 0x49, 0x54, 0x59, 0x5f, 0x53, 0x54, + 0x41, 0x54, 0x55, 0x53, 0x5f, 0x46, 0x52, 0x4f, 0x5a, 0x45, 0x4e, 0x10, 0x04, 0x42, 0x71, 0x0a, + 0x23, 0x69, 0x6f, 0x2e, 0x77, 0x65, 0x61, 0x76, 0x69, 0x61, 0x74, 0x65, 0x2e, 0x63, 0x6c, 0x69, + 0x65, 0x6e, 0x74, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, + 0x6c, 0x2e, 0x76, 0x31, 0x42, 0x14, 0x57, 0x65, 0x61, 0x76, 0x69, 0x61, 0x74, 0x65, 0x50, 0x72, + 0x6f, 0x74, 0x6f, 0x54, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x73, 0x5a, 0x34, 0x67, 0x69, 0x74, 0x68, + 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x77, 0x65, 0x61, 0x76, 0x69, 0x61, 0x74, 0x65, 0x2f, + 0x77, 0x65, 0x61, 0x76, 0x69, 0x61, 0x74, 0x65, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x67, 0x65, + 0x6e, 0x65, 0x72, 0x61, 0x74, 0x65, 0x64, 0x3b, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, + 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/grpc/proto/v1/batch_delete.proto b/grpc/proto/v1/batch_delete.proto index 958103af63..d0a9f8e827 100644 --- a/grpc/proto/v1/batch_delete.proto +++ b/grpc/proto/v1/batch_delete.proto @@ -6,7 +6,7 @@ import "v1/base.proto"; option go_package = "github.com/weaviate/weaviate/grpc/generated;protocol"; option java_package = "io.weaviate.client.grpc.protocol.v1"; -option java_outer_classname = "WeaviateProtoBatch"; +option java_outer_classname = "WeaviateProtoBatchDelete"; message BatchDeleteRequest { string collection = 1; diff --git a/grpc/proto/v1/search_get.proto b/grpc/proto/v1/search_get.proto index 3e00334864..66b0d35c1c 100644 --- a/grpc/proto/v1/search_get.proto +++ b/grpc/proto/v1/search_get.proto @@ -116,8 +116,8 @@ message Hybrid { FusionType fusion_type = 5; bytes vector_bytes = 6; repeated string target_vectors = 7; - NearTextSearch near_text = 8; - NearVector near_vector = 9; + NearTextSearch near_text = 8; // target_vector in msg is ignored and should not be set for hybrid + NearVector near_vector = 9; // same as above. Use the target vector in the hybrid message } message NearTextSearch { diff --git a/grpc/proto/v1/tenants.proto b/grpc/proto/v1/tenants.proto index f0078d2194..3dc995dc1b 100644 --- a/grpc/proto/v1/tenants.proto +++ b/grpc/proto/v1/tenants.proto @@ -16,10 +16,9 @@ enum TenantActivityStatus { message TenantsGetRequest { string collection = 1; - bool is_consistent = 2; // we might need to add a tenant-cursor api at some point, make this easily extendable oneof params { - TenantNames names = 3; + TenantNames names = 2; }; } diff --git a/modules/generative-aws/clients/aws.go b/modules/generative-aws/clients/aws.go index a2d07096d3..9d6db1d26a 100644 --- a/modules/generative-aws/clients/aws.go +++ b/modules/generative-aws/clients/aws.go @@ -22,10 +22,15 @@ import ( "strings" "time" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/weaviate/weaviate/entities/moduletools" - "github.com/weaviate/weaviate/modules/generative-aws/config" + generativeconfig "github.com/weaviate/weaviate/modules/generative-aws/config" + "github.com/weaviate/weaviate/usecases/modulecomponents" generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models" ) @@ -41,19 +46,21 @@ func buildSagemakerUrl(service, region, endpoint string) string { return fmt.Sprintf(urlTemplate, service, region, endpoint) } -type aws struct { +type awsClient struct { awsAccessKey string awsSecretKey string + awsSessionToken string buildBedrockUrlFn func(service, region, model string) string buildSagemakerUrlFn func(service, region, endpoint string) string httpClient *http.Client logger logrus.FieldLogger } -func New(awsAccessKey string, awsSecretKey string, timeout time.Duration, logger logrus.FieldLogger) *aws { - return &aws{ - awsAccessKey: awsAccessKey, - awsSecretKey: awsSecretKey, +func New(awsAccessKey, awsSecretKey, awsSessionToken string, timeout time.Duration, logger logrus.FieldLogger) *awsClient { + return &awsClient{ + awsAccessKey: awsAccessKey, + awsSecretKey: awsSecretKey, + awsSessionToken: awsSessionToken, httpClient: &http.Client{ Timeout: timeout, }, @@ -63,7 +70,7 @@ func New(awsAccessKey string, awsSecretKey string, timeout time.Duration, logger } } -func (v *aws) GenerateSingleResult(ctx context.Context, textProperties map[string]string, prompt string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) { +func (v *awsClient) GenerateSingleResult(ctx context.Context, textProperties map[string]string, prompt string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) { forPrompt, err := v.generateForPrompt(textProperties, prompt) if err != nil { return nil, err @@ -71,7 +78,7 @@ func (v *aws) GenerateSingleResult(ctx context.Context, textProperties map[strin return v.Generate(ctx, cfg, forPrompt) } -func (v *aws) GenerateAllResults(ctx context.Context, textProperties []map[string]string, task string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) { +func (v *awsClient) GenerateAllResults(ctx context.Context, textProperties []map[string]string, task string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) { forTask, err := v.generatePromptForTask(textProperties, task) if err != nil { return nil, err @@ -79,74 +86,51 @@ func (v *aws) GenerateAllResults(ctx context.Context, textProperties []map[strin return v.Generate(ctx, cfg, forTask) } -func (v *aws) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) { - settings := config.NewClassSettings(cfg) +func (v *awsClient) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) { + settings := generativeconfig.NewClassSettings(cfg) service := settings.Service() - region := settings.Region() - model := settings.Model() - endpoint := settings.Endpoint() - targetModel := settings.TargetModel() - targetVariant := settings.TargetVariant() - - var body []byte - var endpointUrl string - var host string - var path string - var err error - headers := map[string]string{ - "accept": "*/*", - "content-type": contentType, + accessKey, err := v.getAwsAccessKey(ctx) + if err != nil { + return nil, errors.Wrapf(err, "AWS Access Key") } + secretKey, err := v.getAwsAccessSecret(ctx) + if err != nil { + return nil, errors.Wrapf(err, "AWS Secret Key") + } + awsSessionToken, err := v.getAwsSessionToken(ctx) + if err != nil { + return nil, err + } + maxRetries := 5 if v.isBedrock(service) { - endpointUrl = v.buildBedrockUrlFn(service, region, model) - host = service + "-runtime" + "." + region + ".amazonaws.com" - path = "/model/" + model + "/invoke" - - if v.isAmazonModel(model) { - body, err = json.Marshal(bedrockAmazonGenerateRequest{ - InputText: prompt, - }) - } else if v.isAnthropicModel(model) { - var builder strings.Builder - builder.WriteString("\n\nHuman: ") - builder.WriteString(prompt) - builder.WriteString("\n\nAssistant:") - body, err = json.Marshal(bedrockAnthropicGenerateRequest{ - Prompt: builder.String(), - MaxTokensToSample: *settings.MaxTokenCount(), - Temperature: *settings.Temperature(), - TopK: *settings.TopK(), - TopP: settings.TopP(), - StopSequences: settings.StopSequences(), - AnthropicVersion: "bedrock-2023-05-31", - }) - } else if v.isAI21Model(model) { - body, err = json.Marshal(bedrockAI21GenerateRequest{ - Prompt: prompt, - MaxTokens: *settings.MaxTokenCount(), - Temperature: *settings.Temperature(), - TopP: settings.TopP(), - StopSequences: settings.StopSequences(), - }) - } else if v.isCohereModel(model) { - body, err = json.Marshal(bedrockCohereRequest{ - Prompt: prompt, - Temperature: *settings.Temperature(), - MaxTokens: *settings.MaxTokenCount(), - // ReturnLikeliHood: "GENERATION", // contray to docs, this is invalid - }) - } - - headers["x-amzn-bedrock-save"] = "false" - if err != nil { - return nil, errors.Wrapf(err, "marshal body") - } + return v.sendBedrockRequest(ctx, + prompt, + accessKey, secretKey, awsSessionToken, maxRetries, + cfg, + ) } else if v.isSagemaker(service) { + var body []byte + var endpointUrl string + var host string + var path string + var err error + + region := settings.Region() + endpoint := settings.Endpoint() + targetModel := settings.TargetModel() + targetVariant := settings.TargetVariant() + endpointUrl = v.buildSagemakerUrlFn(service, region, endpoint) host = "runtime." + service + "." + region + ".amazonaws.com" path = "/endpoints/" + endpoint + "/invocations" + + headers := map[string]string{ + "accept": "*/*", + "content-type": contentType, + } + if targetModel != "" { headers["x-amzn-sagemaker-target-model"] = targetModel } @@ -159,97 +143,255 @@ func (v *aws) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt if err != nil { return nil, errors.Wrapf(err, "marshal body") } + headers["host"] = host + amzDate, headers, authorizationHeader := getAuthHeader(accessKey, secretKey, host, service, region, path, body, headers) + headers["Authorization"] = authorizationHeader + headers["x-amz-date"] = amzDate + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointUrl, bytes.NewReader(body)) + if err != nil { + return nil, errors.Wrap(err, "create POST request") + } + + for k, v := range headers { + req.Header.Set(k, v) + } + + res, err := v.httpClient.Do(req) + if err != nil { + return nil, errors.Wrap(err, "send POST request") + } + defer res.Body.Close() + + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + return nil, errors.Wrap(err, "read response body") + } + + return v.parseSagemakerResponse(bodyBytes, res) } else { - return nil, errors.Wrapf(err, "service error") + return &generativemodels.GenerateResponse{ + Result: nil, + }, nil } +} - accessKey, err := v.getAwsAccessKey(ctx) +func (v *awsClient) sendBedrockRequest( + ctx context.Context, + prompt string, + awsKey, awsSecret, awsSessionToken string, + maxRetries int, + cfg moduletools.ClassConfig, +) (*generativemodels.GenerateResponse, error) { + settings := generativeconfig.NewClassSettings(cfg) + model := settings.Model() + region := settings.Region() + req, err := v.createRequestBody(prompt, cfg) if err != nil { - return nil, errors.Wrapf(err, "AWS Access Key") + return nil, fmt.Errorf("failed to create request for model %s: %w", model, err) } - secretKey, err := v.getAwsAccessSecret(ctx) + + body, err := json.Marshal(req) if err != nil { - return nil, errors.Wrapf(err, "AWS Secret Key") + return nil, fmt.Errorf("failed to marshal request for model %s: %w", model, err) } - headers["host"] = host - amzDate, headers, authorizationHeader := getAuthHeader(accessKey, secretKey, host, service, region, path, body, headers) - headers["Authorization"] = authorizationHeader - headers["x-amz-date"] = amzDate - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointUrl, bytes.NewReader(body)) + sdkConfig, err := config.LoadDefaultConfig(ctx, + config.WithRegion(region), + config.WithCredentialsProvider( + credentials.NewStaticCredentialsProvider(awsKey, awsSecret, awsSessionToken), + ), + config.WithRetryMaxAttempts(maxRetries), + ) if err != nil { - return nil, errors.Wrap(err, "create POST request") + return nil, fmt.Errorf("failed to load AWS configuration: %w", err) } - for k, v := range headers { - req.Header.Set(k, v) + client := bedrockruntime.NewFromConfig(sdkConfig) + result, err := client.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{ + ModelId: aws.String(model), + ContentType: aws.String("application/json"), + Body: body, + }) + if err != nil { + errMsg := err.Error() + if strings.Contains(errMsg, "no such host") { + return nil, fmt.Errorf("Bedrock service is not available in the selected region. " + + "Please double-check the service availability for your region at " + + "https://aws.amazon.com/about-aws/global-infrastructure/regional-product-services/") + } else if strings.Contains(errMsg, "Could not resolve the foundation model") { + return nil, fmt.Errorf("Could not resolve the foundation model from model identifier: \"%v\". "+ + "Please verify that the requested model exists and is accessible within the specified region", model) + } else { + return nil, fmt.Errorf("Couldn't invoke %s model: %w", model, err) + } } - res, err := v.httpClient.Do(req) - if err != nil { - return nil, errors.Wrap(err, "send POST request") + return v.parseBedrockResponse(result.Body, model) +} + +func (v *awsClient) createRequestBody(prompt string, cfg moduletools.ClassConfig) (interface{}, error) { + settings := generativeconfig.NewClassSettings(cfg) + model := settings.Model() + if v.isAmazonModel(model) { + return bedrockAmazonGenerateRequest{ + InputText: prompt, + }, nil + } else if v.isAnthropicClaude3Model(model) { + return bedrockAnthropicClaude3Request{ + AnthropicVersion: "bedrock-2023-05-31", + MaxTokens: *settings.MaxTokenCount(), + Messages: []bedrockAnthropicClaude3Message{ + { + Role: "user", + Content: []bedrockAnthropicClaude3Content{ + { + ContentType: "text", + Text: &prompt, + }, + }, + }, + }, + }, nil + } else if v.isAnthropicModel(model) { + var builder strings.Builder + builder.WriteString("\n\nHuman: ") + builder.WriteString(prompt) + builder.WriteString("\n\nAssistant:") + return bedrockAnthropicGenerateRequest{ + Prompt: builder.String(), + MaxTokensToSample: *settings.MaxTokenCount(), + Temperature: *settings.Temperature(), + StopSequences: settings.StopSequences(), + TopK: settings.TopK(), + TopP: settings.TopP(), + AnthropicVersion: "bedrock-2023-05-31", + }, nil + } else if v.isAI21Model(model) { + return bedrockAI21GenerateRequest{ + Prompt: prompt, + MaxTokens: *settings.MaxTokenCount(), + Temperature: *settings.Temperature(), + TopP: settings.TopP(), + StopSequences: settings.StopSequences(), + }, nil + } else if v.isCohereCommandRModel(model) { + return bedrockCohereCommandRRequest{ + Message: prompt, + }, nil + } else if v.isCohereModel(model) { + return bedrockCohereRequest{ + Prompt: prompt, + Temperature: *settings.Temperature(), + MaxTokens: *settings.MaxTokenCount(), + // ReturnLikeliHood: "GENERATION", // contray to docs, this is invalid + }, nil + } else if v.isMistralAIModel(model) { + return bedrockMistralAIRequest{ + Prompt: fmt.Sprintf("[INST] %s [/INST]", prompt), + MaxTokens: settings.MaxTokenCount(), + Temperature: settings.Temperature(), + }, nil + } else if v.isMetaModel(model) { + return bedrockMetaRequest{ + Prompt: prompt, + MaxGenLen: settings.MaxTokenCount(), + Temperature: settings.Temperature(), + }, nil } - defer res.Body.Close() + return nil, fmt.Errorf("unspported model: %s", model) +} - bodyBytes, err := io.ReadAll(res.Body) +func (v *awsClient) parseBedrockResponse(bodyBytes []byte, model string) (*generativemodels.GenerateResponse, error) { + content, err := v.getBedrockResponseMessage(model, bodyBytes) if err != nil { - return nil, errors.Wrap(err, "read response body") + return nil, err } - if v.isBedrock(service) { - return v.parseBedrockResponse(bodyBytes, res) - } else if v.isSagemaker(service) { - return v.parseSagemakerResponse(bodyBytes, res) - } else { + if content != "" { return &generativemodels.GenerateResponse{ - Result: nil, + Result: &content, }, nil } + + return &generativemodels.GenerateResponse{ + Result: nil, + }, nil } -func (v *aws) parseBedrockResponse(bodyBytes []byte, res *http.Response) (*generativemodels.GenerateResponse, error) { +func (v *awsClient) getBedrockResponseMessage(model string, bodyBytes []byte) (string, error) { + var content string var resBodyMap map[string]interface{} if err := json.Unmarshal(bodyBytes, &resBodyMap); err != nil { - return nil, errors.Wrap(err, "unmarshal response body") + return "", errors.Wrap(err, "unmarshal response body") } - var resBody bedrockGenerateResponse - if err := json.Unmarshal(bodyBytes, &resBody); err != nil { - return nil, errors.Wrap(err, "unmarshal response body") + if v.isCohereCommandRModel(model) { + var resBody bedrockCohereCommandRResponse + if err := json.Unmarshal(bodyBytes, &resBody); err != nil { + return "", errors.Wrap(err, "unmarshal response body") + } + return resBody.Text, nil + } else if v.isAnthropicClaude3Model(model) { + var resBody bedrockAnthropicClaude3Response + if err := json.Unmarshal(bodyBytes, &resBody); err != nil { + return "", errors.Wrap(err, "unmarshal response body") + } + if len(resBody.Content) > 0 && resBody.Content[0].Text != nil { + return *resBody.Content[0].Text, nil + } + return "", fmt.Errorf("no message from model: %s", model) + } else if v.isAnthropicModel(model) { + var resBody bedrockAnthropicClaudeResponse + if err := json.Unmarshal(bodyBytes, &resBody); err != nil { + return "", errors.Wrap(err, "unmarshal response body") + } + return resBody.Completion, nil + } else if v.isAI21Model(model) { + var resBody bedrockAI21Response + if err := json.Unmarshal(bodyBytes, &resBody); err != nil { + return "", errors.Wrap(err, "unmarshal response body") + } + if len(resBody.Completions) > 0 { + return resBody.Completions[0].Data.Text, nil + } + return "", fmt.Errorf("no message from model: %s", model) + } else if v.isMistralAIModel(model) { + var resBody bedrockMistralAIResponse + if err := json.Unmarshal(bodyBytes, &resBody); err != nil { + return "", errors.Wrap(err, "unmarshal response body") + } + if len(resBody.Outputs) > 0 { + return resBody.Outputs[0].Text, nil + } + return "", fmt.Errorf("no message from model: %s", model) + } else if v.isMetaModel(model) { + var resBody bedrockMetaResponse + if err := json.Unmarshal(bodyBytes, &resBody); err != nil { + return "", errors.Wrap(err, "unmarshal response body") + } + return resBody.Generation, nil } - if res.StatusCode != 200 || resBody.Message != nil { - if resBody.Message != nil { - return nil, fmt.Errorf("connection to AWS Bedrock failed with status: %v error: %s", - res.StatusCode, *resBody.Message) - } - return nil, fmt.Errorf("connection to AWS Bedrock failed with status: %d", res.StatusCode) + var resBody bedrockGenerateResponse + if err := json.Unmarshal(bodyBytes, &resBody); err != nil { + return "", errors.Wrap(err, "unmarshal response body") } if len(resBody.Results) == 0 && len(resBody.Generations) == 0 { - return nil, fmt.Errorf("received empty response from AWS Bedrock") + return "", fmt.Errorf("received empty response from AWS Bedrock") } - var content string if len(resBody.Results) > 0 && len(resBody.Results[0].CompletionReason) > 0 { content = resBody.Results[0].OutputText } else if len(resBody.Generations) > 0 { content = resBody.Generations[0].Text } - if content != "" { - return &generativemodels.GenerateResponse{ - Result: &content, - }, nil - } - - return &generativemodels.GenerateResponse{ - Result: nil, - }, nil + return content, nil } -func (v *aws) parseSagemakerResponse(bodyBytes []byte, res *http.Response) (*generativemodels.GenerateResponse, error) { +func (v *awsClient) parseSagemakerResponse(bodyBytes []byte, res *http.Response) (*generativemodels.GenerateResponse, error) { var resBody sagemakerGenerateResponse if err := json.Unmarshal(bodyBytes, &resBody); err != nil { return nil, errors.Wrap(err, "unmarshal response body") @@ -280,15 +422,15 @@ func (v *aws) parseSagemakerResponse(bodyBytes []byte, res *http.Response) (*gen }, nil } -func (v *aws) isSagemaker(service string) bool { +func (v *awsClient) isSagemaker(service string) bool { return service == "sagemaker" } -func (v *aws) isBedrock(service string) bool { +func (v *awsClient) isBedrock(service string) bool { return service == "bedrock" } -func (v *aws) generatePromptForTask(textProperties []map[string]string, task string) (string, error) { +func (v *awsClient) generatePromptForTask(textProperties []map[string]string, task string) (string, error) { marshal, err := json.Marshal(textProperties) if err != nil { return "", err @@ -297,7 +439,7 @@ func (v *aws) generatePromptForTask(textProperties []map[string]string, task str %v`, task, string(marshal)), nil } -func (v *aws) generateForPrompt(textProperties map[string]string, prompt string) (string, error) { +func (v *awsClient) generateForPrompt(textProperties map[string]string, prompt string) (string, error) { all := compile.FindAll([]byte(prompt), -1) for _, match := range all { originalProperty := string(match) @@ -312,13 +454,11 @@ func (v *aws) generateForPrompt(textProperties map[string]string, prompt string) return prompt, nil } -func (v *aws) getAwsAccessKey(ctx context.Context) (string, error) { - awsAccessKey := ctx.Value("X-Aws-Access-Key") - if awsAccessKeyHeader, ok := awsAccessKey.([]string); ok && - len(awsAccessKeyHeader) > 0 && len(awsAccessKeyHeader[0]) > 0 { - return awsAccessKeyHeader[0], nil +func (v *awsClient) getAwsAccessKey(ctx context.Context) (string, error) { + if awsAccessKey := v.getHeaderValue(ctx, "X-Aws-Access-Key"); awsAccessKey != "" { + return awsAccessKey, nil } - if len(v.awsAccessKey) > 0 { + if v.awsAccessKey != "" { return v.awsAccessKey, nil } return "", errors.New("no access key found " + @@ -326,13 +466,11 @@ func (v *aws) getAwsAccessKey(ctx context.Context) (string, error) { "nor in environment variable under AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY") } -func (v *aws) getAwsAccessSecret(ctx context.Context) (string, error) { - awsAccessSecret := ctx.Value("X-Aws-Secret-Key") - if awsAccessSecretHeader, ok := awsAccessSecret.([]string); ok && - len(awsAccessSecretHeader) > 0 && len(awsAccessSecretHeader[0]) > 0 { - return awsAccessSecretHeader[0], nil +func (v *awsClient) getAwsAccessSecret(ctx context.Context) (string, error) { + if awsSecret := v.getHeaderValue(ctx, "X-Aws-Secret-Key"); awsSecret != "" { + return awsSecret, nil } - if len(v.awsSecretKey) > 0 { + if v.awsSecretKey != "" { return v.awsSecretKey, nil } return "", errors.New("no secret found " + @@ -340,20 +478,59 @@ func (v *aws) getAwsAccessSecret(ctx context.Context) (string, error) { "nor in environment variable under AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY") } -func (v *aws) isAmazonModel(model string) bool { - return strings.Contains(model, "amazon") +func (v *awsClient) getAwsSessionToken(ctx context.Context) (string, error) { + if awsSessionToken := v.getHeaderValue(ctx, "X-Aws-Session-Token"); awsSessionToken != "" { + return awsSessionToken, nil + } + if v.awsSessionToken != "" { + return v.awsSessionToken, nil + } + return "", nil +} + +func (v *awsClient) getHeaderValue(ctx context.Context, header string) string { + headerValue := ctx.Value(header) + if value, ok := headerValue.([]string); ok && + len(value) > 0 && len(value[0]) > 0 { + return value[0] + } + // try getting header from GRPC if not successful + if value := modulecomponents.GetValueFromGRPC(ctx, header); len(value) > 0 && len(value[0]) > 0 { + return value[0] + } + return "" +} + +func (v *awsClient) isAmazonModel(model string) bool { + return strings.HasPrefix(model, "amazon") +} + +func (v *awsClient) isAI21Model(model string) bool { + return strings.HasPrefix(model, "ai21") +} + +func (v *awsClient) isAnthropicModel(model string) bool { + return strings.HasPrefix(model, "anthropic") } -func (v *aws) isAI21Model(model string) bool { - return strings.Contains(model, "ai21") +func (v *awsClient) isAnthropicClaude3Model(model string) bool { + return strings.HasPrefix(model, "anthropic.claude-3") } -func (v *aws) isAnthropicModel(model string) bool { - return strings.Contains(model, "anthropic") +func (v *awsClient) isCohereModel(model string) bool { + return strings.HasPrefix(model, "cohere") } -func (v *aws) isCohereModel(model string) bool { - return strings.Contains(model, "cohere") +func (v *awsClient) isCohereCommandRModel(model string) bool { + return strings.HasPrefix(model, "cohere.command-r") +} + +func (v *awsClient) isMistralAIModel(model string) bool { + return strings.HasPrefix(model, "mistral") +} + +func (v *awsClient) isMetaModel(model string) bool { + return strings.HasPrefix(model, "meta") } type bedrockAmazonGenerateRequest struct { @@ -365,12 +542,58 @@ type bedrockAnthropicGenerateRequest struct { Prompt string `json:"prompt,omitempty"` MaxTokensToSample int `json:"max_tokens_to_sample,omitempty"` Temperature float64 `json:"temperature,omitempty"` - TopK int `json:"top_k,omitempty"` - TopP *float64 `json:"top_p,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` + TopK *int `json:"top_k,omitempty"` + TopP *float64 `json:"top_p,omitempty"` AnthropicVersion string `json:"anthropic_version,omitempty"` } +type bedrockAnthropicClaudeResponse struct { + Completion string `json:"completion"` +} + +type bedrockAnthropicClaude3Request struct { + AnthropicVersion string `json:"anthropic_version,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Messages []bedrockAnthropicClaude3Message `json:"messages,omitempty"` +} + +type bedrockAnthropicClaude3Message struct { + Role string `json:"role,omitempty"` + Content []bedrockAnthropicClaude3Content `json:"content,omitempty"` +} + +type bedrockAnthropicClaude3Content struct { + // possible values are: image, text + ContentType string `json:"type,omitempty"` + Text *string `json:"text,omitempty"` + Source *bedrockAnthropicClaudeV3Source `json:"source,omitempty"` +} + +type bedrockAnthropicClaude3Response struct { + ID string `json:"id,omitempty"` + ContentType string `json:"type,omitempty"` + Role string `json:"role,omitempty"` + Model string `json:"model,omitempty"` + StopReason string `json:"stop_reason,omitempty"` + Usage bedrockAnthropicClaude3UsageResponse `json:"usage,omitempty"` + Content []bedrockAnthropicClaude3Content `json:"content,omitempty"` +} + +type bedrockAnthropicClaude3UsageResponse struct { + InputTokens int `json:"input_tokens,omitempty"` + OutputTokens int `json:"output_tokens,omitempty"` +} + +type bedrockAnthropicClaudeV3Source struct { + // possible values are: base64 + ContentType string `json:"type,omitempty"` + // possible values are: image/jpeg + MediaType string `json:"media_type,omitempty"` + // base64 encoded image + Data string `json:"data,omitempty"` +} + type bedrockAI21GenerateRequest struct { Prompt string `json:"prompt,omitempty"` MaxTokens int `json:"maxTokens,omitempty"` @@ -381,6 +604,19 @@ type bedrockAI21GenerateRequest struct { PresencePenalty penalty `json:"presencePenalty,omitempty"` FrequencyPenalty penalty `json:"frequencyPenalty,omitempty"` } + +type bedrockAI21Response struct { + Completions []bedrockAI21Completion `json:"completions,omitempty"` +} + +type bedrockAI21Completion struct { + Data bedrockAI21Data `json:"data,omitempty"` +} + +type bedrockAI21Data struct { + Text string `json:"text,omitempty"` +} + type bedrockCohereRequest struct { Prompt string `json:"prompt,omitempty"` MaxTokens int `json:"max_tokens,omitempty"` @@ -388,6 +624,10 @@ type bedrockCohereRequest struct { ReturnLikeliHood string `json:"return_likelihood,omitempty"` } +type bedrockCohereCommandRRequest struct { + Message string `json:"message,omitempty"` +} + type penalty struct { Scale int `json:"scale,omitempty"` } @@ -410,6 +650,19 @@ type bedrockGenerateResponse struct { Message *string `json:"message,omitempty"` } +type bedrockCohereCommandRResponse struct { + ChatHistory []bedrockCohereChatHistory `json:"chat_history,omitempty"` + ResponseID string `json:"response_id,omitempty"` + GenerationID string `json:"generation_id,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + Text string `json:"text,omitempty"` +} + +type bedrockCohereChatHistory struct { + Message string `json:"message,omitempty"` + Role string `json:"role,omitempty"` +} + type sagemakerGenerateResponse struct { Generations []Generation `json:"generations,omitempty"` Message *string `json:"message,omitempty"` @@ -431,3 +684,32 @@ type Result struct { OutputText string `json:"outputText,omitempty"` CompletionReason string `json:"completionReason,omitempty"` } + +type bedrockMistralAIRequest struct { + Prompt string `json:"prompt,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"topP,omitempty"` + TopK *int `json:"topK,omitempty"` +} + +type bedrockMistralAIResponse struct { + Outputs []bedrockMistralAIOutput `json:"outputs,omitempty"` +} + +type bedrockMistralAIOutput struct { + Text string `json:"text,omitempty"` +} + +type bedrockMetaRequest struct { + Prompt string `json:"prompt,omitempty"` + MaxGenLen *int `json:"max_gen_len,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` +} + +type bedrockMetaResponse struct { + Generation string `json:"generation,omitempty"` + PromptTokenCount *int `json:"prompt_token_count,omitempty"` + GenerationTokenCount *int `json:"generation_token_count,omitempty"` + StopReason string `json:"stop_reason,omitempty"` +} diff --git a/modules/generative-aws/clients/aws_meta.go b/modules/generative-aws/clients/aws_meta.go index f790c14816..070ec92c96 100644 --- a/modules/generative-aws/clients/aws_meta.go +++ b/modules/generative-aws/clients/aws_meta.go @@ -11,7 +11,7 @@ package clients -func (v *aws) MetaInfo() (map[string]interface{}, error) { +func (v *awsClient) MetaInfo() (map[string]interface{}, error) { return map[string]interface{}{ "name": "Generative Search - AWS", "documentationHref": "https://docs.aws.amazon.com/bedrock/latest/APIReference/welcome.html", diff --git a/modules/generative-aws/clients/aws_meta_test.go b/modules/generative-aws/clients/aws_meta_test.go index 8130a7db3e..a1d3caef1e 100644 --- a/modules/generative-aws/clients/aws_meta_test.go +++ b/modules/generative-aws/clients/aws_meta_test.go @@ -24,7 +24,7 @@ func TestGetMeta(t *testing.T) { t.Run("when the server is providing meta", func(t *testing.T) { server := httptest.NewServer(&testMetaHandler{t: t}) defer server.Close() - c := New(server.URL, "", 60*time.Second, nullLogger()) + c := New(server.URL, "", "", 60*time.Second, nullLogger()) meta, err := c.MetaInfo() assert.Nil(t, err) diff --git a/modules/generative-aws/clients/aws_test.go b/modules/generative-aws/clients/aws_test.go index d459228b34..1e71cc4c1e 100644 --- a/modules/generative-aws/clients/aws_test.go +++ b/modules/generative-aws/clients/aws_test.go @@ -41,7 +41,7 @@ func TestGetAnswer(t *testing.T) { server := httptest.NewServer(handler) defer server.Close() - c := &aws{ + c := &awsClient{ httpClient: &http.Client{}, logger: nullLogger(), awsAccessKey: "123", @@ -72,7 +72,7 @@ func TestGetAnswer(t *testing.T) { }) defer server.Close() - c := &aws{ + c := &awsClient{ httpClient: &http.Client{}, logger: nullLogger(), awsAccessKey: "123", diff --git a/modules/generative-aws/config/class_settings.go b/modules/generative-aws/config/class_settings.go index c45a36040d..9fe655ae45 100644 --- a/modules/generative-aws/config/class_settings.go +++ b/modules/generative-aws/config/class_settings.go @@ -61,14 +61,43 @@ var ( DefaultCohereTopP = 1.0 ) +var ( + DefaultMistralAIMaxTokens = 200 + DefaultMistralAITemperature = 0.5 +) + +var ( + DefaultMetaMaxTokens = 512 + DefaultMetaTemperature = 0.5 +) + var availableAWSServices = []string{ DefaultService, "sagemaker", } var availableBedrockModels = []string{ + "ai21.j2-ultra-v1", + "ai21.j2-mid-v1", + "amazon.titan-text-lite-v1", + "amazon.titan-text-express-v1", + "amazon.titan-text-premier-v1:0", + "anthropic.claude-v2", + "anthropic.claude-v2:1", + "anthropic.claude-instant-v1", + "anthropic.claude-3-sonnet-20240229-v1:0", + "anthropic.claude-3-haiku-20240307-v1:0", "cohere.command-text-v14", "cohere.command-light-text-v14", + "cohere.command-r-v1:0", + "cohere.command-r-plus-v1:0", + "meta.llama3-8b-instruct-v1:0", + "meta.llama3-70b-instruct-v1:0", + "meta.llama2-13b-chat-v1", + "meta.llama2-70b-chat-v1", + "mistral.mistral-7b-instruct-v0:2", + "mistral.mixtral-8x7b-instruct-v0:1", + "mistral.mistral-large-2402-v1:0", } type classSettings struct { @@ -216,6 +245,12 @@ func (ic *classSettings) MaxTokenCount() *int { if isCohereModel(ic.Model()) { return ic.getIntProperty(maxTokenCountProperty, &DefaultCohereMaxTokens) } + if isMistralAIModel(ic.Model()) { + return ic.getIntProperty(maxTokenCountProperty, &DefaultMistralAIMaxTokens) + } + if isMetaModel(ic.Model()) { + return ic.getIntProperty(maxTokenCountProperty, &DefaultMetaMaxTokens) + } } return ic.getIntProperty(maxTokenCountProperty, nil) } @@ -246,6 +281,12 @@ func (ic *classSettings) Temperature() *float64 { if isAI21Model(ic.Model()) { return ic.getFloatProperty(temperatureProperty, &DefaultAI21Temperature) } + if isMistralAIModel(ic.Model()) { + return ic.getFloatProperty(temperatureProperty, &DefaultMistralAITemperature) + } + if isMetaModel(ic.Model()) { + return ic.getFloatProperty(temperatureProperty, &DefaultMetaTemperature) + } } return ic.getFloatProperty(temperatureProperty, nil) } @@ -309,3 +350,11 @@ func isAnthropicModel(model string) bool { func isCohereModel(model string) bool { return strings.HasPrefix(model, "cohere") } + +func isMistralAIModel(model string) bool { + return strings.HasPrefix(model, "mistral") +} + +func isMetaModel(model string) bool { + return strings.HasPrefix(model, "meta") +} diff --git a/modules/generative-aws/config/class_settings_test.go b/modules/generative-aws/config/class_settings_test.go index da5a130e3d..638ed3f940 100644 --- a/modules/generative-aws/config/class_settings_test.go +++ b/modules/generative-aws/config/class_settings_test.go @@ -20,7 +20,6 @@ import ( ) func Test_classSettings_Validate(t *testing.T) { - t.Skip("Skipping this test for now") tests := []struct { name string cfg moduletools.ClassConfig @@ -33,7 +32,7 @@ func Test_classSettings_Validate(t *testing.T) { wantMaxTokenCount int wantStopSequences []string wantTemperature float64 - wantTopP int + wantTopP float64 wantErr error }{ { @@ -144,20 +143,6 @@ func Test_classSettings_Validate(t *testing.T) { }, wantErr: errors.Errorf("topP has to be an integer value between 0 and 1"), }, - { - name: "wrong all", - cfg: fakeClassConfig{ - classConfig: map[string]interface{}{ - "maxTokenCount": 9000, - "temperature": 2, - "topP": 3, - }, - }, - wantErr: errors.Errorf("wrong service, " + - "available services are: [bedrock sagemaker], " + - "region cannot be empty", - ), - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/modules/generative-aws/module.go b/modules/generative-aws/module.go index 2cf46121a4..1813537c26 100644 --- a/modules/generative-aws/module.go +++ b/modules/generative-aws/module.go @@ -67,8 +67,8 @@ func (m *GenerativeAWSModule) initAdditional(ctx context.Context, timeout time.D ) error { awsAccessKey := m.getAWSAccessKey() awsSecret := m.getAWSSecretAccessKey() - - client := clients.New(awsAccessKey, awsSecret, timeout, logger) + awsSessionToken := os.Getenv("AWS_SESSION_TOKEN") + client := clients.New(awsAccessKey, awsSecret, awsSessionToken, timeout, logger) m.generative = client diff --git a/modules/generative-octoai/clients/octoai.go b/modules/generative-octoai/clients/octoai.go index 97aae7327e..74314bcf1d 100644 --- a/modules/generative-octoai/clients/octoai.go +++ b/modules/generative-octoai/clients/octoai.go @@ -69,7 +69,7 @@ func (v *octoai) GenerateAllResults(ctx context.Context, textProperties []map[st func (v *octoai) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) { settings := config.NewClassSettings(cfg) - octoAIUrl, err := v.getOctoAIUrl(ctx, settings.BaseURL()) + octoAIUrl, isImage, err := v.getOctoAIUrl(ctx, settings.BaseURL()) if err != nil { return nil, errors.Wrap(err, "join OctoAI API host and path") } @@ -78,11 +78,29 @@ func (v *octoai) Generate(ctx context.Context, cfg moduletools.ClassConfig, prom {"role": "user", "content": prompt}, } - input := generateInput{ - Messages: octoAIPrompt, - Model: settings.Model(), - MaxTokens: settings.MaxTokens(), - Temperature: settings.Temperature(), + var input interface{} + if !isImage { + input = generateInputText{ + Messages: octoAIPrompt, + Model: settings.Model(), + MaxTokens: settings.MaxTokens(), + Temperature: settings.Temperature(), + } + } else { + input = generateInputImage{ + Prompt: prompt, + NegativePrompt: "ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, signature, cut off, draft", + Sampler: "DDIM", + CfgScale: 11, + Height: 1024, + Width: 1024, + Seed: 0, + Steps: 20, + NumImages: 1, + HighNoiseFrac: 0.7, + Strength: 0.92, + UseRefiner: true, + } } body, err := json.Marshal(input) @@ -125,19 +143,29 @@ func (v *octoai) Generate(ctx context.Context, cfg moduletools.ClassConfig, prom return nil, errors.Errorf("connection to OctoAI API failed with status: %d", res.StatusCode) } - textResponse := resBody.Choices[0].Message.Content + var textResponse string + if isImage { + textResponse = resBody.Images[0].Image + } else { + textResponse = resBody.Choices[0].Message.Content + } return &generativemodels.GenerateResponse{ Result: &textResponse, }, nil } -func (v *octoai) getOctoAIUrl(ctx context.Context, baseURL string) (string, error) { +func (v *octoai) getOctoAIUrl(ctx context.Context, baseURL string) (string, bool, error) { passedBaseURL := baseURL if headerBaseURL := v.getValueFromContext(ctx, "X-Octoai-Baseurl"); headerBaseURL != "" { passedBaseURL = headerBaseURL } - return url.JoinPath(passedBaseURL, "/v1/chat/completions") + if strings.Contains(passedBaseURL, "image") { + urlTmp, err := url.JoinPath(passedBaseURL, "/generate/sdxl") + return urlTmp, true, err + } + urlTmp, err := url.JoinPath(passedBaseURL, "/v1/chat/completions") + return urlTmp, false, err } func (v *octoai) generatePromptForTask(textProperties []map[string]string, task string) (string, error) { @@ -177,24 +205,40 @@ func (v *octoai) getValueFromContext(ctx context.Context, key string) string { } func (v *octoai) getApiKey(ctx context.Context) (string, error) { - if apiKey := v.getValueFromContext(ctx, "X-OctoAI-Api-Key"); apiKey != "" { - return apiKey, nil - } if v.apiKey != "" { return v.apiKey, nil } + if apiKey := modulecomponents.GetValueFromContext(ctx, "X-OctoAI-Api-Key"); apiKey != "" { + return apiKey, nil + } return "", errors.New("no api key found " + "neither in request header: X-OctoAI-Api-Key " + "nor in environment variable under OCTOAI_APIKEY") } -type generateInput struct { +type generateInputText struct { Model string `json:"model"` Messages []map[string]string `json:"messages"` MaxTokens int `json:"max_tokens"` Temperature int `json:"temperature"` } +type generateInputImage struct { + Prompt string `json:"prompt"` + NegativePrompt string `json:"negative_prompt"` + Sampler string `json:"sampler"` + CfgScale int `json:"cfg_scale"` + Height int `json:"height"` + Width int `json:"width"` + Seed int `json:"seed"` + Steps int `json:"steps"` + NumImages int `json:"num_images"` + HighNoiseFrac float64 `json:"high_noise_frac"` + Strength float64 `json:"strength"` + UseRefiner bool `json:"use_refiner"` + // StylePreset string `json:"style_preset"` +} + type Message struct { Role string `json:"role"` Content string `json:"content"` @@ -205,9 +249,16 @@ type Choice struct { Index int `json:"index"` FinishReason string `json:"finish_reason"` } + +type Image struct { + Image string `json:"image_b64"` +} + type generateResponse struct { Choices []Choice - Error *octoaiApiError `json:"error,omitempty"` + Images []Image + + Error *octoaiApiError `json:"error,omitempty"` } type octoaiApiError struct { diff --git a/modules/generative-octoai/clients/octoai_test.go b/modules/generative-octoai/clients/octoai_test.go index 4061ff2767..27c85b98c2 100644 --- a/modules/generative-octoai/clients/octoai_test.go +++ b/modules/generative-octoai/clients/octoai_test.go @@ -90,13 +90,23 @@ func TestGetAnswer(t *testing.T) { } }) } - t.Run("when X-Octoai-BaseURL header is passed", func(t *testing.T) { + t.Run("when X-Octoai-BaseURL header is passed for text", func(t *testing.T) { c := New("apiKey", 5*time.Second, nullLogger()) baseUrl := "https://text.octoai.run" - buildURL, err := c.getOctoAIUrl(context.Background(), baseUrl) + buildURL, isImage, err := c.getOctoAIUrl(context.Background(), baseUrl) assert.Equal(t, nil, err) + assert.Equal(t, false, isImage) assert.Equal(t, "https://text.octoai.run/v1/chat/completions", buildURL) }) + + t.Run("when X-Octoai-BaseURL header is passed for image", func(t *testing.T) { + c := New("apiKey", 5*time.Second, nullLogger()) + baseUrl := "https://image.octoai.run" + buildURL, isImage, err := c.getOctoAIUrl(context.Background(), baseUrl) + assert.Equal(t, nil, err) + assert.Equal(t, true, isImage) + assert.Equal(t, "https://image.octoai.run/generate/sdxl", buildURL) + }) } type testAnswerHandler struct { diff --git a/modules/generative-ollama/clients/ollama.go b/modules/generative-ollama/clients/ollama.go index a2ea7f0d34..390ca39da4 100644 --- a/modules/generative-ollama/clients/ollama.go +++ b/modules/generative-ollama/clients/ollama.go @@ -68,7 +68,7 @@ func (v *ollama) Generate(ctx context.Context, cfg moduletools.ClassConfig, prom ollamaUrl := v.getOllamaUrl(ctx, settings.ApiEndpoint()) input := generateInput{ - Model: settings.ModelID(), + Model: settings.Model(), Prompt: prompt, Stream: false, } diff --git a/modules/generative-ollama/config/class_settings.go b/modules/generative-ollama/config/class_settings.go index d9719a25d2..a520a7af69 100644 --- a/modules/generative-ollama/config/class_settings.go +++ b/modules/generative-ollama/config/class_settings.go @@ -20,12 +20,12 @@ import ( const ( apiEndpointProperty = "apiEndpoint" - modelIDProperty = "modelId" + modelProperty = "model" ) const ( DefaultApiEndpoint = "http://localhost:11434" - DefaultModelID = "llama3" + DefaultModel = "llama3" ) type classSettings struct { @@ -45,9 +45,9 @@ func (ic *classSettings) Validate(class *models.Class) error { if ic.ApiEndpoint() == "" { return errors.New("apiEndpoint cannot be empty") } - model := ic.ModelID() + model := ic.Model() if model == "" { - return errors.New("modelId cannot be empty") + return errors.New("model cannot be empty") } return nil } @@ -60,6 +60,6 @@ func (ic *classSettings) ApiEndpoint() string { return ic.getStringProperty(apiEndpointProperty, DefaultApiEndpoint) } -func (ic *classSettings) ModelID() string { - return ic.getStringProperty(modelIDProperty, DefaultModelID) +func (ic *classSettings) Model() string { + return ic.getStringProperty(modelProperty, DefaultModel) } diff --git a/modules/generative-ollama/config/class_settings_test.go b/modules/generative-ollama/config/class_settings_test.go index 6e217bd637..83b818696b 100644 --- a/modules/generative-ollama/config/class_settings_test.go +++ b/modules/generative-ollama/config/class_settings_test.go @@ -41,7 +41,7 @@ func Test_classSettings_Validate(t *testing.T) { name: "everything non default configured", cfg: fakeClassConfig{ classConfig: map[string]interface{}{ - "modelId": "mistral", + "model": "mistral", }, }, wantApiEndpoint: "http://localhost:11434", @@ -52,10 +52,10 @@ func Test_classSettings_Validate(t *testing.T) { name: "empty model", cfg: fakeClassConfig{ classConfig: map[string]interface{}{ - "modelId": "", + "model": "", }, }, - wantErr: errors.New("modelId cannot be empty"), + wantErr: errors.New("model cannot be empty"), }, } for _, tt := range tests { @@ -66,7 +66,7 @@ func Test_classSettings_Validate(t *testing.T) { require.Error(t, err) assert.Equal(t, tt.wantErr.Error(), err.Error()) } else { - assert.Equal(t, tt.wantModel, ic.ModelID()) + assert.Equal(t, tt.wantModel, ic.Model()) } }) } diff --git a/modules/text2vec-aws/clients/aws.go b/modules/text2vec-aws/clients/aws.go index 2210eedaaf..643eea511e 100644 --- a/modules/text2vec-aws/clients/aws.go +++ b/modules/text2vec-aws/clients/aws.go @@ -22,6 +22,10 @@ import ( "strings" "time" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/google/uuid" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -50,19 +54,21 @@ func buildSagemakerUrl(service, region, endpoint string) string { return fmt.Sprintf(urlTemplate, service, region, endpoint) } -type aws struct { +type awsClient struct { awsAccessKey string awsSecret string + awsSessionToken string buildBedrockUrlFn func(service, region, model string) string buildSagemakerUrlFn func(service, region, endpoint string) string httpClient *http.Client logger logrus.FieldLogger } -func New(awsAccessKey string, awsSecret string, timeout time.Duration, logger logrus.FieldLogger) *aws { - return &aws{ - awsAccessKey: awsAccessKey, - awsSecret: awsSecret, +func New(awsAccessKey, awsSecret, awsSessionToken string, timeout time.Duration, logger logrus.FieldLogger) *awsClient { + return &awsClient{ + awsAccessKey: awsAccessKey, + awsSecret: awsSecret, + awsSessionToken: awsSessionToken, httpClient: &http.Client{ Timeout: timeout, }, @@ -72,22 +78,21 @@ func New(awsAccessKey string, awsSecret string, timeout time.Duration, logger lo } } -func (v *aws) Vectorize(ctx context.Context, input []string, +func (v *awsClient) Vectorize(ctx context.Context, input []string, config ent.VectorizationConfig, ) (*ent.VectorizationResult, error) { return v.vectorize(ctx, input, vectorizeObject, config) } -func (v *aws) VectorizeQuery(ctx context.Context, input []string, +func (v *awsClient) VectorizeQuery(ctx context.Context, input []string, config ent.VectorizationConfig, ) (*ent.VectorizationResult, error) { return v.vectorize(ctx, input, vectorizeQuery, config) } -func (v *aws) vectorize(ctx context.Context, input []string, operation operationType, config ent.VectorizationConfig) (*ent.VectorizationResult, error) { +func (v *awsClient) vectorize(ctx context.Context, input []string, operation operationType, config ent.VectorizationConfig) (*ent.VectorizationResult, error) { service := v.getService(config) region := v.getRegion(config) - model := v.getModel(config) endpoint := v.getEndpoint(config) targetModel := v.getTargetModel(config) targetVariant := v.getTargetVariant(config) @@ -103,20 +108,7 @@ func (v *aws) vectorize(ctx context.Context, input []string, operation operation "content-type": contentType, } - if v.isBedrock(service) { - endpointUrl = v.buildBedrockUrlFn(service, region, model) - host, path, _ = extractHostAndPath(endpointUrl) - - req, err := createRequestBody(model, input, operation) - if err != nil { - return nil, err - } - - body, err = json.Marshal(req) - if err != nil { - return nil, errors.Wrapf(err, "marshal body") - } - } else if v.isSagemaker(service) { + if v.isSagemaker(service) { endpointUrl = v.buildSagemakerUrlFn(service, region, endpoint) host = "runtime." + service + "." + region + ".amazonaws.com" path = "/endpoints/" + endpoint + "/invocations" @@ -132,8 +124,6 @@ func (v *aws) vectorize(ctx context.Context, input []string, operation operation if err != nil { return nil, errors.Wrapf(err, "marshal body") } - } else { - return nil, errors.Wrapf(err, "service error") } accessKey, err := v.getAwsAccessKey(ctx) @@ -144,39 +134,98 @@ func (v *aws) vectorize(ctx context.Context, input []string, operation operation if err != nil { return nil, errors.Wrapf(err, "AWS Secret Key") } + awsSessionToken, err := v.getAwsSessionToken(ctx) + if err != nil { + return nil, err + } + maxRetries := 5 - headers["host"] = host - amzDate, headers, authorizationHeader := getAuthHeader(accessKey, secretKey, host, service, region, path, body, headers) - headers["Authorization"] = authorizationHeader - headers["x-amz-date"] = amzDate + if v.isBedrock(service) { + return v.sendBedrockRequest(ctx, input, operation, maxRetries, accessKey, secretKey, awsSessionToken, config) + } else { + headers["host"] = host + amzDate, headers, authorizationHeader := getAuthHeader(accessKey, secretKey, host, service, region, path, body, headers) + headers["Authorization"] = authorizationHeader + headers["x-amz-date"] = amzDate - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointUrl, bytes.NewReader(body)) - if err != nil { - return nil, errors.Wrap(err, "create POST request") + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointUrl, bytes.NewReader(body)) + if err != nil { + return nil, errors.Wrap(err, "create POST request") + } + + for k, v := range headers { + req.Header.Set(k, v) + } + + res, err := v.makeRequest(req, 30, maxRetries) + if err != nil { + return nil, errors.Wrap(err, "send POST request") + } + defer res.Body.Close() + + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + return nil, errors.Wrap(err, "read response body") + } + return v.parseSagemakerResponse(bodyBytes, res, input) } +} - for k, v := range headers { - req.Header.Set(k, v) +func (v *awsClient) sendBedrockRequest(ctx context.Context, + input []string, + operation operationType, + maxRetries int, + awsKey, awsSecret, awsSessionToken string, + cfg ent.VectorizationConfig, +) (*ent.VectorizationResult, error) { + model := cfg.Model + region := cfg.Region + + req, err := createRequestBody(model, input, operation) + if err != nil { + return nil, fmt.Errorf("failed to create request for model %s: %w", model, err) } - res, err := v.makeRequest(req, 30, 5) + body, err := json.Marshal(req) if err != nil { - return nil, errors.Wrap(err, "send POST request") + return nil, fmt.Errorf("failed to marshal request for model %s: %w", model, err) } - defer res.Body.Close() - bodyBytes, err := io.ReadAll(res.Body) + sdkConfig, err := config.LoadDefaultConfig(ctx, + config.WithRegion(region), + config.WithCredentialsProvider( + credentials.NewStaticCredentialsProvider(awsKey, awsSecret, awsSessionToken), + ), + config.WithRetryMaxAttempts(maxRetries), + ) if err != nil { - return nil, errors.Wrap(err, "read response body") + return nil, fmt.Errorf("failed to load AWS configuration: %w", err) } - if v.isBedrock(service) { - return v.parseBedrockResponse(bodyBytes, res, input) - } else { - return v.parseSagemakerResponse(bodyBytes, res, input) + + client := bedrockruntime.NewFromConfig(sdkConfig) + result, err := client.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{ + ModelId: aws.String(model), + ContentType: aws.String("application/json"), + Body: body, + }) + if err != nil { + errMsg := err.Error() + if strings.Contains(errMsg, "no such host") { + return nil, fmt.Errorf("Bedrock service is not available in the selected region. " + + "Please double-check the service availability for your region at " + + "https://aws.amazon.com/about-aws/global-infrastructure/regional-product-services/") + } else if strings.Contains(errMsg, "Could not resolve the foundation model") { + return nil, fmt.Errorf("Could not resolve the foundation model from model identifier: \"%v\". "+ + "Please verify that the requested model exists and is accessible within the specified region", model) + } else { + return nil, fmt.Errorf("Couldn't invoke %s model: %w", model, err) + } } + + return v.parseBedrockResponse(result.Body, input) } -func (v *aws) makeRequest(req *http.Request, delayInSeconds int, maxRetries int) (*http.Response, error) { +func (v *awsClient) makeRequest(req *http.Request, delayInSeconds int, maxRetries int) (*http.Response, error) { var res *http.Response var err error @@ -201,33 +250,16 @@ func (v *aws) makeRequest(req *http.Request, delayInSeconds int, maxRetries int) // Double the delay for the next iteration delayInSeconds *= 2 - } return res, err } -func (v *aws) parseBedrockResponse(bodyBytes []byte, res *http.Response, input []string) (*ent.VectorizationResult, error) { - var resBodyMap map[string]interface{} - if err := json.Unmarshal(bodyBytes, &resBodyMap); err != nil { - return nil, errors.Wrap(err, "unmarshal response body") - } - - // if resBodyMap has inputTextTokenCount, it's a resonse from an Amazon model - // otherwise, it is a response from a Cohere model +func (v *awsClient) parseBedrockResponse(bodyBytes []byte, input []string) (*ent.VectorizationResult, error) { var resBody bedrockEmbeddingResponse if err := json.Unmarshal(bodyBytes, &resBody); err != nil { return nil, errors.Wrap(err, "unmarshal response body") } - - if res.StatusCode != 200 || resBody.Message != nil { - if resBody.Message != nil { - return nil, fmt.Errorf("connection to AWS Bedrock failed with status: %v error: %s", - res.StatusCode, *resBody.Message) - } - return nil, fmt.Errorf("connection to AWS Bedrock failed with status: %d", res.StatusCode) - } - if len(resBody.Embedding) == 0 && len(resBody.Embeddings) == 0 { return nil, fmt.Errorf("could not obtain vector from AWS Bedrock") } @@ -244,7 +276,7 @@ func (v *aws) parseBedrockResponse(bodyBytes []byte, res *http.Response, input [ }, nil } -func (v *aws) parseSagemakerResponse(bodyBytes []byte, res *http.Response, input []string) (*ent.VectorizationResult, error) { +func (v *awsClient) parseSagemakerResponse(bodyBytes []byte, res *http.Response, input []string) (*ent.VectorizationResult, error) { var resBody sagemakerEmbeddingResponse if err := json.Unmarshal(bodyBytes, &resBody); err != nil { return nil, errors.Wrap(err, "unmarshal response body") @@ -269,71 +301,78 @@ func (v *aws) parseSagemakerResponse(bodyBytes []byte, res *http.Response, input }, nil } -func (v *aws) isSagemaker(service string) bool { +func (v *awsClient) isSagemaker(service string) bool { return service == "sagemaker" } -func (v *aws) isBedrock(service string) bool { +func (v *awsClient) isBedrock(service string) bool { return service == "bedrock" } -func (v *aws) getAwsAccessKey(ctx context.Context) (string, error) { - awsAccessKey := ctx.Value("X-Aws-Access-Key") - if awsAccessKeyHeader, ok := awsAccessKey.([]string); ok && - len(awsAccessKeyHeader) > 0 && len(awsAccessKeyHeader[0]) > 0 { - return awsAccessKeyHeader[0], nil +func (v *awsClient) getAwsAccessKey(ctx context.Context) (string, error) { + if awsAccessKey := v.getHeaderValue(ctx, "X-Aws-Access-Key"); awsAccessKey != "" { + return awsAccessKey, nil } - if len(v.awsAccessKey) > 0 { + if v.awsAccessKey != "" { return v.awsAccessKey, nil } - // try getting header from GRPC if not successful - if accessKey := modulecomponents.GetValueFromGRPC(ctx, "X-Aws-Access-Key"); len(accessKey) > 0 && len(accessKey[0]) > 0 { - return accessKey[0], nil - } return "", errors.New("no access key found " + "neither in request header: X-AWS-Access-Key " + "nor in environment variable under AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY") } -func (v *aws) getAwsAccessSecret(ctx context.Context) (string, error) { - awsSecretKey := ctx.Value("X-Aws-Secret-Key") - if awsAccessSecretHeader, ok := awsSecretKey.([]string); ok && - len(awsAccessSecretHeader) > 0 && len(awsAccessSecretHeader[0]) > 0 { - return awsAccessSecretHeader[0], nil +func (v *awsClient) getAwsAccessSecret(ctx context.Context) (string, error) { + if awsSecret := v.getHeaderValue(ctx, "X-Aws-Secret-Key"); awsSecret != "" { + return awsSecret, nil } - if len(v.awsSecret) > 0 { + if v.awsSecret != "" { return v.awsSecret, nil } - // try getting header from GRPC if not successful - if secretKey := modulecomponents.GetValueFromGRPC(ctx, "X-Aws-Secret-Key"); len(secretKey) > 0 && len(secretKey[0]) > 0 { - return secretKey[0], nil - } return "", errors.New("no secret found " + "neither in request header: X-AWS-Secret-Key " + "nor in environment variable under AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY") } -func (v *aws) getModel(config ent.VectorizationConfig) string { - return config.Model +func (v *awsClient) getAwsSessionToken(ctx context.Context) (string, error) { + if awsSessionToken := v.getHeaderValue(ctx, "X-Aws-Session-Token"); awsSessionToken != "" { + return awsSessionToken, nil + } + if v.awsSessionToken != "" { + return v.awsSessionToken, nil + } + return "", nil +} + +func (v *awsClient) getHeaderValue(ctx context.Context, header string) string { + headerValue := ctx.Value(header) + if value, ok := headerValue.([]string); ok && + len(value) > 0 && len(value[0]) > 0 { + return value[0] + } + // try getting header from GRPC if not successful + if value := modulecomponents.GetValueFromGRPC(ctx, header); len(value) > 0 && len(value[0]) > 0 { + return value[0] + } + return "" } -func (v *aws) getRegion(config ent.VectorizationConfig) string { +func (v *awsClient) getRegion(config ent.VectorizationConfig) string { return config.Region } -func (v *aws) getService(config ent.VectorizationConfig) string { +func (v *awsClient) getService(config ent.VectorizationConfig) string { return config.Service } -func (v *aws) getEndpoint(config ent.VectorizationConfig) string { +func (v *awsClient) getEndpoint(config ent.VectorizationConfig) string { return config.Endpoint } -func (v *aws) getTargetModel(config ent.VectorizationConfig) string { +func (v *awsClient) getTargetModel(config ent.VectorizationConfig) string { return config.TargetModel } -func (v *aws) getTargetVariant(config ent.VectorizationConfig) string { +func (v *awsClient) getTargetVariant(config ent.VectorizationConfig) string { return config.TargetVariant } diff --git a/modules/text2vec-aws/clients/aws_test.go b/modules/text2vec-aws/clients/aws_test.go index 0f346e706c..6f7bf93a0a 100644 --- a/modules/text2vec-aws/clients/aws_test.go +++ b/modules/text2vec-aws/clients/aws_test.go @@ -35,7 +35,7 @@ func TestClient(t *testing.T) { t.Skip("Skipping this test for now") server := httptest.NewServer(&fakeHandler{t: t}) defer server.Close() - c := &aws{ + c := &awsClient{ httpClient: &http.Client{}, logger: nullLogger(), awsAccessKey: "access_key", @@ -66,7 +66,7 @@ func TestClient(t *testing.T) { t.Run("when all is fine - Sagemaker", func(t *testing.T) { server := httptest.NewServer(&fakeHandler{t: t}) defer server.Close() - c := &aws{ + c := &awsClient{ httpClient: &http.Client{}, logger: nullLogger(), awsAccessKey: "access_key", @@ -101,7 +101,7 @@ func TestClient(t *testing.T) { serverError: errors.Errorf("nope, not gonna happen"), }) defer server.Close() - c := &aws{ + c := &awsClient{ httpClient: &http.Client{}, logger: nullLogger(), awsAccessKey: "access_key", @@ -126,7 +126,7 @@ func TestClient(t *testing.T) { t.Skip("Skipping this test for now") server := httptest.NewServer(&fakeHandler{t: t}) defer server.Close() - c := &aws{ + c := &awsClient{ httpClient: &http.Client{}, logger: nullLogger(), awsAccessKey: "access_key", @@ -158,7 +158,7 @@ func TestClient(t *testing.T) { t.Skip("Skipping this test for now") server := httptest.NewServer(&fakeHandler{t: t}) defer server.Close() - c := &aws{ + c := &awsClient{ httpClient: &http.Client{}, logger: nullLogger(), awsAccessKey: "", @@ -186,7 +186,7 @@ func TestClient(t *testing.T) { t.Skip("Skipping this test for now") server := httptest.NewServer(&fakeHandler{t: t}) defer server.Close() - c := &aws{ + c := &awsClient{ httpClient: &http.Client{}, logger: nullLogger(), awsAccessKey: "123", @@ -282,7 +282,7 @@ func TestVectorize(t *testing.T) { awsAccessKeyID := os.Getenv("AWS_ACCESS_KEY_ID_AMAZON") awsSecretAccessKey := os.Getenv("AWS_SECRET_ACCESS_KEY_AMAZON") - aws := New(awsAccessKeyID, awsSecretAccessKey, 60*time.Second, nil) + aws := New(awsAccessKeyID, awsSecretAccessKey, "sessionToken", 60*time.Second, nil) _, err := aws.Vectorize(ctx, input, config) if err != nil { @@ -301,7 +301,7 @@ func TestVectorize(t *testing.T) { awsAccessKeyID := os.Getenv("AWS_ACCESS_KEY_ID_COHERE") awsSecretAccessKey := os.Getenv("AWS_SECRET_ACCESS_KEY_COHERE") - aws := New(awsAccessKeyID, awsSecretAccessKey, 60*time.Second, nil) + aws := New(awsAccessKeyID, awsSecretAccessKey, "sessionToken", 60*time.Second, nil) _, err := aws.Vectorize(ctx, input, config) if err != nil { diff --git a/modules/text2vec-aws/clients/meta.go b/modules/text2vec-aws/clients/meta.go index d64031ed53..98e868b327 100644 --- a/modules/text2vec-aws/clients/meta.go +++ b/modules/text2vec-aws/clients/meta.go @@ -11,7 +11,7 @@ package clients -func (v *aws) MetaInfo() (map[string]interface{}, error) { +func (v *awsClient) MetaInfo() (map[string]interface{}, error) { return map[string]interface{}{ "name": "AWS Module", "documentationHref": "https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings", diff --git a/modules/text2vec-aws/module.go b/modules/text2vec-aws/module.go index 59845f9359..dfc07a97d3 100644 --- a/modules/text2vec-aws/module.go +++ b/modules/text2vec-aws/module.go @@ -94,7 +94,8 @@ func (m *AwsModule) initVectorizer(ctx context.Context, timeout time.Duration, ) error { awsAccessKey := m.getAWSAccessKey() awsSecret := m.getAWSSecretAccessKey() - client := clients.New(awsAccessKey, awsSecret, timeout, logger) + awsSessionToken := os.Getenv("AWS_SESSION_TOKEN") + client := clients.New(awsAccessKey, awsSecret, awsSessionToken, timeout, logger) m.vectorizer = vectorizer.New(client) m.metaProvider = client diff --git a/modules/text2vec-aws/vectorizer/class_settings.go b/modules/text2vec-aws/vectorizer/class_settings.go index 214847a553..479e63754d 100644 --- a/modules/text2vec-aws/vectorizer/class_settings.go +++ b/modules/text2vec-aws/vectorizer/class_settings.go @@ -46,6 +46,7 @@ var availableAWSServices = []string{ var availableAWSBedrockModels = []string{ "amazon.titan-embed-text-v1", + "amazon.titan-embed-text-v2:0", "cohere.embed-english-v3", "cohere.embed-multilingual-v3", } diff --git a/modules/text2vec-aws/vectorizer/class_settings_test.go b/modules/text2vec-aws/vectorizer/class_settings_test.go index a2af0c3f93..582581d782 100644 --- a/modules/text2vec-aws/vectorizer/class_settings_test.go +++ b/modules/text2vec-aws/vectorizer/class_settings_test.go @@ -90,7 +90,7 @@ func Test_classSettings_Validate(t *testing.T) { "model": "wrong-model", }, }, - wantErr: errors.Errorf("wrong model, available models are: [amazon.titan-embed-text-v1 cohere.embed-english-v3 cohere.embed-multilingual-v3]"), + wantErr: errors.Errorf("wrong model, available models are: [amazon.titan-embed-text-v1 amazon.titan-embed-text-v2:0 cohere.embed-english-v3 cohere.embed-multilingual-v3]"), }, { name: "all wrong", diff --git a/modules/text2vec-huggingface/clients/fakes_for_test.go b/modules/text2vec-huggingface/clients/fakes_for_test.go new file mode 100644 index 0000000000..a068f7c314 --- /dev/null +++ b/modules/text2vec-huggingface/clients/fakes_for_test.go @@ -0,0 +1,54 @@ +// _ _ +// __ _____ __ ___ ___ __ _| |_ ___ +// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ +// \ V V / __/ (_| |\ V /| | (_| | || __/ +// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| +// +// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. +// +// CONTACT: hello@weaviate.io +// + +package clients + +type fakeClassConfig struct { + classConfig map[string]interface{} + vectorizePropertyName bool + skippedProperty string + excludedProperty string +} + +func (f fakeClassConfig) Class() map[string]interface{} { + return f.classConfig +} + +func (f fakeClassConfig) ClassByModuleName(moduleName string) map[string]interface{} { + return f.classConfig +} + +func (f fakeClassConfig) Property(propName string) map[string]interface{} { + if propName == f.skippedProperty { + return map[string]interface{}{ + "skip": true, + } + } + if propName == f.excludedProperty { + return map[string]interface{}{ + "vectorizePropertyName": false, + } + } + if f.vectorizePropertyName { + return map[string]interface{}{ + "vectorizePropertyName": true, + } + } + return nil +} + +func (f fakeClassConfig) Tenant() string { + return "" +} + +func (f fakeClassConfig) TargetVector() string { + return "" +} diff --git a/modules/text2vec-huggingface/clients/huggingface.go b/modules/text2vec-huggingface/clients/huggingface.go index 5afa96e83e..bcfd766f93 100644 --- a/modules/text2vec-huggingface/clients/huggingface.go +++ b/modules/text2vec-huggingface/clients/huggingface.go @@ -14,12 +14,15 @@ package clients import ( "bytes" "context" + "crypto/sha256" "encoding/json" "fmt" "io" "net/http" "time" + "github.com/weaviate/weaviate/entities/moduletools" + "github.com/weaviate/weaviate/usecases/modulecomponents" "github.com/pkg/errors" @@ -32,6 +35,13 @@ const ( DefaultPath = "pipeline/feature-extraction" ) +// there are no explicit rate limits: https://huggingface.co/docs/api-inference/en/faq#rate-limits +// so we set values that work and leave it up to the users to increase these values +const ( + DefaultRPM = 100 // + DefaultTPM = 10000000 // no token limit +) + type embeddingsRequest struct { Inputs []string `json:"inputs"` Options *options `json:"options,omitempty"` @@ -75,23 +85,37 @@ func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *vecto } } -func (v *vectorizer) Vectorize(ctx context.Context, input string, - config ent.VectorizationConfig, -) (*ent.VectorizationResult, error) { - return v.vectorize(ctx, v.getURL(config), input, v.getOptions(config)) +func (v *vectorizer) Vectorize(ctx context.Context, input []string, + cfg moduletools.ClassConfig, +) (*modulecomponents.VectorizationResult, *modulecomponents.RateLimits, error) { + config := v.getVectorizationConfig(cfg) + res, err := v.vectorize(ctx, v.getURL(config), input, v.getOptions(config)) + return res, nil, err } -func (v *vectorizer) VectorizeQuery(ctx context.Context, input string, - config ent.VectorizationConfig, -) (*ent.VectorizationResult, error) { +func (v *vectorizer) VectorizeQuery(ctx context.Context, input []string, + cfg moduletools.ClassConfig, +) (*modulecomponents.VectorizationResult, error) { + config := v.getVectorizationConfig(cfg) return v.vectorize(ctx, v.getURL(config), input, v.getOptions(config)) } +func (v *vectorizer) getVectorizationConfig(cfg moduletools.ClassConfig) ent.VectorizationConfig { + icheck := ent.NewClassSettings(cfg) + return ent.VectorizationConfig{ + EndpointURL: icheck.EndpointURL(), + Model: icheck.PassageModel(), + WaitForModel: icheck.OptionWaitForModel(), + UseGPU: icheck.OptionUseGPU(), + UseCache: icheck.OptionUseCache(), + } +} + func (v *vectorizer) vectorize(ctx context.Context, url string, - input string, options options, -) (*ent.VectorizationResult, error) { + input []string, options options, +) (*modulecomponents.VectorizationResult, error) { body, err := json.Marshal(embeddingsRequest{ - Inputs: []string{input}, + Inputs: input, Options: &options, }) if err != nil { @@ -103,7 +127,7 @@ func (v *vectorizer) vectorize(ctx context.Context, url string, if err != nil { return nil, errors.Wrap(err, "create POST request") } - if apiKey := v.getApiKey(ctx); apiKey != "" { + if apiKey, err := v.getApiKey(ctx); apiKey != "" && err == nil { req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey)) } req.Header.Add("Content-Type", "application/json") @@ -123,15 +147,16 @@ func (v *vectorizer) vectorize(ctx context.Context, url string, return nil, err } - vector, err := v.decodeVector(bodyBytes) + vector, errs, err := v.decodeVector(bodyBytes) if err != nil { return nil, errors.Wrap(err, "cannot decode vector") } - return &ent.VectorizationResult{ + return &modulecomponents.VectorizationResult{ Text: input, - Dimensions: len(vector), + Dimensions: len(vector[0]), Vector: vector, + Errors: errs, }, nil } @@ -163,52 +188,88 @@ func checkResponse(res *http.Response, bodyBytes []byte) error { return errors.New(message) } -func (v *vectorizer) decodeVector(bodyBytes []byte) ([]float32, error) { +func (v *vectorizer) decodeVector(bodyBytes []byte) ([][]float32, []error, error) { var emb embedding if err := json.Unmarshal(bodyBytes, &emb); err != nil { var embObject embeddingObject if err := json.Unmarshal(bodyBytes, &embObject); err != nil { var embBert embeddingBert if err := json.Unmarshal(bodyBytes, &embBert); err != nil { - return nil, errors.Wrap(err, "unmarshal response body") + return nil, nil, errors.Wrap(err, "unmarshal response body") } - if len(embBert) == 1 && len(embBert[0]) == 1 { - return v.bertEmbeddingsDecoder.calculateVector(embBert[0][0]) + if len(embBert) == 1 && len(embBert[0]) > 0 { + vectors := make([][]float32, len(embBert[0])) + errs := make([]error, len(embBert[0])) + for i, embBer := range embBert[0] { + vectors[i], errs[i] = v.bertEmbeddingsDecoder.calculateVector(embBer) + } + return vectors, errs, nil } - return nil, errors.New("unprocessable response body") + return nil, nil, errors.New("unprocessable response body") } - if len(embObject.Embeddings) == 1 { - return embObject.Embeddings[0], nil + if len(embObject.Embeddings) > 0 { + return embObject.Embeddings, nil, nil } - return nil, errors.New("unprocessable response body") + return nil, nil, errors.New("unprocessable response body") } - if len(emb) == 1 { - return emb[0], nil + if len(emb) > 0 { + return emb, nil, nil } - return nil, errors.New("unprocessable response body") + return nil, nil, errors.New("unprocessable response body") } -func (v *vectorizer) getApiKey(ctx context.Context) string { - if len(v.apiKey) > 0 { - return v.apiKey +func (v *vectorizer) GetApiKeyHash(ctx context.Context, config moduletools.ClassConfig) [32]byte { + key, err := v.getApiKey(ctx) + if err != nil { + return [32]byte{} } - key := "X-Huggingface-Api-Key" - apiKey := ctx.Value(key) - // try getting header from GRPC if not successful - if apiKey == nil { - apiKey = modulecomponents.GetValueFromGRPC(ctx, key) + return sha256.Sum256([]byte(key)) +} + +func (v *vectorizer) GetVectorizerRateLimit(ctx context.Context) *modulecomponents.RateLimits { + rpm, _ := modulecomponents.GetRateLimitFromContext(ctx, "Cohere", DefaultRPM, 0) + + execAfterRequestFunction := func(limits *modulecomponents.RateLimits, tokensUsed int, deductRequest bool) { + // refresh is after 60 seconds but leave a bit of room for errors. Otherwise, we only deduct the request that just happened + if limits.LastOverwrite.Add(61 * time.Second).After(time.Now()) { + if deductRequest { + limits.RemainingRequests -= 1 + } + return + } + + limits.RemainingRequests = rpm + limits.ResetRequests = time.Now().Add(time.Duration(61) * time.Second) + limits.LimitRequests = rpm + limits.LastOverwrite = time.Now() + + // high dummy values + limits.RemainingTokens = DefaultTPM + limits.LimitTokens = DefaultTPM + limits.ResetTokens = time.Now().Add(time.Duration(1) * time.Second) } - if apiKeyHeader, ok := apiKey.([]string); ok && - len(apiKeyHeader) > 0 && len(apiKeyHeader[0]) > 0 { - return apiKeyHeader[0] + initialRL := &modulecomponents.RateLimits{AfterRequestFunction: execAfterRequestFunction, LastOverwrite: time.Now().Add(-61 * time.Minute)} + initialRL.ResetAfterRequestFunction(0) // set initial values + + return initialRL +} + +func (v *vectorizer) getApiKey(ctx context.Context) (string, error) { + if apiKey := modulecomponents.GetValueFromContext(ctx, "X-Huggingface-Api-Key"); apiKey != "" { + return apiKey, nil + } + if v.apiKey != "" { + return v.apiKey, nil } - return "" + return "", errors.New("no api key found " + + "neither in request header: X-Huggingface-Api-Key " + + "nor in environment variable under HUGGINGFACE_APIKEY") } func (v *vectorizer) getOptions(config ent.VectorizationConfig) options { diff --git a/modules/text2vec-huggingface/clients/huggingface_test.go b/modules/text2vec-huggingface/clients/huggingface_test.go index f4c4e82e19..65e044e4f6 100644 --- a/modules/text2vec-huggingface/clients/huggingface_test.go +++ b/modules/text2vec-huggingface/clients/huggingface_test.go @@ -26,6 +26,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/weaviate/weaviate/modules/text2vec-huggingface/ent" + "github.com/weaviate/weaviate/usecases/modulecomponents" ) func TestClient(t *testing.T) { @@ -37,19 +38,19 @@ func TestClient(t *testing.T) { httpClient: &http.Client{}, logger: nullLogger(), } - expected := &ent.VectorizationResult{ - Text: "This is my text", - Vector: []float32{0.1, 0.2, 0.3}, + expected := &modulecomponents.VectorizationResult{ + Text: []string{"This is my text"}, + Vector: [][]float32{{0.1, 0.2, 0.3}}, Dimensions: 3, } - res, err := c.Vectorize(context.Background(), "This is my text", - ent.VectorizationConfig{ - Model: "sentence-transformers/gtr-t5-xxl", - WaitForModel: false, - UseGPU: false, - UseCache: true, - EndpointURL: server.URL, - }) + res, _, err := c.Vectorize(context.Background(), []string{"This is my text"}, + fakeClassConfig{classConfig: map[string]interface{}{ + "Model": "sentence-transformers/gtr-t5-xxl", + "endpointURL": server.URL, + "WaitForModel": false, + "UseGPU": false, + "UseCache": true, + }}) assert.Nil(t, err) assert.Equal(t, expected, res) @@ -65,10 +66,8 @@ func TestClient(t *testing.T) { } ctx, cancel := context.WithDeadline(context.Background(), time.Now()) defer cancel() - - _, err := c.Vectorize(ctx, "This is my text", ent.VectorizationConfig{ - EndpointURL: server.URL, - }) + _, _, err := c.Vectorize(ctx, []string{"This is my text"}, + fakeClassConfig{classConfig: map[string]interface{}{"endpointURL": server.URL}}) require.NotNil(t, err) assert.Contains(t, err.Error(), "context deadline exceeded") @@ -85,10 +84,8 @@ func TestClient(t *testing.T) { httpClient: &http.Client{}, logger: nullLogger(), } - _, err := c.Vectorize(context.Background(), "This is my text", - ent.VectorizationConfig{ - EndpointURL: server.URL, - }) + _, _, err := c.Vectorize(context.Background(), []string{"This is my text"}, + fakeClassConfig{classConfig: map[string]interface{}{"endpointURL": server.URL}}) require.NotNil(t, err) assert.Equal(t, err.Error(), "connection to HuggingFace failed with status: 500 error: nope, not gonna happen estimated time: 20") @@ -105,19 +102,20 @@ func TestClient(t *testing.T) { ctxWithValue := context.WithValue(context.Background(), "X-Huggingface-Api-Key", []string{"some-key"}) - expected := &ent.VectorizationResult{ - Text: "This is my text", - Vector: []float32{0.1, 0.2, 0.3}, + expected := &modulecomponents.VectorizationResult{ + Text: []string{"This is my text"}, + Vector: [][]float32{{0.1, 0.2, 0.3}}, Dimensions: 3, } - res, err := c.Vectorize(ctxWithValue, "This is my text", - ent.VectorizationConfig{ - Model: "sentence-transformers/gtr-t5-xxl", - WaitForModel: true, - UseGPU: false, - UseCache: true, - EndpointURL: server.URL, - }) + + res, _, err := c.Vectorize(ctxWithValue, []string{"This is my text"}, + fakeClassConfig{classConfig: map[string]interface{}{ + "Model": "sentence-transformers/gtr-t5-xxl", + "endpointURL": server.URL, + "WaitForModel": true, + "UseGPU": false, + "UseCache": true, + }}) require.Nil(t, err) assert.Equal(t, expected, res) @@ -137,12 +135,11 @@ func TestClient(t *testing.T) { ctxWithValue := context.WithValue(context.Background(), "X-Huggingface-Api-Key", []string{""}) - _, err := c.Vectorize(ctxWithValue, "This is my text", - ent.VectorizationConfig{ - Model: "sentence-transformers/gtr-t5-xxl", - EndpointURL: server.URL, - }) - + _, _, err := c.Vectorize(ctxWithValue, []string{"This is my text"}, + fakeClassConfig{classConfig: map[string]interface{}{ + "Model": "sentence-transformers/gtr-t5-xxl", + "endpointURL": server.URL, + }}) require.NotNil(t, err) assert.Equal(t, err.Error(), "failed with status: 401 error: A valid user or organization token is required") }) @@ -158,10 +155,8 @@ func TestClient(t *testing.T) { httpClient: &http.Client{}, logger: nullLogger(), } - _, err := c.Vectorize(context.Background(), "This is my text", - ent.VectorizationConfig{ - EndpointURL: server.URL, - }) + _, _, err := c.Vectorize(context.Background(), []string{"This is my text"}, + fakeClassConfig{classConfig: map[string]interface{}{"endpointURL": server.URL}}) require.NotNil(t, err) assert.Equal(t, err.Error(), "connection to HuggingFace failed with status: 500 error: with warnings "+ diff --git a/modules/text2vec-huggingface/config.go b/modules/text2vec-huggingface/config.go index b72af343f0..1c94738b63 100644 --- a/modules/text2vec-huggingface/config.go +++ b/modules/text2vec-huggingface/config.go @@ -14,16 +14,17 @@ package modhuggingface import ( "context" + "github.com/weaviate/weaviate/modules/text2vec-huggingface/ent" + "github.com/weaviate/weaviate/entities/models" "github.com/weaviate/weaviate/entities/modulecapabilities" "github.com/weaviate/weaviate/entities/moduletools" "github.com/weaviate/weaviate/entities/schema" - "github.com/weaviate/weaviate/modules/text2vec-huggingface/vectorizer" ) func (m *HuggingFaceModule) ClassConfigDefaults() map[string]interface{} { return map[string]interface{}{ - "vectorizeClassName": vectorizer.DefaultVectorizeClassName, + "vectorizeClassName": ent.DefaultVectorizeClassName, } } @@ -31,15 +32,15 @@ func (m *HuggingFaceModule) PropertyConfigDefaults( dt *schema.DataType, ) map[string]interface{} { return map[string]interface{}{ - "skip": !vectorizer.DefaultPropertyIndexed, - "vectorizePropertyName": vectorizer.DefaultVectorizePropertyName, + "skip": !ent.DefaultPropertyIndexed, + "vectorizePropertyName": ent.DefaultVectorizePropertyName, } } func (m *HuggingFaceModule) ValidateClass(ctx context.Context, class *models.Class, cfg moduletools.ClassConfig, ) error { - settings := vectorizer.NewClassSettings(cfg) + settings := ent.NewClassSettings(cfg) return settings.Validate(class) } diff --git a/modules/text2vec-huggingface/vectorizer/class_settings.go b/modules/text2vec-huggingface/ent/class_settings.go similarity index 95% rename from modules/text2vec-huggingface/vectorizer/class_settings.go rename to modules/text2vec-huggingface/ent/class_settings.go index 2f867ce9df..2d321c8802 100644 --- a/modules/text2vec-huggingface/vectorizer/class_settings.go +++ b/modules/text2vec-huggingface/ent/class_settings.go @@ -9,11 +9,10 @@ // CONTACT: hello@weaviate.io // -package vectorizer +package ent import ( "github.com/pkg/errors" - "github.com/weaviate/weaviate/entities/models" "github.com/weaviate/weaviate/entities/moduletools" basesettings "github.com/weaviate/weaviate/usecases/modulecomponents/settings" @@ -71,11 +70,7 @@ func (cs *classSettings) OptionUseCache() bool { } func (cs *classSettings) Validate(class *models.Class) error { - return cs.BaseClassSettings.Validate(class) -} - -func (cs *classSettings) validateClassSettings() error { - if err := cs.BaseClassSettings.ValidateClassSettings(); err != nil { + if err := cs.BaseClassSettings.Validate(class); err != nil { return err } diff --git a/modules/text2vec-huggingface/ent/vectorization_result.go b/modules/text2vec-huggingface/ent/vectorization_result.go deleted file mode 100644 index 6a1acfaf38..0000000000 --- a/modules/text2vec-huggingface/ent/vectorization_result.go +++ /dev/null @@ -1,18 +0,0 @@ -// _ _ -// __ _____ __ ___ ___ __ _| |_ ___ -// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ -// \ V V / __/ (_| |\ V /| | (_| | || __/ -// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| -// -// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. -// -// CONTACT: hello@weaviate.io -// - -package ent - -type VectorizationResult struct { - Text string - Dimensions int - Vector []float32 -} diff --git a/modules/text2vec-huggingface/module.go b/modules/text2vec-huggingface/module.go index 5bb5673c4f..c7010bb4f9 100644 --- a/modules/text2vec-huggingface/module.go +++ b/modules/text2vec-huggingface/module.go @@ -17,9 +17,9 @@ import ( "os" "time" - "github.com/weaviate/weaviate/usecases/modulecomponents/text2vecbase" + "github.com/weaviate/weaviate/modules/text2vec-huggingface/ent" - "github.com/weaviate/weaviate/usecases/modulecomponents/batch" + "github.com/weaviate/weaviate/usecases/modulecomponents/text2vecbase" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -38,7 +38,7 @@ func New() *HuggingFaceModule { } type HuggingFaceModule struct { - vectorizer text2vecbase.TextVectorizer + vectorizer text2vecbase.TextVectorizerBatch metaProvider text2vecbase.MetaProvider graphqlProvider modulecapabilities.GraphQLArguments searcher modulecapabilities.Searcher @@ -95,7 +95,7 @@ func (m *HuggingFaceModule) initVectorizer(ctx context.Context, timeout time.Dur apiKey := os.Getenv("HUGGINGFACE_APIKEY") client := clients.New(apiKey, timeout, logger) - m.vectorizer = vectorizer.New(client) + m.vectorizer = vectorizer.New(client, logger) m.metaProvider = client return nil @@ -114,7 +114,7 @@ func (m *HuggingFaceModule) RootHandler() http.Handler { func (m *HuggingFaceModule) VectorizeObject(ctx context.Context, obj *models.Object, cfg moduletools.ClassConfig, ) ([]float32, models.AdditionalProperties, error) { - return m.vectorizer.Object(ctx, obj, cfg) + return m.vectorizer.Object(ctx, obj, cfg, ent.NewClassSettings(cfg)) } func (m *HuggingFaceModule) VectorizableProperties(cfg moduletools.ClassConfig) (bool, []string, error) { @@ -122,7 +122,8 @@ func (m *HuggingFaceModule) VectorizableProperties(cfg moduletools.ClassConfig) } func (m *HuggingFaceModule) VectorizeBatch(ctx context.Context, objs []*models.Object, skipObject []bool, cfg moduletools.ClassConfig) ([][]float32, []models.AdditionalProperties, map[int]error) { - return batch.VectorizeBatch(ctx, objs, skipObject, cfg, m.logger, m.vectorizer.Object) + vecs, errs := m.vectorizer.ObjectBatch(ctx, objs, skipObject, cfg) + return vecs, nil, errs } func (m *HuggingFaceModule) MetaInfo() (map[string]interface{}, error) { diff --git a/modules/text2vec-huggingface/vectorizer/class_settings_test.go b/modules/text2vec-huggingface/vectorizer/class_settings_test.go deleted file mode 100644 index 763ff49a7a..0000000000 --- a/modules/text2vec-huggingface/vectorizer/class_settings_test.go +++ /dev/null @@ -1,122 +0,0 @@ -// _ _ -// __ _____ __ ___ ___ __ _| |_ ___ -// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ -// \ V V / __/ (_| |\ V /| | (_| | || __/ -// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| -// -// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. -// -// CONTACT: hello@weaviate.io -// - -package vectorizer - -import ( - "errors" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/weaviate/weaviate/entities/moduletools" -) - -func Test_classSettings_getPassageModel(t *testing.T) { - tests := []struct { - name string - cfg moduletools.ClassConfig - wantPassageModel string - wantQueryModel string - wantWaitForModel bool - wantUseGPU bool - wantUseCache bool - wantEndpointURL string - wantError error - }{ - { - name: "CShorten/CORD-19-Title-Abstracts", - cfg: fakeClassConfig{ - classConfig: map[string]interface{}{ - "model": "CShorten/CORD-19-Title-Abstracts", - "options": map[string]interface{}{ - "waitForModel": true, - "useGPU": false, - "useCache": false, - }, - }, - }, - wantPassageModel: "CShorten/CORD-19-Title-Abstracts", - wantQueryModel: "CShorten/CORD-19-Title-Abstracts", - wantWaitForModel: true, - wantUseGPU: false, - wantUseCache: false, - }, - { - name: "sentence-transformers/all-MiniLM-L6-v2", - cfg: fakeClassConfig{ - classConfig: map[string]interface{}{ - "model": "sentence-transformers/all-MiniLM-L6-v2", - }, - }, - wantPassageModel: "sentence-transformers/all-MiniLM-L6-v2", - wantQueryModel: "sentence-transformers/all-MiniLM-L6-v2", - wantWaitForModel: false, - wantUseGPU: false, - wantUseCache: true, - }, - { - name: "DPR models", - cfg: fakeClassConfig{ - classConfig: map[string]interface{}{ - "passageModel": "sentence-transformers/facebook-dpr-ctx_encoder-single-nq-base", - "queryModel": "sentence-transformers/facebook-dpr-question_encoder-single-nq-base", - }, - }, - wantPassageModel: "sentence-transformers/facebook-dpr-ctx_encoder-single-nq-base", - wantQueryModel: "sentence-transformers/facebook-dpr-question_encoder-single-nq-base", - wantWaitForModel: false, - wantUseGPU: false, - wantUseCache: true, - }, - { - name: "Hugging Face Inference API - endpointURL", - cfg: fakeClassConfig{ - classConfig: map[string]interface{}{ - "endpointURL": "http://endpoint.cloud", - }, - }, - wantPassageModel: "", - wantQueryModel: "", - wantWaitForModel: false, - wantUseGPU: false, - wantUseCache: true, - wantEndpointURL: "http://endpoint.cloud", - }, - { - name: "Hugging Face Inference API - wrong properties", - cfg: fakeClassConfig{ - classConfig: map[string]interface{}{ - "endpointUrl": "http://endpoint.cloud", - "properties": "wrong-properties", - }, - }, - wantPassageModel: "", - wantQueryModel: "", - wantWaitForModel: false, - wantUseGPU: false, - wantUseCache: true, - wantEndpointURL: "http://endpoint.cloud", - wantError: errors.New("properties field needs to be of array type, got: string"), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ic := NewClassSettings(tt.cfg) - assert.Equal(t, tt.wantPassageModel, ic.getPassageModel()) - assert.Equal(t, tt.wantQueryModel, ic.getQueryModel()) - assert.Equal(t, tt.wantWaitForModel, ic.OptionWaitForModel()) - assert.Equal(t, tt.wantUseGPU, ic.OptionUseGPU()) - assert.Equal(t, tt.wantUseCache, ic.OptionUseCache()) - assert.Equal(t, tt.wantEndpointURL, ic.EndpointURL()) - assert.Equal(t, tt.wantError, ic.validateClassSettings()) - }) - } -} diff --git a/modules/text2vec-huggingface/vectorizer/fakes_for_test.go b/modules/text2vec-huggingface/vectorizer/fakes_for_test.go index 63da74be84..f8c55c720e 100644 --- a/modules/text2vec-huggingface/vectorizer/fakes_for_test.go +++ b/modules/text2vec-huggingface/vectorizer/fakes_for_test.go @@ -14,38 +14,47 @@ package vectorizer import ( "context" - "github.com/weaviate/weaviate/modules/text2vec-huggingface/ent" + "github.com/weaviate/weaviate/entities/moduletools" + "github.com/weaviate/weaviate/usecases/modulecomponents" ) type fakeClient struct { - lastInput string - lastConfig ent.VectorizationConfig + lastInput []string + lastConfig moduletools.ClassConfig } func (c *fakeClient) Vectorize(ctx context.Context, - text string, cfg ent.VectorizationConfig, -) (*ent.VectorizationResult, error) { + text []string, cfg moduletools.ClassConfig, +) (*modulecomponents.VectorizationResult, *modulecomponents.RateLimits, error) { c.lastInput = text c.lastConfig = cfg - return &ent.VectorizationResult{ - Vector: []float32{0, 1, 2, 3}, + return &modulecomponents.VectorizationResult{ + Vector: [][]float32{{0, 1, 2, 3}}, Dimensions: 4, Text: text, - }, nil + }, nil, nil } func (c *fakeClient) VectorizeQuery(ctx context.Context, - text string, cfg ent.VectorizationConfig, -) (*ent.VectorizationResult, error) { + text []string, cfg moduletools.ClassConfig, +) (*modulecomponents.VectorizationResult, error) { c.lastInput = text c.lastConfig = cfg - return &ent.VectorizationResult{ - Vector: []float32{0.1, 1.1, 2.1, 3.1}, + return &modulecomponents.VectorizationResult{ + Vector: [][]float32{{0.1, 1.1, 2.1, 3.1}}, Dimensions: 4, Text: text, }, nil } +func (c *fakeClient) GetVectorizerRateLimit(ctx context.Context) *modulecomponents.RateLimits { + return &modulecomponents.RateLimits{} +} + +func (c *fakeClient) GetApiKeyHash(ctx context.Context, cfg moduletools.ClassConfig) [32]byte { + return [32]byte{} +} + type fakeClassConfig struct { classConfig map[string]interface{} vectorizeClassName bool diff --git a/modules/text2vec-huggingface/vectorizer/objects.go b/modules/text2vec-huggingface/vectorizer/objects.go index 319ac8827d..ffe12badcc 100644 --- a/modules/text2vec-huggingface/vectorizer/objects.go +++ b/modules/text2vec-huggingface/vectorizer/objects.go @@ -13,6 +13,11 @@ package vectorizer import ( "context" + "time" + + "github.com/sirupsen/logrus" + "github.com/weaviate/weaviate/usecases/modulecomponents/batch" + "github.com/weaviate/weaviate/usecases/modulecomponents/text2vecbase" "github.com/weaviate/weaviate/entities/models" "github.com/weaviate/weaviate/entities/moduletools" @@ -20,59 +25,31 @@ import ( objectsvectorizer "github.com/weaviate/weaviate/usecases/modulecomponents/vectorizer" ) -type Vectorizer struct { - client Client - objectVectorizer *objectsvectorizer.ObjectVectorizer -} - -func New(client Client) *Vectorizer { - return &Vectorizer{ - client: client, - objectVectorizer: objectsvectorizer.New(), - } -} - -type Client interface { - Vectorize(ctx context.Context, input string, - config ent.VectorizationConfig) (*ent.VectorizationResult, error) - VectorizeQuery(ctx context.Context, input string, - config ent.VectorizationConfig) (*ent.VectorizationResult, error) -} - -// IndexCheck returns whether a property of a class should be indexed -type ClassSettings interface { - PropertyIndexed(property string) bool - VectorizePropertyName(propertyName string) bool - VectorizeClassName() bool - EndpointURL() string - PassageModel() string - QueryModel() string - OptionWaitForModel() bool - OptionUseGPU() bool - OptionUseCache() bool -} - -func (v *Vectorizer) Object(ctx context.Context, object *models.Object, cfg moduletools.ClassConfig, -) ([]float32, models.AdditionalProperties, error) { - vec, err := v.object(ctx, object, cfg) - return vec, nil, err -} - -func (v *Vectorizer) object(ctx context.Context, object *models.Object, cfg moduletools.ClassConfig, -) ([]float32, error) { - icheck := NewClassSettings(cfg) - text := v.objectVectorizer.Texts(ctx, object, icheck) +const ( + MaxObjectsPerBatch = 100 // https://docs.cohere.com/reference/embed + MaxTimePerBatch = float64(10) +) - res, err := v.client.Vectorize(ctx, text, ent.VectorizationConfig{ - EndpointURL: icheck.EndpointURL(), - Model: icheck.PassageModel(), - WaitForModel: icheck.OptionWaitForModel(), - UseGPU: icheck.OptionUseGPU(), - UseCache: icheck.OptionUseCache(), - }) - if err != nil { - return nil, err +func New(client text2vecbase.BatchClient, logger logrus.FieldLogger) *text2vecbase.BatchVectorizer { + batchTokenizer := func(ctx context.Context, objects []*models.Object, skipObject []bool, cfg moduletools.ClassConfig, objectVectorizer *objectsvectorizer.ObjectVectorizer) ([]string, []int, bool, error) { + texts := make([]string, len(objects)) + tokenCounts := make([]int, len(objects)) + icheck := ent.NewClassSettings(cfg) + + // prepare input for vectorizer, and send it to the queue. Prepare here to avoid work in the queue-worker + skipAll := true + for i := range texts { + if skipObject[i] { + continue + } + skipAll = false + text := objectVectorizer.Texts(ctx, objects[i], icheck) + texts[i] = text + tokenCounts[i] = 0 + } + return texts, tokenCounts, skipAll, nil } - - return res.Vector, nil + // there does not seem to be a limit + maxTokensPerBatch := func(cfg moduletools.ClassConfig) int { return 500000 } + return text2vecbase.New(client, batch.NewBatchVectorizer(client, 50*time.Second, MaxObjectsPerBatch, maxTokensPerBatch, MaxTimePerBatch, logger), batchTokenizer) } diff --git a/modules/text2vec-huggingface/vectorizer/objects_test.go b/modules/text2vec-huggingface/vectorizer/objects_test.go index 806d7ef8f8..c90a4aaa13 100644 --- a/modules/text2vec-huggingface/vectorizer/objects_test.go +++ b/modules/text2vec-huggingface/vectorizer/objects_test.go @@ -13,9 +13,11 @@ package vectorizer import ( "context" - "strings" "testing" + "github.com/sirupsen/logrus/hooks/test" + "github.com/weaviate/weaviate/modules/text2vec-huggingface/ent" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/weaviate/weaviate/entities/models" @@ -35,6 +37,7 @@ func TestVectorizingObjects(t *testing.T) { passageModel string endpointURL string } + logger, _ := test.NewNullLogger() tests := []testCase{ { @@ -183,7 +186,7 @@ func TestVectorizingObjects(t *testing.T) { t.Run(test.name, func(t *testing.T) { client := &fakeClient{} - v := New(client) + v := New(client, logger) ic := &fakeClassConfig{ excludedProperty: test.excludedProperty, @@ -193,15 +196,14 @@ func TestVectorizingObjects(t *testing.T) { endpointURL: test.endpointURL, vectorizePropertyName: true, } - vector, _, err := v.Object(context.Background(), test.input, ic) + vector, _, err := v.Object(context.Background(), test.input, ic, ent.NewClassSettings(ic)) require.Nil(t, err) assert.Equal(t, []float32{0, 1, 2, 3}, vector) - expected := strings.Split(test.expectedClientCall, " ") - actual := strings.Split(client.lastInput, " ") - assert.Equal(t, expected, actual) + assert.Equal(t, []string{test.expectedClientCall}, client.lastInput) if test.expectedHuggingFaceModel != "" { - assert.Equal(t, test.expectedHuggingFaceModel, client.lastConfig.Model) + ic := ent.NewClassSettings(client.lastConfig) + assert.Equal(t, test.expectedHuggingFaceModel, ic.PassageModel()) } }) } diff --git a/modules/text2vec-huggingface/vectorizer/texts.go b/modules/text2vec-huggingface/vectorizer/texts.go deleted file mode 100644 index 8c78cea8dd..0000000000 --- a/modules/text2vec-huggingface/vectorizer/texts.go +++ /dev/null @@ -1,54 +0,0 @@ -// _ _ -// __ _____ __ ___ ___ __ _| |_ ___ -// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ -// \ V V / __/ (_| |\ V /| | (_| | || __/ -// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| -// -// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. -// -// CONTACT: hello@weaviate.io -// - -package vectorizer - -import ( - "context" - - "github.com/pkg/errors" - "github.com/weaviate/weaviate/entities/moduletools" - "github.com/weaviate/weaviate/modules/text2vec-contextionary/vectorizer" - "github.com/weaviate/weaviate/modules/text2vec-huggingface/ent" - libvectorizer "github.com/weaviate/weaviate/usecases/vectorizer" -) - -func (v *Vectorizer) VectorizeInput(ctx context.Context, input string, - icheck vectorizer.ClassIndexCheck, -) ([]float32, error) { - res, err := v.client.VectorizeQuery(ctx, input, ent.VectorizationConfig{}) - if err != nil { - return nil, err - } - return res.Vector, nil -} - -func (v *Vectorizer) Texts(ctx context.Context, inputs []string, - cfg moduletools.ClassConfig, -) ([]float32, error) { - settings := NewClassSettings(cfg) - vectors := make([][]float32, len(inputs)) - for i := range inputs { - res, err := v.client.VectorizeQuery(ctx, inputs[i], ent.VectorizationConfig{ - EndpointURL: settings.EndpointURL(), - Model: settings.QueryModel(), - WaitForModel: settings.OptionWaitForModel(), - UseGPU: settings.OptionUseGPU(), - UseCache: settings.OptionUseCache(), - }) - if err != nil { - return nil, errors.Wrap(err, "remote client vectorize") - } - vectors[i] = res.Vector - } - - return libvectorizer.CombineVectors(vectors), nil -} diff --git a/modules/text2vec-huggingface/vectorizer/texts_test.go b/modules/text2vec-huggingface/vectorizer/texts_test.go deleted file mode 100644 index 43850a0ee6..0000000000 --- a/modules/text2vec-huggingface/vectorizer/texts_test.go +++ /dev/null @@ -1,100 +0,0 @@ -// _ _ -// __ _____ __ ___ ___ __ _| |_ ___ -// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ -// \ V V / __/ (_| |\ V /| | (_| | || __/ -// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| -// -// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. -// -// CONTACT: hello@weaviate.io -// - -package vectorizer - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// as used in the nearText searcher -func TestVectorizingTexts(t *testing.T) { - type testCase struct { - name string - input []string - expectedHuggingFaceModel string - huggingFaceModel string - huggingFaceEndpointURL string - } - - tests := []testCase{ - { - name: "single word", - input: []string{"hello"}, - huggingFaceModel: "sentence-transformers/gtr-t5-xl", - expectedHuggingFaceModel: "sentence-transformers/gtr-t5-xl", - }, - { - name: "multiple words", - input: []string{"hello world, this is me!"}, - huggingFaceModel: "sentence-transformers/gtr-t5-xl", - expectedHuggingFaceModel: "sentence-transformers/gtr-t5-xl", - }, - { - name: "multiple sentences (joined with a dot)", - input: []string{"this is sentence 1", "and here's number 2"}, - huggingFaceModel: "sentence-transformers/gtr-t5-xl", - expectedHuggingFaceModel: "sentence-transformers/gtr-t5-xl", - }, - { - name: "multiple sentences already containing a dot", - input: []string{"this is sentence 1.", "and here's number 2"}, - huggingFaceModel: "sentence-transformers/gtr-t5-xl", - expectedHuggingFaceModel: "sentence-transformers/gtr-t5-xl", - }, - { - name: "multiple sentences already containing a question mark", - input: []string{"this is sentence 1?", "and here's number 2"}, - huggingFaceModel: "sentence-transformers/gtr-t5-xl", - expectedHuggingFaceModel: "sentence-transformers/gtr-t5-xl", - }, - { - name: "multiple sentences already containing an exclamation mark", - input: []string{"this is sentence 1!", "and here's number 2"}, - huggingFaceModel: "sentence-transformers/gtr-t5-xl", - expectedHuggingFaceModel: "sentence-transformers/gtr-t5-xl", - }, - { - name: "multiple sentences already containing comma", - input: []string{"this is sentence 1,", "and here's number 2"}, - huggingFaceModel: "sentence-transformers/gtr-t5-xl", - expectedHuggingFaceModel: "sentence-transformers/gtr-t5-xl", - }, - { - name: "single word with inference url", - input: []string{"hello"}, - huggingFaceEndpointURL: "http://url.cloud", - expectedHuggingFaceModel: "sentence-transformers/msmarco-bert-base-dot-v5", - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - client := &fakeClient{} - - v := New(client) - - settings := &fakeClassConfig{ - model: test.huggingFaceModel, - endpointURL: test.huggingFaceEndpointURL, - } - vec, err := v.Texts(context.Background(), test.input, settings) - - require.Nil(t, err) - assert.Equal(t, []float32{0.1, 1.1, 2.1, 3.1}, vec) - assert.Equal(t, client.lastConfig.Model, test.expectedHuggingFaceModel) - }) - } -} diff --git a/modules/text2vec-jinaai/vectorizer/texts_test.go b/modules/text2vec-jinaai/vectorizer/texts_test.go deleted file mode 100644 index 2db0068abd..0000000000 --- a/modules/text2vec-jinaai/vectorizer/texts_test.go +++ /dev/null @@ -1,98 +0,0 @@ -// _ _ -// __ _____ __ ___ ___ __ _| |_ ___ -// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ -// \ V V / __/ (_| |\ V /| | (_| | || __/ -// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| -// -// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. -// -// CONTACT: hello@weaviate.io -// - -package vectorizer - -import ( - "context" - "testing" - - "github.com/sirupsen/logrus/hooks/test" - "github.com/weaviate/weaviate/modules/text2vec-cohere/ent" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// as used in the nearText searcher -func TestVectorizingTexts(t *testing.T) { - logger, _ := test.NewNullLogger() - type testCase struct { - name string - input []string - expectedJinaAIModel string - jinaAIModel string - } - - tests := []testCase{ - { - name: "single word", - input: []string{"hello"}, - jinaAIModel: "jina-embedding-v2", - expectedJinaAIModel: "jina-embedding-v2", - }, - { - name: "multiple words", - input: []string{"hello world, this is me!"}, - jinaAIModel: "jina-embedding-v2", - expectedJinaAIModel: "jina-embedding-v2", - }, - { - name: "multiple sentences (joined with a dot)", - input: []string{"this is sentence 1", "and here's number 2"}, - jinaAIModel: "jina-embedding-v2", - expectedJinaAIModel: "jina-embedding-v2", - }, - { - name: "multiple sentences already containing a dot", - input: []string{"this is sentence 1.", "and here's number 2"}, - jinaAIModel: "jina-embedding-v2", - expectedJinaAIModel: "jina-embedding-v2", - }, - { - name: "multiple sentences already containing a question mark", - input: []string{"this is sentence 1?", "and here's number 2"}, - jinaAIModel: "jina-embedding-v2", - expectedJinaAIModel: "jina-embedding-v2", - }, - { - name: "multiple sentences already containing an exclamation mark", - input: []string{"this is sentence 1!", "and here's number 2"}, - jinaAIModel: "jina-embedding-v2", - expectedJinaAIModel: "jina-embedding-v2", - }, - { - name: "multiple sentences already containing comma", - input: []string{"this is sentence 1,", "and here's number 2"}, - jinaAIModel: "jina-embedding-v2", - expectedJinaAIModel: "jina-embedding-v2", - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - client := &fakeClient{} - - v := New(client, logger) - - settings := &fakeClassConfig{ - jinaAIModel: test.jinaAIModel, - } - vec, err := v.Texts(context.Background(), test.input, settings) - - require.Nil(t, err) - assert.Equal(t, []float32{0.1, 1.1, 2.1, 3.1}, vec) - assert.Equal(t, test.input, client.lastInput) - config := ent.NewClassSettings(client.lastConfig) - assert.Equal(t, config.Model(), test.expectedJinaAIModel) - }) - } -} diff --git a/modules/text2vec-octoai/clients/octoai.go b/modules/text2vec-octoai/clients/octoai.go index 5f0c7235b4..0c3451436d 100644 --- a/modules/text2vec-octoai/clients/octoai.go +++ b/modules/text2vec-octoai/clients/octoai.go @@ -67,15 +67,15 @@ func buildUrl(config ent.VectorizationConfig) (string, error) { } type vectorizer struct { - octoAIApiKey string - httpClient *http.Client - buildUrlFn func(config ent.VectorizationConfig) (string, error) - logger logrus.FieldLogger + apiKey string + httpClient *http.Client + buildUrlFn func(config ent.VectorizationConfig) (string, error) + logger logrus.FieldLogger } -func New(octoAIApiKey string, timeout time.Duration, logger logrus.FieldLogger) *vectorizer { +func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *vectorizer { return &vectorizer{ - octoAIApiKey: octoAIApiKey, + apiKey: apiKey, httpClient: &http.Client{ Timeout: timeout, }, @@ -182,36 +182,15 @@ func (v *vectorizer) getApiKeyHeaderAndValue(apiKey string) (string, string) { } func (v *vectorizer) getApiKey(ctx context.Context) (string, error) { - var apiKey, envVar string - - apiKey = "X-OctoAI-Api-Key" - envVar = "OCTOAI_APIKEY" - if len(v.octoAIApiKey) > 0 { - return v.octoAIApiKey, nil - } - - return v.getApiKeyFromContext(ctx, apiKey, envVar) -} - -func (v *vectorizer) getApiKeyFromContext(ctx context.Context, apiKey, envVar string) (string, error) { - if apiKeyValue := v.getValueFromContext(ctx, apiKey); apiKeyValue != "" { - return apiKeyValue, nil + if v.apiKey != "" { + return v.apiKey, nil } - return "", fmt.Errorf("no api key found neither in request header: %s nor in environment variable under %s", apiKey, envVar) -} - -func (v *vectorizer) getValueFromContext(ctx context.Context, key string) string { - if value := ctx.Value(key); value != nil { - if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 { - return keyHeader[0] - } + if apiKey := modulecomponents.GetValueFromContext(ctx, "X-OctoAI-Api-Key"); apiKey != "" { + return apiKey, nil } - // try getting header from GRPC if not successful - if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 { - return apiKey[0] - } - - return "" + return "", errors.New("no api key found " + + "neither in request header: X-OctoAI-Api-Key " + + "nor in environment variable under OCTOAI_APIKEY") } func (v *vectorizer) GetApiKeyHash(ctx context.Context, config moduletools.ClassConfig) [32]byte { diff --git a/modules/text2vec-ollama/vectorizer/class_settings.go b/modules/text2vec-ollama/vectorizer/class_settings.go index 0f5fc015d0..9a6d77e95f 100644 --- a/modules/text2vec-ollama/vectorizer/class_settings.go +++ b/modules/text2vec-ollama/vectorizer/class_settings.go @@ -21,7 +21,7 @@ import ( const ( apiEndpointProperty = "apiEndpoint" - modelIDProperty = "modelId" + modelProperty = "model" ) const ( @@ -29,7 +29,7 @@ const ( DefaultPropertyIndexed = true DefaultVectorizePropertyName = false DefaultApiEndpoint = "http://localhost:11434" - DefaultModelID = "nomic-embed-text" + DefaultModel = "nomic-embed-text" ) type classSettings struct { @@ -48,8 +48,8 @@ func (ic *classSettings) Validate(class *models.Class) error { if ic.ApiEndpoint() == "" { return errors.New("apiEndpoint cannot be empty") } - if ic.ModelID() == "" { - return errors.New("modelId cannot be empty") + if ic.Model() == "" { + return errors.New("model cannot be empty") } return nil } @@ -58,14 +58,10 @@ func (ic *classSettings) getStringProperty(name, defaultValue string) string { return ic.BaseClassSettings.GetPropertyAsString(name, defaultValue) } -func (ic *classSettings) getDefaultModel() string { - return DefaultModelID -} - func (ic *classSettings) ApiEndpoint() string { return ic.getStringProperty(apiEndpointProperty, DefaultApiEndpoint) } -func (ic *classSettings) ModelID() string { - return ic.getStringProperty(modelIDProperty, ic.getDefaultModel()) +func (ic *classSettings) Model() string { + return ic.getStringProperty(modelProperty, DefaultModel) } diff --git a/modules/text2vec-ollama/vectorizer/class_settings_test.go b/modules/text2vec-ollama/vectorizer/class_settings_test.go index 7bc67f0c51..c6d6b65b47 100644 --- a/modules/text2vec-ollama/vectorizer/class_settings_test.go +++ b/modules/text2vec-ollama/vectorizer/class_settings_test.go @@ -27,7 +27,7 @@ func Test_classSettings_Validate(t *testing.T) { name string cfg moduletools.ClassConfig wantApiEndpoint string - wantModelID string + wantModel string wantErr error }{ { @@ -36,7 +36,7 @@ func Test_classSettings_Validate(t *testing.T) { classConfig: map[string]interface{}{}, }, wantApiEndpoint: "http://localhost:11434", - wantModelID: "nomic-embed-text", + wantModel: "nomic-embed-text", wantErr: nil, }, { @@ -44,11 +44,11 @@ func Test_classSettings_Validate(t *testing.T) { cfg: fakeClassConfig{ classConfig: map[string]interface{}{ "apiEndpoint": "https://localhost:11434", - "modelId": "future-text-embed", + "model": "future-text-embed", }, }, wantApiEndpoint: "https://localhost:11434", - wantModelID: "future-text-embed", + wantModel: "future-text-embed", wantErr: nil, }, { @@ -56,7 +56,7 @@ func Test_classSettings_Validate(t *testing.T) { cfg: fakeClassConfig{ classConfig: map[string]interface{}{ "apiEndpoint": "", - "modelId": "test", + "model": "test", }, }, wantErr: errors.Errorf("apiEndpoint cannot be empty"), @@ -66,10 +66,10 @@ func Test_classSettings_Validate(t *testing.T) { cfg: fakeClassConfig{ classConfig: map[string]interface{}{ "apiEndpoint": "http://localhost:8080", - "modelId": "", + "model": "", }, }, - wantErr: errors.Errorf("modelId cannot be empty"), + wantErr: errors.Errorf("model cannot be empty"), }, } for _, tt := range tests { @@ -84,7 +84,7 @@ func Test_classSettings_Validate(t *testing.T) { }}), tt.wantErr.Error()) } else { assert.Equal(t, tt.wantApiEndpoint, ic.ApiEndpoint()) - assert.Equal(t, tt.wantModelID, ic.ModelID()) + assert.Equal(t, tt.wantModel, ic.Model()) } }) } diff --git a/modules/text2vec-ollama/vectorizer/objects.go b/modules/text2vec-ollama/vectorizer/objects.go index 473407b076..f0d7e099df 100644 --- a/modules/text2vec-ollama/vectorizer/objects.go +++ b/modules/text2vec-ollama/vectorizer/objects.go @@ -60,7 +60,7 @@ func (v *Vectorizer) object(ctx context.Context, object *models.Object, cfg modu text := v.objectVectorizer.Texts(ctx, object, icheck) res, err := v.client.Vectorize(ctx, text, ent.VectorizationConfig{ ApiEndpoint: icheck.ApiEndpoint(), - Model: icheck.ModelID(), + Model: icheck.Model(), }) if err != nil { return nil, err diff --git a/modules/text2vec-ollama/vectorizer/texts.go b/modules/text2vec-ollama/vectorizer/texts.go index 6c521c94c6..e03af9aad6 100644 --- a/modules/text2vec-ollama/vectorizer/texts.go +++ b/modules/text2vec-ollama/vectorizer/texts.go @@ -28,7 +28,7 @@ func (v *Vectorizer) Texts(ctx context.Context, inputs []string, for i := range inputs { res, err := v.client.Vectorize(ctx, inputs[i], ent.VectorizationConfig{ ApiEndpoint: settings.ApiEndpoint(), - Model: settings.ModelID(), + Model: settings.Model(), }) if err != nil { return nil, errors.Wrap(err, "remote client vectorize") diff --git a/openapi-specs/schema.json b/openapi-specs/schema.json index 2acd889fd6..5e37eceffa 100644 --- a/openapi-specs/schema.json +++ b/openapi-specs/schema.json @@ -2112,7 +2112,7 @@ }, "description": "Cloud-native, modular vector database", "title": "Weaviate", - "version": "1.25.0-rc.0" + "version": "1.25.0" }, "parameters": { "CommonAfterParameterQuery": { diff --git a/test/acceptance/grpc/tenants_test.go b/test/acceptance/grpc/tenants_test.go index 474cf24905..4fcb6aae76 100644 --- a/test/acceptance/grpc/tenants_test.go +++ b/test/acceptance/grpc/tenants_test.go @@ -55,59 +55,51 @@ func TestGRPCTenantsGet(t *testing.T) { } helper.CreateTenants(t, className, tenants) - t.Run("Gets consistent tenants of a class", func(t *testing.T) { - for _, isConsistent := range []bool{true, false} { - resp, err := grpcClient.TenantsGet(context.TODO(), &pb.TenantsGetRequest{ - Collection: className, - IsConsistent: isConsistent, - }) - if err != nil { - t.Fatalf("error while getting tenants: %v", err) - } - for _, tenant := range resp.Tenants { - require.Equal(t, slices.Contains(tenantNames, tenant.Name), true) - require.Equal(t, tenant.ActivityStatus, pb.TenantActivityStatus_TENANT_ACTIVITY_STATUS_HOT) - } + t.Run("Gets tenants of a class", func(t *testing.T) { + resp, err := grpcClient.TenantsGet(context.TODO(), &pb.TenantsGetRequest{ + Collection: className, + }) + if err != nil { + t.Fatalf("error while getting tenants: %v", err) + } + for _, tenant := range resp.Tenants { + require.Equal(t, slices.Contains(tenantNames, tenant.Name), true) + require.Equal(t, tenant.ActivityStatus, pb.TenantActivityStatus_TENANT_ACTIVITY_STATUS_HOT) } }) t.Run("Gets two tenants by their names", func(t *testing.T) { - for _, isConsistent := range []bool{true, false} { - resp, err := grpcClient.TenantsGet(context.TODO(), &pb.TenantsGetRequest{ - Collection: className, - IsConsistent: isConsistent, - Params: &pb.TenantsGetRequest_Names{ - Names: &pb.TenantNames{ - Values: []string{tenantNames[0], tenantNames[2]}, - }, + resp, err := grpcClient.TenantsGet(context.TODO(), &pb.TenantsGetRequest{ + Collection: className, + Params: &pb.TenantsGetRequest_Names{ + Names: &pb.TenantNames{ + Values: []string{tenantNames[0], tenantNames[2]}, }, - }) - if err != nil { - t.Fatalf("error while getting tenants: %v", err) - } - require.Equal(t, resp.Tenants, []*pb.Tenant{{ - Name: tenantNames[0], - ActivityStatus: pb.TenantActivityStatus_TENANT_ACTIVITY_STATUS_HOT, - }, { - Name: tenantNames[2], - ActivityStatus: pb.TenantActivityStatus_TENANT_ACTIVITY_STATUS_HOT, - }}) + }, + }) + if err != nil { + t.Fatalf("error while getting tenants: %v", err) } + require.Equal(t, resp.Tenants, []*pb.Tenant{{ + Name: tenantNames[0], + ActivityStatus: pb.TenantActivityStatus_TENANT_ACTIVITY_STATUS_HOT, + }, { + Name: tenantNames[2], + ActivityStatus: pb.TenantActivityStatus_TENANT_ACTIVITY_STATUS_HOT, + }}) }) t.Run("Returns error when tenant names are missing", func(t *testing.T) { _, err := grpcClient.TenantsGet(context.TODO(), &pb.TenantsGetRequest{ - Collection: className, - IsConsistent: true, - Params: &pb.TenantsGetRequest_Names{}, + Collection: className, + Params: &pb.TenantsGetRequest_Names{}, }) require.NotNil(t, err) }) t.Run("Returns error when tenant names are specified empty", func(t *testing.T) { _, err := grpcClient.TenantsGet(context.TODO(), &pb.TenantsGetRequest{ - Collection: className, - IsConsistent: true, + Collection: className, Params: &pb.TenantsGetRequest_Names{ Names: &pb.TenantNames{ Values: []string{}, @@ -118,18 +110,15 @@ func TestGRPCTenantsGet(t *testing.T) { }) t.Run("Returns nothing when tenant names are not found", func(t *testing.T) { - for _, isConsistent := range []bool{true, false} { - resp, err := grpcClient.TenantsGet(context.TODO(), &pb.TenantsGetRequest{ - Collection: className, - IsConsistent: isConsistent, - Params: &pb.TenantsGetRequest_Names{ - Names: &pb.TenantNames{ - Values: []string{"NonExistentTenant"}, - }, + resp, err := grpcClient.TenantsGet(context.TODO(), &pb.TenantsGetRequest{ + Collection: className, + Params: &pb.TenantsGetRequest_Names{ + Names: &pb.TenantNames{ + Values: []string{"NonExistentTenant"}, }, - }) - require.Nil(t, err) - require.Empty(t, resp.Tenants) - } + }, + }) + require.Nil(t, err) + require.Empty(t, resp.Tenants) }) } diff --git a/test/acceptance/replication/crud_test.go b/test/acceptance/replication/crud_test.go index c4551516b2..64479bd5bd 100644 --- a/test/acceptance/replication/crud_test.go +++ b/test/acceptance/replication/crud_test.go @@ -22,13 +22,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/weaviate/weaviate/client/objects" + "github.com/weaviate/weaviate/client/schema" "github.com/weaviate/weaviate/entities/models" "github.com/weaviate/weaviate/entities/schema/crossref" "github.com/weaviate/weaviate/test/docker" "github.com/weaviate/weaviate/test/helper" "github.com/weaviate/weaviate/test/helper/sample-schema/articles" "github.com/weaviate/weaviate/usecases/replica" - "golang.org/x/sync/errgroup" ) var ( @@ -252,7 +252,7 @@ func immediateReplicaCRUD(t *testing.T) { }) t.Run("OnNode-1", func(t *testing.T) { - _, err := getObjectFromNode(t, compose.ContainerURI(1), "Article", articleIDs[0], "node2") + _, err := getObjectFromNode(t, compose.ContainerURI(1), "Article", articleIDs[0], "node1") assert.Equal(t, &objects.ObjectsClassGetNotFound{}, err) }) t.Run("OnNode-2", func(t *testing.T) { @@ -292,12 +292,11 @@ func immediateReplicaCRUD(t *testing.T) { } func eventualReplicaCRUD(t *testing.T) { - t.Skip("Skip until https://github.com/weaviate/weaviate/issues/4840 is resolved") ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() compose, err := docker.New(). - WithWeaviateCluster(). + With3NodeCluster(). WithText2VecContextionary(). Start(ctx) require.Nil(t, err) @@ -340,36 +339,44 @@ func eventualReplicaCRUD(t *testing.T) { createObjects(t, compose.GetWeaviate().URI(), batch) }) - t.Run("configure classes to replicate to node 2", func(t *testing.T) { + t.Run("configure classes to replicate to node 2 and 3", func(t *testing.T) { ac := helper.GetClass(t, "Article") ac.ReplicationConfig = &models.ReplicationConfig{ - Factor: 2, + Factor: 3, } helper.UpdateClass(t, ac) pc := helper.GetClass(t, "Paragraph") pc.ReplicationConfig = &models.ReplicationConfig{ - Factor: 2, + Factor: 3, } helper.UpdateClass(t, pc) }) - t.Run("StopNode-1", func(t *testing.T) { - stopNodeAt(ctx, t, compose, 1) + t.Run("StopNode-3", func(t *testing.T) { + stopNodeAt(ctx, t, compose, 3) }) t.Run("assert all previous data replicated to node 2", func(t *testing.T) { - // TODO-RAFT : we need to avoid any sleeps, come back and remove it - // sleep 2 sec to make sure data not affected by EC issue - time.Sleep(2 * time.Second) - resp := gqlGet(t, compose.GetWeaviateNode2().URI(), "Article", replica.One) - assert.Len(t, resp, len(articleIDs)) - resp = gqlGet(t, compose.GetWeaviateNode2().URI(), "Paragraph", replica.One) - assert.Len(t, resp, len(paragraphIDs)) + assert.EventuallyWithT(t, func(collect *assert.CollectT) { + resp := gqlGet(t, compose.GetWeaviateNode2().URI(), "Article", replica.One) + assert.Len(collect, resp, len(articleIDs)) + resp = gqlGet(t, compose.GetWeaviateNode2().URI(), "Paragraph", replica.One) + assert.Len(collect, resp, len(paragraphIDs)) + }, 5*time.Second, 100*time.Millisecond) }) - t.Run("RestartNode-1", func(t *testing.T) { - restartNode1(ctx, t, compose) + t.Run("RestartNode-3", func(t *testing.T) { + startNodeAt(ctx, t, compose, 3) + }) + + t.Run("assert all previous data replicated to node 3", func(t *testing.T) { + assert.EventuallyWithT(t, func(collect *assert.CollectT) { + resp := gqlGet(t, compose.GetWeaviateNode3().URI(), "Article", replica.All) + assert.Len(collect, resp, len(articleIDs)) + resp = gqlGet(t, compose.GetWeaviateNode3().URI(), "Paragraph", replica.All) + assert.Len(collect, resp, len(paragraphIDs)) + }, 5*time.Second, 100*time.Millisecond) }) t.Run("assert any future writes are replicated", func(t *testing.T) { @@ -387,92 +394,81 @@ func eventualReplicaCRUD(t *testing.T) { patchObject(t, compose.GetWeaviateNode2().URI(), patch) }) - t.Run("StopNode-2", func(t *testing.T) { - stopNodeAt(ctx, t, compose, 2) - }) - t.Run("PatchedOnNode-1", func(t *testing.T) { after, err := getObjectFromNode(t, compose.GetWeaviate().URI(), "Article", articleIDs[0], "node1") require.Nil(t, err) - newVal, ok := after.Properties.(map[string]interface{})["title"] - require.True(t, ok) - assert.Equal(t, newTitle, newVal) + require.Contains(t, after.Properties.(map[string]interface{}), "title") + assert.Equal(t, newTitle, after.Properties.(map[string]interface{})["title"]) }) - t.Run("RestartNode-2", func(t *testing.T) { - err = compose.Start(ctx, compose.GetWeaviateNode2().Name()) - require.Nil(t, err) - }) - }) + t.Run("PatchedOnNode-2", func(t *testing.T) { + assert.EventuallyWithT(t, func(collect *assert.CollectT) { + after, err := getObjectFromNode(t, compose.GetWeaviateNode2().URI(), "Article", articleIDs[0], "node2") + require.Nil(collect, err) - t.Run("DeleteObject", func(t *testing.T) { - t.Run("OnNode-1", func(t *testing.T) { - deleteObject(t, compose.GetWeaviate().URI(), "Article", articleIDs[0]) + require.Contains(collect, after.Properties.(map[string]interface{}), "title") + assert.Equal(collect, newTitle, after.Properties.(map[string]interface{})["title"]) + }, 5*time.Second, 100*time.Millisecond) }) - t.Run("StopNode-1", func(t *testing.T) { - stopNodeAt(ctx, t, compose, 1) + t.Run("PatchedOnNode-3", func(t *testing.T) { + assert.EventuallyWithT(t, func(collect *assert.CollectT) { + after, err := getObjectFromNode(t, compose.GetWeaviate().URI(), "Article", articleIDs[0], "node3") + require.Nil(collect, err) + + require.Contains(collect, after.Properties.(map[string]interface{}), "title") + assert.Equal(collect, newTitle, after.Properties.(map[string]interface{})["title"]) + }, 5*time.Second, 100*time.Millisecond) }) + }) + t.Run("DeleteObject", func(t *testing.T) { t.Run("OnNode-2", func(t *testing.T) { - _, err := getObjectFromNode(t, compose.GetWeaviateNode2().URI(), "Article", articleIDs[0], "node2") - assert.Equal(t, &objects.ObjectsClassGetNotFound{}, err) + deleteObject(t, compose.GetWeaviateNode2().URI(), "Article", articleIDs[0]) }) - t.Run("RestartNode-1", func(t *testing.T) { - restartNode1(ctx, t, compose) + t.Run("OnNode-1", func(t *testing.T) { + assert.EventuallyWithT(t, func(collect *assert.CollectT) { + _, err := getObjectFromNode(t, compose.GetWeaviate().URI(), "Article", articleIDs[0], "node1") + assert.Equal(collect, &objects.ObjectsClassGetNotFound{}, err) + }, 5*time.Second, 100*time.Millisecond) }) }) - t.Run("BatchAllObjects", func(t *testing.T) { + t.Run("BatchDeleteAllObjects", func(t *testing.T) { t.Run("OnNode-2", func(t *testing.T) { deleteObjects(t, compose.GetWeaviateNode2().URI(), "Article", []string{"title"}, "Article#*") }) - t.Run("StopNode-2", func(t *testing.T) { - stopNodeAt(ctx, t, compose, 2) - }) - t.Run("OnNode-1", func(t *testing.T) { - resp := gqlGet(t, compose.GetWeaviate().URI(), "Article", replica.One) - assert.Empty(t, resp) + assert.EventuallyWithT(t, func(collect *assert.CollectT) { + resp := gqlGet(t, compose.GetWeaviate().URI(), "Article", replica.One) + assert.Empty(collect, resp) + }, 5*time.Second, 100*time.Millisecond) }) + }) - t.Run("RestartNode-2", func(t *testing.T) { - err = compose.Start(ctx, compose.GetWeaviateNode2().Name()) - require.Nil(t, err) + t.Run("configure classes to decrease replication factor should fail", func(t *testing.T) { + ac := helper.GetClass(t, "Article") + ac.ReplicationConfig = &models.ReplicationConfig{ + Factor: 2, + } + + params := schema.NewSchemaObjectsUpdateParams(). + WithObjectClass(ac).WithClassName(ac.Class) + resp, err := helper.Client(t).Schema.SchemaObjectsUpdate(params, nil) + assert.NotNil(t, err) + helper.AssertRequestFail(t, resp, err, func() { + errResponse, ok := err.(*schema.SchemaObjectsUpdateUnprocessableEntity) + assert.True(t, ok) + assert.Equal(t, fmt.Sprintf("scale \"%s\" from 3 replicas to 2: scaling in not supported yet", ac.Class), errResponse.Payload.Error[0].Message) }) }) }) } -func restartNode1(ctx context.Context, t *testing.T, compose *docker.DockerCompose) { - // since node1 is the gossip "leader", node 2 must be stopped and restarted - // after node1 to re-facilitate internode communication - eg := errgroup.Group{} - eg.Go(func() error { - require.Nil(t, compose.StartAt(ctx, 1)) - return nil - }) - eg.Go(func() error { // restart node 2 - time.Sleep(3 * time.Second) // wait for member list initialization - stopNodeAt(ctx, t, compose, 2) - require.Nil(t, compose.StartAt(ctx, 2)) - return nil - }) - eg.Go(func() error { // restart node 3 - time.Sleep(3 * time.Second) // wait for member list initialization - stopNodeAt(ctx, t, compose, 3) - require.Nil(t, compose.StartAt(ctx, 3)) - return nil - }) - - eg.Wait() - <-time.After(3 * time.Second) // wait for initialization -} - func stopNodeAt(ctx context.Context, t *testing.T, compose *docker.DockerCompose, index int) { <-time.After(1 * time.Second) require.Nil(t, compose.StopAt(ctx, index, nil)) diff --git a/test/acceptance/replication/scale_test.go b/test/acceptance/replication/scale_test.go index ce960a96ba..acca6b6511 100644 --- a/test/acceptance/replication/scale_test.go +++ b/test/acceptance/replication/scale_test.go @@ -29,12 +29,11 @@ import ( ) func multiShardScaleOut(t *testing.T) { - t.Skip("Skip until https://github.com/weaviate/weaviate/issues/4840 is resolved") ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() compose, err := docker.New(). - WithWeaviateCluster(). + With3NodeCluster(). WithText2VecContextionary(). Start(ctx) require.Nil(t, err) @@ -99,17 +98,20 @@ func multiShardScaleOut(t *testing.T) { }) t.Run("assert paragraphs were scaled out", func(t *testing.T) { - n := getNodes(t, compose.GetWeaviate().URI()) - var shardsFound int - for _, node := range n.Nodes { - for _, shard := range node.Shards { - if shard.Class == paragraphClass.Class { - assert.EqualValues(t, 10, shard.ObjectCount) - shardsFound++ + // shard.ObjectCount is eventually consistent, see Bucket::CountAsync() + assert.EventuallyWithT(t, func(collect *assert.CollectT) { + n := getNodes(t, compose.GetWeaviate().URI()) + var shardsFound int + for _, node := range n.Nodes { + for _, shard := range node.Shards { + if shard.Class == paragraphClass.Class { + assert.EqualValues(collect, int64(10), shard.ObjectCount) + shardsFound++ + } } } - } - assert.Equal(t, 2, shardsFound) + assert.Equal(collect, 2, shardsFound) + }, 10*time.Second, 100*time.Millisecond) }) t.Run("scale out articles", func(t *testing.T) { @@ -119,24 +121,24 @@ func multiShardScaleOut(t *testing.T) { }) t.Run("assert articles were scaled out", func(t *testing.T) { - n := getNodes(t, compose.GetWeaviate().URI()) - var shardsFound int - for _, node := range n.Nodes { - for _, shard := range node.Shards { - if shard.Class == articleClass.Class { - assert.EqualValues(t, 10, shard.ObjectCount) - shardsFound++ + // shard.ObjectCount is eventually consistent, see Bucket::CountAsync() + assert.EventuallyWithT(t, func(collect *assert.CollectT) { + n := getNodes(t, compose.GetWeaviate().URI()) + var shardsFound int + for _, node := range n.Nodes { + for _, shard := range node.Shards { + if shard.Class == articleClass.Class { + assert.EqualValues(collect, int64(10), shard.ObjectCount) + shardsFound++ + } } } - } - assert.Equal(t, 2, shardsFound) + assert.Equal(collect, 2, shardsFound) + }, 10*time.Second, 100*time.Millisecond) }) t.Run("kill a node and check contents of remaining node", func(t *testing.T) { stopNodeAt(ctx, t, compose, 2) - // TODO-RAFT : we need to avoid any sleeps, come back and remove it - // sleep 2 sec to make sure data not affected by EC issue - time.Sleep(2 * time.Second) p := gqlGet(t, compose.GetWeaviate().URI(), paragraphClass.Class, replica.One) assert.Len(t, p, 10) a := gqlGet(t, compose.GetWeaviate().URI(), articleClass.Class, replica.One) diff --git a/test/acceptance_with_go_client/go.mod b/test/acceptance_with_go_client/go.mod index 3dcbb12211..dc0744441f 100644 --- a/test/acceptance_with_go_client/go.mod +++ b/test/acceptance_with_go_client/go.mod @@ -31,6 +31,21 @@ require ( github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect github.com/armon/go-metrics v0.4.1 // indirect github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d // indirect + github.com/aws/aws-sdk-go-v2 v1.26.1 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect + github.com/aws/aws-sdk-go-v2/config v1.27.12 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.12 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect + github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.20.6 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.5 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.28.7 // indirect + github.com/aws/smithy-go v1.20.2 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cenkalti/backoff/v4 v4.2.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect diff --git a/test/acceptance_with_go_client/go.sum b/test/acceptance_with_go_client/go.sum index 30a04972c2..2fab486772 100644 --- a/test/acceptance_with_go_client/go.sum +++ b/test/acceptance_with_go_client/go.sum @@ -48,6 +48,36 @@ github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d h1:Byv0BzEl github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= +github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA= +github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg= +github.com/aws/aws-sdk-go-v2/config v1.27.12 h1:vq88mBaZI4NGLXk8ierArwSILmYHDJZGJOeAc/pzEVQ= +github.com/aws/aws-sdk-go-v2/config v1.27.12/go.mod h1:IOrsf4IiN68+CgzyuyGUYTpCrtUQTbbMEAtR/MR/4ZU= +github.com/aws/aws-sdk-go-v2/credentials v1.17.12 h1:PVbKQ0KjDosI5+nEdRMU8ygEQDmkJTSHBqPjEX30lqc= +github.com/aws/aws-sdk-go-v2/credentials v1.17.12/go.mod h1:jlWtGFRtKsqc5zqerHZYmKmRkUXo3KPM14YJ13ZEjwE= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 h1:FVJ0r5XTHSmIHJV6KuDmdYhEpvlHpiSd38RQWhut5J4= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1/go.mod h1:zusuAeqezXzAB24LGuzuekqMAEgWkVYukBec3kr3jUg= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.1 h1:vTHgBjsGhgKWWIgioxd7MkBH5Ekr8C6Cb+/8iWf1dpc= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.1/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 h1:Ji0DY1xUsUr3I8cHps0G+XM3WWU16lP6yG8qu1GAZAs= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2/go.mod h1:5CsjAbs3NlGQyZNFACh+zztPDI7fU6eW9QsxjfnuBKg= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 h1:ogRAwT1/gxJBcSWDMZlgyFUM962F51A5CRhDLbxLdmo= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7/go.mod h1:YCsIZhXfRPLFFCl5xxY+1T9RKzOKjCut+28JSX2DnAk= +github.com/aws/aws-sdk-go-v2/service/sso v1.20.6 h1:o5cTaeunSpfXiLTIBx5xo2enQmiChtu1IBbzXnfU9Hs= +github.com/aws/aws-sdk-go-v2/service/sso v1.20.6/go.mod h1:qGzynb/msuZIE8I75DVRCUXw3o3ZyBmUvMwQ2t/BrGM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.5 h1:Ciiz/plN+Z+pPO1G0W2zJoYIIl0KtKzY0LJ78NXYTws= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.5/go.mod h1:mUYPBhaF2lGiukDEjJX2BLRRKTmoUSitGDUgM4tRxak= +github.com/aws/aws-sdk-go-v2/service/sts v1.28.7 h1:et3Ta53gotFR4ERLXXHIHl/Uuk1qYpP5uU7cvNql8ns= +github.com/aws/aws-sdk-go-v2/service/sts v1.28.7/go.mod h1:FZf1/nKNEkHdGGJP/cI2MoIMquumuRK6ol3QQJNDxmw= +github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= +github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= diff --git a/test/acceptance_with_go_client/named_vectors_tests/named_vectors_restart_test.go b/test/acceptance_with_go_client/named_vectors_tests/named_vectors_restart_test.go index dfd60a652b..6b8176ab0c 100644 --- a/test/acceptance_with_go_client/named_vectors_tests/named_vectors_restart_test.go +++ b/test/acceptance_with_go_client/named_vectors_tests/named_vectors_restart_test.go @@ -12,10 +12,11 @@ package named_vectors_tests import ( - "acceptance_tests_with_client/fixtures" "context" "testing" + "acceptance_tests_with_client/fixtures" + "github.com/go-openapi/strfmt" "github.com/stretchr/testify/require" wvt "github.com/weaviate/weaviate-go-client/v4/weaviate" diff --git a/test/docker/compose.go b/test/docker/compose.go index a94030b514..65517cb408 100644 --- a/test/docker/compose.go +++ b/test/docker/compose.go @@ -253,7 +253,10 @@ func (d *Compose) WithText2VecPaLM(apiKey string) *Compose { return d } -func (d *Compose) WithText2VecAWS() *Compose { +func (d *Compose) WithText2VecAWS(accessKey, secretKey, sessionToken string) *Compose { + d.weaviateEnvs["AWS_ACCESS_KEY"] = accessKey + d.weaviateEnvs["AWS_SECRET_KEY"] = secretKey + d.weaviateEnvs["AWS_SESSION_TOKEN"] = sessionToken d.enableModules = append(d.enableModules, modaws.Name) return d } @@ -268,7 +271,10 @@ func (d *Compose) WithGenerativeOpenAI() *Compose { return d } -func (d *Compose) WithGenerativeAWS() *Compose { +func (d *Compose) WithGenerativeAWS(accessKey, secretKey, sessionToken string) *Compose { + d.weaviateEnvs["AWS_ACCESS_KEY"] = accessKey + d.weaviateEnvs["AWS_SECRET_KEY"] = secretKey + d.weaviateEnvs["AWS_SESSION_TOKEN"] = sessionToken d.enableModules = append(d.enableModules, modgenerativeaws.Name) return d } diff --git a/test/docker/docker.go b/test/docker/docker.go index 42e5fe22c2..582cbc9231 100644 --- a/test/docker/docker.go +++ b/test/docker/docker.go @@ -134,6 +134,10 @@ func (d *DockerCompose) GetWeaviateNode2() *DockerContainer { return d.getContainerByName(Weaviate2) } +func (d *DockerCompose) GetWeaviateNode3() *DockerContainer { + return d.getContainerByName(Weaviate3) +} + func (d *DockerCompose) GetText2VecTransformers() *DockerContainer { return d.getContainerByName(Text2VecTransformers) } diff --git a/test/docker/ollama.go b/test/docker/ollama.go index 1974340a93..b2e7c4e3c3 100644 --- a/test/docker/ollama.go +++ b/test/docker/ollama.go @@ -38,7 +38,7 @@ func startOllama(ctx context.Context, networkName, hostname, model string) (*Doc port := nat.Port("11434/tcp") container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ ContainerRequest: testcontainers.ContainerRequest{ - Image: "ollama/ollama:0.1.30", + Image: "ollama/ollama:0.1.33", Hostname: hostname, Networks: []string{networkName}, NetworkAliases: map[string][]string{ diff --git a/test/modules/generative-aws/generative_aws_test.go b/test/modules/generative-aws/generative_aws_test.go new file mode 100644 index 0000000000..34484134aa --- /dev/null +++ b/test/modules/generative-aws/generative_aws_test.go @@ -0,0 +1,266 @@ +// _ _ +// __ _____ __ ___ ___ __ _| |_ ___ +// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ +// \ V V / __/ (_| |\ V /| | (_| | || __/ +// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| +// +// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. +// +// CONTACT: hello@weaviate.io +// + +package generative_palm_tests + +import ( + "fmt" + "testing" + + "github.com/go-openapi/strfmt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/weaviate/weaviate/entities/models" + "github.com/weaviate/weaviate/entities/schema" + "github.com/weaviate/weaviate/test/helper" + graphqlhelper "github.com/weaviate/weaviate/test/helper/graphql" +) + +func testGenerativeAWS(host, region string) func(t *testing.T) { + return func(t *testing.T) { + helper.SetupClient(host) + // Data + companies := []struct { + id strfmt.UUID + name, description string + }{ + { + id: strfmt.UUID("00000000-0000-0000-0000-000000000001"), + name: "OpenAI", + description: ` + OpenAI is a research organization and AI development company that focuses on artificial intelligence (AI) and machine learning (ML). + Founded in December 2015, OpenAI's mission is to ensure that artificial general intelligence (AGI) benefits all of humanity. + The organization has been at the forefront of AI research, producing cutting-edge advancements in natural language processing, + reinforcement learning, robotics, and other AI-related fields. + + OpenAI has garnered attention for its work on various projects, including the development of the GPT (Generative Pre-trained Transformer) + series of models, such as GPT-2 and GPT-3, which have demonstrated remarkable capabilities in generating human-like text. + Additionally, OpenAI has contributed to advancements in reinforcement learning through projects like OpenAI Five, an AI system + capable of playing the complex strategy game Dota 2 at a high level. + `, + }, + { + id: strfmt.UUID("00000000-0000-0000-0000-000000000002"), + name: "SpaceX", + description: ` + SpaceX, short for Space Exploration Technologies Corp., is an American aerospace manufacturer and space transportation company + founded by Elon Musk in 2002. The company's primary goal is to reduce space transportation costs and enable the colonization of Mars, + among other ambitious objectives. + + SpaceX has made significant strides in the aerospace industry by developing advanced rocket technology, spacecraft, + and satellite systems. The company is best known for its Falcon series of rockets, including the Falcon 1, Falcon 9, + and Falcon Heavy, which have been designed with reusability in mind. Reusability has been a key innovation pioneered by SpaceX, + aiming to drastically reduce the cost of space travel by reusing rocket components multiple times. + `, + }, + } + // Define class + className := "BooksGenerativeTest" + class := &models.Class{ + Class: className, + Properties: []*models.Property{ + { + Name: "name", DataType: []string{schema.DataTypeText.String()}, + }, + { + Name: "description", DataType: []string{schema.DataTypeText.String()}, + }, + }, + VectorConfig: map[string]models.VectorConfig{ + "description": { + Vectorizer: map[string]interface{}{ + "text2vec-aws": map[string]interface{}{ + "properties": []interface{}{"description"}, + "vectorizeClassName": false, + "service": "bedrock", + "region": region, + "model": "amazon.titan-embed-text-v2:0", + }, + }, + VectorIndexType: "flat", + }, + }, + } + tests := []struct { + name string + generativeModel string + }{ + { + name: "cohere.command-text-v14", + generativeModel: "cohere.command-text-v14", + }, + { + name: "cohere.command-light-text-v14", + generativeModel: "cohere.command-light-text-v14", + }, + { + name: "cohere.command-r-v1:0", + generativeModel: "cohere.command-r-v1:0", + }, + { + name: "cohere.command-r-plus-v1:0", + generativeModel: "cohere.command-r-plus-v1:0", + }, + { + name: "anthropic.claude-v2", + generativeModel: "anthropic.claude-v2", + }, + { + name: "anthropic.claude-v2:1", + generativeModel: "anthropic.claude-v2:1", + }, + { + name: "anthropic.claude-instant-v1", + generativeModel: "anthropic.claude-instant-v1", + }, + { + name: "anthropic.claude-3-sonnet-20240229-v1:0", + generativeModel: "anthropic.claude-3-sonnet-20240229-v1:0", + }, + { + name: "anthropic.claude-3-haiku-20240307-v1:0", + generativeModel: "anthropic.claude-3-haiku-20240307-v1:0", + }, + { + name: "ai21.j2-ultra-v1", + generativeModel: "ai21.j2-ultra-v1", + }, + { + name: "ai21.j2-mid-v1", + generativeModel: "ai21.j2-mid-v1", + }, + { + name: "amazon.titan-text-lite-v1", + generativeModel: "amazon.titan-text-lite-v1", + }, + { + name: "amazon.titan-text-premier-v1:0", + generativeModel: "amazon.titan-text-premier-v1:0", + }, + { + name: "amazon.titan-text-express-v1", + generativeModel: "amazon.titan-text-express-v1", + }, + { + name: "mistral.mistral-7b-instruct-v0:2", + generativeModel: "mistral.mistral-7b-instruct-v0:2", + }, + { + name: "mistral.mixtral-8x7b-instruct-v0:1", + generativeModel: "mistral.mixtral-8x7b-instruct-v0:1", + }, + { + name: "mistral.mistral-large-2402-v1:0", + generativeModel: "mistral.mistral-large-2402-v1:0", + }, + { + name: "meta.llama3-8b-instruct-v1:0", + generativeModel: "meta.llama3-8b-instruct-v1:0", + }, + { + name: "meta.llama3-70b-instruct-v1:0", + generativeModel: "meta.llama3-70b-instruct-v1:0", + }, + { + name: "meta.llama2-13b-chat-v1", + generativeModel: "meta.llama2-13b-chat-v1", + }, + { + name: "meta.llama2-70b-chat-v1", + generativeModel: "meta.llama2-70b-chat-v1", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + class.ModuleConfig = map[string]interface{}{ + "generative-aws": map[string]interface{}{ + "projectId": "semi-random-dev", + "service": "bedrock", + "region": region, + "model": tt.generativeModel, + }, + } + // create schema + helper.CreateClass(t, class) + defer helper.DeleteClass(t, class.Class) + // create objects + t.Run("create objects", func(t *testing.T) { + for _, company := range companies { + obj := &models.Object{ + Class: class.Class, + ID: company.id, + Properties: map[string]interface{}{ + "name": company.name, + "description": company.description, + }, + } + helper.CreateObject(t, obj) + helper.AssertGetObjectEventually(t, obj.Class, obj.ID) + } + }) + t.Run("check objects existence", func(t *testing.T) { + for _, company := range companies { + t.Run(company.id.String(), func(t *testing.T) { + obj, err := helper.GetObject(t, class.Class, company.id, "vector") + require.NoError(t, err) + require.NotNil(t, obj) + require.Len(t, obj.Vectors, 1) + assert.True(t, len(obj.Vectors["description"]) > 0) + }) + } + }) + // generative task + t.Run("create a tweet", func(t *testing.T) { + prompt := "Generate a funny tweet out of this content: {description}" + query := fmt.Sprintf(` + { + Get { + %s{ + name + _additional { + generate( + singleResult: { + prompt: """ + %s + """ + } + ) { + singleResult + error + } + } + } + } + } + `, class.Class, prompt) + result := graphqlhelper.AssertGraphQL(t, helper.RootAuth, query) + objs := result.Get("Get", class.Class).AsSlice() + require.Len(t, objs, 2) + for _, obj := range objs { + name := obj.(map[string]interface{})["name"] + assert.NotEmpty(t, name) + additional, ok := obj.(map[string]interface{})["_additional"].(map[string]interface{}) + require.True(t, ok) + require.NotNil(t, additional) + generate, ok := additional["generate"].(map[string]interface{}) + require.True(t, ok) + require.NotNil(t, generate) + require.Nil(t, generate["error"]) + require.NotNil(t, generate["singleResult"]) + singleResult, ok := generate["singleResult"].(string) + require.True(t, ok) + require.NotEmpty(t, singleResult) + } + }) + }) + } + } +} diff --git a/test/modules/generative-aws/setup_test.go b/test/modules/generative-aws/setup_test.go new file mode 100644 index 0000000000..1453f9a50e --- /dev/null +++ b/test/modules/generative-aws/setup_test.go @@ -0,0 +1,70 @@ +// _ _ +// __ _____ __ ___ ___ __ _| |_ ___ +// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ +// \ V V / __/ (_| |\ V /| | (_| | || __/ +// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| +// +// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. +// +// CONTACT: hello@weaviate.io +// + +package generative_palm_tests + +import ( + "context" + "os" + "testing" + + "github.com/stretchr/testify/require" + "github.com/weaviate/weaviate/test/docker" +) + +func TestGenerativeAWS_SingleNode(t *testing.T) { + accessKey := os.Getenv("AWS_ACCESS_KEY") + if accessKey == "" { + accessKey = os.Getenv("AWS_ACCESS_KEY_ID") + if accessKey == "" { + t.Skip("skipping, AWS_ACCESS_KEY environment variable not present") + } + } + secretKey := os.Getenv("AWS_SECRET_KEY") + if secretKey == "" { + secretKey = os.Getenv("AWS_SECRET_ACCESS_KEY") + if secretKey == "" { + t.Skip("skipping, AWS_SECRET_KEY environment variable not present") + } + } + sessionToken := os.Getenv("AWS_SESSION_TOKEN") + if sessionToken == "" { + t.Skip("skipping, AWS_SESSION_TOKEN environment variable not present") + } + region := os.Getenv("AWS_REGION") + if region == "" { + t.Skip("skipping, AWS_REGION environment variable not present") + } + ctx := context.Background() + compose, err := createSingleNodeEnvironment(ctx, accessKey, secretKey, sessionToken) + require.NoError(t, err) + defer func() { + require.NoError(t, compose.Terminate(ctx)) + }() + endpoint := compose.GetWeaviate().URI() + + t.Run("tests", testGenerativeAWS(endpoint, region)) +} + +func createSingleNodeEnvironment(ctx context.Context, accessKey, secretKey, sessionToken string, +) (compose *docker.DockerCompose, err error) { + compose, err = composeModules(accessKey, secretKey, sessionToken). + WithWeaviate(). + Start(ctx) + return +} + +func composeModules(accessKey, secretKey, sessionToken string) (composeModules *docker.Compose) { + composeModules = docker.New(). + WithText2VecAWS(accessKey, secretKey, sessionToken). + WithGenerativeAWS(accessKey, secretKey, sessionToken) + return +} diff --git a/test/modules/generative-ollama/ollama_generative_test.go b/test/modules/generative-ollama/ollama_generative_test.go index f95a631b51..5013876d37 100644 --- a/test/modules/generative-ollama/ollama_generative_test.go +++ b/test/modules/generative-ollama/ollama_generative_test.go @@ -68,7 +68,7 @@ func testGenerativeOllama(host, ollamaApiEndpoint string) func(t *testing.T) { class.ModuleConfig = map[string]interface{}{ "generative-ollama": map[string]interface{}{ "apiEndpoint": ollamaApiEndpoint, - "modelId": tt.generativeModel, + "model": tt.generativeModel, }, } // create schema diff --git a/test/modules/many-modules/many_modules_test.go b/test/modules/many-modules/many_modules_test.go index 16bd538d06..5dbe8fe0d2 100644 --- a/test/modules/many-modules/many_modules_test.go +++ b/test/modules/many-modules/many_modules_test.go @@ -67,11 +67,11 @@ func composeModules() (composeModules *docker.Compose) { WithText2VecVoyageAI(). WithText2VecPaLM(os.Getenv("PALM_APIKEY")). WithText2VecHuggingFace(). - WithText2VecAWS(). + WithText2VecAWS(os.Getenv("AWS_ACCESS_KEY_ID"), os.Getenv("AWS_SECRET_ACCESS_KEY"), os.Getenv("AWS_SESSION_TOKEN")). WithGenerativeOpenAI(). WithGenerativeCohere(). WithGenerativePaLM(os.Getenv("PALM_APIKEY")). - WithGenerativeAWS(). + WithGenerativeAWS(os.Getenv("AWS_ACCESS_KEY_ID"), os.Getenv("AWS_SECRET_ACCESS_KEY"), os.Getenv("AWS_SESSION_TOKEN")). WithGenerativeAnyscale(). WithQnAOpenAI(). WithRerankerCohere(). diff --git a/test/modules/text2vec-aws/setup_test.go b/test/modules/text2vec-aws/setup_test.go new file mode 100644 index 0000000000..9f8a11622a --- /dev/null +++ b/test/modules/text2vec-aws/setup_test.go @@ -0,0 +1,69 @@ +// _ _ +// __ _____ __ ___ ___ __ _| |_ ___ +// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ +// \ V V / __/ (_| |\ V /| | (_| | || __/ +// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| +// +// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. +// +// CONTACT: hello@weaviate.io +// + +package tests + +import ( + "context" + "os" + "testing" + + "github.com/stretchr/testify/require" + "github.com/weaviate/weaviate/test/docker" +) + +func TestText2VecAWS_SingleNode(t *testing.T) { + accessKey := os.Getenv("AWS_ACCESS_KEY") + if accessKey == "" { + accessKey = os.Getenv("AWS_ACCESS_KEY_ID") + if accessKey == "" { + t.Skip("skipping, AWS_ACCESS_KEY environment variable not present") + } + } + secretKey := os.Getenv("AWS_SECRET_KEY") + if secretKey == "" { + secretKey = os.Getenv("AWS_SECRET_ACCESS_KEY") + if secretKey == "" { + t.Skip("skipping, AWS_SECRET_KEY environment variable not present") + } + } + sessionToken := os.Getenv("AWS_SESSION_TOKEN") + if sessionToken == "" { + t.Skip("skipping, AWS_SESSION_TOKEN environment variable not present") + } + region := os.Getenv("AWS_REGION") + if region == "" { + t.Skip("skipping, AWS_REGION environment variable not present") + } + ctx := context.Background() + compose, err := createSingleNodeEnvironment(ctx, accessKey, secretKey, sessionToken) + require.NoError(t, err) + defer func() { + require.NoError(t, compose.Terminate(ctx)) + }() + endpoint := compose.GetWeaviate().URI() + + t.Run("tests", testText2VecAWS(endpoint, region)) +} + +func createSingleNodeEnvironment(ctx context.Context, accessKey, secretKey, sessionToken string, +) (compose *docker.DockerCompose, err error) { + compose, err = composeModules(accessKey, secretKey, sessionToken). + WithWeaviate(). + Start(ctx) + return +} + +func composeModules(accessKey, secretKey, sessionToken string) (composeModules *docker.Compose) { + composeModules = docker.New(). + WithText2VecAWS(accessKey, secretKey, sessionToken) + return +} diff --git a/test/modules/text2vec-aws/text2vec_aws_test.go b/test/modules/text2vec-aws/text2vec_aws_test.go new file mode 100644 index 0000000000..20cdea9706 --- /dev/null +++ b/test/modules/text2vec-aws/text2vec_aws_test.go @@ -0,0 +1,179 @@ +// _ _ +// __ _____ __ ___ ___ __ _| |_ ___ +// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ +// \ V V / __/ (_| |\ V /| | (_| | || __/ +// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| +// +// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. +// +// CONTACT: hello@weaviate.io +// + +package tests + +import ( + "fmt" + "testing" + + "github.com/go-openapi/strfmt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/weaviate/weaviate/entities/models" + "github.com/weaviate/weaviate/entities/schema" + "github.com/weaviate/weaviate/test/helper" + graphqlhelper "github.com/weaviate/weaviate/test/helper/graphql" +) + +func testText2VecAWS(host, region string) func(t *testing.T) { + return func(t *testing.T) { + helper.SetupClient(host) + // Data + companies := []struct { + id strfmt.UUID + name, description string + }{ + { + id: strfmt.UUID("00000000-0000-0000-0000-000000000001"), + name: "OpenAI", + description: ` + OpenAI is a research organization and AI development company that focuses on artificial intelligence (AI) and machine learning (ML). + Founded in December 2015, OpenAI's mission is to ensure that artificial general intelligence (AGI) benefits all of humanity. + The organization has been at the forefront of AI research, producing cutting-edge advancements in natural language processing, + reinforcement learning, robotics, and other AI-related fields. + + OpenAI has garnered attention for its work on various projects, including the development of the GPT (Generative Pre-trained Transformer) + series of models, such as GPT-2 and GPT-3, which have demonstrated remarkable capabilities in generating human-like text. + Additionally, OpenAI has contributed to advancements in reinforcement learning through projects like OpenAI Five, an AI system + capable of playing the complex strategy game Dota 2 at a high level. + `, + }, + { + id: strfmt.UUID("00000000-0000-0000-0000-000000000002"), + name: "SpaceX", + description: ` + SpaceX, short for Space Exploration Technologies Corp., is an American aerospace manufacturer and space transportation company + founded by Elon Musk in 2002. The company's primary goal is to reduce space transportation costs and enable the colonization of Mars, + among other ambitious objectives. + + SpaceX has made significant strides in the aerospace industry by developing advanced rocket technology, spacecraft, + and satellite systems. The company is best known for its Falcon series of rockets, including the Falcon 1, Falcon 9, + and Falcon Heavy, which have been designed with reusability in mind. Reusability has been a key innovation pioneered by SpaceX, + aiming to drastically reduce the cost of space travel by reusing rocket components multiple times. + `, + }, + } + tests := []struct { + name string + model string + }{ + { + name: "amazon.titan-embed-text-v1", + model: "amazon.titan-embed-text-v1", + }, + { + name: "amazon.titan-embed-text-v2:0", + model: "amazon.titan-embed-text-v2:0", + }, + { + name: "cohere.embed-english-v3", + model: "cohere.embed-english-v3", + }, + { + name: "cohere.embed-multilingual-v3", + model: "cohere.embed-multilingual-v3", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Define class + className := "VectorizerTest" + class := &models.Class{ + Class: className, + Properties: []*models.Property{ + { + Name: "name", DataType: []string{schema.DataTypeText.String()}, + }, + { + Name: "description", DataType: []string{schema.DataTypeText.String()}, + }, + }, + VectorConfig: map[string]models.VectorConfig{ + "description": { + Vectorizer: map[string]interface{}{ + "text2vec-aws": map[string]interface{}{ + "properties": []interface{}{"description"}, + "vectorizeClassName": false, + "service": "bedrock", + "region": region, + "model": tt.model, + }, + }, + VectorIndexType: "flat", + }, + }, + } + // create schema + helper.CreateClass(t, class) + defer helper.DeleteClass(t, class.Class) + // create objects + t.Run("create objects", func(t *testing.T) { + for _, company := range companies { + obj := &models.Object{ + Class: class.Class, + ID: company.id, + Properties: map[string]interface{}{ + "name": company.name, + "description": company.description, + }, + } + helper.CreateObject(t, obj) + helper.AssertGetObjectEventually(t, obj.Class, obj.ID) + } + }) + t.Run("check objects existence", func(t *testing.T) { + for _, company := range companies { + t.Run(company.id.String(), func(t *testing.T) { + obj, err := helper.GetObject(t, class.Class, company.id, "vector") + require.NoError(t, err) + require.NotNil(t, obj) + require.Len(t, obj.Vectors, 1) + assert.True(t, len(obj.Vectors["description"]) > 0) + }) + } + }) + // vector search + t.Run("perform vector search", func(t *testing.T) { + query := fmt.Sprintf(` + { + Get { + %s( + nearText:{ + concepts:["SpaceX"] + } + ){ + name + _additional { + id + } + } + } + } + `, class.Class) + result := graphqlhelper.AssertGraphQL(t, helper.RootAuth, query) + objs := result.Get("Get", class.Class).AsSlice() + require.Len(t, objs, 2) + for _, obj := range objs { + name := obj.(map[string]interface{})["name"] + assert.NotEmpty(t, name) + additional, ok := obj.(map[string]interface{})["_additional"].(map[string]interface{}) + require.True(t, ok) + require.NotNil(t, additional) + id, ok := additional["id"].(string) + require.True(t, ok) + require.NotEmpty(t, id) + } + }) + }) + } + } +} diff --git a/test/modules/text2vec-ollama/ollama_vectorizer_test.go b/test/modules/text2vec-ollama/ollama_vectorizer_test.go index 137894cc49..e3693acc68 100644 --- a/test/modules/text2vec-ollama/ollama_vectorizer_test.go +++ b/test/modules/text2vec-ollama/ollama_vectorizer_test.go @@ -49,6 +49,7 @@ func testText2VecOllama(host, ollamaApiEndpoint string) func(t *testing.T) { "properties": []interface{}{"description"}, "vectorizeClassName": false, "apiEndpoint": ollamaApiEndpoint, + "model": "nomic-embed-text", }, }, VectorIndexType: "flat", diff --git a/test/run.sh b/test/run.sh index d6f4200de5..3f8eb26b7a 100755 --- a/test/run.sh +++ b/test/run.sh @@ -234,7 +234,7 @@ function run_acceptance_graphql_tests() { function run_acceptance_replication_tests() { for pkg in $(go list ./.../ | grep 'test/acceptance/replication'); do - if ! go test -count 1 -race "$pkg"; then + if ! go test -count 1 -v -race "$pkg"; then echo "Test for $pkg failed" >&2 return 1 fi diff --git a/tools/dev/restart_dev_environment.sh b/tools/dev/restart_dev_environment.sh index 4c32c828ba..0fed830362 100755 --- a/tools/dev/restart_dev_environment.sh +++ b/tools/dev/restart_dev_environment.sh @@ -64,7 +64,7 @@ fi docker compose -f $DOCKER_COMPOSE_FILE down --remove-orphans -rm -rf data-weaviate-0 data-weaviate-1 data-weaviate-2 connector_state.json schema_state.json +rm -rf data data-weaviate-0 data-weaviate-1 data-weaviate-2 backups-weaviate-0 backups-weaviate-1 backups-weaviate-2 connector_state.json schema_state.json docker compose -f $DOCKER_COMPOSE_FILE up -d "${ADDITIONAL_SERVICES[@]}" diff --git a/tools/dev/run_dev_server.sh b/tools/dev/run_dev_server.sh index 6151095307..9aba0e968d 100755 --- a/tools/dev/run_dev_server.sh +++ b/tools/dev/run_dev_server.sh @@ -40,7 +40,8 @@ case $CONFIG in local-single-node) AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED=true \ - BACKUP_FILESYSTEM_PATH="${PWD}/backups" \ + PERSISTENCE_DATA_PATH="./data-weaviate-0" \ + BACKUP_FILESYSTEM_PATH="${PWD}/backups-weaviate-0" \ ENABLE_MODULES="backup-filesystem" \ CLUSTER_IN_LOCALHOST=true \ CLUSTER_GOSSIP_BIND_PORT="7100" \ @@ -57,8 +58,9 @@ case $CONFIG in local-development) CONTEXTIONARY_URL=localhost:9999 \ AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED=true \ + PERSISTENCE_DATA_PATH="./data-weaviate-0" \ + BACKUP_FILESYSTEM_PATH="${PWD}/backups-weaviate-0" \ DEFAULT_VECTORIZER_MODULE=text2vec-contextionary \ - BACKUP_FILESYSTEM_PATH="${PWD}/backups" \ ENABLE_MODULES="text2vec-contextionary,backup-filesystem" \ CLUSTER_IN_LOCALHOST=true \ CLUSTER_GOSSIP_BIND_PORT="7100" \ diff --git a/usecases/config/config_handler.go b/usecases/config/config_handler.go index 704c66c13a..57cc2c8589 100644 --- a/usecases/config/config_handler.go +++ b/usecases/config/config_handler.go @@ -14,6 +14,7 @@ package config import ( "encoding/json" "fmt" + "math" "os" "regexp" "strings" @@ -214,11 +215,20 @@ type Persistence struct { MemtablesMaxSizeMB int `json:"memtablesMaxSizeMB" yaml:"memtablesMaxSizeMB"` MemtablesMinActiveDurationSeconds int `json:"memtablesMinActiveDurationSeconds" yaml:"memtablesMinActiveDurationSeconds"` MemtablesMaxActiveDurationSeconds int `json:"memtablesMaxActiveDurationSeconds" yaml:"memtablesMaxActiveDurationSeconds"` + LSMMaxSegmentSize int64 `json:"lsmMaxSegmentSize" yaml:"lsmMaxSegmentSize"` + HNSWMaxLogSize int64 `json:"hnswMaxLogSize" yaml:"hnswMaxLogSize"` } // DefaultPersistenceDataPath is the default location for data directory when no location is provided const DefaultPersistenceDataPath string = "./data" +// DefaultPersistenceLSMMaxSegmentSize is effectively unlimited for backward +// compatibility. TODO: consider changing this in a future release and make +// some noise about it. This is technically a breaking change. +const DefaultPersistenceLSMMaxSegmentSize = math.MaxInt64 + +const DefaultPersistenceHNSWMaxLogSize = 500 * 1024 * 1024 // 500MB for backward compatibility + func (p Persistence) Validate() error { if p.DataPath == "" { return fmt.Errorf("persistence.dataPath must be set") @@ -291,15 +301,16 @@ func (r ResourceUsage) Validate() error { } type Raft struct { - Port int - InternalRPCPort int - RPCMessageMaxSize int - Join []string - SnapshotThreshold uint64 - HeartbeatTimeout time.Duration - RecoveryTimeout time.Duration - ElectionTimeout time.Duration - SnapshotInterval time.Duration + Port int + InternalRPCPort int + RPCMessageMaxSize int + Join []string + SnapshotThreshold uint64 + HeartbeatTimeout time.Duration + RecoveryTimeout time.Duration + ElectionTimeout time.Duration + SnapshotInterval time.Duration + ConsistencyWaitTimeout time.Duration BootstrapTimeout time.Duration BootstrapExpect int @@ -351,6 +362,18 @@ func (r *Raft) Validate() error { if r.BootstrapExpect > len(r.Join) { return fmt.Errorf("raft.bootstrap.expect must be less than or equal to the length of raft.join") } + + if r.SnapshotInterval <= 0 { + return fmt.Errorf("raft.bootstrap.snapshot_interval must be more than 0") + } + + if r.SnapshotThreshold <= 0 { + return fmt.Errorf("raft.bootstrap.snapshot_threshold must be more than 0") + } + + if r.ConsistencyWaitTimeout <= 0 { + return fmt.Errorf("raft.bootstrap.consistency_wait_timeout must be more than 0") + } return nil } diff --git a/usecases/config/environment.go b/usecases/config/environment.go index 5a4655e984..4d053f26e3 100644 --- a/usecases/config/environment.go +++ b/usecases/config/environment.go @@ -31,6 +31,7 @@ const ( DefaultRaftGRPCMaxSize = 1024 * 1024 * 1024 DefaultRaftBootstrapTimeout = 90 DefaultRaftBootstrapExpect = 1 + DefaultRaftDir = "raft" ) // FromEnv takes a *Config as it will respect initial config that has been @@ -178,6 +179,28 @@ func FromEnv(config *Config) error { config.AvoidMmap = true } + if v := os.Getenv("PERSISTENCE_LSM_MAX_SEGMENT_SIZE"); v != "" { + parsed, err := parseResourceString(v) + if err != nil { + return fmt.Errorf("parse PERSISTENCE_LSM_MAX_SEGMENT_SIZE: %w", err) + } + + config.Persistence.LSMMaxSegmentSize = parsed + } else { + config.Persistence.LSMMaxSegmentSize = DefaultPersistenceLSMMaxSegmentSize + } + + if v := os.Getenv("PERSISTENCE_HNSW_MAX_LOG_SIZE"); v != "" { + parsed, err := parseResourceString(v) + if err != nil { + return fmt.Errorf("parse PERSISTENCE_HNSW_MAX_LOG_SIZE: %w", err) + } + + config.Persistence.HNSWMaxLogSize = parsed + } else { + config.Persistence.HNSWMaxLogSize = DefaultPersistenceHNSWMaxLogSize + } + clusterCfg, err := parseClusterConfig() if err != nil { return err @@ -467,6 +490,14 @@ func parseRAFTConfig(hostname string) (Raft, error) { return cfg, err } + if err := parsePositiveInt( + "RAFT_CONSISTENCY_WAIT_TIMEOUT", + func(val int) { cfg.ConsistencyWaitTimeout = time.Second * time.Duration(val) }, + 10, + ); err != nil { + return cfg, err + } + return cfg, nil } diff --git a/usecases/config/environment_test.go b/usecases/config/environment_test.go index 298bd14f6b..cb4375b7b0 100644 --- a/usecases/config/environment_test.go +++ b/usecases/config/environment_test.go @@ -743,3 +743,34 @@ func TestEnvironmentAuthentication(t *testing.T) { }) } } + +func TestEnvironmentHNSWMaxLogSize(t *testing.T) { + factors := []struct { + name string + value []string + expected int64 + expectedErr bool + }{ + {"Valid no unit", []string{"3"}, 3, false}, + {"Valid IEC unit", []string{"3KB"}, 3000, false}, + {"Valid SI unit", []string{"3KiB"}, 3 * 1024, false}, + {"not given", []string{}, DefaultPersistenceHNSWMaxLogSize, false}, + {"invalid factor", []string{"-1"}, -1, true}, + {"not parsable", []string{"I'm not a number"}, -1, true}, + } + for _, tt := range factors { + t.Run(tt.name, func(t *testing.T) { + if len(tt.value) == 1 { + t.Setenv("PERSISTENCE_HNSW_MAX_LOG_SIZE", tt.value[0]) + } + conf := Config{} + err := FromEnv(&conf) + + if tt.expectedErr { + require.NotNil(t, err) + } else { + require.Equal(t, tt.expected, conf.Persistence.HNSWMaxLogSize) + } + }) + } +} diff --git a/usecases/config/parse_resource_strings.go b/usecases/config/parse_resource_strings.go new file mode 100644 index 0000000000..505d98c4e9 --- /dev/null +++ b/usecases/config/parse_resource_strings.go @@ -0,0 +1,64 @@ +// _ _ +// __ _____ __ ___ ___ __ _| |_ ___ +// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ +// \ V V / __/ (_| |\ V /| | (_| | || __/ +// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| +// +// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. +// +// CONTACT: hello@weaviate.io +// + +package config + +import ( + "fmt" + "math" + "strconv" + "strings" +) + +// parseResourceString takes a string like "1024", "1KiB", "43TiB" and converts it to an integer number of bytes. +func parseResourceString(resource string) (int64, error) { + resource = strings.TrimSpace(resource) + + if strings.EqualFold(resource, "unlimited") || strings.EqualFold(resource, "nolimit") { + return math.MaxInt64, nil + } + + // Find where the digits end + lastDigit := len(resource) + for i, r := range resource { + if r < '0' || r > '9' { + lastDigit = i + break + } + } + + // Split the numeric part and the unit + number, unit := resource[:lastDigit], resource[lastDigit:] + unit = strings.TrimSpace(unit) // Clean up any surrounding whitespace + value, err := strconv.ParseInt(number, 10, 64) + if err != nil { + return 0, err + } + + unitMultipliers := map[string]int64{ + "": 1, // No unit means bytes + "B": 1, + "KiB": 1024, + "MiB": 1024 * 1024, + "GiB": 1024 * 1024 * 1024, + "TiB": 1024 * 1024 * 1024 * 1024, + "KB": 1000, + "MB": 1000 * 1000, + "GB": 1000 * 1000 * 1000, + "TB": 1000 * 1000 * 1000 * 1000, + } + multiplier, exists := unitMultipliers[unit] + if !exists { + return 0, fmt.Errorf("invalid or unsupported unit") + } + + return value * multiplier, nil +} diff --git a/usecases/config/parse_resource_strings_test.go b/usecases/config/parse_resource_strings_test.go new file mode 100644 index 0000000000..06e2077f6f --- /dev/null +++ b/usecases/config/parse_resource_strings_test.go @@ -0,0 +1,55 @@ +// _ _ +// __ _____ __ ___ ___ __ _| |_ ___ +// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ +// \ V V / __/ (_| |\ V /| | (_| | || __/ +// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| +// +// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. +// +// CONTACT: hello@weaviate.io +// + +package config + +import ( + "math" + "testing" +) + +func TestParseResourceString(t *testing.T) { + tests := []struct { + name string + input string + expected int64 + err bool + }{ + {"ValidBytes", "1024", 1024, false}, + {"ValidKiB", "1KiB", 1024, false}, + {"ValidMiB", "500MiB", 500 * 1024 * 1024, false}, + {"ValidTiB", "43TiB", 43 * 1024 * 1024 * 1024 * 1024, false}, + {"ValidKB", "1KB", 1000, false}, + {"ValidMB", "500MB", 500 * 1e6, false}, + {"ValidTB", "43TB", 43 * 1e12, false}, + {"InvalidUnit", "100GiL", 0, true}, + {"InvalidNumber", "tenKiB", 0, true}, + {"InvalidFormat", "1024 KiB", 1024 * 1024, false}, + {"EmptyString", "", 0, true}, + {"NoUnit", "12345", 12345, false}, + {"Unlimited lower case", "unlimited", math.MaxInt64, false}, + {"Unlimited unlimited upper case", "UNLIMITED", math.MaxInt64, false}, + {"Nolimit lower case", "nolimit", math.MaxInt64, false}, + {"Nolimit upper case", "NOLIMIT", math.MaxInt64, false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := parseResourceString(tc.input) + if (err != nil) != tc.err { + t.Errorf("parseResourceString(%s) expected error: %v, got: %v", tc.input, tc.err, err != nil) + } + if result != tc.expected { + t.Errorf("parseResourceString(%s) expected %d, got %d", tc.input, tc.expected, result) + } + }) + } +} diff --git a/usecases/objects/validate.go b/usecases/objects/validate.go index c6a03c7ad0..54ae3a1985 100644 --- a/usecases/objects/validate.go +++ b/usecases/objects/validate.go @@ -15,6 +15,7 @@ import ( "context" "github.com/weaviate/weaviate/entities/additional" + "github.com/weaviate/weaviate/entities/classcache" "github.com/weaviate/weaviate/entities/models" ) @@ -34,6 +35,7 @@ func (m *Manager) ValidateObject(ctx context.Context, principal *models.Principa } defer unlock() + ctx = classcache.ContextWithClassCache(ctx) err = m.validateObjectAndNormalizeNames(ctx, principal, repl, obj, nil) if err != nil { return NewErrInvalidUserInput("invalid object: %v", err) diff --git a/usecases/schema/class.go b/usecases/schema/class.go index 19d8024c61..c665bac41e 100644 --- a/usecases/schema/class.go +++ b/usecases/schema/class.go @@ -222,23 +222,21 @@ func (h *Handler) UpdateClass(ctx context.Context, principal *models.Principal, return err } - // TODO: fix PushShard issues before enabling scale out - // https://github.com/weaviate/weaviate/issues/4840 - //initialRF := initial.ReplicationConfig.Factor - //updatedRF := updated.ReplicationConfig.Factor - // - //if initialRF != updatedRF { - // ss, _, err := h.metaWriter.QueryShardingState(className) - // if err != nil { - // return fmt.Errorf("query sharding state for %q: %w", className, err) - // } - // shardingState, err = h.scaleOut.Scale(ctx, className, ss.Config, initialRF, updatedRF) - // if err != nil { - // return fmt.Errorf( - // "scale %q from %d replicas to %d: %w", - // className, initialRF, updatedRF, err) - // } - //} + initialRF := initial.ReplicationConfig.Factor + updatedRF := updated.ReplicationConfig.Factor + + if initialRF != updatedRF { + ss, _, err := h.metaWriter.QueryShardingState(className) + if err != nil { + return fmt.Errorf("query sharding state for %q: %w", className, err) + } + shardingState, err = h.scaleOut.Scale(ctx, className, ss.Config, initialRF, updatedRF) + if err != nil { + return fmt.Errorf( + "scale %q from %d replicas to %d: %w", + className, initialRF, updatedRF, err) + } + } if err := validateImmutableFields(initial, updated); err != nil { return err diff --git a/usecases/schema/class_test.go b/usecases/schema/class_test.go index 86440913ab..a5d2bb62a9 100644 --- a/usecases/schema/class_test.go +++ b/usecases/schema/class_test.go @@ -73,6 +73,29 @@ func Test_AddClass(t *testing.T) { assert.EqualError(t, err, "'' is not a valid class name") }) + t.Run("with reserved class name", func(t *testing.T) { + handler, _ := newTestHandler(t, &fakeDB{}) + class := models.Class{Class: config.DefaultRaftDir} + _, _, err := handler.AddClass(ctx, nil, &class) + assert.EqualError(t, err, fmt.Sprintf("parse class name: class name `%s` is reserved", config.DefaultRaftDir)) + + class = models.Class{Class: "rAFT"} + _, _, err = handler.AddClass(ctx, nil, &class) + assert.EqualError(t, err, fmt.Sprintf("parse class name: class name `%s` is reserved", config.DefaultRaftDir)) + + class = models.Class{Class: "rAfT"} + _, _, err = handler.AddClass(ctx, nil, &class) + assert.EqualError(t, err, fmt.Sprintf("parse class name: class name `%s` is reserved", config.DefaultRaftDir)) + + class = models.Class{Class: "RaFT"} + _, _, err = handler.AddClass(ctx, nil, &class) + assert.EqualError(t, err, fmt.Sprintf("parse class name: class name `%s` is reserved", config.DefaultRaftDir)) + + class = models.Class{Class: "RAFT"} + _, _, err = handler.AddClass(ctx, nil, &class) + assert.EqualError(t, err, fmt.Sprintf("parse class name: class name `%s` is reserved", config.DefaultRaftDir)) + }) + t.Run("with default params", func(t *testing.T) { handler, fakeMetaHandler := newTestHandler(t, &fakeDB{}) class := models.Class{ diff --git a/usecases/schema/executor.go b/usecases/schema/executor.go index eaeca60636..9e62b11377 100644 --- a/usecases/schema/executor.go +++ b/usecases/schema/executor.go @@ -14,6 +14,7 @@ package schema import ( "context" "fmt" + "sync" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -23,9 +24,12 @@ import ( ) type executor struct { - store metaReader - migrator Migrator - callbacks []func(updatedSchema schema.Schema) + store metaReader + migrator Migrator + + callbacksLock sync.RWMutex + callbacks []func(updatedSchema schema.Schema) + logger logrus.FieldLogger restoreClassDir func(string) error } @@ -98,6 +102,11 @@ func (e *executor) UpdateClass(req api.UpdateClassRequest) error { req.Class.InvertedIndexConfig); err != nil { return errors.Wrap(err, "inverted index config") } + + if err := e.migrator.UpdateReplicationFactor(ctx, className, req.Class.ReplicationConfig.Factor); err != nil { + return fmt.Errorf("replication index update: %w", err) + } + return nil } @@ -223,6 +232,9 @@ func (e *executor) GetShardsStatus(class, tenant string) (models.ShardStatusList } func (e *executor) rebuildGQL(s models.Schema) { + e.callbacksLock.RLock() + defer e.callbacksLock.RUnlock() + for _, cb := range e.callbacks { cb(schema.Schema{ Objects: &s, @@ -238,5 +250,8 @@ func (e *executor) TriggerSchemaUpdateCallbacks() { // type update callback. The callbacks will be called any time we persist a // schema update func (e *executor) RegisterSchemaUpdateCallback(callback func(updatedSchema schema.Schema)) { + e.callbacksLock.Lock() + defer e.callbacksLock.Unlock() + e.callbacks = append(e.callbacks, callback) } diff --git a/usecases/schema/executor_test.go b/usecases/schema/executor_test.go index 10e6ce3f8f..844dedef17 100644 --- a/usecases/schema/executor_test.go +++ b/usecases/schema/executor_test.go @@ -43,6 +43,9 @@ func TestExecutor(t *testing.T) { cls := &models.Class{ Class: "A", VectorIndexConfig: flat.NewDefaultUserConfig(), + ReplicationConfig: &models.ReplicationConfig{ + Factor: 1, + }, } store.On("ReadOnlySchema").Return(models.Schema{}) store.On("ReadOnlyClass", "A", mock.Anything).Return(cls) diff --git a/usecases/schema/helpers_test.go b/usecases/schema/helpers_test.go index 6dbdeac646..7c2c14a746 100644 --- a/usecases/schema/helpers_test.go +++ b/usecases/schema/helpers_test.go @@ -323,6 +323,10 @@ func (f *fakeMigrator) UpdateInvertedIndexConfig(ctx context.Context, className return args.Error(0) } +func (f *fakeMigrator) UpdateReplicationFactor(ctx context.Context, className string, factor int64) error { + return nil +} + func (f *fakeMigrator) WaitForStartup(ctx context.Context) error { args := f.Called(ctx) return args.Error(0) diff --git a/usecases/schema/migrator.go b/usecases/schema/migrator.go index 54cf3b655c..9f5cca740a 100644 --- a/usecases/schema/migrator.go +++ b/usecases/schema/migrator.go @@ -59,6 +59,7 @@ type Migrator interface { ValidateInvertedIndexConfigUpdate(old, updated *models.InvertedIndexConfig) error UpdateInvertedIndexConfig(ctx context.Context, className string, updated *models.InvertedIndexConfig) error + UpdateReplicationFactor(ctx context.Context, className string, factor int64) error WaitForStartup(context.Context) error Shutdown(context.Context) error } diff --git a/usecases/schema/parser.go b/usecases/schema/parser.go index 05d9302c09..5b115b4362 100644 --- a/usecases/schema/parser.go +++ b/usecases/schema/parser.go @@ -14,12 +14,14 @@ package schema import ( "fmt" "reflect" + "strings" "github.com/pkg/errors" "github.com/weaviate/weaviate/entities/models" "github.com/weaviate/weaviate/entities/schema" schemaConfig "github.com/weaviate/weaviate/entities/schema/config" "github.com/weaviate/weaviate/entities/vectorindex" + "github.com/weaviate/weaviate/usecases/config" shardingConfig "github.com/weaviate/weaviate/usecases/sharding/config" ) @@ -42,6 +44,10 @@ func (m *Parser) ParseClass(class *models.Class) error { return fmt.Errorf("class cannot be nil") } + if strings.EqualFold(class.Class, config.DefaultRaftDir) { + return fmt.Errorf("parse class name: %w", fmt.Errorf("class name `raft` is reserved")) + } + if err := m.parseShardingConfig(class); err != nil { return fmt.Errorf("parse sharding config: %w", err) } @@ -126,12 +132,6 @@ func (p *Parser) ParseClassUpdate(class, update *models.Class) (*models.Class, e return nil, err } - // TODO: fix PushShard issues before enabling scale out - // https://github.com/weaviate/weaviate/issues/4840 - if class.ReplicationConfig.Factor != update.ReplicationConfig.Factor { - return nil, fmt.Errorf("updating replication factor is not supported yet") - } - if err := validateImmutableFields(class, update); err != nil { return nil, err } diff --git a/usecases/traverser/hybrid_group_by.go b/usecases/traverser/hybrid_group_by.go index 8b3d7fca1a..0f6ab595ea 100644 --- a/usecases/traverser/hybrid_group_by.go +++ b/usecases/traverser/hybrid_group_by.go @@ -25,7 +25,6 @@ func (e *Explorer) groupSearchResults(ctx context.Context, sr search.Results, gr groups := map[string][]search.Result{} for _, result := range sr { - prop_i := result.Object().Properties prop := prop_i.(map[string]interface{}) val, ok := prop[groupBy.Property].(string)