Skip to content

Commit

Permalink
Enforce same endpoint when updating credentials
Browse files Browse the repository at this point in the history
When updating credentials on an entity, we must ensure that the new credentials
belong to the same endpoint as the entity.

When an entity is created, the endpoint is determined by the credentials that
were used during the create operation. From that point forward the entity is
associated with an endpoint, and that cannot change.

Signed-off-by: Gabriel Adrian Samfira <[email protected]>
  • Loading branch information
gabriel-samfira committed Apr 17, 2024
1 parent 14a8a65 commit d111897
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 26 deletions.
23 changes: 22 additions & 1 deletion database/sql/enterprise.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,20 @@ func (s *sqlDatabase) CreateEnterprise(ctx context.Context, name, credentialsNam
if err != nil {
return errors.Wrap(err, "creating enterprise")
}
if creds.EndpointName == nil {
return errors.Wrap(runnerErrors.ErrUnprocessable, "credentials have no endpoint")
}
newEnterprise.CredentialsID = &creds.ID
newEnterprise.CredentialsName = creds.Name
newEnterprise.EndpointName = creds.EndpointName

q := tx.Create(&newEnterprise)
if q.Error != nil {
return errors.Wrap(q.Error, "creating enterprise")
}

newEnterprise.Credentials = creds
newEnterprise.Endpoint = creds.Endpoint

return nil
})
Expand Down Expand Up @@ -132,16 +138,27 @@ func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string,
var creds GithubCredentials
err := s.conn.Transaction(func(tx *gorm.DB) error {
var err error
enterprise, err = s.getEnterpriseByID(ctx, tx, enterpriseID, "Credentials", "Endpoint")
enterprise, err = s.getEnterpriseByID(ctx, tx, enterpriseID)
if err != nil {
return errors.Wrap(err, "fetching enterprise")
}

if enterprise.EndpointName == nil {
return errors.Wrap(runnerErrors.ErrUnprocessable, "enterprise has no endpoint")
}

if param.CredentialsName != "" {
creds, err = s.getGithubCredentialsByName(ctx, tx, param.CredentialsName, false)
if err != nil {
return errors.Wrap(err, "fetching credentials")
}
if creds.EndpointName == nil {
return errors.Wrap(runnerErrors.ErrUnprocessable, "credentials have no endpoint")
}

if *creds.EndpointName != *enterprise.EndpointName {
return errors.Wrap(runnerErrors.ErrBadRequest, "endpoint mismatch")
}
enterprise.CredentialsID = &creds.ID
}
if param.WebhookSecret != "" {
Expand Down Expand Up @@ -171,6 +188,10 @@ func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string,
return params.Enterprise{}, errors.Wrap(err, "updating enterprise")
}

enterprise, err = s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Endpoint", "Credentials")
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "updating enterprise")
}
newParams, err := s.sqlToCommonEnterprise(enterprise, true)
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "updating enterprise")
Expand Down
36 changes: 29 additions & 7 deletions database/sql/enterprise_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,11 @@ func (s *EnterpriseTestSuite) TestCreateEnterpriseDBCreateErr() {
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")).
WithArgs(s.Fixtures.Enterprises[0].CredentialsName).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.testCreds.ID))
WillReturnRows(sqlmock.NewRows([]string{"id", "endpoint_name"}).AddRow(s.testCreds.ID, s.testCreds.Endpoint))
s.Fixtures.SQLMock.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_endpoints` WHERE `github_endpoints`.`name` = ? AND `github_endpoints`.`deleted_at` IS NULL")).
WithArgs(s.testCreds.Endpoint).
WillReturnRows(sqlmock.NewRows([]string{"name"}).
AddRow(s.testCreds.Endpoint))
s.Fixtures.SQLMock.
ExpectExec(regexp.QuoteMeta("INSERT INTO `enterprises`")).
WillReturnError(fmt.Errorf("creating enterprise mock error"))
Expand Down Expand Up @@ -346,11 +350,17 @@ func (s *EnterpriseTestSuite) TestUpdateEnterpriseDBEncryptErr() {
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")).
WithArgs(s.Fixtures.Enterprises[0].ID).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID))
WillReturnRows(sqlmock.NewRows([]string{"id", "endpoint_name"}).
AddRow(s.Fixtures.Enterprises[0].ID, s.Fixtures.Enterprises[0].Endpoint.Name))
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")).
WithArgs(s.secondaryTestCreds.Name).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID))
WillReturnRows(sqlmock.NewRows([]string{"id", "endpoint_name"}).
AddRow(s.secondaryTestCreds.ID, s.secondaryTestCreds.Endpoint))
s.Fixtures.SQLMock.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_endpoints` WHERE `github_endpoints`.`name` = ? AND `github_endpoints`.`deleted_at` IS NULL")).
WithArgs(s.testCreds.Endpoint).
WillReturnRows(sqlmock.NewRows([]string{"name"}).
AddRow(s.secondaryTestCreds.Endpoint))
s.Fixtures.SQLMock.ExpectRollback()

_, err := s.StoreSQLMocked.UpdateEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].ID, s.Fixtures.UpdateRepoParams)
Expand All @@ -365,11 +375,17 @@ func (s *EnterpriseTestSuite) TestUpdateEnterpriseDBSaveErr() {
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")).
WithArgs(s.Fixtures.Enterprises[0].ID).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID))
WillReturnRows(sqlmock.NewRows([]string{"id", "endpoint_name"}).
AddRow(s.Fixtures.Enterprises[0].ID, s.Fixtures.Enterprises[0].Endpoint.Name))
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")).
WithArgs(s.secondaryTestCreds.Name).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID))
WillReturnRows(sqlmock.NewRows([]string{"id", "endpoint_name"}).
AddRow(s.secondaryTestCreds.ID, s.secondaryTestCreds.Endpoint))
s.Fixtures.SQLMock.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_endpoints` WHERE `github_endpoints`.`name` = ? AND `github_endpoints`.`deleted_at` IS NULL")).
WithArgs(s.testCreds.Endpoint).
WillReturnRows(sqlmock.NewRows([]string{"name"}).
AddRow(s.secondaryTestCreds.Endpoint))
s.Fixtures.SQLMock.
ExpectExec(("UPDATE `enterprises` SET")).
WillReturnError(fmt.Errorf("saving enterprise mock error"))
Expand All @@ -390,11 +406,17 @@ func (s *EnterpriseTestSuite) TestUpdateEnterpriseDBDecryptingErr() {
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")).
WithArgs(s.Fixtures.Enterprises[0].ID).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID))
WillReturnRows(sqlmock.NewRows([]string{"id", "endpoint_name"}).
AddRow(s.Fixtures.Enterprises[0].ID, s.Fixtures.Enterprises[0].Endpoint.Name))
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")).
WithArgs(s.secondaryTestCreds.Name).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID))
WillReturnRows(sqlmock.NewRows([]string{"id", "endpoint_name"}).
AddRow(s.secondaryTestCreds.ID, s.secondaryTestCreds.Endpoint))
s.Fixtures.SQLMock.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_endpoints` WHERE `github_endpoints`.`name` = ? AND `github_endpoints`.`deleted_at` IS NULL")).
WithArgs(s.testCreds.Endpoint).
WillReturnRows(sqlmock.NewRows([]string{"name"}).
AddRow(s.secondaryTestCreds.Endpoint))
s.Fixtures.SQLMock.ExpectRollback()

_, err := s.StoreSQLMocked.UpdateEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].ID, s.Fixtures.UpdateRepoParams)
Expand Down
22 changes: 21 additions & 1 deletion database/sql/organizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,20 @@ func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, credentialsN
if err != nil {
return errors.Wrap(err, "creating org")
}
if creds.EndpointName == nil {
return errors.Wrap(runnerErrors.ErrUnprocessable, "credentials have no endpoint")
}
newOrg.CredentialsID = &creds.ID
newOrg.CredentialsName = creds.Name
newOrg.EndpointName = creds.EndpointName

q := tx.Create(&newOrg)
if q.Error != nil {
return errors.Wrap(q.Error, "creating org")
}

newOrg.Credentials = creds
newOrg.Endpoint = creds.Endpoint

return nil
})
Expand Down Expand Up @@ -123,17 +129,27 @@ func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, para
var creds GithubCredentials
err := s.conn.Transaction(func(tx *gorm.DB) error {
var err error
org, err = s.getOrgByID(ctx, tx, orgID, "Credentials", "Endpoint")
org, err = s.getOrgByID(ctx, tx, orgID)
if err != nil {
return errors.Wrap(err, "fetching org")
}
if org.EndpointName == nil {
return errors.Wrap(runnerErrors.ErrUnprocessable, "org has no endpoint")
}

if param.CredentialsName != "" {
org.CredentialsName = param.CredentialsName
creds, err = s.getGithubCredentialsByName(ctx, tx, param.CredentialsName, false)
if err != nil {
return errors.Wrap(err, "fetching credentials")
}
if creds.EndpointName == nil {
return errors.Wrap(runnerErrors.ErrUnprocessable, "credentials have no endpoint")
}

if *creds.EndpointName != *org.EndpointName {
return errors.Wrap(runnerErrors.ErrBadRequest, "endpoint mismatch")
}
org.CredentialsID = &creds.ID
}

Expand Down Expand Up @@ -164,6 +180,10 @@ func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, para
return params.Organization{}, errors.Wrap(err, "saving org")
}

org, err = s.getOrgByID(ctx, s.conn, orgID, "Endpoint", "Credentials")
if err != nil {
return params.Organization{}, errors.Wrap(err, "updating enterprise")
}
newParams, err := s.sqlToCommonOrganization(org, true)
if err != nil {
return params.Organization{}, errors.Wrap(err, "saving org")
Expand Down
37 changes: 30 additions & 7 deletions database/sql/organizations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,12 @@ func (s *OrgTestSuite) TestCreateOrganizationDBCreateErr() {
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")).
WithArgs(s.Fixtures.Orgs[0].CredentialsName).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.testCreds.ID))
WillReturnRows(sqlmock.NewRows([]string{"id", "endpoint_name"}).
AddRow(s.testCreds.ID, s.githubEndpoint.Name))
s.Fixtures.SQLMock.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_endpoints` WHERE `github_endpoints`.`name` = ? AND `github_endpoints`.`deleted_at` IS NULL")).
WithArgs(s.testCreds.Endpoint).
WillReturnRows(sqlmock.NewRows([]string{"name"}).
AddRow(s.githubEndpoint.Name))
s.Fixtures.SQLMock.
ExpectExec(regexp.QuoteMeta("INSERT INTO `organizations`")).
WillReturnError(fmt.Errorf("creating org mock error"))
Expand Down Expand Up @@ -347,11 +352,17 @@ func (s *OrgTestSuite) TestUpdateOrganizationDBEncryptErr() {
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")).
WithArgs(s.Fixtures.Orgs[0].ID).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID))
WillReturnRows(sqlmock.NewRows([]string{"id", "endpoint_name"}).
AddRow(s.Fixtures.Orgs[0].ID, s.Fixtures.Orgs[0].Endpoint.Name))
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")).
WithArgs(s.secondaryTestCreds.Name).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID))
WillReturnRows(sqlmock.NewRows([]string{"id", "endpoint_name"}).
AddRow(s.secondaryTestCreds.ID, s.secondaryTestCreds.Endpoint))
s.Fixtures.SQLMock.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_endpoints` WHERE `github_endpoints`.`name` = ? AND `github_endpoints`.`deleted_at` IS NULL")).
WithArgs(s.testCreds.Endpoint).
WillReturnRows(sqlmock.NewRows([]string{"name"}).
AddRow(s.secondaryTestCreds.Endpoint))
s.Fixtures.SQLMock.ExpectRollback()

_, err := s.StoreSQLMocked.UpdateOrganization(s.adminCtx, s.Fixtures.Orgs[0].ID, s.Fixtures.UpdateRepoParams)
Expand All @@ -366,11 +377,17 @@ func (s *OrgTestSuite) TestUpdateOrganizationDBSaveErr() {
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")).
WithArgs(s.Fixtures.Orgs[0].ID).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID))
WillReturnRows(sqlmock.NewRows([]string{"id", "endpoint_name"}).
AddRow(s.Fixtures.Orgs[0].ID, s.Fixtures.Orgs[0].Endpoint.Name))
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")).
WithArgs(s.secondaryTestCreds.Name).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID))
WillReturnRows(sqlmock.NewRows([]string{"id", "endpoint_name"}).
AddRow(s.secondaryTestCreds.ID, s.secondaryTestCreds.Endpoint))
s.Fixtures.SQLMock.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_endpoints` WHERE `github_endpoints`.`name` = ? AND `github_endpoints`.`deleted_at` IS NULL")).
WithArgs(s.testCreds.Endpoint).
WillReturnRows(sqlmock.NewRows([]string{"name"}).
AddRow(s.secondaryTestCreds.Endpoint))
s.Fixtures.SQLMock.
ExpectExec(("UPDATE `organizations` SET")).
WillReturnError(fmt.Errorf("saving org mock error"))
Expand All @@ -391,11 +408,17 @@ func (s *OrgTestSuite) TestUpdateOrganizationDBDecryptingErr() {
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")).
WithArgs(s.Fixtures.Orgs[0].ID).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID))
WillReturnRows(sqlmock.NewRows([]string{"id", "endpoint_name"}).
AddRow(s.Fixtures.Orgs[0].ID, s.Fixtures.Orgs[0].Endpoint.Name))
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")).
WithArgs(s.secondaryTestCreds.Name).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID))
WillReturnRows(sqlmock.NewRows([]string{"id", "endpoint_name"}).
AddRow(s.secondaryTestCreds.ID, s.secondaryTestCreds.Endpoint))
s.Fixtures.SQLMock.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_endpoints` WHERE `github_endpoints`.`name` = ? AND `github_endpoints`.`deleted_at` IS NULL")).
WithArgs(s.testCreds.Endpoint).
WillReturnRows(sqlmock.NewRows([]string{"name"}).
AddRow(s.secondaryTestCreds.Endpoint))
s.Fixtures.SQLMock.ExpectRollback()

_, err := s.StoreSQLMocked.UpdateOrganization(s.adminCtx, s.Fixtures.Orgs[0].ID, s.Fixtures.UpdateRepoParams)
Expand Down
23 changes: 22 additions & 1 deletion database/sql/repositories.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,21 @@ func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, credent
if err != nil {
return errors.Wrap(err, "creating repository")
}
if creds.EndpointName == nil {
return errors.Wrap(runnerErrors.ErrUnprocessable, "credentials have no endpoint")
}
newRepo.CredentialsID = &creds.ID
newRepo.CredentialsName = creds.Name
newRepo.EndpointName = creds.EndpointName

q := tx.Create(&newRepo)
if q.Error != nil {
return errors.Wrap(q.Error, "creating repository")
}

newRepo.Credentials = creds
newRepo.Endpoint = creds.Endpoint

return nil
})
if err != nil {
Expand Down Expand Up @@ -121,17 +128,27 @@ func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param
var creds GithubCredentials
err := s.conn.Transaction(func(tx *gorm.DB) error {
var err error
repo, err = s.getRepoByID(ctx, tx, repoID, "Credentials", "Endpoint")
repo, err = s.getRepoByID(ctx, tx, repoID)
if err != nil {
return errors.Wrap(err, "fetching repo")
}
if repo.EndpointName == nil {
return errors.Wrap(runnerErrors.ErrUnprocessable, "repository has no endpoint")
}

if param.CredentialsName != "" {
repo.CredentialsName = param.CredentialsName
creds, err = s.getGithubCredentialsByName(ctx, tx, param.CredentialsName, false)
if err != nil {
return errors.Wrap(err, "fetching credentials")
}
if creds.EndpointName == nil {
return errors.Wrap(runnerErrors.ErrUnprocessable, "credentials have no endpoint")
}

if *creds.EndpointName != *repo.EndpointName {
return errors.Wrap(runnerErrors.ErrBadRequest, "endpoint mismatch")
}
repo.CredentialsID = &creds.ID
}

Expand Down Expand Up @@ -161,6 +178,10 @@ func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param
return params.Repository{}, errors.Wrap(err, "saving repo")
}

repo, err = s.getRepoByID(ctx, s.conn, repoID, "Endpoint", "Credentials")
if err != nil {
return params.Repository{}, errors.Wrap(err, "updating enterprise")
}
newParams, err := s.sqlToCommonRepository(repo, true)
if err != nil {
return params.Repository{}, errors.Wrap(err, "saving repo")
Expand Down
Loading

0 comments on commit d111897

Please sign in to comment.