cancel in-flight if payment fails immediately

This commit is contained in:
Carsten Otto
2022-05-09 23:19:02 +02:00
parent f136ae71f2
commit b24dc880b5
14 changed files with 223 additions and 42 deletions

View File

@@ -67,12 +67,12 @@ public class LiquidityInformationUpdater implements PaymentListener {
}
}
private void addInFlight(List<PaymentAttemptHop> paymentAttemptHops) {
updateInFlight(paymentAttemptHops, false);
public void removeInFlight(List<PaymentAttemptHop> paymentAttemptHops) {
updateInFlight(paymentAttemptHops, true);
}
private void removeInFlight(List<PaymentAttemptHop> paymentAttemptHops) {
updateInFlight(paymentAttemptHops, true);
private void addInFlight(List<PaymentAttemptHop> paymentAttemptHops) {
updateInFlight(paymentAttemptHops, false);
}
private void updateInFlight(List<PaymentAttemptHop> paymentAttemptHops, boolean negate) {

View File

@@ -404,6 +404,13 @@ class LiquidityInformationUpdaterTest {
verifyNoMoreInteractions(liquidityBoundsService);
}
@Test
void removeInFlight() {
liquidityInformationUpdater.removeInFlight(hopsWithChannelIdsAndPubkeys);
verifyRemovesInFlightForAllHops();
verifyNoMoreInteractions(liquidityBoundsService);
}
private void verifyRemovesInFlightForAllHops() {
verify(liquidityBoundsService).markAsInFlight(PUBKEY, PUBKEY_2, Coins.ofSatoshis(-100));
verify(liquidityBoundsService).markAsInFlight(PUBKEY_2, PUBKEY_3, Coins.ofSatoshis(-90));

View File

@@ -18,6 +18,7 @@ tasks.withType(JavaCompile).configureEach {
options.errorprone.nullaway {
severity = net.ltgt.gradle.errorprone.CheckSeverity.ERROR
excludedFieldAnnotations.add('org.mockito.Mock')
excludedFieldAnnotations.add('org.mockito.Captor')
excludedFieldAnnotations.add('org.springframework.beans.factory.annotation.Value')
excludedFieldAnnotations.add('org.mockito.InjectMocks')
excludedFieldAnnotations.add('org.junit.jupiter.api.io.TempDir')

View File

@@ -115,7 +115,7 @@
</rule>
<rule ref="category/java/design.xml/ImmutableField">
<properties>
<property name="ignoredAnnotations" value="org.springframework.beans.factory.annotation.Autowired|org.springframework.boot.test.mock.mockito.MockBean|org.mockito.Mock|org.mockito.InjectMocks|javax.persistence.Id" />
<property name="ignoredAnnotations" value="org.springframework.beans.factory.annotation.Autowired|org.springframework.boot.test.mock.mockito.MockBean|org.mockito.Mock|org.mockito.Captor|org.mockito.InjectMocks|javax.persistence.Id" />
</properties>
</rule>

View File

@@ -2,9 +2,13 @@ package de.cotto.lndmanagej.grpc;
import io.grpc.stub.StreamObserver;
class NoopObserver<T> implements StreamObserver<T> {
public NoopObserver() {
// default constructor
import java.util.function.Consumer;
class ErrorReporter<T> implements StreamObserver<T> {
private final Consumer<Throwable> consumer;
public ErrorReporter(Consumer<Throwable> consumer) {
this.consumer = consumer;
}
@Override
@@ -14,7 +18,7 @@ class NoopObserver<T> implements StreamObserver<T> {
@Override
public void onError(Throwable throwable) {
// nothing
consumer.accept(throwable);
}
@Override

View File

@@ -27,7 +27,7 @@ public class GrpcSendToRoute {
this.grpcGetInfo = grpcGetInfo;
}
public void sendToRoute(Route route, DecodedPaymentRequest decodedPaymentRequest) {
public void sendToRoute(Route route, DecodedPaymentRequest decodedPaymentRequest, SendToRouteObserver observer) {
Integer blockHeight = grpcGetInfo.getBlockHeight().orElse(null);
if (blockHeight == null) {
logger.error("Unable to get current block height");
@@ -37,7 +37,7 @@ public class GrpcSendToRoute {
decodedPaymentRequest.paymentHash(),
buildLndRoute(route, blockHeight, decodedPaymentRequest)
);
grpcRouterService.sendToRoute(request, new NoopObserver<>());
grpcRouterService.sendToRoute(request, new ErrorReporter<>(observer));
}
private lnrpc.Route buildLndRoute(Route route, int blockHeight, DecodedPaymentRequest decodedPaymentRequest) {

View File

@@ -0,0 +1,6 @@
package de.cotto.lndmanagej.grpc;
import java.util.function.Consumer;
public interface SendToRouteObserver extends Consumer<Throwable> {
}

View File

@@ -0,0 +1,37 @@
package de.cotto.lndmanagej.grpc;
import org.junit.jupiter.api.Test;
import javax.annotation.Nullable;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
class ErrorReporterTest {
private final ErrorReporter<String> errorReporter = new ErrorReporter<>(this::consumeThrowable);
@Nullable
private Throwable seenThrowable;
@Test
void onNext() {
assertThatCode(() -> errorReporter.onNext("foo")).doesNotThrowAnyException();
}
@Test
void onCompleted() {
assertThatCode(errorReporter::onCompleted).doesNotThrowAnyException();
}
@Test
void onError() {
NullPointerException throwable = new NullPointerException();
assertThatCode(() -> errorReporter.onError(throwable)).doesNotThrowAnyException();
assertThat(seenThrowable).isSameAs(throwable);
}
private void consumeThrowable(Throwable throwable) {
seenThrowable = throwable;
}
}

View File

@@ -2,11 +2,16 @@ package de.cotto.lndmanagej.grpc;
import com.google.protobuf.ByteString;
import de.cotto.lndmanagej.model.HexString;
import io.grpc.stub.StreamObserver;
import lnrpc.HTLCAttempt;
import lnrpc.Hop;
import lnrpc.MPPRecord;
import lnrpc.Route;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
@@ -40,17 +45,27 @@ class GrpcSendToRouteTest {
@Mock
private GrpcRouterService grpcRouterService;
@Mock
private SendToRouteObserver observer;
@Captor
private ArgumentCaptor<StreamObserver<HTLCAttempt>> captor;
@BeforeEach
void setUp() {
when(grpcGetInfo.getBlockHeight()).thenReturn(Optional.of(BLOCK_HEIGHT));
}
@Test
void block_height_not_available() {
when(grpcGetInfo.getBlockHeight()).thenReturn(Optional.empty());
grpcSendToRoute.sendToRoute(ROUTE, DECODED_PAYMENT_REQUEST);
grpcSendToRoute.sendToRoute(ROUTE, DECODED_PAYMENT_REQUEST, observer);
verifyNoInteractions(grpcRouterService);
}
@Test
void sends_to_converted_route() {
when(grpcGetInfo.getBlockHeight()).thenReturn(Optional.of(BLOCK_HEIGHT));
grpcSendToRoute.sendToRoute(ROUTE, DECODED_PAYMENT_REQUEST);
grpcSendToRoute.sendToRoute(ROUTE, DECODED_PAYMENT_REQUEST, observer);
RouterOuterClass.SendToRouteRequest expectedRequest = RouterOuterClass.SendToRouteRequest.newBuilder()
.setRoute(Route.newBuilder()
.setTotalTimeLock(ROUTE.getTotalTimeLock(BLOCK_HEIGHT, DECODED_PAYMENT_REQUEST.cltvExpiry()))
@@ -86,6 +101,15 @@ class GrpcSendToRouteTest {
verify(grpcRouterService).sendToRoute(eq(expectedRequest), any());
}
@Test
void error_reporter_reports_to_given_observer() {
grpcSendToRoute.sendToRoute(ROUTE, DECODED_PAYMENT_REQUEST, observer);
verify(grpcRouterService).sendToRoute(any(), captor.capture());
NullPointerException throwable = new NullPointerException();
captor.getValue().onError(throwable);
verify(observer).accept(throwable);
}
private ByteString toByteString(HexString hexString) {
return ByteString.copyFrom(hexString.getByteArray());
}

View File

@@ -1,25 +0,0 @@
package de.cotto.lndmanagej.grpc;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThatCode;
class NoopObserverTest {
private final NoopObserver<String> noopObserver = new NoopObserver<>();
@Test
void onNext() {
assertThatCode(() -> noopObserver.onNext("foo")).doesNotThrowAnyException();
}
@Test
void onCompleted() {
assertThatCode(noopObserver::onCompleted).doesNotThrowAnyException();
}
@Test
void onError() {
assertThatCode(() -> noopObserver.onError(new NullPointerException())).doesNotThrowAnyException();
}
}

View File

@@ -0,0 +1,48 @@
package de.cotto.lndmanagej.pickhardtpayments;
import de.cotto.lndmanagej.grpc.SendToRouteObserver;
import de.cotto.lndmanagej.model.Coins;
import de.cotto.lndmanagej.model.Edge;
import de.cotto.lndmanagej.model.PaymentAttemptHop;
import de.cotto.lndmanagej.model.Route;
import de.cotto.lndmanagej.service.LiquidityInformationUpdater;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
@Component
public class MultiPathPaymentObserver {
private final LiquidityInformationUpdater liquidityInformationUpdater;
private final Logger logger = LoggerFactory.getLogger(getClass());
public MultiPathPaymentObserver(LiquidityInformationUpdater liquidityInformationUpdater) {
this.liquidityInformationUpdater = liquidityInformationUpdater;
}
public SendToRouteObserver forRoute(Route route) {
return throwable -> {
logger.warn("Send to route failed for route {}: ", route, throwable);
liquidityInformationUpdater.removeInFlight(topPaymentAttemptHops(route));
};
}
private List<PaymentAttemptHop> topPaymentAttemptHops(Route route) {
List<Edge> edges = route.getEdges();
List<PaymentAttemptHop> result = new ArrayList<>();
for (int i = 0; i < edges.size(); i++) {
Edge edge = edges.get(i);
Coins forwardAmountForHop = route.getForwardAmountForHop(i);
PaymentAttemptHop hop = new PaymentAttemptHop(
Optional.of(edge.channelId()),
forwardAmountForHop,
Optional.empty()
);
result.add(hop);
}
return result;
}
}

View File

@@ -16,15 +16,18 @@ public class MultiPathPaymentSender {
private final GrpcPayments grpcPayments;
private final GrpcSendToRoute grpcSendToRoute;
private final MultiPathPaymentSplitter multiPathPaymentSplitter;
private final MultiPathPaymentObserver multiPathPaymentObserver;
public MultiPathPaymentSender(
GrpcPayments grpcPayments,
GrpcSendToRoute grpcSendToRoute,
MultiPathPaymentSplitter multiPathPaymentSplitter
MultiPathPaymentSplitter multiPathPaymentSplitter,
MultiPathPaymentObserver multiPathPaymentObserver
) {
this.grpcPayments = grpcPayments;
this.grpcSendToRoute = grpcSendToRoute;
this.multiPathPaymentSplitter = multiPathPaymentSplitter;
this.multiPathPaymentObserver = multiPathPaymentObserver;
}
public MultiPathPayment payPaymentRequest(String paymentRequest, int feeRateWeight) {
@@ -38,7 +41,7 @@ public class MultiPathPaymentSender {
multiPathPaymentSplitter.getMultiPathPaymentTo(destination, amount, feeRateWeight);
List<Route> routes = multiPathPayment.routes();
for (Route route : routes) {
grpcSendToRoute.sendToRoute(route, decodedPaymentRequest);
grpcSendToRoute.sendToRoute(route, decodedPaymentRequest, multiPathPaymentObserver.forRoute(route));
}
return multiPathPayment;
}

View File

@@ -0,0 +1,51 @@
package de.cotto.lndmanagej.pickhardtpayments;
import de.cotto.lndmanagej.grpc.SendToRouteObserver;
import de.cotto.lndmanagej.model.Coins;
import de.cotto.lndmanagej.model.Edge;
import de.cotto.lndmanagej.model.PaymentAttemptHop;
import de.cotto.lndmanagej.service.LiquidityInformationUpdater;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import static de.cotto.lndmanagej.model.RouteFixtures.ROUTE;
import static org.mockito.Mockito.verify;
@ExtendWith(MockitoExtension.class)
class MultiPathPaymentObserverTest {
@InjectMocks
private MultiPathPaymentObserver multiPathPaymentObserver;
@Mock
private LiquidityInformationUpdater liquidityInformationUpdater;
@Test
void cancels_in_flight_on_error() {
SendToRouteObserver sendToRouteObserver = multiPathPaymentObserver.forRoute(ROUTE);
sendToRouteObserver.accept(new NullPointerException());
verify(liquidityInformationUpdater).removeInFlight(hops());
}
private List<PaymentAttemptHop> hops() {
List<Edge> edges = ROUTE.getEdges();
List<PaymentAttemptHop> result = new ArrayList<>();
for (int i = 0; i < edges.size(); i++) {
Edge edge = edges.get(i);
Coins forwardAmountForHop = ROUTE.getForwardAmountForHop(i);
PaymentAttemptHop hop = new PaymentAttemptHop(
Optional.of(edge.channelId()),
forwardAmountForHop,
Optional.empty()
);
result.add(hop);
}
return result;
}
}

View File

@@ -2,6 +2,7 @@ package de.cotto.lndmanagej.pickhardtpayments;
import de.cotto.lndmanagej.grpc.GrpcPayments;
import de.cotto.lndmanagej.grpc.GrpcSendToRoute;
import de.cotto.lndmanagej.grpc.SendToRouteObserver;
import de.cotto.lndmanagej.model.Route;
import de.cotto.lndmanagej.pickhardtpayments.model.MultiPathPayment;
import org.junit.jupiter.api.Test;
@@ -10,13 +11,18 @@ import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Optional;
import static de.cotto.lndmanagej.model.DecodedPaymentRequestFixtures.DECODED_PAYMENT_REQUEST;
import static de.cotto.lndmanagej.pickhardtpayments.model.MultiPathPaymentFixtures.MULTI_PATH_PAYMENT;
import static java.util.Objects.requireNonNull;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
@@ -38,6 +44,9 @@ class MultiPathPaymentSenderTest {
@Mock
private MultiPathPaymentSplitter multiPathPaymentSplitter;
@Mock
private MultiPathPaymentObserver multiPathPaymentObserver;
@Test
void payment_request_cannot_be_decoded() {
when(grpcPayments.decodePaymentRequest(any())).thenReturn(Optional.empty());
@@ -67,7 +76,23 @@ class MultiPathPaymentSenderTest {
MultiPathPayment multiPathPayment = multiPathPaymentSender.payPaymentRequest(PAYMENT_REQUEST, FEE_RATE_WEIGHT);
assertThat(multiPathPayment.isFailure()).isFalse();
for (Route route : MULTI_PATH_PAYMENT.routes()) {
verify(grpcSendToRoute).sendToRoute(route, DECODED_PAYMENT_REQUEST);
verify(grpcSendToRoute).sendToRoute(eq(route), eq(DECODED_PAYMENT_REQUEST), any());
}
}
@Test
void registers_observers_for_routes() {
when(grpcPayments.decodePaymentRequest(any())).thenReturn(Optional.of(DECODED_PAYMENT_REQUEST));
when(multiPathPaymentSplitter.getMultiPathPaymentTo(any(), any(), anyInt())).thenReturn(MULTI_PATH_PAYMENT);
Map<Route, SendToRouteObserver> expected = new LinkedHashMap<>();
for (Route route : MULTI_PATH_PAYMENT.routes()) {
SendToRouteObserver expectedObserver = mock(SendToRouteObserver.class);
when(multiPathPaymentObserver.forRoute(route)).thenReturn(expectedObserver);
expected.put(route, expectedObserver);
}
multiPathPaymentSender.payPaymentRequest(PAYMENT_REQUEST, FEE_RATE_WEIGHT);
for (Route route : MULTI_PATH_PAYMENT.routes()) {
verify(grpcSendToRoute).sendToRoute(route, DECODED_PAYMENT_REQUEST, requireNonNull(expected.get(route)));
}
}
}