Skip to content

Commit

Permalink
[SM-923] Add project service accounts access policies management endp…
Browse files Browse the repository at this point in the history
…oints (#3993)


* Add new models

* Update repositories

* Add new authz handler

* Add new query

* Add new command

* Add authz, command, and query to DI

* Add new endpoint to controller

* Add query unit tests

* Add api unit tests

* Add api integration tests
  • Loading branch information
Thomas-Avery committed May 2, 2024
1 parent e302ee1 commit 7f8cea5
Show file tree
Hide file tree
Showing 23 changed files with 1,559 additions and 29 deletions.
@@ -0,0 +1,107 @@
#nullable enable
using Bit.Core.Context;
using Bit.Core.Enums;
using Bit.Core.SecretsManager.AuthorizationRequirements;
using Bit.Core.SecretsManager.Enums.AccessPolicies;
using Bit.Core.SecretsManager.Models.Data.AccessPolicyUpdates;
using Bit.Core.SecretsManager.Queries.Interfaces;
using Bit.Core.SecretsManager.Repositories;
using Microsoft.AspNetCore.Authorization;

namespace Bit.Commercial.Core.SecretsManager.AuthorizationHandlers.AccessPolicies;

public class ProjectServiceAccountsAccessPoliciesAuthorizationHandler : AuthorizationHandler<
ProjectServiceAccountsAccessPoliciesOperationRequirement,
ProjectServiceAccountsAccessPoliciesUpdates>
{
private readonly IAccessClientQuery _accessClientQuery;
private readonly ICurrentContext _currentContext;
private readonly IProjectRepository _projectRepository;
private readonly IServiceAccountRepository _serviceAccountRepository;

public ProjectServiceAccountsAccessPoliciesAuthorizationHandler(ICurrentContext currentContext,
IAccessClientQuery accessClientQuery,
IProjectRepository projectRepository,
IServiceAccountRepository serviceAccountRepository)
{
_currentContext = currentContext;
_accessClientQuery = accessClientQuery;
_serviceAccountRepository = serviceAccountRepository;
_projectRepository = projectRepository;
}

protected override async Task HandleRequirementAsync(AuthorizationHandlerContext context,
ProjectServiceAccountsAccessPoliciesOperationRequirement requirement,
ProjectServiceAccountsAccessPoliciesUpdates resource)
{
if (!_currentContext.AccessSecretsManager(resource.OrganizationId))
{
return;
}

// Only users and admins should be able to manipulate access policies
var (accessClient, userId) =
await _accessClientQuery.GetAccessClientAsync(context.User, resource.OrganizationId);
if (accessClient != AccessClientType.User && accessClient != AccessClientType.NoAccessCheck)
{
return;
}

switch (requirement)
{
case not null when requirement == ProjectServiceAccountsAccessPoliciesOperations.Updates:
await CanUpdateAsync(context, requirement, resource, accessClient,
userId);
break;
default:
throw new ArgumentException("Unsupported operation requirement type provided.",
nameof(requirement));
}
}

private async Task CanUpdateAsync(AuthorizationHandlerContext context,
ProjectServiceAccountsAccessPoliciesOperationRequirement requirement,
ProjectServiceAccountsAccessPoliciesUpdates resource,
AccessClientType accessClient, Guid userId)
{
var access =
await _projectRepository.AccessToProjectAsync(resource.ProjectId, userId,
accessClient);
if (!access.Write)
{
return;
}

var serviceAccountIds = resource.ServiceAccountAccessPolicyUpdates.Select(update =>
update.AccessPolicy.ServiceAccountId!.Value).ToList();

var inSameOrganization =
await _serviceAccountRepository.ServiceAccountsAreInOrganizationAsync(serviceAccountIds,
resource.OrganizationId);
if (!inSameOrganization)
{
return;
}

// Users can only create access policies for service accounts they have access to.
// User can delete and update any service account access policy if they have write access to the project.
var serviceAccountIdsToCheck = resource.ServiceAccountAccessPolicyUpdates
.Where(update => update.Operation == AccessPolicyOperation.Create).Select(update =>
update.AccessPolicy.ServiceAccountId!.Value).ToList();

if (serviceAccountIdsToCheck.Count == 0)
{
context.Succeed(requirement);
return;
}

var serviceAccountsAccess =
await _serviceAccountRepository.AccessToServiceAccountsAsync(serviceAccountIdsToCheck, userId,
accessClient);
if (serviceAccountsAccess.Count == serviceAccountIdsToCheck.Count &&
serviceAccountsAccess.All(a => a.Value.Write))
{
context.Succeed(requirement);
}
}
}
@@ -0,0 +1,26 @@
#nullable enable
using Bit.Core.SecretsManager.Commands.AccessPolicies.Interfaces;
using Bit.Core.SecretsManager.Models.Data.AccessPolicyUpdates;
using Bit.Core.SecretsManager.Repositories;

namespace Bit.Commercial.Core.SecretsManager.Commands.AccessPolicies;

public class UpdateProjectServiceAccountsAccessPoliciesCommand : IUpdateProjectServiceAccountsAccessPoliciesCommand
{
private readonly IAccessPolicyRepository _accessPolicyRepository;

public UpdateProjectServiceAccountsAccessPoliciesCommand(IAccessPolicyRepository accessPolicyRepository)
{
_accessPolicyRepository = accessPolicyRepository;
}

public async Task UpdateAsync(ProjectServiceAccountsAccessPoliciesUpdates accessPoliciesUpdates)
{
if (!accessPoliciesUpdates.ServiceAccountAccessPolicyUpdates.Any())
{
return;
}

await _accessPolicyRepository.UpdateProjectServiceAccountsAccessPoliciesAsync(accessPoliciesUpdates);
}
}
@@ -0,0 +1,44 @@
#nullable enable
using Bit.Core.SecretsManager.Enums.AccessPolicies;
using Bit.Core.SecretsManager.Models.Data;
using Bit.Core.SecretsManager.Models.Data.AccessPolicyUpdates;
using Bit.Core.SecretsManager.Queries.AccessPolicies.Interfaces;
using Bit.Core.SecretsManager.Repositories;

namespace Bit.Commercial.Core.SecretsManager.Queries.AccessPolicies;

public class ProjectServiceAccountsAccessPoliciesUpdatesQuery : IProjectServiceAccountsAccessPoliciesUpdatesQuery
{
private readonly IAccessPolicyRepository _accessPolicyRepository;

public ProjectServiceAccountsAccessPoliciesUpdatesQuery(IAccessPolicyRepository accessPolicyRepository)
{
_accessPolicyRepository = accessPolicyRepository;
}

public async Task<ProjectServiceAccountsAccessPoliciesUpdates> GetAsync(
ProjectServiceAccountsAccessPolicies projectServiceAccountsAccessPolicies)
{
var currentPolicies =
await _accessPolicyRepository.GetProjectServiceAccountsAccessPoliciesAsync(
projectServiceAccountsAccessPolicies.ProjectId);

if (currentPolicies == null)
{
return new ProjectServiceAccountsAccessPoliciesUpdates
{
ProjectId = projectServiceAccountsAccessPolicies.ProjectId,
OrganizationId = projectServiceAccountsAccessPolicies.OrganizationId,
ServiceAccountAccessPolicyUpdates =
projectServiceAccountsAccessPolicies.ServiceAccountAccessPolicies.Select(p =>
new ServiceAccountProjectAccessPolicyUpdate
{
Operation = AccessPolicyOperation.Create,
AccessPolicy = p
})
};
}

return currentPolicies.GetPolicyUpdates(projectServiceAccountsAccessPolicies);
}
}
Expand Up @@ -42,12 +42,14 @@ public static void AddSecretsManagerServices(this IServiceCollection services)
services.AddScoped<IAuthorizationHandler, ProjectPeopleAccessPoliciesAuthorizationHandler>();
services.AddScoped<IAuthorizationHandler, ServiceAccountPeopleAccessPoliciesAuthorizationHandler>();
services.AddScoped<IAuthorizationHandler, ServiceAccountGrantedPoliciesAuthorizationHandler>();
services.AddScoped<IAuthorizationHandler, ProjectServiceAccountsAccessPoliciesAuthorizationHandler>();
services.AddScoped<IAccessClientQuery, AccessClientQuery>();
services.AddScoped<IMaxProjectsQuery, MaxProjectsQuery>();
services.AddScoped<ISameOrganizationQuery, SameOrganizationQuery>();
services.AddScoped<IServiceAccountSecretsDetailsQuery, ServiceAccountSecretsDetailsQuery>();
services.AddScoped<IServiceAccountGrantedPolicyUpdatesQuery, ServiceAccountGrantedPolicyUpdatesQuery>();
services.AddScoped<ISecretsSyncQuery, SecretsSyncQuery>();
services.AddScoped<IProjectServiceAccountsAccessPoliciesUpdatesQuery, ProjectServiceAccountsAccessPoliciesUpdatesQuery>();
services.AddScoped<ICreateSecretCommand, CreateSecretCommand>();
services.AddScoped<IUpdateSecretCommand, UpdateSecretCommand>();
services.AddScoped<IDeleteSecretCommand, DeleteSecretCommand>();
Expand All @@ -67,5 +69,6 @@ public static void AddSecretsManagerServices(this IServiceCollection services)
services.AddScoped<IEmptyTrashCommand, EmptyTrashCommand>();
services.AddScoped<IRestoreTrashCommand, RestoreTrashCommand>();
services.AddScoped<IUpdateServiceAccountGrantedPoliciesCommand, UpdateServiceAccountGrantedPoliciesCommand>();
services.AddScoped<IUpdateProjectServiceAccountsAccessPoliciesCommand, UpdateProjectServiceAccountsAccessPoliciesCommand>();
}
}
Expand Up @@ -465,12 +465,68 @@ public async Task UpdateServiceAccountGrantedPoliciesAsync(ServiceAccountGranted
dbContext.RemoveRange(policiesToDelete);
}

await UpsertServiceAccountGrantedPoliciesAsync(dbContext, currentAccessPolicies,
await UpsertServiceAccountProjectPoliciesAsync(dbContext, currentAccessPolicies,
updates.ProjectGrantedPolicyUpdates.Where(pu => pu.Operation != AccessPolicyOperation.Delete).ToList());
await UpdateServiceAccountRevisionAsync(dbContext, updates.ServiceAccountId);
await dbContext.SaveChangesAsync();
}

public async Task<ProjectServiceAccountsAccessPolicies?> GetProjectServiceAccountsAccessPoliciesAsync(Guid projectId)
{
await using var scope = ServiceScopeFactory.CreateAsyncScope();
var dbContext = GetDatabaseContext(scope);
var entities = await dbContext.ServiceAccountProjectAccessPolicy
.Where(ap => ap.GrantedProjectId == projectId)
.Include(ap => ap.ServiceAccount)
.Include(ap => ap.GrantedProject)
.ToListAsync();

if (entities.Count == 0)
{
return null;
}

return new ProjectServiceAccountsAccessPolicies(projectId, entities.Select(MapToCore).ToList());
}

public async Task UpdateProjectServiceAccountsAccessPoliciesAsync(
ProjectServiceAccountsAccessPoliciesUpdates updates)
{
await using var scope = ServiceScopeFactory.CreateAsyncScope();
var dbContext = GetDatabaseContext(scope);
await using var transaction = await dbContext.Database.BeginTransactionAsync();

var currentAccessPolicies = await dbContext.ServiceAccountProjectAccessPolicy
.Where(ap => ap.GrantedProjectId == updates.ProjectId)
.ToListAsync();

if (currentAccessPolicies.Count != 0)
{
var serviceAccountIdsToDelete = updates.ServiceAccountAccessPolicyUpdates
.Where(pu => pu.Operation == AccessPolicyOperation.Delete)
.Select(pu => pu.AccessPolicy.ServiceAccountId!.Value)
.ToList();

var accessPolicyIdsToDelete = currentAccessPolicies
.Where(entity => serviceAccountIdsToDelete.Contains(entity.ServiceAccountId!.Value))
.Select(ap => ap.Id)
.ToList();

await dbContext.ServiceAccountProjectAccessPolicy
.Where(ap => accessPolicyIdsToDelete.Contains(ap.Id))
.ExecuteDeleteAsync();
}

await UpsertServiceAccountProjectPoliciesAsync(dbContext, currentAccessPolicies,
updates.ServiceAccountAccessPolicyUpdates.Where(update => update.Operation != AccessPolicyOperation.Delete)
.ToList());
var effectedServiceAccountIds = updates.ServiceAccountAccessPolicyUpdates
.Select(sa => sa.AccessPolicy.ServiceAccountId!.Value).ToList();
await UpdateServiceAccountsRevisionAsync(dbContext, effectedServiceAccountIds);
await dbContext.SaveChangesAsync();
await transaction.CommitAsync();
}

private static async Task UpsertPeoplePoliciesAsync(DatabaseContext dbContext,
List<BaseAccessPolicy> policies, IReadOnlyCollection<AccessPolicy> userPolicyEntities,
IReadOnlyCollection<AccessPolicy> groupPolicyEntities)
Expand Down Expand Up @@ -506,7 +562,7 @@ public async Task UpdateServiceAccountGrantedPoliciesAsync(ServiceAccountGranted
}
}

private async Task UpsertServiceAccountGrantedPoliciesAsync(DatabaseContext dbContext,
private async Task UpsertServiceAccountProjectPoliciesAsync(DatabaseContext dbContext,
IReadOnlyCollection<ServiceAccountProjectAccessPolicy> currentPolices,
List<ServiceAccountProjectAccessPolicyUpdate> policyUpdates)
{
Expand All @@ -515,7 +571,8 @@ public async Task UpdateServiceAccountGrantedPoliciesAsync(ServiceAccountGranted
{
var updatedEntity = MapToEntity(policyUpdate.AccessPolicy);
var currentEntity = currentPolices.FirstOrDefault(e =>
e.GrantedProjectId == policyUpdate.AccessPolicy.GrantedProjectId!.Value);
e.GrantedProjectId == policyUpdate.AccessPolicy.GrantedProjectId!.Value &&
e.ServiceAccountId == policyUpdate.AccessPolicy.ServiceAccountId!.Value);

switch (policyUpdate.Operation)
{
Expand Down Expand Up @@ -628,4 +685,13 @@ private static async Task UpdateServiceAccountRevisionAsync(DatabaseContext dbCo
entity.RevisionDate = DateTime.UtcNow;
}
}

private static async Task UpdateServiceAccountsRevisionAsync(DatabaseContext dbContext, List<Guid> serviceAccountIds)
{
var utcNow = DateTime.UtcNow;
await dbContext.ServiceAccount
.Where(sa => serviceAccountIds.Contains(sa.Id))
.ExecuteUpdateAsync(setters =>
setters.SetProperty(sa => sa.RevisionDate, utcNow));
}
}
Expand Up @@ -112,30 +112,29 @@ await dbContext.ServiceAccount
public async Task<(bool Read, bool Write)> AccessToServiceAccountAsync(Guid id, Guid userId,
AccessClientType accessType)
{
using var scope = ServiceScopeFactory.CreateScope();
await using var scope = ServiceScopeFactory.CreateAsyncScope();
var dbContext = GetDatabaseContext(scope);

var serviceAccount = dbContext.ServiceAccount.Where(sa => sa.Id == id);
var serviceAccountQuery = dbContext.ServiceAccount.Where(sa => sa.Id == id);

var query = accessType switch
{
AccessClientType.NoAccessCheck => serviceAccount.Select(_ => new { Read = true, Write = true }),
AccessClientType.User => serviceAccount.Select(sa => new
{
Read = sa.UserAccessPolicies.Any(ap => ap.OrganizationUser.User.Id == userId && ap.Read) ||
sa.GroupAccessPolicies.Any(ap =>
ap.Group.GroupUsers.Any(gu => gu.OrganizationUser.User.Id == userId && ap.Read)),
Write = sa.UserAccessPolicies.Any(ap => ap.OrganizationUser.User.Id == userId && ap.Write) ||
sa.GroupAccessPolicies.Any(ap =>
ap.Group.GroupUsers.Any(gu => gu.OrganizationUser.User.Id == userId && ap.Write)),
}),
AccessClientType.ServiceAccount => serviceAccount.Select(_ => new { Read = false, Write = false }),
_ => serviceAccount.Select(_ => new { Read = false, Write = false }),
};
var accessQuery = BuildServiceAccountAccessQuery(serviceAccountQuery, userId, accessType);
var access = await accessQuery.FirstOrDefaultAsync();

return access == null ? (false, false) : (access.Read, access.Write);
}

public async Task<Dictionary<Guid, (bool Read, bool Write)>> AccessToServiceAccountsAsync(
IEnumerable<Guid> ids,
Guid userId,
AccessClientType accessType)
{
await using var scope = ServiceScopeFactory.CreateAsyncScope();
var dbContext = GetDatabaseContext(scope);

var policy = await query.FirstOrDefaultAsync();
var serviceAccountsQuery = dbContext.ServiceAccount.Where(p => ids.Contains(p.Id));
var accessQuery = BuildServiceAccountAccessQuery(serviceAccountsQuery, userId, accessType);

return policy == null ? (false, false) : (policy.Read, policy.Write);
return await accessQuery.ToDictionaryAsync(access => access.Id, access => (access.Read, access.Write));
}

public async Task<int> GetServiceAccountCountByOrganizationIdAsync(Guid organizationId)
Expand All @@ -148,6 +147,15 @@ public async Task<int> GetServiceAccountCountByOrganizationIdAsync(Guid organiza
}
}

public async Task<bool> ServiceAccountsAreInOrganizationAsync(List<Guid> serviceAccountIds, Guid organizationId)
{
await using var scope = ServiceScopeFactory.CreateAsyncScope();
var dbContext = GetDatabaseContext(scope);
var result = await dbContext.ServiceAccount.CountAsync(sa =>
sa.OrganizationId == organizationId && serviceAccountIds.Contains(sa.Id));
return serviceAccountIds.Count == result;
}

public async Task<IEnumerable<ServiceAccountSecretsDetails>> GetManyByOrganizationIdWithSecretsDetailsAsync(
Guid organizationId, Guid userId, AccessClientType accessType)
{
Expand Down Expand Up @@ -186,6 +194,27 @@ select new
return results;
}

private record ServiceAccountAccess(Guid Id, bool Read, bool Write);

private static IQueryable<ServiceAccountAccess> BuildServiceAccountAccessQuery(IQueryable<ServiceAccount> serviceAccountQuery, Guid userId,
AccessClientType accessType) =>
accessType switch
{
AccessClientType.NoAccessCheck => serviceAccountQuery.Select(sa => new ServiceAccountAccess(sa.Id, true, true)),
AccessClientType.User => serviceAccountQuery.Select(sa => new ServiceAccountAccess
(
sa.Id,
sa.UserAccessPolicies.Any(ap => ap.OrganizationUser.User.Id == userId && ap.Read) ||
sa.GroupAccessPolicies.Any(ap =>
ap.Group.GroupUsers.Any(gu => gu.OrganizationUser.User.Id == userId && ap.Read)),
sa.UserAccessPolicies.Any(ap => ap.OrganizationUser.User.Id == userId && ap.Write) ||
sa.GroupAccessPolicies.Any(ap =>
ap.Group.GroupUsers.Any(gu => gu.OrganizationUser.User.Id == userId && ap.Write))
)),
AccessClientType.ServiceAccount => serviceAccountQuery.Select(sa => new ServiceAccountAccess(sa.Id, false, false)),
_ => serviceAccountQuery.Select(sa => new ServiceAccountAccess(sa.Id, false, false))
};

private static Expression<Func<ServiceAccount, bool>> UserHasReadAccessToServiceAccount(Guid userId) => sa =>
sa.UserAccessPolicies.Any(ap => ap.OrganizationUser.User.Id == userId && ap.Read) ||
sa.GroupAccessPolicies.Any(ap => ap.Group.GroupUsers.Any(gu => gu.OrganizationUser.User.Id == userId && ap.Read));
Expand Down

0 comments on commit 7f8cea5

Please sign in to comment.