Skip to content

Commit

Permalink
[PM-5766] Automatic Tax Feature Flag (#3729)
Browse files Browse the repository at this point in the history
* Added feature flag constant

* Wrapped Automatic Tax logic behind feature flag

* Only getting customer if feature is anabled.

* Enabled feature flag in unit tests

* Made IPaymentService scoped

* Added missing StripeFacade calls

(cherry picked from commit 9a1519f)
  • Loading branch information
cturnbull-bitwarden committed Feb 1, 2024
1 parent a33a890 commit 581aec3
Show file tree
Hide file tree
Showing 8 changed files with 318 additions and 64 deletions.
116 changes: 78 additions & 38 deletions src/Billing/Controllers/StripeController.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Bit.Billing.Constants;
using Bit.Billing.Services;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Context;
using Bit.Core.Enums;
Expand All @@ -21,6 +22,7 @@
using Event = Stripe.Event;
using PaymentMethod = Stripe.PaymentMethod;
using Subscription = Stripe.Subscription;
using TaxRate = Bit.Core.Entities.TaxRate;
using Transaction = Bit.Core.Entities.Transaction;
using TransactionType = Bit.Core.Enums.TransactionType;

Expand Down Expand Up @@ -50,6 +52,7 @@ public class StripeController : Controller
private readonly GlobalSettings _globalSettings;
private readonly IStripeEventService _stripeEventService;
private readonly IStripeFacade _stripeFacade;
private readonly IFeatureService _featureService;

public StripeController(
GlobalSettings globalSettings,
Expand All @@ -68,7 +71,8 @@ public class StripeController : Controller
IUserRepository userRepository,
ICurrentContext currentContext,
IStripeEventService stripeEventService,
IStripeFacade stripeFacade)
IStripeFacade stripeFacade,
IFeatureService featureService)
{
_billingSettings = billingSettings?.Value;
_hostingEnvironment = hostingEnvironment;
Expand All @@ -95,6 +99,7 @@ public class StripeController : Controller
_globalSettings = globalSettings;
_stripeEventService = stripeEventService;
_stripeFacade = stripeFacade;
_featureService = featureService;
}

[HttpPost("webhook")]
Expand Down Expand Up @@ -222,17 +227,29 @@ public async Task<IActionResult> PostWebhook([FromQuery] string key)
$"Received null Subscription from Stripe for ID '{invoice.SubscriptionId}' while processing Event with ID '{parsedEvent.Id}'");
}

if (!subscription.AutomaticTax.Enabled)
var pm5766AutomaticTaxIsEnabled = _featureService.IsEnabled(FeatureFlagKeys.PM5766AutomaticTax);
if (pm5766AutomaticTaxIsEnabled)
{
subscription = await _stripeFacade.UpdateSubscription(subscription.Id,
new SubscriptionUpdateOptions
{
DefaultTaxRates = new List<string>(),
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
});
var customer = await _stripeFacade.GetCustomer(subscription.CustomerId);
if (!subscription.AutomaticTax.Enabled &&
!string.IsNullOrEmpty(customer.Address?.PostalCode) &&
!string.IsNullOrEmpty(customer.Address?.Country))
{
subscription = await _stripeFacade.UpdateSubscription(subscription.Id,
new SubscriptionUpdateOptions
{
DefaultTaxRates = new List<string>(),
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
});
}
}

var (organizationId, userId) = GetIdsFromMetaData(subscription.Metadata);

var updatedSubscription = pm5766AutomaticTaxIsEnabled
? subscription
: await VerifyCorrectTaxRateForCharge(invoice, subscription);

var (organizationId, userId) = GetIdsFromMetaData(updatedSubscription.Metadata);

var invoiceLineItemDescriptions = invoice.Lines.Select(i => i.Description).ToList();

Expand All @@ -253,7 +270,7 @@ async Task SendEmails(IEnumerable<string> emails)

if (organizationId.HasValue)
{
if (IsSponsoredSubscription(subscription))
if (IsSponsoredSubscription(updatedSubscription))
{
await _validateSponsorshipCommand.ValidateSponsorshipAsync(organizationId.Value);
}
Expand Down Expand Up @@ -301,22 +318,20 @@ async Task SendEmails(IEnumerable<string> emails)

Tuple<Guid?, Guid?> ids = null;
Subscription subscription = null;
var subscriptionService = new SubscriptionService();

if (charge.InvoiceId != null)
{
var invoiceService = new InvoiceService();
var invoice = await invoiceService.GetAsync(charge.InvoiceId);
var invoice = await _stripeFacade.GetInvoice(charge.InvoiceId);
if (invoice?.SubscriptionId != null)
{
subscription = await subscriptionService.GetAsync(invoice.SubscriptionId);
subscription = await _stripeFacade.GetSubscription(invoice.SubscriptionId);
ids = GetIdsFromMetaData(subscription?.Metadata);
}
}

if (subscription == null || ids == null || (ids.Item1.HasValue && ids.Item2.HasValue))
{
var subscriptions = await subscriptionService.ListAsync(new SubscriptionListOptions
var subscriptions = await _stripeFacade.ListSubscriptions(new SubscriptionListOptions
{
Customer = charge.CustomerId
});
Expand Down Expand Up @@ -470,8 +485,7 @@ async Task SendEmails(IEnumerable<string> emails)
var invoice = await _stripeEventService.GetInvoice(parsedEvent, true);
if (invoice.Paid && invoice.BillingReason == "subscription_create")
{
var subscriptionService = new SubscriptionService();
var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId);
var subscription = await _stripeFacade.GetSubscription(invoice.SubscriptionId);
if (subscription?.Status == StripeSubscriptionStatus.Active)
{
if (DateTime.UtcNow - invoice.Created < TimeSpan.FromMinutes(1))
Expand Down Expand Up @@ -576,7 +590,6 @@ private async Task HandlePaymentMethodAttachedAsync(PaymentMethod paymentMethod)
return;
}

var subscriptionService = new SubscriptionService();
var subscriptionListOptions = new SubscriptionListOptions
{
Customer = paymentMethod.CustomerId,
Expand All @@ -587,7 +600,7 @@ private async Task HandlePaymentMethodAttachedAsync(PaymentMethod paymentMethod)
StripeList<Subscription> unpaidSubscriptions;
try
{
unpaidSubscriptions = await subscriptionService.ListAsync(subscriptionListOptions);
unpaidSubscriptions = await _stripeFacade.ListSubscriptions(subscriptionListOptions);
}
catch (Exception e)
{
Expand Down Expand Up @@ -682,8 +695,7 @@ private async Task AttemptToPayOpenSubscriptionAsync(Subscription unpaidSubscrip

private async Task<bool> AttemptToPayInvoiceAsync(Invoice invoice, bool attemptToPayWithStripe = false)
{
var customerService = new CustomerService();
var customer = await customerService.GetAsync(invoice.CustomerId);
var customer = await _stripeFacade.GetCustomer(invoice.CustomerId);

if (customer?.Metadata?.ContainsKey("btCustomerId") ?? false)
{
Expand All @@ -708,8 +720,7 @@ private async Task<bool> AttemptToPayInvoiceWithBraintreeAsync(Invoice invoice,
return false;
}

var subscriptionService = new SubscriptionService();
var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId);
var subscription = await _stripeFacade.GetSubscription(invoice.SubscriptionId);
var ids = GetIdsFromMetaData(subscription?.Metadata);
if (!ids.Item1.HasValue && !ids.Item2.HasValue)
{
Expand Down Expand Up @@ -777,10 +788,9 @@ private async Task<bool> AttemptToPayInvoiceWithBraintreeAsync(Invoice invoice,
return false;
}

var invoiceService = new InvoiceService();
try
{
await invoiceService.UpdateAsync(invoice.Id, new InvoiceUpdateOptions
await _stripeFacade.UpdateInvoice(invoice.Id, new InvoiceUpdateOptions
{
Metadata = new Dictionary<string, string>
{
Expand All @@ -789,14 +799,14 @@ private async Task<bool> AttemptToPayInvoiceWithBraintreeAsync(Invoice invoice,
transactionResult.Target.PayPalDetails?.AuthorizationId
}
});
await invoiceService.PayAsync(invoice.Id, new InvoicePayOptions { PaidOutOfBand = true });
await _stripeFacade.PayInvoice(invoice.Id, new InvoicePayOptions { PaidOutOfBand = true });
}
catch (Exception e)
{
await _btGateway.Transaction.RefundAsync(transactionResult.Target.Id);
if (e.Message.Contains("Invoice is already paid"))
{
await invoiceService.UpdateAsync(invoice.Id, new InvoiceUpdateOptions
await _stripeFacade.UpdateInvoice(invoice.Id, new InvoiceUpdateOptions
{
Metadata = invoice.Metadata
});
Expand All @@ -814,8 +824,7 @@ private async Task<bool> AttemptToPayInvoiceWithStripeAsync(Invoice invoice)
{
try
{
var invoiceService = new InvoiceService();
await invoiceService.PayAsync(invoice.Id);
await _stripeFacade.PayInvoice(invoice.Id);
return true;
}
catch (Exception e)
Expand All @@ -835,15 +844,49 @@ private bool UnpaidAutoChargeInvoiceForSubscriptionCycle(Invoice invoice)
invoice.BillingReason == "subscription_cycle" && invoice.SubscriptionId != null;
}

private async Task<Subscription> VerifyCorrectTaxRateForCharge(Invoice invoice, Subscription subscription)
{
if (string.IsNullOrWhiteSpace(invoice?.CustomerAddress?.Country) ||
string.IsNullOrWhiteSpace(invoice?.CustomerAddress?.PostalCode))
{
return subscription;
}

var localBitwardenTaxRates = await _taxRateRepository.GetByLocationAsync(
new TaxRate()
{
Country = invoice.CustomerAddress.Country,
PostalCode = invoice.CustomerAddress.PostalCode
}
);

if (!localBitwardenTaxRates.Any())
{
return subscription;
}

var stripeTaxRate = await _stripeFacade.GetTaxRate(localBitwardenTaxRates.First().Id);
if (stripeTaxRate == null || subscription.DefaultTaxRates.Any(x => x == stripeTaxRate))
{
return subscription;
}

subscription.DefaultTaxRates = new List<Stripe.TaxRate> { stripeTaxRate };

var subscriptionOptions = new SubscriptionUpdateOptions { DefaultTaxRates = new List<string> { stripeTaxRate.Id } };
subscription = await _stripeFacade.UpdateSubscription(subscription.Id, subscriptionOptions);

return subscription;
}

private static bool IsSponsoredSubscription(Subscription subscription) =>
StaticStore.SponsoredPlans.Any(p => p.StripePlanId == subscription.Id);

private async Task HandlePaymentFailed(Invoice invoice)
{
if (!invoice.Paid && invoice.AttemptCount > 1 && UnpaidAutoChargeInvoiceForSubscriptionCycle(invoice))
{
var subscriptionService = new SubscriptionService();
var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId);
var subscription = await _stripeFacade.GetSubscription(invoice.SubscriptionId);
// attempt count 4 = 11 days after initial failure
if (invoice.AttemptCount <= 3 ||
!subscription.Items.Any(i => i.Price.Id is PremiumPlanId or PremiumPlanIdAppStore))
Expand All @@ -853,23 +896,20 @@ private async Task HandlePaymentFailed(Invoice invoice)
}
}

private async Task CancelSubscription(string subscriptionId)
{
await new SubscriptionService().CancelAsync(subscriptionId, new SubscriptionCancelOptions());
}
private async Task CancelSubscription(string subscriptionId) =>
await _stripeFacade.CancelSubscription(subscriptionId, new SubscriptionCancelOptions());

private async Task VoidOpenInvoices(string subscriptionId)
{
var invoiceService = new InvoiceService();
var options = new InvoiceListOptions
{
Status = StripeInvoiceStatus.Open,
Subscription = subscriptionId
};
var invoices = invoiceService.List(options);
var invoices = await _stripeFacade.ListInvoices(options);
foreach (var invoice in invoices)
{
await invoiceService.VoidInvoiceAsync(invoice.Id);
await _stripeFacade.VoidInvoice(invoice.Id);
}
}
}
40 changes: 40 additions & 0 deletions src/Billing/Services/IStripeFacade.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,40 @@ public interface IStripeFacade
RequestOptions requestOptions = null,
CancellationToken cancellationToken = default);

Task<StripeList<Invoice>> ListInvoices(
InvoiceListOptions options = null,
RequestOptions requestOptions = null,
CancellationToken cancellationToken = default);

Task<Invoice> UpdateInvoice(
string invoiceId,
InvoiceUpdateOptions invoiceGetOptions = null,
RequestOptions requestOptions = null,
CancellationToken cancellationToken = default);

Task<Invoice> PayInvoice(
string invoiceId,
InvoicePayOptions options = null,
RequestOptions requestOptions = null,
CancellationToken cancellationToken = default);

Task<Invoice> VoidInvoice(
string invoiceId,
InvoiceVoidOptions options = null,
RequestOptions requestOptions = null,
CancellationToken cancellationToken = default);

Task<PaymentMethod> GetPaymentMethod(
string paymentMethodId,
PaymentMethodGetOptions paymentMethodGetOptions = null,
RequestOptions requestOptions = null,
CancellationToken cancellationToken = default);

Task<StripeList<Subscription>> ListSubscriptions(
SubscriptionListOptions options = null,
RequestOptions requestOptions = null,
CancellationToken cancellationToken = default);

Task<Subscription> GetSubscription(
string subscriptionId,
SubscriptionGetOptions subscriptionGetOptions = null,
Expand All @@ -39,4 +67,16 @@ public interface IStripeFacade
SubscriptionUpdateOptions subscriptionGetOptions = null,
RequestOptions requestOptions = null,
CancellationToken cancellationToken = default);

Task<Subscription> CancelSubscription(
string subscriptionId,
SubscriptionCancelOptions options = null,
RequestOptions requestOptions = null,
CancellationToken cancellationToken = default);

Task<TaxRate> GetTaxRate(
string taxRateId,
TaxRateGetOptions options = null,
RequestOptions requestOptions = null,
CancellationToken cancellationToken = default);
}

0 comments on commit 581aec3

Please sign in to comment.