From 9bb1a5b80a0772aa09aeb20d00ab92bdefd7e906 Mon Sep 17 00:00:00 2001 From: "nicolas.dorier" Date: Sun, 27 Oct 2024 19:53:13 +0900 Subject: [PATCH] Prevent concurrency race on lightning payout update --- .../LightningPendingPayoutListener.cs | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/BTCPayServer/Payments/Lightning/LightningPendingPayoutListener.cs b/BTCPayServer/Payments/Lightning/LightningPendingPayoutListener.cs index fc9ae63be..2b3d78d80 100644 --- a/BTCPayServer/Payments/Lightning/LightningPendingPayoutListener.cs +++ b/BTCPayServer/Payments/Lightning/LightningPendingPayoutListener.cs @@ -13,6 +13,7 @@ using BTCPayServer.Services.Invoices; using BTCPayServer.Services.Stores; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using Newtonsoft.Json.Linq; using PayoutData = BTCPayServer.Data.PayoutData; using StoreData = BTCPayServer.Data.StoreData; @@ -21,7 +22,6 @@ namespace BTCPayServer.Payments.Lightning; public class LightningPendingPayoutListener : BaseAsyncService { private readonly LightningClientFactoryService _lightningClientFactoryService; - private readonly ApplicationDbContextFactory _applicationDbContextFactory; private readonly PullPaymentHostedService _pullPaymentHostedService; private readonly StoreRepository _storeRepository; private readonly IOptions _options; @@ -32,7 +32,6 @@ public class LightningPendingPayoutListener : BaseAsyncService public LightningPendingPayoutListener( LightningClientFactoryService lightningClientFactoryService, - ApplicationDbContextFactory applicationDbContextFactory, PullPaymentHostedService pullPaymentHostedService, StoreRepository storeRepository, IOptions options, @@ -42,7 +41,6 @@ public class LightningPendingPayoutListener : BaseAsyncService ILogger logger) : base(logger) { _lightningClientFactoryService = lightningClientFactoryService; - _applicationDbContextFactory = applicationDbContextFactory; _pullPaymentHostedService = pullPaymentHostedService; _storeRepository = storeRepository; _options = options; @@ -54,19 +52,18 @@ public class LightningPendingPayoutListener : BaseAsyncService private async Task Act() { - await using var context = _applicationDbContextFactory.CreateContext(); var networks = _networkProvider.GetAll() .OfType() .Where(network => network.SupportLightning) .ToDictionary(network => PaymentTypes.LN.GetPaymentMethodId(network.CryptoCode)); - var payouts = await PullPaymentHostedService.GetPayouts( + var payouts = await _pullPaymentHostedService.GetPayouts( new PullPaymentHostedService.PayoutQuery() { States = new PayoutState[] { PayoutState.InProgress }, PayoutMethods = networks.Keys.Select(id => id.ToString()).ToArray() - }, context); + }); var storeIds = payouts.Select(data => data.StoreDataId).Distinct(); var stores = (await Task.WhenAll(storeIds.Select(_storeRepository.FindStore))) .Where(data => data is not null).ToDictionary(data => data.Id, data => (StoreData)data); @@ -83,9 +80,7 @@ public class LightningPendingPayoutListener : BaseAsyncService .Select(c => (LightningPaymentMethodConfig)c.Value) .FirstOrDefault(); if (pm is null) - { continue; - } var client = pm.CreateLightningClient(networks[pmi], _options.Value, _lightningClientFactoryService); @@ -94,9 +89,6 @@ public class LightningPendingPayoutListener : BaseAsyncService var handler = _payoutHandlers.TryGet(payoutData.GetPayoutMethodId()) as LightningLikePayoutHandler; if (handler is null || handler.PayoutsPaymentProcessing.Contains(payoutData.Id)) continue; - using var tracking = handler.PayoutsPaymentProcessing.StartTracking(); - if (!tracking.TryTrack(payoutData.Id)) - continue; var proof = handler.ParseProof(payoutData) as PayoutLightningBlob; LightningPayment payment = null; @@ -125,10 +117,23 @@ public class LightningPendingPayoutListener : BaseAsyncService break; } } + + foreach (PayoutData payoutData in payoutByStoreByPaymentMethod) + { + if (payoutData.State != PayoutState.InProgress) + { + // This update can fail if the payout has been updated in the meantime + await _pullPaymentHostedService.MarkPaid(new HostedServices.MarkPayoutRequest() + { + PayoutId = payoutData.Id, + State = payoutData.State, + Proof = payoutData.State is PayoutState.Completed ? JObject.Parse(payoutData.Proof) : null + }); + } + } } } - await context.SaveChangesAsync(CancellationToken); await Task.Delay(TimeSpan.FromSeconds(SecondsDelay), CancellationToken); }