diff --git a/grpc-adapter/src/main/java/de/cotto/lndmanagej/grpc/GrpcSendToRoute.java b/grpc-adapter/src/main/java/de/cotto/lndmanagej/grpc/GrpcSendToRoute.java index dba380f8..b6a529e2 100644 --- a/grpc-adapter/src/main/java/de/cotto/lndmanagej/grpc/GrpcSendToRoute.java +++ b/grpc-adapter/src/main/java/de/cotto/lndmanagej/grpc/GrpcSendToRoute.java @@ -37,7 +37,7 @@ public class GrpcSendToRoute { decodedPaymentRequest.paymentHash(), buildLndRoute(route, blockHeight, decodedPaymentRequest) ); - grpcRouterService.sendToRoute(request, new ReportingStreamObserver<>(observer)); + grpcRouterService.sendToRoute(request, new ReportingStreamObserver(observer)); } private lnrpc.Route buildLndRoute(Route route, int blockHeight, DecodedPaymentRequest decodedPaymentRequest) { diff --git a/grpc-adapter/src/main/java/de/cotto/lndmanagej/grpc/ReportingStreamObserver.java b/grpc-adapter/src/main/java/de/cotto/lndmanagej/grpc/ReportingStreamObserver.java index efb4e2bd..21f29bf1 100644 --- a/grpc-adapter/src/main/java/de/cotto/lndmanagej/grpc/ReportingStreamObserver.java +++ b/grpc-adapter/src/main/java/de/cotto/lndmanagej/grpc/ReportingStreamObserver.java @@ -1,8 +1,10 @@ package de.cotto.lndmanagej.grpc; +import de.cotto.lndmanagej.model.HexString; import io.grpc.stub.StreamObserver; +import lnrpc.HTLCAttempt; -class ReportingStreamObserver implements StreamObserver { +class ReportingStreamObserver implements StreamObserver { private final SendToRouteObserver sendToRouteObserver; public ReportingStreamObserver(SendToRouteObserver sendToRouteObserver) { @@ -10,8 +12,9 @@ class ReportingStreamObserver implements StreamObserver { } @Override - public void onNext(T value) { - sendToRouteObserver.onValue(value); + public void onNext(HTLCAttempt htlcAttempt) { + HexString preimage = new HexString(htlcAttempt.getPreimage().toByteArray()); + sendToRouteObserver.onValue(preimage); } @Override diff --git a/grpc-adapter/src/main/java/de/cotto/lndmanagej/grpc/SendToRouteObserver.java b/grpc-adapter/src/main/java/de/cotto/lndmanagej/grpc/SendToRouteObserver.java index 298e5759..1fa6cee0 100644 --- a/grpc-adapter/src/main/java/de/cotto/lndmanagej/grpc/SendToRouteObserver.java +++ b/grpc-adapter/src/main/java/de/cotto/lndmanagej/grpc/SendToRouteObserver.java @@ -1,7 +1,9 @@ package de.cotto.lndmanagej.grpc; +import de.cotto.lndmanagej.model.HexString; + public interface SendToRouteObserver { void onError(Throwable throwable); - void onValue(Object value); + void onValue(HexString preimage); } diff --git a/grpc-adapter/src/test/java/de/cotto/lndmanagej/grpc/GrpcSendToRouteTest.java b/grpc-adapter/src/test/java/de/cotto/lndmanagej/grpc/GrpcSendToRouteTest.java index 5b0b734e..58fd525a 100644 --- a/grpc-adapter/src/test/java/de/cotto/lndmanagej/grpc/GrpcSendToRouteTest.java +++ b/grpc-adapter/src/test/java/de/cotto/lndmanagej/grpc/GrpcSendToRouteTest.java @@ -114,13 +114,15 @@ class GrpcSendToRouteTest { void reporter_reports_value_to_given_observer() { grpcSendToRoute.sendToRoute(ROUTE, DECODED_PAYMENT_REQUEST, observer); verify(grpcRouterService).sendToRoute(any(), captor.capture()); - HTLCAttempt value = htlcAttempt(); + HexString preimage = new HexString("0011FF"); + HTLCAttempt value = htlcAttempt(preimage); captor.getValue().onNext(value); - verify(observer).onValue(value); + verify(observer).onValue(preimage); } - private HTLCAttempt htlcAttempt() { - return HTLCAttempt.newBuilder().build(); + private HTLCAttempt htlcAttempt(HexString hexString) { + ByteString bytestring = ByteString.copyFrom(hexString.getByteArray()); + return HTLCAttempt.newBuilder().setPreimage(bytestring).build(); } private ByteString toByteString(HexString hexString) { diff --git a/grpc-adapter/src/test/java/de/cotto/lndmanagej/grpc/ReportingStreamObserverTest.java b/grpc-adapter/src/test/java/de/cotto/lndmanagej/grpc/ReportingStreamObserverTest.java index 49880308..4c43431b 100644 --- a/grpc-adapter/src/test/java/de/cotto/lndmanagej/grpc/ReportingStreamObserverTest.java +++ b/grpc-adapter/src/test/java/de/cotto/lndmanagej/grpc/ReportingStreamObserverTest.java @@ -1,5 +1,8 @@ package de.cotto.lndmanagej.grpc; +import com.google.protobuf.ByteString; +import de.cotto.lndmanagej.model.HexString; +import lnrpc.HTLCAttempt; import org.junit.jupiter.api.Test; import javax.annotation.Nullable; @@ -9,14 +12,8 @@ import static org.assertj.core.api.Assertions.assertThatCode; class ReportingStreamObserverTest { - private final MySendToRouteObserver sendToRouteObserver = new MySendToRouteObserver(); - private final ReportingStreamObserver reportingStreamObserver = - new ReportingStreamObserver<>(sendToRouteObserver); - - @Test - void onNext() { - assertThatCode(() -> reportingStreamObserver.onNext("foo")).doesNotThrowAnyException(); - } + private final TestableSendToRouteObserver sendToRouteObserver = new TestableSendToRouteObserver(); + private final ReportingStreamObserver reportingStreamObserver = new ReportingStreamObserver(sendToRouteObserver); @Test void onCompleted() { @@ -31,18 +28,21 @@ class ReportingStreamObserverTest { } @Test - void onValue() { - String value = ""; - assertThatCode(() -> reportingStreamObserver.onNext(value)).doesNotThrowAnyException(); - assertThat(sendToRouteObserver.seenValue).isSameAs(value); + void onNext_forwards_preimage() { + HexString preimage = new HexString("AA00"); + HTLCAttempt htlcAttempt = HTLCAttempt.newBuilder() + .setPreimage(ByteString.copyFrom(preimage.getByteArray())) + .build(); + assertThatCode(() -> reportingStreamObserver.onNext(htlcAttempt)).doesNotThrowAnyException(); + assertThat(sendToRouteObserver.seenPreimage).isEqualTo(preimage); } - private static class MySendToRouteObserver implements SendToRouteObserver { + private static class TestableSendToRouteObserver implements SendToRouteObserver { @Nullable private Throwable seenThrowable; @Nullable - private Object seenValue; + private HexString seenPreimage; @Override public void onError(Throwable throwable) { @@ -50,8 +50,8 @@ class ReportingStreamObserverTest { } @Override - public void onValue(Object value) { - seenValue = value; + public void onValue(HexString preimage) { + seenPreimage = preimage; } } } diff --git a/pickhardt-payments/src/main/java/de/cotto/lndmanagej/pickhardtpayments/MultiPathPaymentObserver.java b/pickhardt-payments/src/main/java/de/cotto/lndmanagej/pickhardtpayments/MultiPathPaymentObserver.java index d25cce6b..a34a0b5b 100644 --- a/pickhardt-payments/src/main/java/de/cotto/lndmanagej/pickhardtpayments/MultiPathPaymentObserver.java +++ b/pickhardt-payments/src/main/java/de/cotto/lndmanagej/pickhardtpayments/MultiPathPaymentObserver.java @@ -3,6 +3,7 @@ package de.cotto.lndmanagej.pickhardtpayments; import de.cotto.lndmanagej.grpc.SendToRouteObserver; import de.cotto.lndmanagej.model.Coins; import de.cotto.lndmanagej.model.Edge; +import de.cotto.lndmanagej.model.HexString; import de.cotto.lndmanagej.model.PaymentAttemptHop; import de.cotto.lndmanagej.model.Route; import de.cotto.lndmanagej.service.LiquidityInformationUpdater; @@ -12,19 +13,30 @@ import org.springframework.stereotype.Component; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; @Component public class MultiPathPaymentObserver { private final LiquidityInformationUpdater liquidityInformationUpdater; private final Logger logger = LoggerFactory.getLogger(getClass()); + private final Map inFlight = new ConcurrentHashMap<>(); public MultiPathPaymentObserver(LiquidityInformationUpdater liquidityInformationUpdater) { this.liquidityInformationUpdater = liquidityInformationUpdater; } - public SendToRouteObserver forRoute(Route route) { - return new SendToRouteObserverImpl(route); + public SendToRouteObserver getFor(Route route, HexString paymentHash) { + return new SendToRouteObserverImpl(route, paymentHash); + } + + public Coins getInFlight(HexString paymentHash) { + return inFlight.getOrDefault(paymentHash, Coins.NONE); + } + + private void addInFlight(HexString paymentHash, Coins amount) { + inFlight.compute(paymentHash, (key, value) -> value == null ? amount : amount.add(value)); } private List topPaymentAttemptHops(Route route) { @@ -45,20 +57,24 @@ public class MultiPathPaymentObserver { private class SendToRouteObserverImpl implements SendToRouteObserver { private final Route route; + private final HexString paymentHash; - public SendToRouteObserverImpl(Route route) { + public SendToRouteObserverImpl(Route route, HexString paymentHash) { this.route = route; + this.paymentHash = paymentHash; + addInFlight(paymentHash, route.getAmount()); } @Override public void onError(Throwable throwable) { logger.warn("Send to route failed for route {}: ", route, throwable); liquidityInformationUpdater.removeInFlight(topPaymentAttemptHops(route)); + addInFlight(paymentHash, route.getAmount().negate()); } @Override - public void onValue(Object value) { - logger.info("Got value {}: ", value); + public void onValue(HexString preimage) { + addInFlight(paymentHash, route.getAmount().negate()); } } } diff --git a/pickhardt-payments/src/main/java/de/cotto/lndmanagej/pickhardtpayments/MultiPathPaymentSender.java b/pickhardt-payments/src/main/java/de/cotto/lndmanagej/pickhardtpayments/MultiPathPaymentSender.java index e0f5bc41..a19b1b4e 100644 --- a/pickhardt-payments/src/main/java/de/cotto/lndmanagej/pickhardtpayments/MultiPathPaymentSender.java +++ b/pickhardt-payments/src/main/java/de/cotto/lndmanagej/pickhardtpayments/MultiPathPaymentSender.java @@ -2,8 +2,10 @@ package de.cotto.lndmanagej.pickhardtpayments; import de.cotto.lndmanagej.grpc.GrpcPayments; import de.cotto.lndmanagej.grpc.GrpcSendToRoute; +import de.cotto.lndmanagej.grpc.SendToRouteObserver; import de.cotto.lndmanagej.model.Coins; import de.cotto.lndmanagej.model.DecodedPaymentRequest; +import de.cotto.lndmanagej.model.HexString; import de.cotto.lndmanagej.model.Pubkey; import de.cotto.lndmanagej.model.Route; import de.cotto.lndmanagej.pickhardtpayments.model.MultiPathPayment; @@ -40,8 +42,10 @@ public class MultiPathPaymentSender { MultiPathPayment multiPathPayment = multiPathPaymentSplitter.getMultiPathPaymentTo(destination, amount, feeRateWeight); List routes = multiPathPayment.routes(); + HexString paymentHash = decodedPaymentRequest.paymentHash(); for (Route route : routes) { - grpcSendToRoute.sendToRoute(route, decodedPaymentRequest, multiPathPaymentObserver.forRoute(route)); + SendToRouteObserver sendToRouteObserver = multiPathPaymentObserver.getFor(route, paymentHash); + grpcSendToRoute.sendToRoute(route, decodedPaymentRequest, sendToRouteObserver); } return multiPathPayment; } diff --git a/pickhardt-payments/src/test/java/de/cotto/lndmanagej/pickhardtpayments/MultiPathPaymentObserverTest.java b/pickhardt-payments/src/test/java/de/cotto/lndmanagej/pickhardtpayments/MultiPathPaymentObserverTest.java index e5dfadd9..21e0f2b2 100644 --- a/pickhardt-payments/src/test/java/de/cotto/lndmanagej/pickhardtpayments/MultiPathPaymentObserverTest.java +++ b/pickhardt-payments/src/test/java/de/cotto/lndmanagej/pickhardtpayments/MultiPathPaymentObserverTest.java @@ -3,6 +3,7 @@ package de.cotto.lndmanagej.pickhardtpayments; import de.cotto.lndmanagej.grpc.SendToRouteObserver; import de.cotto.lndmanagej.model.Coins; import de.cotto.lndmanagej.model.Edge; +import de.cotto.lndmanagej.model.HexString; import de.cotto.lndmanagej.model.PaymentAttemptHop; import de.cotto.lndmanagej.service.LiquidityInformationUpdater; import org.junit.jupiter.api.Test; @@ -15,7 +16,11 @@ import java.util.ArrayList; import java.util.List; import java.util.Optional; +import static de.cotto.lndmanagej.model.DecodedPaymentRequestFixtures.DECODED_PAYMENT_REQUEST; import static de.cotto.lndmanagej.model.RouteFixtures.ROUTE; +import static de.cotto.lndmanagej.model.RouteFixtures.ROUTE_2; +import static de.cotto.lndmanagej.model.RouteFixtures.ROUTE_3; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.mockito.Mockito.verify; @@ -29,15 +34,65 @@ class MultiPathPaymentObserverTest { @Test void cancels_in_flight_on_error() { - SendToRouteObserver sendToRouteObserver = multiPathPaymentObserver.forRoute(ROUTE); + HexString paymentHash = DECODED_PAYMENT_REQUEST.paymentHash(); + SendToRouteObserver sendToRouteObserver = + multiPathPaymentObserver.getFor(ROUTE, paymentHash); + sendToRouteObserver.onError(new NullPointerException()); + assertThat(multiPathPaymentObserver.getInFlight(paymentHash)).isEqualTo(Coins.NONE); + } + + @Test + void cancels_in_flight_using_liquidity_information_updater_on_error() { + SendToRouteObserver sendToRouteObserver = + multiPathPaymentObserver.getFor(ROUTE, DECODED_PAYMENT_REQUEST.paymentHash()); sendToRouteObserver.onError(new NullPointerException()); verify(liquidityInformationUpdater).removeInFlight(hops()); } @Test void accepts_value() { - SendToRouteObserver sendToRouteObserver = multiPathPaymentObserver.forRoute(ROUTE); - assertThatCode(() -> sendToRouteObserver.onValue("")).doesNotThrowAnyException(); + SendToRouteObserver sendToRouteObserver = + multiPathPaymentObserver.getFor(ROUTE, DECODED_PAYMENT_REQUEST.paymentHash()); + assertThatCode(() -> sendToRouteObserver.onValue(HexString.EMPTY)).doesNotThrowAnyException(); + } + + @Test + void inFlight_initially_zero() { + assertThat(multiPathPaymentObserver.getInFlight(DECODED_PAYMENT_REQUEST.paymentHash())) + .isEqualTo(Coins.NONE); + } + + @Test + void inFlight_initialized_with_route_amount() { + HexString paymentHash = DECODED_PAYMENT_REQUEST.paymentHash(); + multiPathPaymentObserver.getFor(ROUTE, paymentHash); + assertThat(multiPathPaymentObserver.getInFlight(paymentHash)).isEqualTo(ROUTE.getAmount()); + } + + @Test + void inFlight_reset_on_success() { + HexString paymentHash = DECODED_PAYMENT_REQUEST.paymentHash(); + SendToRouteObserver observer = multiPathPaymentObserver.getFor(ROUTE, paymentHash); + observer.onValue(new HexString("AABBCC")); + assertThat(multiPathPaymentObserver.getInFlight(paymentHash)).isEqualTo(Coins.NONE); + } + + @Test + void inFlight_reset_on_failure() { + HexString paymentHash = DECODED_PAYMENT_REQUEST.paymentHash(); + SendToRouteObserver observer = multiPathPaymentObserver.getFor(ROUTE, paymentHash); + observer.onValue(HexString.EMPTY); + assertThat(multiPathPaymentObserver.getInFlight(paymentHash)).isEqualTo(Coins.NONE); + } + + @Test + void inFlight_updated_with_several_routes() { + HexString paymentHash = DECODED_PAYMENT_REQUEST.paymentHash(); + multiPathPaymentObserver.getFor(ROUTE, paymentHash); + multiPathPaymentObserver.getFor(ROUTE_2, HexString.EMPTY); + multiPathPaymentObserver.getFor(ROUTE_3, paymentHash); + Coins expectedAmount = ROUTE.getAmount().add(ROUTE_3.getAmount()); + assertThat(multiPathPaymentObserver.getInFlight(paymentHash)).isEqualTo(expectedAmount); } private List hops() { diff --git a/pickhardt-payments/src/test/java/de/cotto/lndmanagej/pickhardtpayments/MultiPathPaymentSenderTest.java b/pickhardt-payments/src/test/java/de/cotto/lndmanagej/pickhardtpayments/MultiPathPaymentSenderTest.java index 1eaaf030..217bb29b 100644 --- a/pickhardt-payments/src/test/java/de/cotto/lndmanagej/pickhardtpayments/MultiPathPaymentSenderTest.java +++ b/pickhardt-payments/src/test/java/de/cotto/lndmanagej/pickhardtpayments/MultiPathPaymentSenderTest.java @@ -3,6 +3,7 @@ package de.cotto.lndmanagej.pickhardtpayments; import de.cotto.lndmanagej.grpc.GrpcPayments; import de.cotto.lndmanagej.grpc.GrpcSendToRoute; import de.cotto.lndmanagej.grpc.SendToRouteObserver; +import de.cotto.lndmanagej.model.HexString; import de.cotto.lndmanagej.model.Route; import de.cotto.lndmanagej.pickhardtpayments.model.MultiPathPayment; import org.junit.jupiter.api.Test; @@ -84,10 +85,11 @@ class MultiPathPaymentSenderTest { void registers_observers_for_routes() { when(grpcPayments.decodePaymentRequest(any())).thenReturn(Optional.of(DECODED_PAYMENT_REQUEST)); when(multiPathPaymentSplitter.getMultiPathPaymentTo(any(), any(), anyInt())).thenReturn(MULTI_PATH_PAYMENT); + HexString paymentHash = DECODED_PAYMENT_REQUEST.paymentHash(); Map expected = new LinkedHashMap<>(); for (Route route : MULTI_PATH_PAYMENT.routes()) { SendToRouteObserver expectedObserver = mock(SendToRouteObserver.class); - when(multiPathPaymentObserver.forRoute(route)).thenReturn(expectedObserver); + when(multiPathPaymentObserver.getFor(route, paymentHash)).thenReturn(expectedObserver); expected.put(route, expectedObserver); } multiPathPaymentSender.payPaymentRequest(PAYMENT_REQUEST, FEE_RATE_WEIGHT);