From 8d57fc8fa2f15a63d7f43f0f24107d4f0889ea18 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Wed, 3 Apr 2024 14:46:32 +0000 Subject: [PATCH 01/11] Add rudimentary database watcher Adds a simple database watcher. At this point it's just one process, but the plan is to allow different implementations that inform the local running workers of changes that have occured on entities of interest in the database. Signed-off-by: Gabriel Adrian Samfira --- README.md | 2 +- cmd/garm/main.go | 3 + database/common/errors.go | 12 ++ database/common/{common.go => store.go} | 7 +- database/common/watcher.go | 50 ++++++++ database/sql/instances.go | 20 ---- database/sql/pools.go | 33 +++++- database/sql/repositories.go | 34 +++++- database/sql/repositories_test.go | 7 ++ database/sql/sql.go | 19 ++- database/sql/util.go | 29 +++++ database/watcher/consumer.go | 75 ++++++++++++ database/watcher/producer.go | 49 ++++++++ database/watcher/test_export.go | 12 ++ database/watcher/watcher.go | 151 ++++++++++++++++++++++++ internal/testing/mock_watcher.go | 45 +++++++ runner/pool/pool.go | 2 +- runner/repositories_test.go | 5 + 18 files changed, 514 insertions(+), 41 deletions(-) create mode 100644 database/common/errors.go rename database/common/{common.go => store.go} (98%) create mode 100644 database/common/watcher.go create mode 100644 database/watcher/consumer.go create mode 100644 database/watcher/producer.go create mode 100644 database/watcher/test_export.go create mode 100644 database/watcher/watcher.go create mode 100644 internal/testing/mock_watcher.go diff --git a/README.md b/README.md index 9536b8fb..f68da05b 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ GARM supports creating pools on either GitHub itself or on your own deployment o Through the use of providers, `GARM` can create runners in a variety of environments using the same `GARM` instance. Whether you want to create pools of runners in your OpenStack cloud, your Azure cloud and your Kubernetes cluster, that is easily achieved by just installing the appropriate providers, configuring them in `GARM` and creating pools that use them. You can create zero-runner pools for instances with high costs (large VMs, GPU enabled instances, etc) and have them spin up on demand, or you can create large pools of k8s backed runners that can be used for your CI/CD pipelines at a moment's notice. You can mix them up and create pools in any combination of providers or resource allocations you want. -:warning: **Important note**: The README and documentation in the `main` branch are relevant to the not yet released code that is present in `main`. Following the documentation from the `main` branch for a stable release of GARM, may lead to errors. To view the documentation for the latest stable release, please switch to the appropriate tag. For information about setting up `v0.1.4`, please refer to the [v0.1.4 tag](https://github.com/cloudbase/garm/tree/v0.1.4) +:warning: **Important note**: The README and documentation in the `main` branch are relevant to the not yet released code that is present in `main`. Following the documentation from the `main` branch for a stable release of GARM, may lead to errors. To view the documentation for the latest stable release, please switch to the appropriate tag. For information about setting up `v0.1.4`, please refer to the [v0.1.4 tag](https://github.com/cloudbase/garm/tree/v0.1.4). ## Join us on slack diff --git a/cmd/garm/main.go b/cmd/garm/main.go index 83d70326..d8eed80c 100644 --- a/cmd/garm/main.go +++ b/cmd/garm/main.go @@ -40,6 +40,7 @@ import ( "github.com/cloudbase/garm/config" "github.com/cloudbase/garm/database" "github.com/cloudbase/garm/database/common" + "github.com/cloudbase/garm/database/watcher" "github.com/cloudbase/garm/metrics" "github.com/cloudbase/garm/params" "github.com/cloudbase/garm/runner" //nolint:typecheck @@ -183,6 +184,7 @@ func main() { } ctx, stop := signal.NotifyContext(context.Background(), signals...) defer stop() + watcher.InitWatcher(ctx) ctx = auth.GetAdminContext(ctx) @@ -313,6 +315,7 @@ func main() { }() <-ctx.Done() + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 60*time.Second) defer shutdownCancel() if err := srv.Shutdown(shutdownCtx); err != nil { diff --git a/database/common/errors.go b/database/common/errors.go new file mode 100644 index 00000000..b68d8311 --- /dev/null +++ b/database/common/errors.go @@ -0,0 +1,12 @@ +package common + +import "fmt" + +var ( + ErrProducerClosed = fmt.Errorf("producer is closed") + ErrProducerTimeoutErr = fmt.Errorf("producer timeout error") + ErrProducerAlreadyRegistered = fmt.Errorf("producer already registered") + ErrConsumerAlreadyRegistered = fmt.Errorf("consumer already registered") + ErrWatcherAlreadyStarted = fmt.Errorf("watcher already started") + ErrWatcherNotInitialized = fmt.Errorf("watcher not initialized") +) diff --git a/database/common/common.go b/database/common/store.go similarity index 98% rename from database/common/common.go rename to database/common/store.go index 4f0df368..18075c1d 100644 --- a/database/common/common.go +++ b/database/common/store.go @@ -119,7 +119,7 @@ type JobsStore interface { DeleteCompletedJobs(ctx context.Context) error } -type EntityPools interface { +type EntityPoolStore interface { CreateEntityPool(ctx context.Context, entity params.GithubEntity, param params.CreatePoolParams) (params.Pool, error) GetEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) (params.Pool, error) DeleteEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) error @@ -144,8 +144,11 @@ type Store interface { UserStore InstanceStore JobsStore - EntityPools GithubEndpointStore GithubCredentialsStore ControllerStore + EntityPoolStore + + ControllerInfo() (params.ControllerInfo, error) + InitController() (params.ControllerInfo, error) } diff --git a/database/common/watcher.go b/database/common/watcher.go new file mode 100644 index 00000000..4903e4ab --- /dev/null +++ b/database/common/watcher.go @@ -0,0 +1,50 @@ +package common + +type ( + DatabaseEntityType string + OperationType string + PayloadFilterFunc func(ChangePayload) bool +) + +const ( + RepositoryEntityType DatabaseEntityType = "repository" + OrganizationEntityType DatabaseEntityType = "organization" + EnterpriseEntityType DatabaseEntityType = "enterprise" + PoolEntityType DatabaseEntityType = "pool" + UserEntityType DatabaseEntityType = "user" + InstanceEntityType DatabaseEntityType = "instance" + JobEntityType DatabaseEntityType = "job" + ControllerEntityType DatabaseEntityType = "controller" + GithubCredentialsEntityType DatabaseEntityType = "github_credentials" + GithubEndpointEntityType DatabaseEntityType = "github_endpoint" +) + +const ( + CreateOperation OperationType = "create" + UpdateOperation OperationType = "update" + DeleteOperation OperationType = "delete" +) + +type ChangePayload struct { + EntityType DatabaseEntityType + Operation OperationType + Payload interface{} +} + +type Consumer interface { + Watch() <-chan ChangePayload + IsClosed() bool + Close() + SetFilters(filters ...PayloadFilterFunc) +} + +type Producer interface { + Notify(ChangePayload) error + IsClosed() bool + Close() +} + +type Watcher interface { + RegisterProducer(ID string) (Producer, error) + RegisterConsumer(ID string, filters ...PayloadFilterFunc) (Consumer, error) +} diff --git a/database/sql/instances.go b/database/sql/instances.go index 1544090b..c09b60f3 100644 --- a/database/sql/instances.go +++ b/database/sql/instances.go @@ -25,29 +25,9 @@ import ( "gorm.io/gorm/clause" runnerErrors "github.com/cloudbase/garm-provider-common/errors" - "github.com/cloudbase/garm-provider-common/util" "github.com/cloudbase/garm/params" ) -func (s *sqlDatabase) marshalAndSeal(data interface{}) ([]byte, error) { - enc, err := json.Marshal(data) - if err != nil { - return nil, errors.Wrap(err, "marshalling data") - } - return util.Seal(enc, []byte(s.cfg.Passphrase)) -} - -func (s *sqlDatabase) unsealAndUnmarshal(data []byte, target interface{}) error { - decrypted, err := util.Unseal(data, []byte(s.cfg.Passphrase)) - if err != nil { - return errors.Wrap(err, "decrypting data") - } - if err := json.Unmarshal(decrypted, target); err != nil { - return errors.Wrap(err, "unmarshalling data") - } - return nil -} - func (s *sqlDatabase) CreateInstance(_ context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error) { pool, err := s.getPoolByID(s.conn, poolID) if err != nil { diff --git a/database/sql/pools.go b/database/sql/pools.go index 161b1d58..4d33343d 100644 --- a/database/sql/pools.go +++ b/database/sql/pools.go @@ -24,6 +24,7 @@ import ( "gorm.io/gorm" runnerErrors "github.com/cloudbase/garm-provider-common/errors" + "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/params" ) @@ -66,12 +67,18 @@ func (s *sqlDatabase) GetPoolByID(_ context.Context, poolID string) (params.Pool return s.sqlToCommonPool(pool) } -func (s *sqlDatabase) DeletePoolByID(_ context.Context, poolID string) error { +func (s *sqlDatabase) DeletePoolByID(_ context.Context, poolID string) (err error) { pool, err := s.getPoolByID(s.conn, poolID) if err != nil { return errors.Wrap(err, "fetching pool by ID") } + defer func() { + if err == nil { + s.sendNotify(common.PoolEntityType, common.DeleteOperation, pool) + } + }() + if q := s.conn.Unscoped().Delete(&pool); q.Error != nil { return errors.Wrap(q.Error, "removing pool") } @@ -247,11 +254,17 @@ func (s *sqlDatabase) FindPoolsMatchingAllTags(_ context.Context, entityType par return pools, nil } -func (s *sqlDatabase) CreateEntityPool(_ context.Context, entity params.GithubEntity, param params.CreatePoolParams) (params.Pool, error) { +func (s *sqlDatabase) CreateEntityPool(_ context.Context, entity params.GithubEntity, param params.CreatePoolParams) (pool params.Pool, err error) { if len(param.Tags) == 0 { return params.Pool{}, runnerErrors.NewBadRequestError("no tags specified") } + defer func() { + if err == nil { + s.sendNotify(common.PoolEntityType, common.CreateOperation, pool) + } + }() + newPool := Pool{ ProviderName: param.ProviderName, MaxRunners: param.MaxRunners, @@ -313,12 +326,12 @@ func (s *sqlDatabase) CreateEntityPool(_ context.Context, entity params.GithubEn return params.Pool{}, err } - pool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") + dbPool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } - return s.sqlToCommonPool(pool) + return s.sqlToCommonPool(dbPool) } func (s *sqlDatabase) GetEntityPool(_ context.Context, entity params.GithubEntity, poolID string) (params.Pool, error) { @@ -329,12 +342,21 @@ func (s *sqlDatabase) GetEntityPool(_ context.Context, entity params.GithubEntit return s.sqlToCommonPool(pool) } -func (s *sqlDatabase) DeleteEntityPool(_ context.Context, entity params.GithubEntity, poolID string) error { +func (s *sqlDatabase) DeleteEntityPool(_ context.Context, entity params.GithubEntity, poolID string) (err error) { entityID, err := uuid.Parse(entity.ID) if err != nil { return errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") } + defer func() { + if err == nil { + pool := params.Pool{ + ID: poolID, + } + s.sendNotify(common.PoolEntityType, common.DeleteOperation, pool) + } + }() + poolUUID, err := uuid.Parse(poolID) if err != nil { return errors.Wrap(runnerErrors.ErrBadRequest, "parsing pool id") @@ -374,6 +396,7 @@ func (s *sqlDatabase) UpdateEntityPool(_ context.Context, entity params.GithubEn if err != nil { return params.Pool{}, err } + s.sendNotify(common.PoolEntityType, common.UpdateOperation, updatedPool) return updatedPool, nil } diff --git a/database/sql/repositories.go b/database/sql/repositories.go index 510a9959..7ab1c522 100644 --- a/database/sql/repositories.go +++ b/database/sql/repositories.go @@ -24,10 +24,17 @@ import ( runnerErrors "github.com/cloudbase/garm-provider-common/errors" "github.com/cloudbase/garm-provider-common/util" + "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/params" ) -func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Repository, error) { +func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (param params.Repository, err error) { + defer func() { + if err == nil { + s.sendNotify(common.RepositoryEntityType, common.CreateOperation, param) + } + }() + if webhookSecret == "" { return params.Repository{}, errors.New("creating repo: missing secret") } @@ -68,7 +75,7 @@ func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, credent return params.Repository{}, errors.Wrap(err, "creating repository") } - param, err := s.sqlToCommonRepository(newRepo, true) + param, err = s.sqlToCommonRepository(newRepo, true) if err != nil { return params.Repository{}, errors.Wrap(err, "creating repository") } @@ -113,12 +120,21 @@ func (s *sqlDatabase) ListRepositories(_ context.Context) ([]params.Repository, return ret, nil } -func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) error { +func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) (err error) { repo, err := s.getRepoByID(ctx, s.conn, repoID) if err != nil { return errors.Wrap(err, "fetching repo") } + defer func(repo Repository) { + if err == nil { + asParam, innerErr := s.sqlToCommonRepository(repo, true) + if innerErr == nil { + s.sendNotify(common.RepositoryEntityType, common.DeleteOperation, asParam) + } + } + }(repo) + q := s.conn.Unscoped().Delete(&repo) if q.Error != nil && !errors.Is(q.Error, gorm.ErrRecordNotFound) { return errors.Wrap(q.Error, "deleting repo") @@ -127,10 +143,15 @@ func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) error return nil } -func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param params.UpdateEntityParams) (params.Repository, error) { +func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param params.UpdateEntityParams) (newParams params.Repository, err error) { + defer func() { + if err == nil { + s.sendNotify(common.RepositoryEntityType, common.UpdateOperation, newParams) + } + }() var repo Repository var creds GithubCredentials - err := s.conn.Transaction(func(tx *gorm.DB) error { + err = s.conn.Transaction(func(tx *gorm.DB) error { var err error repo, err = s.getRepoByID(ctx, tx, repoID) if err != nil { @@ -186,7 +207,8 @@ func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param if err != nil { return params.Repository{}, errors.Wrap(err, "updating enterprise") } - newParams, err := s.sqlToCommonRepository(repo, true) + + newParams, err = s.sqlToCommonRepository(repo, true) if err != nil { return params.Repository{}, errors.Wrap(err, "saving repo") } diff --git a/database/sql/repositories_test.go b/database/sql/repositories_test.go index 6eb20a2c..34b07bd1 100644 --- a/database/sql/repositories_test.go +++ b/database/sql/repositories_test.go @@ -30,6 +30,7 @@ import ( "github.com/cloudbase/garm/auth" dbCommon "github.com/cloudbase/garm/database/common" + "github.com/cloudbase/garm/database/watcher" garmTesting "github.com/cloudbase/garm/internal/testing" "github.com/cloudbase/garm/params" ) @@ -827,5 +828,11 @@ func (s *RepoTestSuite) TestUpdateRepositoryPoolInvalidRepoID() { func TestRepoTestSuite(t *testing.T) { t.Parallel() + + watcher.SetWatcher(&garmTesting.MockWatcher{}) suite.Run(t, new(RepoTestSuite)) } + +func init() { + watcher.SetWatcher(&garmTesting.MockWatcher{}) +} diff --git a/database/sql/sql.go b/database/sql/sql.go index 5f970558..6ee8a2d9 100644 --- a/database/sql/sql.go +++ b/database/sql/sql.go @@ -31,6 +31,7 @@ import ( "github.com/cloudbase/garm/auth" "github.com/cloudbase/garm/config" "github.com/cloudbase/garm/database/common" + "github.com/cloudbase/garm/database/watcher" "github.com/cloudbase/garm/params" "github.com/cloudbase/garm/util/appdefaults" ) @@ -68,10 +69,15 @@ func NewSQLDatabase(ctx context.Context, cfg config.Database) (common.Store, err if err != nil { return nil, errors.Wrap(err, "creating DB connection") } + producer, err := watcher.RegisterProducer("sql") + if err != nil { + return nil, errors.Wrap(err, "registering producer") + } db := &sqlDatabase{ - conn: conn, - ctx: ctx, - cfg: cfg, + conn: conn, + ctx: ctx, + cfg: cfg, + producer: producer, } if err := db.migrateDB(); err != nil { @@ -81,9 +87,10 @@ func NewSQLDatabase(ctx context.Context, cfg config.Database) (common.Store, err } type sqlDatabase struct { - conn *gorm.DB - ctx context.Context - cfg config.Database + conn *gorm.DB + ctx context.Context + cfg config.Database + producer common.Producer } var renameTemplate = ` diff --git a/database/sql/util.go b/database/sql/util.go index b7f8a058..5814483d 100644 --- a/database/sql/util.go +++ b/database/sql/util.go @@ -26,6 +26,7 @@ import ( runnerErrors "github.com/cloudbase/garm-provider-common/errors" commonParams "github.com/cloudbase/garm-provider-common/params" "github.com/cloudbase/garm-provider-common/util" + dbCommon "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/params" ) @@ -467,3 +468,31 @@ func (s *sqlDatabase) hasGithubEntity(tx *gorm.DB, entityType params.GithubEntit } return nil } + +func (s *sqlDatabase) marshalAndSeal(data interface{}) ([]byte, error) { + enc, err := json.Marshal(data) + if err != nil { + return nil, errors.Wrap(err, "marshalling data") + } + return util.Seal(enc, []byte(s.cfg.Passphrase)) +} + +func (s *sqlDatabase) unsealAndUnmarshal(data []byte, target interface{}) error { + decrypted, err := util.Unseal(data, []byte(s.cfg.Passphrase)) + if err != nil { + return errors.Wrap(err, "decrypting data") + } + if err := json.Unmarshal(decrypted, target); err != nil { + return errors.Wrap(err, "unmarshalling data") + } + return nil +} + +func (s *sqlDatabase) sendNotify(entityType dbCommon.DatabaseEntityType, op dbCommon.OperationType, payload interface{}) { + message := dbCommon.ChangePayload{ + Operation: op, + Payload: payload, + EntityType: entityType, + } + s.producer.Notify(message) +} diff --git a/database/watcher/consumer.go b/database/watcher/consumer.go new file mode 100644 index 00000000..369344ba --- /dev/null +++ b/database/watcher/consumer.go @@ -0,0 +1,75 @@ +package watcher + +import ( + "log/slog" + "sync" + "time" + + "github.com/cloudbase/garm/database/common" +) + +type consumer struct { + messages chan common.ChangePayload + filters []common.PayloadFilterFunc + id string + + mux sync.Mutex + closed bool + quit chan struct{} +} + +func (w *consumer) SetFilters(filters ...common.PayloadFilterFunc) { + w.mux.Lock() + defer w.mux.Unlock() + w.filters = filters +} + +func (w *consumer) Watch() <-chan common.ChangePayload { + return w.messages +} + +func (w *consumer) Close() { + w.mux.Lock() + defer w.mux.Unlock() + if w.closed { + return + } + close(w.messages) + close(w.quit) + w.closed = true +} + +func (w *consumer) IsClosed() bool { + w.mux.Lock() + defer w.mux.Unlock() + return w.closed +} + +func (w *consumer) Send(payload common.ChangePayload) { + w.mux.Lock() + defer w.mux.Unlock() + + if w.closed { + return + } + + if len(w.filters) > 0 { + shouldSend := false + for _, filter := range w.filters { + if filter(payload) { + shouldSend = true + break + } + } + + if !shouldSend { + return + } + } + + slog.Info("Sending payload to consumer", "consumer", w.id) + select { + case w.messages <- payload: + case <-time.After(1 * time.Second): + } +} diff --git a/database/watcher/producer.go b/database/watcher/producer.go new file mode 100644 index 00000000..70578004 --- /dev/null +++ b/database/watcher/producer.go @@ -0,0 +1,49 @@ +package watcher + +import ( + "sync" + + "github.com/cloudbase/garm/database/common" +) + +type producer struct { + closed bool + mux sync.Mutex + id string + + messages chan common.ChangePayload + quit chan struct{} +} + +func (w *producer) Notify(payload common.ChangePayload) error { + w.mux.Lock() + defer w.mux.Unlock() + + if w.closed { + return common.ErrProducerClosed + } + + select { + case w.messages <- payload: + default: + return common.ErrProducerTimeoutErr + } + return nil +} + +func (w *producer) Close() { + w.mux.Lock() + defer w.mux.Unlock() + if w.closed { + return + } + w.closed = true + close(w.messages) + close(w.quit) +} + +func (w *producer) IsClosed() bool { + w.mux.Lock() + defer w.mux.Unlock() + return w.closed +} diff --git a/database/watcher/test_export.go b/database/watcher/test_export.go new file mode 100644 index 00000000..4c75233e --- /dev/null +++ b/database/watcher/test_export.go @@ -0,0 +1,12 @@ +//go:build testing +// +build testing + +package watcher + +import "github.com/cloudbase/garm/database/common" + +// SetWatcher sets the watcher to be used by the database package. +// This function is intended for use in tests only. +func SetWatcher(w common.Watcher) { + databaseWatcher = w +} diff --git a/database/watcher/watcher.go b/database/watcher/watcher.go new file mode 100644 index 00000000..23400e21 --- /dev/null +++ b/database/watcher/watcher.go @@ -0,0 +1,151 @@ +package watcher + +import ( + "context" + "sync" + + "github.com/cloudbase/garm/database/common" +) + +var databaseWatcher common.Watcher + +func InitWatcher(ctx context.Context) { + if databaseWatcher != nil { + return + } + w := &watcher{ + producers: make(map[string]*producer), + consumers: make(map[string]*consumer), + quit: make(chan struct{}), + ctx: ctx, + } + + go w.loop() + databaseWatcher = w +} + +func RegisterProducer(id string) (common.Producer, error) { + if databaseWatcher == nil { + return nil, common.ErrWatcherNotInitialized + } + return databaseWatcher.RegisterProducer(id) +} + +func RegisterConsumer(id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) { + if databaseWatcher == nil { + return nil, common.ErrWatcherNotInitialized + } + return databaseWatcher.RegisterConsumer(id, filters...) +} + +type watcher struct { + producers map[string]*producer + consumers map[string]*consumer + + mux sync.Mutex + closed bool + quit chan struct{} + ctx context.Context +} + +func (w *watcher) RegisterProducer(id string) (common.Producer, error) { + if _, ok := w.producers[id]; ok { + return nil, common.ErrProducerAlreadyRegistered + } + p := &producer{ + id: id, + messages: make(chan common.ChangePayload, 1), + quit: make(chan struct{}), + } + w.producers[id] = p + go w.serviceProducer(p) + return p, nil +} + +func (w *watcher) serviceProducer(prod *producer) { + defer func() { + w.mux.Lock() + defer w.mux.Unlock() + prod.Close() + delete(w.producers, prod.id) + }() + for { + select { + case <-w.quit: + return + case <-w.ctx.Done(): + return + case payload := <-prod.messages: + for _, c := range w.consumers { + go c.Send(payload) + } + } + } +} + +func (w *watcher) RegisterConsumer(id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) { + if _, ok := w.consumers[id]; ok { + return nil, common.ErrConsumerAlreadyRegistered + } + c := &consumer{ + messages: make(chan common.ChangePayload, 1), + filters: filters, + quit: make(chan struct{}), + id: id, + } + w.consumers[id] = c + go w.serviceConsumer(c) + return c, nil +} + +func (w *watcher) serviceConsumer(consumer *consumer) { + defer func() { + w.mux.Lock() + defer w.mux.Unlock() + consumer.Close() + delete(w.consumers, consumer.id) + }() + for { + select { + case <-consumer.quit: + return + case <-w.quit: + return + case <-w.ctx.Done(): + return + } + } +} + +func (w *watcher) Close() { + w.mux.Lock() + defer w.mux.Unlock() + if w.closed { + return + } + + close(w.quit) + w.closed = true + + for _, p := range w.producers { + p.Close() + } + + for _, c := range w.consumers { + c.Close() + } +} + +func (w *watcher) loop() { + defer func() { + w.Close() + }() + for { + select { + case <-w.quit: + return + case <-w.ctx.Done(): + return + } + } +} diff --git a/internal/testing/mock_watcher.go b/internal/testing/mock_watcher.go new file mode 100644 index 00000000..394091bd --- /dev/null +++ b/internal/testing/mock_watcher.go @@ -0,0 +1,45 @@ +//go:build testing +// +build testing + +package testing + +import "github.com/cloudbase/garm/database/common" + +type MockWatcher struct{} + +func (w *MockWatcher) RegisterProducer(_ string) (common.Producer, error) { + return &MockProducer{}, nil +} + +func (w *MockWatcher) RegisterConsumer(_ string, _ ...common.PayloadFilterFunc) (common.Consumer, error) { + return &MockConsumer{}, nil +} + +type MockProducer struct{} + +func (p *MockProducer) Notify(_ common.ChangePayload) error { + return nil +} + +func (p *MockProducer) IsClosed() bool { + return false +} + +func (p *MockProducer) Close() { +} + +type MockConsumer struct{} + +func (c *MockConsumer) Watch() <-chan common.ChangePayload { + return nil +} + +func (c *MockConsumer) SetFilters(_ ...common.PayloadFilterFunc) { +} + +func (c *MockConsumer) Close() { +} + +func (c *MockConsumer) IsClosed() bool { + return false +} diff --git a/runner/pool/pool.go b/runner/pool/pool.go index 1261ad21..f7a8ea68 100644 --- a/runner/pool/pool.go +++ b/runner/pool/pool.go @@ -69,7 +69,7 @@ type urls struct { } func NewEntityPoolManager(ctx context.Context, entity params.GithubEntity, cfgInternal params.Internal, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error) { - ctx = garmUtil.WithContext(ctx, slog.Any("pool_mgr", entity.String()), slog.Any("pool_type", params.GithubEntityTypeRepository)) + ctx = garmUtil.WithContext(ctx, slog.Any("pool_mgr", entity.String()), slog.Any("pool_type", entity.EntityType)) ghc, err := garmUtil.GithubClient(ctx, entity, cfgInternal.GithubCredentialsDetails) if err != nil { return nil, errors.Wrap(err, "getting github client") diff --git a/runner/repositories_test.go b/runner/repositories_test.go index d0a6ab61..f17bd93a 100644 --- a/runner/repositories_test.go +++ b/runner/repositories_test.go @@ -26,6 +26,7 @@ import ( runnerErrors "github.com/cloudbase/garm-provider-common/errors" "github.com/cloudbase/garm/database" dbCommon "github.com/cloudbase/garm/database/common" + "github.com/cloudbase/garm/database/watcher" garmTesting "github.com/cloudbase/garm/internal/testing" "github.com/cloudbase/garm/params" "github.com/cloudbase/garm/runner/common" @@ -51,6 +52,10 @@ type RepoTestFixtures struct { PoolMgrCtrlMock *runnerMocks.PoolManagerController } +func init() { + watcher.SetWatcher(&garmTesting.MockWatcher{}) +} + type RepoTestSuite struct { suite.Suite Fixtures *RepoTestFixtures From 7f9db2e4139ef95dfa2d6724dbb3849bc2418091 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Fri, 14 Jun 2024 20:24:45 +0000 Subject: [PATCH 02/11] Send notify on update controller Signed-off-by: Gabriel Adrian Samfira --- database/common/watcher.go | 2 +- database/sql/controller.go | 58 +++++++++++++++++++++++--------------- database/sql/pools.go | 11 +++++--- 3 files changed, 43 insertions(+), 28 deletions(-) diff --git a/database/common/watcher.go b/database/common/watcher.go index 4903e4ab..69bf9788 100644 --- a/database/common/watcher.go +++ b/database/common/watcher.go @@ -15,7 +15,7 @@ const ( InstanceEntityType DatabaseEntityType = "instance" JobEntityType DatabaseEntityType = "job" ControllerEntityType DatabaseEntityType = "controller" - GithubCredentialsEntityType DatabaseEntityType = "github_credentials" + GithubCredentialsEntityType DatabaseEntityType = "github_credentials" // #nosec G101 GithubEndpointEntityType DatabaseEntityType = "github_endpoint" ) diff --git a/database/sql/controller.go b/database/sql/controller.go index 8d6c3477..c5b900f3 100644 --- a/database/sql/controller.go +++ b/database/sql/controller.go @@ -22,6 +22,7 @@ import ( "gorm.io/gorm" runnerErrors "github.com/cloudbase/garm-provider-common/errors" + "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/params" ) @@ -82,38 +83,49 @@ func (s *sqlDatabase) InitController() (params.ControllerInfo, error) { }, nil } -func (s *sqlDatabase) UpdateController(info params.UpdateControllerParams) (params.ControllerInfo, error) { +func (s *sqlDatabase) UpdateController(info params.UpdateControllerParams) (paramInfo params.ControllerInfo, err error) { + defer func() { + if err == nil { + s.sendNotify(common.ControllerEntityType, common.UpdateOperation, paramInfo) + } + }() var dbInfo ControllerInfo - q := s.conn.Model(&ControllerInfo{}).First(&dbInfo) - if q.Error != nil { - if errors.Is(q.Error, gorm.ErrRecordNotFound) { - return params.ControllerInfo{}, errors.Wrap(runnerErrors.ErrNotFound, "fetching controller info") + err = s.conn.Transaction(func(tx *gorm.DB) error { + q := tx.Model(&ControllerInfo{}).First(&dbInfo) + if q.Error != nil { + if errors.Is(q.Error, gorm.ErrRecordNotFound) { + return errors.Wrap(runnerErrors.ErrNotFound, "fetching controller info") + } + return errors.Wrap(q.Error, "fetching controller info") } - return params.ControllerInfo{}, errors.Wrap(q.Error, "fetching controller info") - } - if err := info.Validate(); err != nil { - return params.ControllerInfo{}, errors.Wrap(err, "validating controller info") - } + if err := info.Validate(); err != nil { + return errors.Wrap(err, "validating controller info") + } - if info.MetadataURL != nil { - dbInfo.MetadataURL = *info.MetadataURL - } + if info.MetadataURL != nil { + dbInfo.MetadataURL = *info.MetadataURL + } - if info.CallbackURL != nil { - dbInfo.CallbackURL = *info.CallbackURL - } + if info.CallbackURL != nil { + dbInfo.CallbackURL = *info.CallbackURL + } - if info.WebhookURL != nil { - dbInfo.WebhookBaseURL = *info.WebhookURL - } + if info.WebhookURL != nil { + dbInfo.WebhookBaseURL = *info.WebhookURL + } - q = s.conn.Save(&dbInfo) - if q.Error != nil { - return params.ControllerInfo{}, errors.Wrap(q.Error, "saving controller info") + q = tx.Save(&dbInfo) + if q.Error != nil { + return errors.Wrap(q.Error, "saving controller info") + } + return nil + }) + if err != nil { + return params.ControllerInfo{}, errors.Wrap(err, "updating controller info") } - paramInfo, err := dbControllerToCommonController(dbInfo) + paramInfo, err = dbControllerToCommonController(dbInfo) if err != nil { return params.ControllerInfo{}, errors.Wrap(err, "converting controller info") } diff --git a/database/sql/pools.go b/database/sql/pools.go index 4d33343d..01d1afc4 100644 --- a/database/sql/pools.go +++ b/database/sql/pools.go @@ -379,9 +379,13 @@ func (s *sqlDatabase) DeleteEntityPool(_ context.Context, entity params.GithubEn return nil } -func (s *sqlDatabase) UpdateEntityPool(_ context.Context, entity params.GithubEntity, poolID string, param params.UpdatePoolParams) (params.Pool, error) { - var updatedPool params.Pool - err := s.conn.Transaction(func(tx *gorm.DB) error { +func (s *sqlDatabase) UpdateEntityPool(_ context.Context, entity params.GithubEntity, poolID string, param params.UpdatePoolParams) (updatedPool params.Pool, err error) { + defer func() { + if err == nil { + s.sendNotify(common.PoolEntityType, common.UpdateOperation, updatedPool) + } + }() + err = s.conn.Transaction(func(tx *gorm.DB) error { pool, err := s.getEntityPool(tx, entity.EntityType, entity.ID, poolID, "Tags", "Instances") if err != nil { return errors.Wrap(err, "fetching pool") @@ -396,7 +400,6 @@ func (s *sqlDatabase) UpdateEntityPool(_ context.Context, entity params.GithubEn if err != nil { return params.Pool{}, err } - s.sendNotify(common.PoolEntityType, common.UpdateOperation, updatedPool) return updatedPool, nil } From 6051629810d135a9209349a51c5aef1d61e3472a Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Fri, 14 Jun 2024 20:35:02 +0000 Subject: [PATCH 03/11] Add watcher for github creds and endpoints Signed-off-by: Gabriel Adrian Samfira --- database/sql/github.go | 84 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 69 insertions(+), 15 deletions(-) diff --git a/database/sql/github.go b/database/sql/github.go index e62cc8ba..b0911222 100644 --- a/database/sql/github.go +++ b/database/sql/github.go @@ -24,6 +24,7 @@ import ( runnerErrors "github.com/cloudbase/garm-provider-common/errors" "github.com/cloudbase/garm-provider-common/util" "github.com/cloudbase/garm/auth" + "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/params" ) @@ -109,9 +110,14 @@ func getUIDFromContext(ctx context.Context) (uuid.UUID, error) { return asUUID, nil } -func (s *sqlDatabase) CreateGithubEndpoint(_ context.Context, param params.CreateGithubEndpointParams) (params.GithubEndpoint, error) { +func (s *sqlDatabase) CreateGithubEndpoint(_ context.Context, param params.CreateGithubEndpointParams) (ghEndpoint params.GithubEndpoint, err error) { + defer func() { + if err == nil { + s.sendNotify(common.GithubEndpointEntityType, common.CreateOperation, ghEndpoint) + } + }() var endpoint GithubEndpoint - err := s.conn.Transaction(func(tx *gorm.DB) error { + err = s.conn.Transaction(func(tx *gorm.DB) error { if err := tx.Where("name = ?", param.Name).First(&endpoint).Error; err == nil { return errors.Wrap(runnerErrors.ErrDuplicateEntity, "github endpoint already exists") } @@ -132,7 +138,11 @@ func (s *sqlDatabase) CreateGithubEndpoint(_ context.Context, param params.Creat if err != nil { return params.GithubEndpoint{}, errors.Wrap(err, "creating github endpoint") } - return s.sqlToCommonGithubEndpoint(endpoint) + ghEndpoint, err = s.sqlToCommonGithubEndpoint(endpoint) + if err != nil { + return params.GithubEndpoint{}, errors.Wrap(err, "converting github endpoint") + } + return ghEndpoint, nil } func (s *sqlDatabase) ListGithubEndpoints(_ context.Context) ([]params.GithubEndpoint, error) { @@ -153,12 +163,18 @@ func (s *sqlDatabase) ListGithubEndpoints(_ context.Context) ([]params.GithubEnd return ret, nil } -func (s *sqlDatabase) UpdateGithubEndpoint(_ context.Context, name string, param params.UpdateGithubEndpointParams) (params.GithubEndpoint, error) { +func (s *sqlDatabase) UpdateGithubEndpoint(_ context.Context, name string, param params.UpdateGithubEndpointParams) (ghEndpoint params.GithubEndpoint, err error) { if name == defaultGithubEndpoint { return params.GithubEndpoint{}, errors.Wrap(runnerErrors.ErrBadRequest, "cannot update default github endpoint") } + + defer func() { + if err == nil { + s.sendNotify(common.GithubEndpointEntityType, common.UpdateOperation, ghEndpoint) + } + }() var endpoint GithubEndpoint - err := s.conn.Transaction(func(tx *gorm.DB) error { + err = s.conn.Transaction(func(tx *gorm.DB) error { if err := tx.Where("name = ?", name).First(&endpoint).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return errors.Wrap(runnerErrors.ErrNotFound, "github endpoint not found") @@ -194,7 +210,11 @@ func (s *sqlDatabase) UpdateGithubEndpoint(_ context.Context, name string, param if err != nil { return params.GithubEndpoint{}, errors.Wrap(err, "updating github endpoint") } - return s.sqlToCommonGithubEndpoint(endpoint) + ghEndpoint, err = s.sqlToCommonGithubEndpoint(endpoint) + if err != nil { + return params.GithubEndpoint{}, errors.Wrap(err, "converting github endpoint") + } + return ghEndpoint, nil } func (s *sqlDatabase) GetGithubEndpoint(_ context.Context, name string) (params.GithubEndpoint, error) { @@ -211,11 +231,17 @@ func (s *sqlDatabase) GetGithubEndpoint(_ context.Context, name string) (params. return s.sqlToCommonGithubEndpoint(endpoint) } -func (s *sqlDatabase) DeleteGithubEndpoint(_ context.Context, name string) error { +func (s *sqlDatabase) DeleteGithubEndpoint(_ context.Context, name string) (err error) { if name == defaultGithubEndpoint { return errors.Wrap(runnerErrors.ErrBadRequest, "cannot delete default github endpoint") } - err := s.conn.Transaction(func(tx *gorm.DB) error { + + defer func() { + if err == nil { + s.sendNotify(common.GithubEndpointEntityType, common.DeleteOperation, params.GithubEndpoint{Name: name}) + } + }() + err = s.conn.Transaction(func(tx *gorm.DB) error { var endpoint GithubEndpoint if err := tx.Where("name = ?", name).First(&endpoint).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -267,7 +293,7 @@ func (s *sqlDatabase) DeleteGithubEndpoint(_ context.Context, name string) error return nil } -func (s *sqlDatabase) CreateGithubCredentials(ctx context.Context, param params.CreateGithubCredentialsParams) (params.GithubCredentials, error) { +func (s *sqlDatabase) CreateGithubCredentials(ctx context.Context, param params.CreateGithubCredentialsParams) (ghCreds params.GithubCredentials, err error) { userID, err := getUIDFromContext(ctx) if err != nil { return params.GithubCredentials{}, errors.Wrap(err, "creating github credentials") @@ -275,6 +301,12 @@ func (s *sqlDatabase) CreateGithubCredentials(ctx context.Context, param params. if param.Endpoint == "" { return params.GithubCredentials{}, errors.Wrap(runnerErrors.ErrBadRequest, "endpoint name is required") } + + defer func() { + if err == nil { + s.sendNotify(common.GithubCredentialsEntityType, common.CreateOperation, ghCreds) + } + }() var creds GithubCredentials err = s.conn.Transaction(func(tx *gorm.DB) error { var endpoint GithubEndpoint @@ -323,7 +355,11 @@ func (s *sqlDatabase) CreateGithubCredentials(ctx context.Context, param params. if err != nil { return params.GithubCredentials{}, errors.Wrap(err, "creating github credentials") } - return s.sqlToCommonGithubCredentials(creds) + ghCreds, err = s.sqlToCommonGithubCredentials(creds) + if err != nil { + return params.GithubCredentials{}, errors.Wrap(err, "converting github credentials") + } + return ghCreds, nil } func (s *sqlDatabase) getGithubCredentialsByName(ctx context.Context, tx *gorm.DB, name string, detailed bool) (GithubCredentials, error) { @@ -420,9 +456,14 @@ func (s *sqlDatabase) ListGithubCredentials(ctx context.Context) ([]params.Githu return ret, nil } -func (s *sqlDatabase) UpdateGithubCredentials(ctx context.Context, id uint, param params.UpdateGithubCredentialsParams) (params.GithubCredentials, error) { +func (s *sqlDatabase) UpdateGithubCredentials(ctx context.Context, id uint, param params.UpdateGithubCredentialsParams) (ghCreds params.GithubCredentials, err error) { + defer func() { + if err == nil { + s.sendNotify(common.GithubCredentialsEntityType, common.UpdateOperation, ghCreds) + } + }() var creds GithubCredentials - err := s.conn.Transaction(func(tx *gorm.DB) error { + err = s.conn.Transaction(func(tx *gorm.DB) error { q := tx.Preload("Endpoint") if !auth.IsAdmin(ctx) { userID, err := getUIDFromContext(ctx) @@ -486,11 +527,22 @@ func (s *sqlDatabase) UpdateGithubCredentials(ctx context.Context, id uint, para if err != nil { return params.GithubCredentials{}, errors.Wrap(err, "updating github credentials") } - return s.sqlToCommonGithubCredentials(creds) + + ghCreds, err = s.sqlToCommonGithubCredentials(creds) + if err != nil { + return params.GithubCredentials{}, errors.Wrap(err, "converting github credentials") + } + return ghCreds, nil } -func (s *sqlDatabase) DeleteGithubCredentials(ctx context.Context, id uint) error { - err := s.conn.Transaction(func(tx *gorm.DB) error { +func (s *sqlDatabase) DeleteGithubCredentials(ctx context.Context, id uint) (err error) { + var name string + defer func() { + if err == nil { + s.sendNotify(common.GithubCredentialsEntityType, common.DeleteOperation, params.GithubCredentials{ID: id, Name: name}) + } + }() + err = s.conn.Transaction(func(tx *gorm.DB) error { q := tx.Where("id = ?", id). Preload("Repositories"). Preload("Organizations"). @@ -511,6 +563,8 @@ func (s *sqlDatabase) DeleteGithubCredentials(ctx context.Context, id uint) erro } return errors.Wrap(err, "fetching github credentials") } + name = creds.Name + if len(creds.Repositories) > 0 { return errors.Wrap(runnerErrors.ErrBadRequest, "cannot delete credentials with repositories") } From b51683f1ae0427c314cfda4d22730d77b2075ebd Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Mon, 17 Jun 2024 19:42:50 +0000 Subject: [PATCH 04/11] Add some tests Signed-off-by: Gabriel Adrian Samfira --- database/common/watcher.go | 7 +- database/sql/sql.go | 2 +- database/watcher/consumer.go | 17 ++- database/watcher/filters.go | 141 ++++++++++++++++++++++ database/watcher/producer.go | 11 +- database/watcher/test_export.go | 5 + database/watcher/watcher.go | 28 ++++- database/watcher/watcher_store_test.go | 45 +++++++ database/watcher/watcher_test.go | 159 +++++++++++++++++++++++++ internal/testing/mock_watcher.go | 13 +- 10 files changed, 409 insertions(+), 19 deletions(-) create mode 100644 database/watcher/filters.go create mode 100644 database/watcher/watcher_store_test.go create mode 100644 database/watcher/watcher_test.go diff --git a/database/common/watcher.go b/database/common/watcher.go index 69bf9788..73af32bd 100644 --- a/database/common/watcher.go +++ b/database/common/watcher.go @@ -1,5 +1,7 @@ package common +import "context" + type ( DatabaseEntityType string OperationType string @@ -45,6 +47,7 @@ type Producer interface { } type Watcher interface { - RegisterProducer(ID string) (Producer, error) - RegisterConsumer(ID string, filters ...PayloadFilterFunc) (Consumer, error) + RegisterProducer(ctx context.Context, ID string) (Producer, error) + RegisterConsumer(ctx context.Context, ID string, filters ...PayloadFilterFunc) (Consumer, error) + Close() } diff --git a/database/sql/sql.go b/database/sql/sql.go index 6ee8a2d9..937ef676 100644 --- a/database/sql/sql.go +++ b/database/sql/sql.go @@ -69,7 +69,7 @@ func NewSQLDatabase(ctx context.Context, cfg config.Database) (common.Store, err if err != nil { return nil, errors.Wrap(err, "creating DB connection") } - producer, err := watcher.RegisterProducer("sql") + producer, err := watcher.RegisterProducer(ctx, "sql") if err != nil { return nil, errors.Wrap(err, "registering producer") } diff --git a/database/watcher/consumer.go b/database/watcher/consumer.go index 369344ba..fb36c694 100644 --- a/database/watcher/consumer.go +++ b/database/watcher/consumer.go @@ -1,6 +1,7 @@ package watcher import ( + "context" "log/slog" "sync" "time" @@ -16,6 +17,7 @@ type consumer struct { mux sync.Mutex closed bool quit chan struct{} + ctx context.Context } func (w *consumer) SetFilters(filters ...common.PayloadFilterFunc) { @@ -54,10 +56,10 @@ func (w *consumer) Send(payload common.ChangePayload) { } if len(w.filters) > 0 { - shouldSend := false + shouldSend := true for _, filter := range w.filters { - if filter(payload) { - shouldSend = true + if !filter(payload) { + shouldSend = false break } } @@ -67,9 +69,14 @@ func (w *consumer) Send(payload common.ChangePayload) { } } - slog.Info("Sending payload to consumer", "consumer", w.id) + slog.DebugContext(w.ctx, "sending payload") select { - case w.messages <- payload: + case <-w.quit: + slog.DebugContext(w.ctx, "consumer is closed") + case <-w.ctx.Done(): + slog.DebugContext(w.ctx, "consumer is closed") case <-time.After(1 * time.Second): + slog.DebugContext(w.ctx, "timeout trying to send payload", "payload", payload) + case w.messages <- payload: } } diff --git a/database/watcher/filters.go b/database/watcher/filters.go new file mode 100644 index 00000000..9b175d7a --- /dev/null +++ b/database/watcher/filters.go @@ -0,0 +1,141 @@ +package watcher + +import ( + dbCommon "github.com/cloudbase/garm/database/common" + "github.com/cloudbase/garm/params" +) + +type idGetter interface { + GetID() string +} + +// WithAny returns a filter function that returns true if any of the provided filters return true. +// This filter is useful if for example you want to watch for update operations on any of the supplied +// entities. +// Example: +// +// // Watch for any update operation on repositories or organizations +// consumer.SetFilters( +// watcher.WithOperationTypeFilter(common.UpdateOperation), +// watcher.WithAny( +// watcher.WithEntityTypeFilter(common.RepositoryEntityType), +// watcher.WithEntityTypeFilter(common.OrganizationEntityType), +// )) +func WithAny(filters ...dbCommon.PayloadFilterFunc) dbCommon.PayloadFilterFunc { + return func(payload dbCommon.ChangePayload) bool { + for _, filter := range filters { + if filter(payload) { + return true + } + } + return false + } +} + +// WithEntityTypeFilter returns a filter function that filters payloads by entity type. +// The filter function returns true if the payload's entity type matches the provided entity type. +func WithEntityTypeFilter(entityType dbCommon.DatabaseEntityType) dbCommon.PayloadFilterFunc { + return func(payload dbCommon.ChangePayload) bool { + return payload.EntityType == entityType + } +} + +// WithOperationTypeFilter returns a filter function that filters payloads by operation type. +func WithOperationTypeFilter(operationType dbCommon.OperationType) dbCommon.PayloadFilterFunc { + return func(payload dbCommon.ChangePayload) bool { + return payload.Operation == operationType + } +} + +// WithEntityPoolFilter returns true if the change payload is a pool that belongs to the +// supplied Github entity. This is useful when an entity worker wants to watch for changes +// in pools that belong to it. +func WithEntityPoolFilter(ghEntity params.GithubEntity) dbCommon.PayloadFilterFunc { + return func(payload dbCommon.ChangePayload) bool { + switch payload.EntityType { + case dbCommon.PoolEntityType: + pool, ok := payload.Payload.(params.Pool) + if !ok { + return false + } + switch ghEntity.EntityType { + case params.GithubEntityTypeRepository: + if pool.RepoID != ghEntity.ID { + return false + } + case params.GithubEntityTypeOrganization: + if pool.OrgID != ghEntity.ID { + return false + } + case params.GithubEntityTypeEnterprise: + if pool.EnterpriseID != ghEntity.ID { + return false + } + default: + return false + } + return true + default: + return false + } + } +} + +// WithEntityFilter returns a filter function that filters payloads by entity. +// Change payloads that match the entity type and ID will return true. +func WithEntityFilter(entity params.GithubEntity) dbCommon.PayloadFilterFunc { + return func(payload dbCommon.ChangePayload) bool { + if params.GithubEntityType(payload.EntityType) != entity.EntityType { + return false + } + var ent idGetter + var ok bool + switch payload.EntityType { + case dbCommon.RepositoryEntityType: + ent, ok = payload.Payload.(params.Repository) + case dbCommon.OrganizationEntityType: + ent, ok = payload.Payload.(params.Organization) + case dbCommon.EnterpriseEntityType: + ent, ok = payload.Payload.(params.Enterprise) + default: + return false + } + if !ok { + return false + } + return ent.GetID() == entity.ID + } +} + +func WithEntityJobFilter(ghEntity params.GithubEntity) dbCommon.PayloadFilterFunc { + return func(payload dbCommon.ChangePayload) bool { + switch payload.EntityType { + case dbCommon.JobEntityType: + job, ok := payload.Payload.(params.Job) + if !ok { + return false + } + + switch ghEntity.EntityType { + case params.GithubEntityTypeRepository: + if job.RepoID != nil && job.RepoID.String() != ghEntity.ID { + return false + } + case params.GithubEntityTypeOrganization: + if job.OrgID != nil && job.OrgID.String() != ghEntity.ID { + return false + } + case params.GithubEntityTypeEnterprise: + if job.EnterpriseID != nil && job.EnterpriseID.String() != ghEntity.ID { + return false + } + default: + return false + } + + return true + default: + return false + } + } +} diff --git a/database/watcher/producer.go b/database/watcher/producer.go index 70578004..fd61aa16 100644 --- a/database/watcher/producer.go +++ b/database/watcher/producer.go @@ -1,7 +1,9 @@ package watcher import ( + "context" "sync" + "time" "github.com/cloudbase/garm/database/common" ) @@ -13,6 +15,7 @@ type producer struct { messages chan common.ChangePayload quit chan struct{} + ctx context.Context } func (w *producer) Notify(payload common.ChangePayload) error { @@ -24,9 +27,13 @@ func (w *producer) Notify(payload common.ChangePayload) error { } select { - case w.messages <- payload: - default: + case <-w.quit: + return common.ErrProducerClosed + case <-w.ctx.Done(): + return common.ErrProducerClosed + case <-time.After(1 * time.Second): return common.ErrProducerTimeoutErr + case w.messages <- payload: } return nil } diff --git a/database/watcher/test_export.go b/database/watcher/test_export.go index 4c75233e..f9b4ecf1 100644 --- a/database/watcher/test_export.go +++ b/database/watcher/test_export.go @@ -10,3 +10,8 @@ import "github.com/cloudbase/garm/database/common" func SetWatcher(w common.Watcher) { databaseWatcher = w } + +// GetWatcher returns the current watcher. +func GetWatcher() common.Watcher { + return databaseWatcher +} diff --git a/database/watcher/watcher.go b/database/watcher/watcher.go index 23400e21..86ba594e 100644 --- a/database/watcher/watcher.go +++ b/database/watcher/watcher.go @@ -2,9 +2,11 @@ package watcher import ( "context" + "log/slog" "sync" "github.com/cloudbase/garm/database/common" + garmUtil "github.com/cloudbase/garm/util" ) var databaseWatcher common.Watcher @@ -13,6 +15,7 @@ func InitWatcher(ctx context.Context) { if databaseWatcher != nil { return } + ctx = garmUtil.WithContext(ctx, slog.Any("watcher", "database")) w := &watcher{ producers: make(map[string]*producer), consumers: make(map[string]*consumer), @@ -24,18 +27,20 @@ func InitWatcher(ctx context.Context) { databaseWatcher = w } -func RegisterProducer(id string) (common.Producer, error) { +func RegisterProducer(ctx context.Context, id string) (common.Producer, error) { if databaseWatcher == nil { return nil, common.ErrWatcherNotInitialized } - return databaseWatcher.RegisterProducer(id) + ctx = garmUtil.WithContext(ctx, slog.Any("producer_id", id)) + return databaseWatcher.RegisterProducer(ctx, id) } -func RegisterConsumer(id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) { +func RegisterConsumer(ctx context.Context, id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) { if databaseWatcher == nil { return nil, common.ErrWatcherNotInitialized } - return databaseWatcher.RegisterConsumer(id, filters...) + ctx = garmUtil.WithContext(ctx, slog.Any("consumer_id", id)) + return databaseWatcher.RegisterConsumer(ctx, id, filters...) } type watcher struct { @@ -48,7 +53,10 @@ type watcher struct { ctx context.Context } -func (w *watcher) RegisterProducer(id string) (common.Producer, error) { +func (w *watcher) RegisterProducer(ctx context.Context, id string) (common.Producer, error) { + w.mux.Lock() + defer w.mux.Unlock() + if _, ok := w.producers[id]; ok { return nil, common.ErrProducerAlreadyRegistered } @@ -56,6 +64,7 @@ func (w *watcher) RegisterProducer(id string) (common.Producer, error) { id: id, messages: make(chan common.ChangePayload, 1), quit: make(chan struct{}), + ctx: ctx, } w.producers[id] = p go w.serviceProducer(p) @@ -67,13 +76,16 @@ func (w *watcher) serviceProducer(prod *producer) { w.mux.Lock() defer w.mux.Unlock() prod.Close() + slog.InfoContext(w.ctx, "removing producer from watcher", "consumer_id", prod.id) delete(w.producers, prod.id) }() for { select { case <-w.quit: + slog.InfoContext(w.ctx, "shutting down watcher") return case <-w.ctx.Done(): + slog.InfoContext(w.ctx, "shutting down watcher") return case payload := <-prod.messages: for _, c := range w.consumers { @@ -83,7 +95,7 @@ func (w *watcher) serviceProducer(prod *producer) { } } -func (w *watcher) RegisterConsumer(id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) { +func (w *watcher) RegisterConsumer(ctx context.Context, id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) { if _, ok := w.consumers[id]; ok { return nil, common.ErrConsumerAlreadyRegistered } @@ -92,6 +104,7 @@ func (w *watcher) RegisterConsumer(id string, filters ...common.PayloadFilterFun filters: filters, quit: make(chan struct{}), id: id, + ctx: ctx, } w.consumers[id] = c go w.serviceConsumer(c) @@ -103,6 +116,7 @@ func (w *watcher) serviceConsumer(consumer *consumer) { w.mux.Lock() defer w.mux.Unlock() consumer.Close() + slog.InfoContext(w.ctx, "removing consumer from watcher", "consumer_id", consumer.id) delete(w.consumers, consumer.id) }() for { @@ -134,6 +148,8 @@ func (w *watcher) Close() { for _, c := range w.consumers { c.Close() } + + databaseWatcher = nil } func (w *watcher) loop() { diff --git a/database/watcher/watcher_store_test.go b/database/watcher/watcher_store_test.go new file mode 100644 index 00000000..b5353c03 --- /dev/null +++ b/database/watcher/watcher_store_test.go @@ -0,0 +1,45 @@ +package watcher_test + +import ( + "context" + "testing" + + "github.com/cloudbase/garm/database" + "github.com/cloudbase/garm/database/common" + "github.com/cloudbase/garm/database/watcher" + garmTesting "github.com/cloudbase/garm/internal/testing" + "github.com/stretchr/testify/suite" +) + +type WatcherStoreTestSuite struct { + suite.Suite + + store common.Store + ctx context.Context +} + +func (s *WatcherStoreTestSuite) TestGithubEndpointWatcher() { + // ghEpParams := params.CreateGithubEndpointParams{ + // Name: "test", + // Description: "test endpoint", + // APIBaseURL: "https://api.ghes.example.com", + // UploadBaseURL: "https://upload.ghes.example.com", + // BaseURL: "https://ghes.example.com", + // } + +} + +func TestWatcherStoreTestSuite(t *testing.T) { + ctx := context.TODO() + watcher.InitWatcher(ctx) + + store, err := database.NewDatabase(ctx, garmTesting.GetTestSqliteDBConfig(t)) + if err != nil { + t.Fatalf("failed to create db connection: %s", err) + } + watcherSuite := &WatcherStoreTestSuite{ + ctx: context.TODO(), + store: store, + } + suite.Run(t, watcherSuite) +} diff --git a/database/watcher/watcher_test.go b/database/watcher/watcher_test.go new file mode 100644 index 00000000..838cdeb0 --- /dev/null +++ b/database/watcher/watcher_test.go @@ -0,0 +1,159 @@ +//go:build testing + +package watcher_test + +import ( + "context" + "testing" + "time" + + "github.com/cloudbase/garm/database" + "github.com/cloudbase/garm/database/common" + "github.com/cloudbase/garm/database/watcher" + garmTesting "github.com/cloudbase/garm/internal/testing" + "github.com/stretchr/testify/suite" +) + +type WatcherTestSuite struct { + suite.Suite + store common.Store + ctx context.Context +} + +func (s *WatcherTestSuite) SetupTest() { + ctx := context.TODO() + watcher.InitWatcher(ctx) + + store, err := database.NewDatabase(ctx, garmTesting.GetTestSqliteDBConfig(s.T())) + if err != nil { + s.T().Fatalf("failed to create db connection: %s", err) + } + s.store = store +} + +func (s *WatcherTestSuite) TearDownTest() { + s.store = nil + currentWatcher := watcher.GetWatcher() + if currentWatcher != nil { + currentWatcher.Close() + } +} + +func (s *WatcherTestSuite) TestRegisterConsumer() { + consumer, err := watcher.RegisterConsumer(s.ctx, "test") + s.Require().NoError(err) + s.Require().NotNil(consumer) + + consumer, err = watcher.RegisterConsumer(s.ctx, "test") + s.Require().Error(err) + s.Require().Nil(consumer) +} + +func (s *WatcherTestSuite) TestRegisterProducer() { + producer, err := watcher.RegisterProducer(s.ctx, "test") + s.Require().NoError(err) + s.Require().NotNil(producer) + + producer, err = watcher.RegisterProducer(s.ctx, "test") + s.Require().Error(err) + s.Require().Nil(producer) +} + +func (s *WatcherTestSuite) TestInitWatcherRanTwiceDoesNotReplaceWatcher() { + ctx := context.TODO() + currentWatcher := watcher.GetWatcher() + s.Require().NotNil(currentWatcher) + watcher.InitWatcher(ctx) + newWatcher := watcher.GetWatcher() + s.Require().Equal(currentWatcher, newWatcher) +} + +func (s *WatcherTestSuite) TestRegisterConsumerFailsIfWatcherIsNotInitialized() { + s.store = nil + currentWatcher := watcher.GetWatcher() + currentWatcher.Close() + + consumer, err := watcher.RegisterConsumer(s.ctx, "test") + s.Require().Nil(consumer) + s.Require().ErrorIs(err, common.ErrWatcherNotInitialized) +} + +func (s *WatcherTestSuite) TestRegisterProducerFailsIfWatcherIsNotInitialized() { + s.store = nil + currentWatcher := watcher.GetWatcher() + currentWatcher.Close() + + producer, err := watcher.RegisterProducer(s.ctx, "test") + s.Require().Nil(producer) + s.Require().ErrorIs(err, common.ErrWatcherNotInitialized) +} + +func (s *WatcherTestSuite) TestProducerAndConsumer() { + producer, err := watcher.RegisterProducer(s.ctx, "test-producer") + s.Require().NoError(err) + s.Require().NotNil(producer) + + consumer, err := watcher.RegisterConsumer(s.ctx, "test-consumer") + s.Require().NoError(err) + s.Require().NotNil(consumer) + + payload := common.ChangePayload{ + EntityType: common.ControllerEntityType, + Operation: common.UpdateOperation, + Payload: "test", + } + err = producer.Notify(payload) + s.Require().NoError(err) + + receivedPayload := <-consumer.Watch() + s.Require().Equal(payload, receivedPayload) +} + +func (s *WatcherTestSuite) TestConsumetWithFilter() { + producer, err := watcher.RegisterProducer(s.ctx, "test-producer") + s.Require().NoError(err) + s.Require().NotNil(producer) + + consumer, err := watcher.RegisterConsumer(s.ctx, "test-consumer", func(payload common.ChangePayload) bool { + return payload.Operation == common.UpdateOperation + }) + s.Require().NoError(err) + s.Require().NotNil(consumer) + + payload := common.ChangePayload{ + EntityType: common.ControllerEntityType, + Operation: common.UpdateOperation, + Payload: "test", + } + err = producer.Notify(payload) + s.Require().NoError(err) + + select { + case receivedPayload := <-consumer.Watch(): + s.Require().Equal(payload, receivedPayload) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + payload = common.ChangePayload{ + EntityType: common.ControllerEntityType, + Operation: common.CreateOperation, + Payload: "test", + } + err = producer.Notify(payload) + s.Require().NoError(err) + + select { + case <-consumer.Watch(): + s.T().Fatal("unexpected payload received") + case <-time.After(1 * time.Second): + } + +} + +func TestWatcherTestSuite(t *testing.T) { + watcherSuite := &WatcherTestSuite{ + ctx: context.TODO(), + } + suite.Run(t, watcherSuite) +} diff --git a/internal/testing/mock_watcher.go b/internal/testing/mock_watcher.go index 394091bd..67ae5da4 100644 --- a/internal/testing/mock_watcher.go +++ b/internal/testing/mock_watcher.go @@ -3,18 +3,25 @@ package testing -import "github.com/cloudbase/garm/database/common" +import ( + "context" + + "github.com/cloudbase/garm/database/common" +) type MockWatcher struct{} -func (w *MockWatcher) RegisterProducer(_ string) (common.Producer, error) { +func (w *MockWatcher) RegisterProducer(_ context.Context, _ string) (common.Producer, error) { return &MockProducer{}, nil } -func (w *MockWatcher) RegisterConsumer(_ string, _ ...common.PayloadFilterFunc) (common.Consumer, error) { +func (w *MockWatcher) RegisterConsumer(_ context.Context, _ string, _ ...common.PayloadFilterFunc) (common.Consumer, error) { return &MockConsumer{}, nil } +func (w *MockWatcher) Close() { +} + type MockProducer struct{} func (p *MockProducer) Notify(_ common.ChangePayload) error { From 37f6434ed8308353c38c0ed15fce887a8b94eef6 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Tue, 18 Jun 2024 16:42:24 +0000 Subject: [PATCH 05/11] Fix race condition and add some tests Signed-off-by: Gabriel Adrian Samfira --- database/watcher/watcher.go | 16 +++++- database/watcher/watcher_store_test.go | 72 ++++++++++++++++++-------- database/watcher/watcher_test.go | 43 +++++++++++---- 3 files changed, 97 insertions(+), 34 deletions(-) diff --git a/database/watcher/watcher.go b/database/watcher/watcher.go index 86ba594e..ef5a5525 100644 --- a/database/watcher/watcher.go +++ b/database/watcher/watcher.go @@ -5,6 +5,8 @@ import ( "log/slog" "sync" + "github.com/pkg/errors" + "github.com/cloudbase/garm/database/common" garmUtil "github.com/cloudbase/garm/util" ) @@ -58,7 +60,7 @@ func (w *watcher) RegisterProducer(ctx context.Context, id string) (common.Produ defer w.mux.Unlock() if _, ok := w.producers[id]; ok { - return nil, common.ErrProducerAlreadyRegistered + return nil, errors.Wrapf(common.ErrProducerAlreadyRegistered, "producer_id: %s", id) } p := &producer{ id: id, @@ -87,15 +89,25 @@ func (w *watcher) serviceProducer(prod *producer) { case <-w.ctx.Done(): slog.InfoContext(w.ctx, "shutting down watcher") return + case <-prod.quit: + slog.InfoContext(w.ctx, "closing producer") + return + case <-prod.ctx.Done(): + slog.InfoContext(w.ctx, "closing producer") + return case payload := <-prod.messages: + w.mux.Lock() for _, c := range w.consumers { go c.Send(payload) } + w.mux.Unlock() } } } func (w *watcher) RegisterConsumer(ctx context.Context, id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) { + w.mux.Lock() + defer w.mux.Unlock() if _, ok := w.consumers[id]; ok { return nil, common.ErrConsumerAlreadyRegistered } @@ -123,6 +135,8 @@ func (w *watcher) serviceConsumer(consumer *consumer) { select { case <-consumer.quit: return + case <-consumer.ctx.Done(): + return case <-w.quit: return case <-w.ctx.Done(): diff --git a/database/watcher/watcher_store_test.go b/database/watcher/watcher_store_test.go index b5353c03..f7a2e4c3 100644 --- a/database/watcher/watcher_store_test.go +++ b/database/watcher/watcher_store_test.go @@ -2,13 +2,13 @@ package watcher_test import ( "context" - "testing" + "time" + + "github.com/stretchr/testify/suite" - "github.com/cloudbase/garm/database" "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/database/watcher" - garmTesting "github.com/cloudbase/garm/internal/testing" - "github.com/stretchr/testify/suite" + "github.com/cloudbase/garm/params" ) type WatcherStoreTestSuite struct { @@ -19,27 +19,55 @@ type WatcherStoreTestSuite struct { } func (s *WatcherStoreTestSuite) TestGithubEndpointWatcher() { - // ghEpParams := params.CreateGithubEndpointParams{ - // Name: "test", - // Description: "test endpoint", - // APIBaseURL: "https://api.ghes.example.com", - // UploadBaseURL: "https://upload.ghes.example.com", - // BaseURL: "https://ghes.example.com", - // } + consumer, err := watcher.RegisterConsumer( + s.ctx, "gh-ep-test", + watcher.WithEntityTypeFilter(common.GithubEndpointEntityType), + watcher.WithAny( + watcher.WithOperationTypeFilter(common.CreateOperation), + watcher.WithOperationTypeFilter(common.UpdateOperation)), + ) + s.Require().NoError(err) + s.Require().NotNil(consumer) + ghEpParams := params.CreateGithubEndpointParams{ + Name: "test", + Description: "test endpoint", + APIBaseURL: "https://api.ghes.example.com", + UploadBaseURL: "https://upload.ghes.example.com", + BaseURL: "https://ghes.example.com", + } -} + ghEp, err := s.store.CreateGithubEndpoint(s.ctx, ghEpParams) + s.Require().NoError(err) + s.Require().NotEmpty(ghEp.Name) -func TestWatcherStoreTestSuite(t *testing.T) { - ctx := context.TODO() - watcher.InitWatcher(ctx) + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.GithubEndpointEntityType, + Operation: common.CreateOperation, + Payload: ghEp, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } - store, err := database.NewDatabase(ctx, garmTesting.GetTestSqliteDBConfig(t)) - if err != nil { - t.Fatalf("failed to create db connection: %s", err) + newDesc := "updated description" + updateParams := params.UpdateGithubEndpointParams{ + Description: &newDesc, } - watcherSuite := &WatcherStoreTestSuite{ - ctx: context.TODO(), - store: store, + + updatedGhEp, err := s.store.UpdateGithubEndpoint(s.ctx, ghEp.Name, updateParams) + s.Require().NoError(err) + s.Require().Equal(newDesc, updatedGhEp.Description) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.GithubEndpointEntityType, + Operation: common.UpdateOperation, + Payload: updatedGhEp, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") } - suite.Run(t, watcherSuite) } diff --git a/database/watcher/watcher_test.go b/database/watcher/watcher_test.go index 838cdeb0..b44c152e 100644 --- a/database/watcher/watcher_test.go +++ b/database/watcher/watcher_test.go @@ -4,14 +4,16 @@ package watcher_test import ( "context" + "fmt" "testing" "time" + "github.com/stretchr/testify/suite" + "github.com/cloudbase/garm/database" "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/database/watcher" garmTesting "github.com/cloudbase/garm/internal/testing" - "github.com/stretchr/testify/suite" ) type WatcherTestSuite struct { @@ -23,7 +25,7 @@ type WatcherTestSuite struct { func (s *WatcherTestSuite) SetupTest() { ctx := context.TODO() watcher.InitWatcher(ctx) - + fmt.Printf("creating store: %v\n", s.store) store, err := database.NewDatabase(ctx, garmTesting.GetTestSqliteDBConfig(s.T())) if err != nil { s.T().Fatalf("failed to create db connection: %s", err) @@ -39,23 +41,23 @@ func (s *WatcherTestSuite) TearDownTest() { } } -func (s *WatcherTestSuite) TestRegisterConsumer() { +func (s *WatcherTestSuite) TestRegisterConsumerTwiceWillError() { consumer, err := watcher.RegisterConsumer(s.ctx, "test") s.Require().NoError(err) s.Require().NotNil(consumer) consumer, err = watcher.RegisterConsumer(s.ctx, "test") - s.Require().Error(err) + s.Require().ErrorIs(err, common.ErrConsumerAlreadyRegistered) s.Require().Nil(consumer) } -func (s *WatcherTestSuite) TestRegisterProducer() { +func (s *WatcherTestSuite) TestRegisterProducerTwiceWillError() { producer, err := watcher.RegisterProducer(s.ctx, "test") s.Require().NoError(err) s.Require().NotNil(producer) producer, err = watcher.RegisterProducer(s.ctx, "test") - s.Require().Error(err) + s.Require().ErrorIs(err, common.ErrProducerAlreadyRegistered) s.Require().Nil(producer) } @@ -93,7 +95,10 @@ func (s *WatcherTestSuite) TestProducerAndConsumer() { s.Require().NoError(err) s.Require().NotNil(producer) - consumer, err := watcher.RegisterConsumer(s.ctx, "test-consumer") + consumer, err := watcher.RegisterConsumer( + s.ctx, "test-consumer", + watcher.WithEntityTypeFilter(common.ControllerEntityType), + watcher.WithOperationTypeFilter(common.UpdateOperation)) s.Require().NoError(err) s.Require().NotNil(consumer) @@ -114,9 +119,10 @@ func (s *WatcherTestSuite) TestConsumetWithFilter() { s.Require().NoError(err) s.Require().NotNil(producer) - consumer, err := watcher.RegisterConsumer(s.ctx, "test-consumer", func(payload common.ChangePayload) bool { - return payload.Operation == common.UpdateOperation - }) + consumer, err := watcher.RegisterConsumer( + s.ctx, "test-consumer", + watcher.WithEntityTypeFilter(common.ControllerEntityType), + watcher.WithOperationTypeFilter(common.UpdateOperation)) s.Require().NoError(err) s.Require().NotNil(consumer) @@ -148,12 +154,27 @@ func (s *WatcherTestSuite) TestConsumetWithFilter() { s.T().Fatal("unexpected payload received") case <-time.After(1 * time.Second): } - } func TestWatcherTestSuite(t *testing.T) { + // Watcher tests watcherSuite := &WatcherTestSuite{ ctx: context.TODO(), } suite.Run(t, watcherSuite) + + // These tests run store changes and make sure that the store properly + // triggers watcher notifications. + ctx := context.TODO() + watcher.InitWatcher(ctx) + + store, err := database.NewDatabase(ctx, garmTesting.GetTestSqliteDBConfig(t)) + if err != nil { + t.Fatalf("failed to create db connection: %s", err) + } + watcherStoreSuite := &WatcherStoreTestSuite{ + ctx: context.TODO(), + store: store, + } + suite.Run(t, watcherStoreSuite) } From 8a79d9e8f95eb4f7e45671501edb1df98511589e Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Tue, 18 Jun 2024 17:45:48 +0000 Subject: [PATCH 06/11] Add more watcher tests Signed-off-by: Gabriel Adrian Samfira --- database/sql/enterprise.go | 38 +++- database/sql/organizations.go | 44 +++- database/sql/repositories.go | 5 +- database/watcher/watcher_store_test.go | 294 ++++++++++++++++++++++++- database/watcher/watcher_test.go | 8 +- 5 files changed, 366 insertions(+), 23 deletions(-) diff --git a/database/sql/enterprise.go b/database/sql/enterprise.go index 7d20d2e8..c5af3bc4 100644 --- a/database/sql/enterprise.go +++ b/database/sql/enterprise.go @@ -16,6 +16,7 @@ package sql import ( "context" + "log/slog" "github.com/google/uuid" "github.com/pkg/errors" @@ -23,10 +24,11 @@ import ( runnerErrors "github.com/cloudbase/garm-provider-common/errors" "github.com/cloudbase/garm-provider-common/util" + "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/params" ) -func (s *sqlDatabase) CreateEnterprise(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Enterprise, error) { +func (s *sqlDatabase) CreateEnterprise(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (paramEnt params.Enterprise, err error) { if webhookSecret == "" { return params.Enterprise{}, errors.New("creating enterprise: missing secret") } @@ -34,6 +36,12 @@ func (s *sqlDatabase) CreateEnterprise(ctx context.Context, name, credentialsNam if err != nil { return params.Enterprise{}, errors.Wrap(err, "encoding secret") } + + defer func() { + if err == nil { + s.sendNotify(common.EnterpriseEntityType, common.CreateOperation, paramEnt) + } + }() newEnterprise := Enterprise{ Name: name, WebhookSecret: secret, @@ -66,12 +74,12 @@ func (s *sqlDatabase) CreateEnterprise(ctx context.Context, name, credentialsNam return params.Enterprise{}, errors.Wrap(err, "creating enterprise") } - param, err := s.sqlToCommonEnterprise(newEnterprise, true) + paramEnt, err = s.sqlToCommonEnterprise(newEnterprise, true) if err != nil { return params.Enterprise{}, errors.Wrap(err, "creating enterprise") } - return param, nil + return paramEnt, nil } func (s *sqlDatabase) GetEnterprise(ctx context.Context, name string) (params.Enterprise, error) { @@ -124,11 +132,22 @@ func (s *sqlDatabase) ListEnterprises(_ context.Context) ([]params.Enterprise, e } func (s *sqlDatabase) DeleteEnterprise(ctx context.Context, enterpriseID string) error { - enterprise, err := s.getEnterpriseByID(ctx, s.conn, enterpriseID) + enterprise, err := s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Endpoint", "Credentials") if err != nil { return errors.Wrap(err, "fetching enterprise") } + defer func(ent Enterprise) { + if err == nil { + asParams, innerErr := s.sqlToCommonEnterprise(ent, true) + if innerErr == nil { + s.sendNotify(common.EnterpriseEntityType, common.DeleteOperation, asParams) + } else { + slog.With(slog.Any("error", innerErr)).ErrorContext(ctx, "error sending delete notification", "enterprise", enterpriseID) + } + } + }(enterprise) + q := s.conn.Unscoped().Delete(&enterprise) if q.Error != nil && !errors.Is(q.Error, gorm.ErrRecordNotFound) { return errors.Wrap(q.Error, "deleting enterprise") @@ -137,10 +156,15 @@ func (s *sqlDatabase) DeleteEnterprise(ctx context.Context, enterpriseID string) return nil } -func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string, param params.UpdateEntityParams) (params.Enterprise, error) { +func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string, param params.UpdateEntityParams) (newParams params.Enterprise, err error) { + defer func() { + if err == nil { + s.sendNotify(common.EnterpriseEntityType, common.UpdateOperation, newParams) + } + }() var enterprise Enterprise var creds GithubCredentials - err := s.conn.Transaction(func(tx *gorm.DB) error { + err = s.conn.Transaction(func(tx *gorm.DB) error { var err error enterprise, err = s.getEnterpriseByID(ctx, tx, enterpriseID) if err != nil { @@ -196,7 +220,7 @@ func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string, if err != nil { return params.Enterprise{}, errors.Wrap(err, "updating enterprise") } - newParams, err := s.sqlToCommonEnterprise(enterprise, true) + newParams, err = s.sqlToCommonEnterprise(enterprise, true) if err != nil { return params.Enterprise{}, errors.Wrap(err, "updating enterprise") } diff --git a/database/sql/organizations.go b/database/sql/organizations.go index 1192c843..0f3d58a3 100644 --- a/database/sql/organizations.go +++ b/database/sql/organizations.go @@ -17,6 +17,7 @@ package sql import ( "context" "fmt" + "log/slog" "github.com/google/uuid" "github.com/pkg/errors" @@ -24,10 +25,11 @@ import ( runnerErrors "github.com/cloudbase/garm-provider-common/errors" "github.com/cloudbase/garm-provider-common/util" + "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/params" ) -func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Organization, error) { +func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (org params.Organization, err error) { if webhookSecret == "" { return params.Organization{}, errors.New("creating org: missing secret") } @@ -35,6 +37,12 @@ func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, credentialsN if err != nil { return params.Organization{}, errors.Wrap(err, "encoding secret") } + + defer func() { + if err == nil { + s.sendNotify(common.OrganizationEntityType, common.CreateOperation, org) + } + }() newOrg := Organization{ Name: name, WebhookSecret: secret, @@ -68,13 +76,13 @@ func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, credentialsN return params.Organization{}, errors.Wrap(err, "creating org") } - param, err := s.sqlToCommonOrganization(newOrg, true) + org, err = s.sqlToCommonOrganization(newOrg, true) if err != nil { return params.Organization{}, errors.Wrap(err, "creating org") } - param.WebhookSecret = webhookSecret + org.WebhookSecret = webhookSecret - return param, nil + return org, nil } func (s *sqlDatabase) GetOrganization(ctx context.Context, name string) (params.Organization, error) { @@ -114,12 +122,23 @@ func (s *sqlDatabase) ListOrganizations(_ context.Context) ([]params.Organizatio return ret, nil } -func (s *sqlDatabase) DeleteOrganization(ctx context.Context, orgID string) error { - org, err := s.getOrgByID(ctx, s.conn, orgID) +func (s *sqlDatabase) DeleteOrganization(ctx context.Context, orgID string) (err error) { + org, err := s.getOrgByID(ctx, s.conn, orgID, "Endpoint", "Credentials") if err != nil { return errors.Wrap(err, "fetching org") } + defer func(org Organization) { + if err == nil { + asParam, innerErr := s.sqlToCommonOrganization(org, true) + if innerErr == nil { + s.sendNotify(common.OrganizationEntityType, common.DeleteOperation, asParam) + } else { + slog.With(slog.Any("error", innerErr)).ErrorContext(ctx, "error sending delete notification", "org", orgID) + } + } + }(org) + q := s.conn.Unscoped().Delete(&org) if q.Error != nil && !errors.Is(q.Error, gorm.ErrRecordNotFound) { return errors.Wrap(q.Error, "deleting org") @@ -128,10 +147,15 @@ func (s *sqlDatabase) DeleteOrganization(ctx context.Context, orgID string) erro return nil } -func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, param params.UpdateEntityParams) (params.Organization, error) { +func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, param params.UpdateEntityParams) (paramOrg params.Organization, err error) { + defer func() { + if err == nil { + s.sendNotify(common.OrganizationEntityType, common.UpdateOperation, paramOrg) + } + }() var org Organization var creds GithubCredentials - err := s.conn.Transaction(func(tx *gorm.DB) error { + err = s.conn.Transaction(func(tx *gorm.DB) error { var err error org, err = s.getOrgByID(ctx, tx, orgID) if err != nil { @@ -188,11 +212,11 @@ func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, para if err != nil { return params.Organization{}, errors.Wrap(err, "updating enterprise") } - newParams, err := s.sqlToCommonOrganization(org, true) + paramOrg, err = s.sqlToCommonOrganization(org, true) if err != nil { return params.Organization{}, errors.Wrap(err, "saving org") } - return newParams, nil + return paramOrg, nil } func (s *sqlDatabase) GetOrganizationByID(ctx context.Context, orgID string) (params.Organization, error) { diff --git a/database/sql/repositories.go b/database/sql/repositories.go index 7ab1c522..5469950f 100644 --- a/database/sql/repositories.go +++ b/database/sql/repositories.go @@ -17,6 +17,7 @@ package sql import ( "context" "fmt" + "log/slog" "github.com/google/uuid" "github.com/pkg/errors" @@ -121,7 +122,7 @@ func (s *sqlDatabase) ListRepositories(_ context.Context) ([]params.Repository, } func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) (err error) { - repo, err := s.getRepoByID(ctx, s.conn, repoID) + repo, err := s.getRepoByID(ctx, s.conn, repoID, "Endpoint", "Credentials") if err != nil { return errors.Wrap(err, "fetching repo") } @@ -131,6 +132,8 @@ func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) (err asParam, innerErr := s.sqlToCommonRepository(repo, true) if innerErr == nil { s.sendNotify(common.RepositoryEntityType, common.DeleteOperation, asParam) + } else { + slog.With(slog.Any("error", innerErr)).ErrorContext(ctx, "error sending delete notification", "repo", repoID) } } }(repo) diff --git a/database/watcher/watcher_store_test.go b/database/watcher/watcher_store_test.go index f7a2e4c3..895dab9d 100644 --- a/database/watcher/watcher_store_test.go +++ b/database/watcher/watcher_store_test.go @@ -8,6 +8,7 @@ import ( "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/database/watcher" + garmTesting "github.com/cloudbase/garm/internal/testing" "github.com/cloudbase/garm/params" ) @@ -18,16 +19,292 @@ type WatcherStoreTestSuite struct { ctx context.Context } +func (s *WatcherStoreTestSuite) TestEnterpriseWatcher() { + consumer, err := watcher.RegisterConsumer( + s.ctx, "enterprise-test", + watcher.WithEntityTypeFilter(common.EnterpriseEntityType), + watcher.WithAny( + watcher.WithOperationTypeFilter(common.CreateOperation), + watcher.WithOperationTypeFilter(common.UpdateOperation), + watcher.WithOperationTypeFilter(common.DeleteOperation)), + ) + s.Require().NoError(err) + s.Require().NotNil(consumer) + s.T().Cleanup(func() { consumer.Close() }) + + ep := garmTesting.CreateDefaultGithubEndpoint(s.ctx, s.store, s.T()) + creds := garmTesting.CreateTestGithubCredentials(s.ctx, "test-creds", s.store, s.T(), ep) + s.T().Cleanup(func() { s.store.DeleteGithubCredentials(s.ctx, creds.ID) }) + + ent, err := s.store.CreateEnterprise(s.ctx, "test-enterprise", creds.Name, "test-secret", params.PoolBalancerTypeRoundRobin) + s.Require().NoError(err) + s.Require().NotEmpty(ent.ID) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.EnterpriseEntityType, + Operation: common.CreateOperation, + Payload: ent, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + updateParams := params.UpdateEntityParams{ + WebhookSecret: "updated", + } + + updatedEnt, err := s.store.UpdateEnterprise(s.ctx, ent.ID, updateParams) + s.Require().NoError(err) + s.Require().Equal("updated", updatedEnt.WebhookSecret) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.EnterpriseEntityType, + Operation: common.UpdateOperation, + Payload: updatedEnt, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + err = s.store.DeleteEnterprise(s.ctx, ent.ID) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.EnterpriseEntityType, + Operation: common.DeleteOperation, + Payload: updatedEnt, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } +} + +func (s *WatcherStoreTestSuite) TestOrgWatcher() { + consumer, err := watcher.RegisterConsumer( + s.ctx, "org-test", + watcher.WithEntityTypeFilter(common.OrganizationEntityType), + watcher.WithAny( + watcher.WithOperationTypeFilter(common.CreateOperation), + watcher.WithOperationTypeFilter(common.UpdateOperation), + watcher.WithOperationTypeFilter(common.DeleteOperation)), + ) + s.Require().NoError(err) + s.Require().NotNil(consumer) + s.T().Cleanup(func() { consumer.Close() }) + + ep := garmTesting.CreateDefaultGithubEndpoint(s.ctx, s.store, s.T()) + creds := garmTesting.CreateTestGithubCredentials(s.ctx, "test-creds", s.store, s.T(), ep) + s.T().Cleanup(func() { s.store.DeleteGithubCredentials(s.ctx, creds.ID) }) + + org, err := s.store.CreateOrganization(s.ctx, "test-org", creds.Name, "test-secret", params.PoolBalancerTypeRoundRobin) + s.Require().NoError(err) + s.Require().NotEmpty(org.ID) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.OrganizationEntityType, + Operation: common.CreateOperation, + Payload: org, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + updateParams := params.UpdateEntityParams{ + WebhookSecret: "updated", + } + + updatedOrg, err := s.store.UpdateOrganization(s.ctx, org.ID, updateParams) + s.Require().NoError(err) + s.Require().Equal("updated", updatedOrg.WebhookSecret) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.OrganizationEntityType, + Operation: common.UpdateOperation, + Payload: updatedOrg, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + err = s.store.DeleteOrganization(s.ctx, org.ID) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.OrganizationEntityType, + Operation: common.DeleteOperation, + Payload: updatedOrg, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } +} + +func (s *WatcherStoreTestSuite) TestRepoWatcher() { + consumer, err := watcher.RegisterConsumer( + s.ctx, "repo-test", + watcher.WithEntityTypeFilter(common.RepositoryEntityType), + watcher.WithAny( + watcher.WithOperationTypeFilter(common.CreateOperation), + watcher.WithOperationTypeFilter(common.UpdateOperation), + watcher.WithOperationTypeFilter(common.DeleteOperation)), + ) + s.Require().NoError(err) + s.Require().NotNil(consumer) + s.T().Cleanup(func() { consumer.Close() }) + + ep := garmTesting.CreateDefaultGithubEndpoint(s.ctx, s.store, s.T()) + creds := garmTesting.CreateTestGithubCredentials(s.ctx, "test-creds", s.store, s.T(), ep) + s.T().Cleanup(func() { s.store.DeleteGithubCredentials(s.ctx, creds.ID) }) + + repo, err := s.store.CreateRepository(s.ctx, "test-owner", "test-repo", creds.Name, "test-secret", params.PoolBalancerTypeRoundRobin) + s.Require().NoError(err) + s.Require().NotEmpty(repo.ID) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.RepositoryEntityType, + Operation: common.CreateOperation, + Payload: repo, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + newSecret := "updated" + updateParams := params.UpdateEntityParams{ + WebhookSecret: newSecret, + } + + updatedRepo, err := s.store.UpdateRepository(s.ctx, repo.ID, updateParams) + s.Require().NoError(err) + s.Require().Equal(newSecret, updatedRepo.WebhookSecret) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.RepositoryEntityType, + Operation: common.UpdateOperation, + Payload: updatedRepo, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + err = s.store.DeleteRepository(s.ctx, repo.ID) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.RepositoryEntityType, + Operation: common.DeleteOperation, + Payload: updatedRepo, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } +} + +func (s *WatcherStoreTestSuite) TestGithubCredentialsWatcher() { + consumer, err := watcher.RegisterConsumer( + s.ctx, "gh-cred-test", + watcher.WithEntityTypeFilter(common.GithubCredentialsEntityType), + watcher.WithAny( + watcher.WithOperationTypeFilter(common.CreateOperation), + watcher.WithOperationTypeFilter(common.UpdateOperation), + watcher.WithOperationTypeFilter(common.DeleteOperation)), + ) + s.Require().NoError(err) + s.Require().NotNil(consumer) + s.T().Cleanup(func() { consumer.Close() }) + + ghCredParams := params.CreateGithubCredentialsParams{ + Name: "test-creds", + Description: "test credentials", + Endpoint: "github.com", + AuthType: params.GithubAuthTypePAT, + PAT: params.GithubPAT{ + OAuth2Token: "bogus", + }, + } + + ghCred, err := s.store.CreateGithubCredentials(s.ctx, ghCredParams) + s.Require().NoError(err) + s.Require().NotEmpty(ghCred.ID) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.GithubCredentialsEntityType, + Operation: common.CreateOperation, + Payload: ghCred, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + newDesc := "updated description" + updateParams := params.UpdateGithubCredentialsParams{ + Description: &newDesc, + } + + updatedGhCred, err := s.store.UpdateGithubCredentials(s.ctx, ghCred.ID, updateParams) + s.Require().NoError(err) + s.Require().Equal(newDesc, updatedGhCred.Description) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.GithubCredentialsEntityType, + Operation: common.UpdateOperation, + Payload: updatedGhCred, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + err = s.store.DeleteGithubCredentials(s.ctx, ghCred.ID) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.GithubCredentialsEntityType, + Operation: common.DeleteOperation, + // We only get the ID and Name of the deleted entity + Payload: params.GithubCredentials{ID: ghCred.ID, Name: ghCred.Name}, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } +} + func (s *WatcherStoreTestSuite) TestGithubEndpointWatcher() { consumer, err := watcher.RegisterConsumer( s.ctx, "gh-ep-test", watcher.WithEntityTypeFilter(common.GithubEndpointEntityType), watcher.WithAny( watcher.WithOperationTypeFilter(common.CreateOperation), - watcher.WithOperationTypeFilter(common.UpdateOperation)), + watcher.WithOperationTypeFilter(common.UpdateOperation), + watcher.WithOperationTypeFilter(common.DeleteOperation)), ) s.Require().NoError(err) s.Require().NotNil(consumer) + s.T().Cleanup(func() { consumer.Close() }) + ghEpParams := params.CreateGithubEndpointParams{ Name: "test", Description: "test endpoint", @@ -70,4 +347,19 @@ func (s *WatcherStoreTestSuite) TestGithubEndpointWatcher() { case <-time.After(1 * time.Second): s.T().Fatal("expected payload not received") } + + err = s.store.DeleteGithubEndpoint(s.ctx, ghEp.Name) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.GithubEndpointEntityType, + Operation: common.DeleteOperation, + // We only get the name of the deleted entity + Payload: params.GithubEndpoint{Name: ghEp.Name}, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } } diff --git a/database/watcher/watcher_test.go b/database/watcher/watcher_test.go index b44c152e..21d15093 100644 --- a/database/watcher/watcher_test.go +++ b/database/watcher/watcher_test.go @@ -163,17 +163,17 @@ func TestWatcherTestSuite(t *testing.T) { } suite.Run(t, watcherSuite) - // These tests run store changes and make sure that the store properly - // triggers watcher notifications. - ctx := context.TODO() + ctx := context.Background() watcher.InitWatcher(ctx) store, err := database.NewDatabase(ctx, garmTesting.GetTestSqliteDBConfig(t)) if err != nil { t.Fatalf("failed to create db connection: %s", err) } + + adminCtx := garmTesting.ImpersonateAdminContext(ctx, store, t) watcherStoreSuite := &WatcherStoreTestSuite{ - ctx: context.TODO(), + ctx: adminCtx, store: store, } suite.Run(t, watcherStoreSuite) From cc9ecf5847433f5ff9b71777c858c134b779e295 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Tue, 18 Jun 2024 19:28:34 +0000 Subject: [PATCH 07/11] Add more tests Signed-off-by: Gabriel Adrian Samfira --- database/sql/pools.go | 12 +++++++--- database/watcher/watcher_store_test.go | 31 ++++++++++++++++++++++++++ database/watcher/watcher_test.go | 18 +++++++++++++++ 3 files changed, 58 insertions(+), 3 deletions(-) diff --git a/database/sql/pools.go b/database/sql/pools.go index 01d1afc4..89500ed9 100644 --- a/database/sql/pools.go +++ b/database/sql/pools.go @@ -17,6 +17,7 @@ package sql import ( "context" "fmt" + "log/slog" "github.com/google/uuid" "github.com/pkg/errors" @@ -73,11 +74,16 @@ func (s *sqlDatabase) DeletePoolByID(_ context.Context, poolID string) (err erro return errors.Wrap(err, "fetching pool by ID") } - defer func() { + defer func(pool Pool) { if err == nil { - s.sendNotify(common.PoolEntityType, common.DeleteOperation, pool) + asParams, innerErr := s.sqlToCommonPool(pool) + if innerErr == nil { + s.sendNotify(common.PoolEntityType, common.DeleteOperation, asParams) + } else { + slog.With(slog.Any("error", innerErr)).ErrorContext(s.ctx, "error sending delete notification", "pool", poolID) + } } - }() + }(pool) if q := s.conn.Unscoped().Delete(&pool); q.Error != nil { return errors.Wrap(q.Error, "removing pool") diff --git a/database/watcher/watcher_store_test.go b/database/watcher/watcher_store_test.go index 895dab9d..fa82a339 100644 --- a/database/watcher/watcher_store_test.go +++ b/database/watcher/watcher_store_test.go @@ -19,6 +19,37 @@ type WatcherStoreTestSuite struct { ctx context.Context } +func (s *WatcherStoreTestSuite) TestControllerWatcher() { + consumer, err := watcher.RegisterConsumer( + s.ctx, "controller-test", + watcher.WithEntityTypeFilter(common.ControllerEntityType), + watcher.WithOperationTypeFilter(common.UpdateOperation), + ) + s.Require().NoError(err) + s.Require().NotNil(consumer) + s.T().Cleanup(func() { consumer.Close() }) + + metadataURL := "http://metadata.example.com" + updateParams := params.UpdateControllerParams{ + MetadataURL: &metadataURL, + } + + controller, err := s.store.UpdateController(updateParams) + s.Require().NoError(err) + s.Require().Equal(metadataURL, controller.MetadataURL) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.ControllerEntityType, + Operation: common.UpdateOperation, + Payload: controller, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } +} + func (s *WatcherStoreTestSuite) TestEnterpriseWatcher() { consumer, err := watcher.RegisterConsumer( s.ctx, "enterprise-test", diff --git a/database/watcher/watcher_test.go b/database/watcher/watcher_test.go index 21d15093..6d1091ed 100644 --- a/database/watcher/watcher_test.go +++ b/database/watcher/watcher_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/pkg/errors" "github.com/stretchr/testify/suite" "github.com/cloudbase/garm/database" @@ -156,6 +157,18 @@ func (s *WatcherTestSuite) TestConsumetWithFilter() { } } +func maybeInitController(db common.Store) error { + if _, err := db.ControllerInfo(); err == nil { + return nil + } + + if _, err := db.InitController(); err != nil { + return errors.Wrap(err, "initializing controller") + } + + return nil +} + func TestWatcherTestSuite(t *testing.T) { // Watcher tests watcherSuite := &WatcherTestSuite{ @@ -171,6 +184,11 @@ func TestWatcherTestSuite(t *testing.T) { t.Fatalf("failed to create db connection: %s", err) } + err = maybeInitController(store) + if err != nil { + t.Fatalf("failed to init controller: %s", err) + } + adminCtx := garmTesting.ImpersonateAdminContext(ctx, store, t) watcherStoreSuite := &WatcherStoreTestSuite{ ctx: adminCtx, From 0c8c6f5668b058361bddaa477b38da1c1c2c7b48 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Wed, 19 Jun 2024 12:19:58 +0000 Subject: [PATCH 08/11] Add more tests Signed-off-by: Gabriel Adrian Samfira --- database/sql/instances.go | 36 +++- database/sql/pools.go | 12 +- database/watcher/watcher_store_test.go | 217 +++++++++++++++++++++++++ database/watcher/watcher_test.go | 2 - internal/testing/testing.go | 2 +- 5 files changed, 253 insertions(+), 16 deletions(-) diff --git a/database/sql/instances.go b/database/sql/instances.go index c09b60f3..3f113669 100644 --- a/database/sql/instances.go +++ b/database/sql/instances.go @@ -25,15 +25,22 @@ import ( "gorm.io/gorm/clause" runnerErrors "github.com/cloudbase/garm-provider-common/errors" + "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/params" ) -func (s *sqlDatabase) CreateInstance(_ context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error) { +func (s *sqlDatabase) CreateInstance(_ context.Context, poolID string, param params.CreateInstanceParams) (instance params.Instance, err error) { pool, err := s.getPoolByID(s.conn, poolID) if err != nil { return params.Instance{}, errors.Wrap(err, "fetching pool") } + defer func() { + if err == nil { + s.sendNotify(common.InstanceEntityType, common.CreateOperation, instance) + } + }() + var labels datatypes.JSON if len(param.AditionalLabels) > 0 { labels, err = json.Marshal(param.AditionalLabels) @@ -134,11 +141,28 @@ func (s *sqlDatabase) GetInstanceByName(ctx context.Context, instanceName string return s.sqlToParamsInstance(instance) } -func (s *sqlDatabase) DeleteInstance(_ context.Context, poolID string, instanceName string) error { +func (s *sqlDatabase) DeleteInstance(_ context.Context, poolID string, instanceName string) (err error) { instance, err := s.getPoolInstanceByName(poolID, instanceName) if err != nil { return errors.Wrap(err, "deleting instance") } + + defer func() { + if err == nil { + var providerID string + if instance.ProviderID != nil { + providerID = *instance.ProviderID + } + s.sendNotify(common.InstanceEntityType, common.DeleteOperation, params.Instance{ + ID: instance.ID.String(), + Name: instance.Name, + ProviderID: providerID, + AgentID: instance.AgentID, + PoolID: instance.PoolID.String(), + }) + } + }() + if q := s.conn.Unscoped().Delete(&instance); q.Error != nil { if errors.Is(q.Error, gorm.ErrRecordNotFound) { return nil @@ -230,8 +254,12 @@ func (s *sqlDatabase) UpdateInstance(ctx context.Context, instanceName string, p return params.Instance{}, errors.Wrap(err, "updating addresses") } } - - return s.sqlToParamsInstance(instance) + inst, err := s.sqlToParamsInstance(instance) + if err != nil { + return params.Instance{}, errors.Wrap(err, "converting instance") + } + s.sendNotify(common.InstanceEntityType, common.UpdateOperation, inst) + return inst, nil } func (s *sqlDatabase) ListPoolInstances(_ context.Context, poolID string) ([]params.Instance, error) { diff --git a/database/sql/pools.go b/database/sql/pools.go index 89500ed9..0cb4a094 100644 --- a/database/sql/pools.go +++ b/database/sql/pools.go @@ -17,7 +17,6 @@ package sql import ( "context" "fmt" - "log/slog" "github.com/google/uuid" "github.com/pkg/errors" @@ -74,16 +73,11 @@ func (s *sqlDatabase) DeletePoolByID(_ context.Context, poolID string) (err erro return errors.Wrap(err, "fetching pool by ID") } - defer func(pool Pool) { + defer func() { if err == nil { - asParams, innerErr := s.sqlToCommonPool(pool) - if innerErr == nil { - s.sendNotify(common.PoolEntityType, common.DeleteOperation, asParams) - } else { - slog.With(slog.Any("error", innerErr)).ErrorContext(s.ctx, "error sending delete notification", "pool", poolID) - } + s.sendNotify(common.PoolEntityType, common.DeleteOperation, params.Pool{ID: poolID}) } - }(pool) + }() if q := s.conn.Unscoped().Delete(&pool); q.Error != nil { return errors.Wrap(q.Error, "removing pool") diff --git a/database/watcher/watcher_store_test.go b/database/watcher/watcher_store_test.go index fa82a339..80f71325 100644 --- a/database/watcher/watcher_store_test.go +++ b/database/watcher/watcher_store_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/suite" + commonParams "github.com/cloudbase/garm-provider-common/params" "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/database/watcher" garmTesting "github.com/cloudbase/garm/internal/testing" @@ -19,6 +20,221 @@ type WatcherStoreTestSuite struct { ctx context.Context } +func (s *WatcherStoreTestSuite) TestInstanceWatcher() { + consumer, err := watcher.RegisterConsumer( + s.ctx, "instance-test", + watcher.WithEntityTypeFilter(common.InstanceEntityType), + watcher.WithAny( + watcher.WithOperationTypeFilter(common.CreateOperation), + watcher.WithOperationTypeFilter(common.UpdateOperation), + watcher.WithOperationTypeFilter(common.DeleteOperation)), + ) + s.Require().NoError(err) + s.Require().NotNil(consumer) + s.T().Cleanup(func() { consumer.Close() }) + + ep := garmTesting.CreateDefaultGithubEndpoint(s.ctx, s.store, s.T()) + creds := garmTesting.CreateTestGithubCredentials(s.ctx, "test-creds", s.store, s.T(), ep) + s.T().Cleanup(func() { s.store.DeleteGithubCredentials(s.ctx, creds.ID) }) + + repo, err := s.store.CreateRepository(s.ctx, "test-owner", "test-repo", creds.Name, "test-secret", params.PoolBalancerTypeRoundRobin) + s.Require().NoError(err) + s.Require().NotEmpty(repo.ID) + s.T().Cleanup(func() { s.store.DeleteRepository(s.ctx, repo.ID) }) + + entity, err := repo.GetEntity() + s.Require().NoError(err) + + createPoolParams := params.CreatePoolParams{ + ProviderName: "test-provider", + Image: "test-image", + Flavor: "test-flavor", + OSType: commonParams.Linux, + OSArch: commonParams.Amd64, + Tags: []string{"test-tag"}, + } + + pool, err := s.store.CreateEntityPool(s.ctx, entity, createPoolParams) + s.Require().NoError(err) + s.Require().NotEmpty(pool.ID) + s.T().Cleanup(func() { s.store.DeleteEntityPool(s.ctx, entity, pool.ID) }) + + createInstanceParams := params.CreateInstanceParams{ + Name: "test-instance", + OSType: commonParams.Linux, + OSArch: commonParams.Amd64, + Status: commonParams.InstanceCreating, + } + instance, err := s.store.CreateInstance(s.ctx, pool.ID, createInstanceParams) + s.Require().NoError(err) + s.Require().NotEmpty(instance.ID) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.InstanceEntityType, + Operation: common.CreateOperation, + Payload: instance, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + updateParams := params.UpdateInstanceParams{ + RunnerStatus: params.RunnerActive, + } + + updatedInstance, err := s.store.UpdateInstance(s.ctx, instance.Name, updateParams) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.InstanceEntityType, + Operation: common.UpdateOperation, + Payload: updatedInstance, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + err = s.store.DeleteInstance(s.ctx, pool.ID, updatedInstance.Name) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.InstanceEntityType, + Operation: common.DeleteOperation, + Payload: params.Instance{ + ID: updatedInstance.ID, + Name: updatedInstance.Name, + ProviderID: updatedInstance.ProviderID, + AgentID: updatedInstance.AgentID, + PoolID: updatedInstance.PoolID, + }, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } +} + +func (s *WatcherStoreTestSuite) TestPoolWatcher() { + consumer, err := watcher.RegisterConsumer( + s.ctx, "pool-test", + watcher.WithEntityTypeFilter(common.PoolEntityType), + watcher.WithAny( + watcher.WithOperationTypeFilter(common.CreateOperation), + watcher.WithOperationTypeFilter(common.UpdateOperation), + watcher.WithOperationTypeFilter(common.DeleteOperation)), + ) + s.Require().NoError(err) + s.Require().NotNil(consumer) + s.T().Cleanup(func() { consumer.Close() }) + + ep := garmTesting.CreateDefaultGithubEndpoint(s.ctx, s.store, s.T()) + creds := garmTesting.CreateTestGithubCredentials(s.ctx, "test-creds", s.store, s.T(), ep) + s.T().Cleanup(func() { + if err := s.store.DeleteGithubCredentials(s.ctx, creds.ID); err != nil { + s.T().Logf("failed to delete Github credentials: %v", err) + } + }) + + repo, err := s.store.CreateRepository(s.ctx, "test-owner", "test-repo", creds.Name, "test-secret", params.PoolBalancerTypeRoundRobin) + s.Require().NoError(err) + s.Require().NotEmpty(repo.ID) + s.T().Cleanup(func() { s.store.DeleteRepository(s.ctx, repo.ID) }) + + entity, err := repo.GetEntity() + s.Require().NoError(err) + + createPoolParams := params.CreatePoolParams{ + ProviderName: "test-provider", + Image: "test-image", + Flavor: "test-flavor", + OSType: commonParams.Linux, + OSArch: commonParams.Amd64, + Tags: []string{"test-tag"}, + } + pool, err := s.store.CreateEntityPool(s.ctx, entity, createPoolParams) + s.Require().NoError(err) + s.Require().NotEmpty(pool.ID) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.PoolEntityType, + Operation: common.CreateOperation, + Payload: pool, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + updateParams := params.UpdatePoolParams{ + Tags: []string{"updated-tag"}, + } + + updatedPool, err := s.store.UpdateEntityPool(s.ctx, entity, pool.ID, updateParams) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.PoolEntityType, + Operation: common.UpdateOperation, + Payload: updatedPool, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + err = s.store.DeleteEntityPool(s.ctx, entity, pool.ID) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.PoolEntityType, + Operation: common.DeleteOperation, + Payload: params.Pool{ID: pool.ID}, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + // Also test DeletePoolByID + pool, err = s.store.CreateEntityPool(s.ctx, entity, createPoolParams) + s.Require().NoError(err) + s.Require().NotEmpty(pool.ID) + + // Consume the create event + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.PoolEntityType, + Operation: common.CreateOperation, + Payload: pool, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + err = s.store.DeletePoolByID(s.ctx, pool.ID) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.PoolEntityType, + Operation: common.DeleteOperation, + Payload: params.Pool{ID: pool.ID}, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } +} + func (s *WatcherStoreTestSuite) TestControllerWatcher() { consumer, err := watcher.RegisterConsumer( s.ctx, "controller-test", @@ -275,6 +491,7 @@ func (s *WatcherStoreTestSuite) TestGithubCredentialsWatcher() { ghCred, err := s.store.CreateGithubCredentials(s.ctx, ghCredParams) s.Require().NoError(err) s.Require().NotEmpty(ghCred.ID) + s.T().Cleanup(func() { s.store.DeleteGithubCredentials(s.ctx, ghCred.ID) }) select { case event := <-consumer.Watch(): diff --git a/database/watcher/watcher_test.go b/database/watcher/watcher_test.go index 6d1091ed..c5b56fe2 100644 --- a/database/watcher/watcher_test.go +++ b/database/watcher/watcher_test.go @@ -4,7 +4,6 @@ package watcher_test import ( "context" - "fmt" "testing" "time" @@ -26,7 +25,6 @@ type WatcherTestSuite struct { func (s *WatcherTestSuite) SetupTest() { ctx := context.TODO() watcher.InitWatcher(ctx) - fmt.Printf("creating store: %v\n", s.store) store, err := database.NewDatabase(ctx, garmTesting.GetTestSqliteDBConfig(s.T())) if err != nil { s.T().Fatalf("failed to create db connection: %s", err) diff --git a/internal/testing/testing.go b/internal/testing/testing.go index 8949f7cf..6e76956f 100644 --- a/internal/testing/testing.go +++ b/internal/testing/testing.go @@ -122,7 +122,7 @@ func CreateTestGithubCredentials(ctx context.Context, credsName string, db commo } newCreds, err := db.CreateGithubCredentials(ctx, newCredsParams) if err != nil { - s.Fatalf("failed to create database object (new-creds): %v", err) + s.Fatalf("failed to create database object (%s): %v", credsName, err) } return newCreds } From 5f07bc2d7c1cc1aba62a1a1766acdb3a00bbe5e2 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Wed, 19 Jun 2024 12:40:56 +0000 Subject: [PATCH 09/11] Check if producer was registered Signed-off-by: Gabriel Adrian Samfira --- database/sql/instances.go | 7 +++++-- database/sql/util.go | 11 +++++++++-- database/watcher/watcher_store_test.go | 1 - 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/database/sql/instances.go b/database/sql/instances.go index 3f113669..864e7ba2 100644 --- a/database/sql/instances.go +++ b/database/sql/instances.go @@ -17,6 +17,7 @@ package sql import ( "context" "encoding/json" + "log/slog" "github.com/google/uuid" "github.com/pkg/errors" @@ -153,13 +154,15 @@ func (s *sqlDatabase) DeleteInstance(_ context.Context, poolID string, instanceN if instance.ProviderID != nil { providerID = *instance.ProviderID } - s.sendNotify(common.InstanceEntityType, common.DeleteOperation, params.Instance{ + if notifyErr := s.sendNotify(common.InstanceEntityType, common.DeleteOperation, params.Instance{ ID: instance.ID.String(), Name: instance.Name, ProviderID: providerID, AgentID: instance.AgentID, PoolID: instance.PoolID.String(), - }) + }); notifyErr != nil { + slog.With(slog.Any("error", notifyErr)).Error("failed to send notify") + } } }() diff --git a/database/sql/util.go b/database/sql/util.go index 5814483d..dd861197 100644 --- a/database/sql/util.go +++ b/database/sql/util.go @@ -488,11 +488,18 @@ func (s *sqlDatabase) unsealAndUnmarshal(data []byte, target interface{}) error return nil } -func (s *sqlDatabase) sendNotify(entityType dbCommon.DatabaseEntityType, op dbCommon.OperationType, payload interface{}) { +func (s *sqlDatabase) sendNotify(entityType dbCommon.DatabaseEntityType, op dbCommon.OperationType, payload interface{}) error { + if s.producer == nil { + // no producer was registered. Not sending notifications. + return nil + } + if payload == nil { + return errors.New("missing payload") + } message := dbCommon.ChangePayload{ Operation: op, Payload: payload, EntityType: entityType, } - s.producer.Notify(message) + return s.producer.Notify(message) } diff --git a/database/watcher/watcher_store_test.go b/database/watcher/watcher_store_test.go index 80f71325..f74f836d 100644 --- a/database/watcher/watcher_store_test.go +++ b/database/watcher/watcher_store_test.go @@ -491,7 +491,6 @@ func (s *WatcherStoreTestSuite) TestGithubCredentialsWatcher() { ghCred, err := s.store.CreateGithubCredentials(s.ctx, ghCredParams) s.Require().NoError(err) s.Require().NotEmpty(ghCred.ID) - s.T().Cleanup(func() { s.store.DeleteGithubCredentials(s.ctx, ghCred.ID) }) select { case event := <-consumer.Watch(): From b7d138d2ac9c97ff402c2212ca385d260364b710 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Wed, 19 Jun 2024 13:44:24 +0000 Subject: [PATCH 10/11] Add notifications for jobs Signed-off-by: Gabriel Adrian Samfira --- database/sql/jobs.go | 45 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 5 deletions(-) diff --git a/database/sql/jobs.go b/database/sql/jobs.go index 0201428f..f3c8ae72 100644 --- a/database/sql/jobs.go +++ b/database/sql/jobs.go @@ -93,7 +93,14 @@ func (s *sqlDatabase) paramsJobToWorkflowJob(ctx context.Context, job params.Job return workflofJob, nil } -func (s *sqlDatabase) DeleteJob(_ context.Context, jobID int64) error { +func (s *sqlDatabase) DeleteJob(_ context.Context, jobID int64) (err error) { + defer func() { + if err == nil { + if notifyErr := s.sendNotify(common.JobEntityType, common.DeleteOperation, params.Job{ID: jobID}); notifyErr != nil { + slog.With(slog.Any("error", notifyErr)).Error("failed to send notify") + } + } + }() q := s.conn.Delete(&WorkflowJob{}, jobID) if q.Error != nil { if errors.Is(q.Error, gorm.ErrRecordNotFound) { @@ -134,10 +141,17 @@ func (s *sqlDatabase) LockJob(_ context.Context, jobID int64, entityID string) e return errors.Wrap(err, "saving job") } + asParams, err := sqlWorkflowJobToParamsJob(workflowJob) + if err == nil { + s.sendNotify(common.JobEntityType, common.UpdateOperation, asParams) + } else { + slog.With(slog.Any("error", err)).Error("failed to convert job to params") + } + return nil } -func (s *sqlDatabase) BreakLockJobIsQueued(_ context.Context, jobID int64) error { +func (s *sqlDatabase) BreakLockJobIsQueued(_ context.Context, jobID int64) (err error) { var workflowJob WorkflowJob q := s.conn.Clauses(clause.Locking{Strength: "UPDATE"}).Preload("Instance").Where("id = ? and status = ?", jobID, params.JobStatusQueued).First(&workflowJob) @@ -157,7 +171,12 @@ func (s *sqlDatabase) BreakLockJobIsQueued(_ context.Context, jobID int64) error if err := s.conn.Save(&workflowJob).Error; err != nil { return errors.Wrap(err, "saving job") } - + asParams, err := sqlWorkflowJobToParamsJob(workflowJob) + if err == nil { + s.sendNotify(common.JobEntityType, common.UpdateOperation, asParams) + } else { + slog.With(slog.Any("error", err)).Error("failed to convert job to params") + } return nil } @@ -186,6 +205,12 @@ func (s *sqlDatabase) UnlockJob(_ context.Context, jobID int64, entityID string) return errors.Wrap(err, "saving job") } + asParams, err := sqlWorkflowJobToParamsJob(workflowJob) + if err == nil { + s.sendNotify(common.JobEntityType, common.UpdateOperation, asParams) + } else { + slog.With(slog.Any("error", err)).Error("failed to convert job to params") + } return nil } @@ -198,9 +223,11 @@ func (s *sqlDatabase) CreateOrUpdateJob(ctx context.Context, job params.Job) (pa return params.Job{}, errors.Wrap(q.Error, "fetching job") } } - + var operation common.OperationType if workflowJob.ID != 0 { // Update workflowJob with values from job. + operation = common.UpdateOperation + workflowJob.Status = job.Status workflowJob.Action = job.Action workflowJob.Conclusion = job.Conclusion @@ -238,6 +265,8 @@ func (s *sqlDatabase) CreateOrUpdateJob(ctx context.Context, job params.Job) (pa return params.Job{}, errors.Wrap(err, "saving job") } } else { + operation = common.CreateOperation + workflowJob, err := s.paramsJobToWorkflowJob(ctx, job) if err != nil { return params.Job{}, errors.Wrap(err, "converting job") @@ -247,7 +276,13 @@ func (s *sqlDatabase) CreateOrUpdateJob(ctx context.Context, job params.Job) (pa } } - return sqlWorkflowJobToParamsJob(workflowJob) + asParams, err := sqlWorkflowJobToParamsJob(workflowJob) + if err != nil { + return params.Job{}, errors.Wrap(err, "converting job") + } + s.sendNotify(common.JobEntityType, operation, asParams) + + return asParams, nil } // ListJobsByStatus lists all jobs for a given status. From 230f002902e5616a7c48c38c9a9bca1450c09ef3 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Thu, 20 Jun 2024 11:01:57 +0000 Subject: [PATCH 11/11] Add more tests Signed-off-by: Gabriel Adrian Samfira --- database/watcher/watcher_store_test.go | 116 +++++++++++++++++++++++++ 1 file changed, 116 insertions(+) diff --git a/database/watcher/watcher_store_test.go b/database/watcher/watcher_store_test.go index f74f836d..bb2a42de 100644 --- a/database/watcher/watcher_store_test.go +++ b/database/watcher/watcher_store_test.go @@ -4,6 +4,7 @@ import ( "context" "time" + "github.com/google/uuid" "github.com/stretchr/testify/suite" commonParams "github.com/cloudbase/garm-provider-common/params" @@ -20,6 +21,121 @@ type WatcherStoreTestSuite struct { ctx context.Context } +func (s *WatcherStoreTestSuite) TestJobWatcher() { + consumer, err := watcher.RegisterConsumer( + s.ctx, "job-test", + watcher.WithEntityTypeFilter(common.JobEntityType), + watcher.WithAny( + watcher.WithOperationTypeFilter(common.CreateOperation), + watcher.WithOperationTypeFilter(common.UpdateOperation), + watcher.WithOperationTypeFilter(common.DeleteOperation)), + ) + s.Require().NoError(err) + s.Require().NotNil(consumer) + s.T().Cleanup(func() { consumer.Close() }) + + jobParams := params.Job{ + ID: 1, + RunID: 2, + Action: "test-action", + Conclusion: "started", + Status: "in_progress", + Name: "test-job", + } + + job, err := s.store.CreateOrUpdateJob(s.ctx, jobParams) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.JobEntityType, + Operation: common.CreateOperation, + Payload: job, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + jobParams.Conclusion = "success" + updatedJob, err := s.store.CreateOrUpdateJob(s.ctx, jobParams) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.JobEntityType, + Operation: common.UpdateOperation, + Payload: updatedJob, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + entityID, err := uuid.NewUUID() + s.Require().NoError(err) + + err = s.store.LockJob(s.ctx, updatedJob.ID, entityID.String()) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(event.Operation, common.UpdateOperation) + s.Require().Equal(event.EntityType, common.JobEntityType) + + job, ok := event.Payload.(params.Job) + s.Require().True(ok) + s.Require().Equal(job.ID, updatedJob.ID) + s.Require().Equal(job.LockedBy, entityID) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + err = s.store.UnlockJob(s.ctx, updatedJob.ID, entityID.String()) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(event.Operation, common.UpdateOperation) + s.Require().Equal(event.EntityType, common.JobEntityType) + + job, ok := event.Payload.(params.Job) + s.Require().True(ok) + s.Require().Equal(job.ID, updatedJob.ID) + s.Require().Equal(job.LockedBy, uuid.Nil) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + jobParams.Status = "queued" + jobParams.LockedBy = entityID + + updatedJob, err = s.store.CreateOrUpdateJob(s.ctx, jobParams) + s.Require().NoError(err) + select { + case <-consumer.Watch(): + // throw away event. + case <-time.After(1 * time.Second): + s.T().Fatal("unexpected payload received") + } + + err = s.store.BreakLockJobIsQueued(s.ctx, updatedJob.ID) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(event.Operation, common.UpdateOperation) + s.Require().Equal(event.EntityType, common.JobEntityType) + + job, ok := event.Payload.(params.Job) + s.Require().True(ok) + s.Require().Equal(job.ID, updatedJob.ID) + s.Require().Equal(job.LockedBy, uuid.Nil) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } +} + func (s *WatcherStoreTestSuite) TestInstanceWatcher() { consumer, err := watcher.RegisterConsumer( s.ctx, "instance-test",