track in-flight coins per payment hash

This commit is contained in:
Carsten Otto
2022-05-12 23:04:19 +02:00
parent ae996c0e5b
commit c024fc187b
9 changed files with 119 additions and 35 deletions

View File

@@ -37,7 +37,7 @@ public class GrpcSendToRoute {
decodedPaymentRequest.paymentHash(),
buildLndRoute(route, blockHeight, decodedPaymentRequest)
);
grpcRouterService.sendToRoute(request, new ReportingStreamObserver<>(observer));
grpcRouterService.sendToRoute(request, new ReportingStreamObserver(observer));
}
private lnrpc.Route buildLndRoute(Route route, int blockHeight, DecodedPaymentRequest decodedPaymentRequest) {

View File

@@ -1,8 +1,10 @@
package de.cotto.lndmanagej.grpc;
import de.cotto.lndmanagej.model.HexString;
import io.grpc.stub.StreamObserver;
import lnrpc.HTLCAttempt;
class ReportingStreamObserver<T> implements StreamObserver<T> {
class ReportingStreamObserver implements StreamObserver<HTLCAttempt> {
private final SendToRouteObserver sendToRouteObserver;
public ReportingStreamObserver(SendToRouteObserver sendToRouteObserver) {
@@ -10,8 +12,9 @@ class ReportingStreamObserver<T> implements StreamObserver<T> {
}
@Override
public void onNext(T value) {
sendToRouteObserver.onValue(value);
public void onNext(HTLCAttempt htlcAttempt) {
HexString preimage = new HexString(htlcAttempt.getPreimage().toByteArray());
sendToRouteObserver.onValue(preimage);
}
@Override

View File

@@ -1,7 +1,9 @@
package de.cotto.lndmanagej.grpc;
import de.cotto.lndmanagej.model.HexString;
public interface SendToRouteObserver {
void onError(Throwable throwable);
void onValue(Object value);
void onValue(HexString preimage);
}

View File

@@ -114,13 +114,15 @@ class GrpcSendToRouteTest {
void reporter_reports_value_to_given_observer() {
grpcSendToRoute.sendToRoute(ROUTE, DECODED_PAYMENT_REQUEST, observer);
verify(grpcRouterService).sendToRoute(any(), captor.capture());
HTLCAttempt value = htlcAttempt();
HexString preimage = new HexString("0011FF");
HTLCAttempt value = htlcAttempt(preimage);
captor.getValue().onNext(value);
verify(observer).onValue(value);
verify(observer).onValue(preimage);
}
private HTLCAttempt htlcAttempt() {
return HTLCAttempt.newBuilder().build();
private HTLCAttempt htlcAttempt(HexString hexString) {
ByteString bytestring = ByteString.copyFrom(hexString.getByteArray());
return HTLCAttempt.newBuilder().setPreimage(bytestring).build();
}
private ByteString toByteString(HexString hexString) {

View File

@@ -1,5 +1,8 @@
package de.cotto.lndmanagej.grpc;
import com.google.protobuf.ByteString;
import de.cotto.lndmanagej.model.HexString;
import lnrpc.HTLCAttempt;
import org.junit.jupiter.api.Test;
import javax.annotation.Nullable;
@@ -9,14 +12,8 @@ import static org.assertj.core.api.Assertions.assertThatCode;
class ReportingStreamObserverTest {
private final MySendToRouteObserver sendToRouteObserver = new MySendToRouteObserver();
private final ReportingStreamObserver<String> reportingStreamObserver =
new ReportingStreamObserver<>(sendToRouteObserver);
@Test
void onNext() {
assertThatCode(() -> reportingStreamObserver.onNext("foo")).doesNotThrowAnyException();
}
private final TestableSendToRouteObserver sendToRouteObserver = new TestableSendToRouteObserver();
private final ReportingStreamObserver reportingStreamObserver = new ReportingStreamObserver(sendToRouteObserver);
@Test
void onCompleted() {
@@ -31,18 +28,21 @@ class ReportingStreamObserverTest {
}
@Test
void onValue() {
String value = "";
assertThatCode(() -> reportingStreamObserver.onNext(value)).doesNotThrowAnyException();
assertThat(sendToRouteObserver.seenValue).isSameAs(value);
void onNext_forwards_preimage() {
HexString preimage = new HexString("AA00");
HTLCAttempt htlcAttempt = HTLCAttempt.newBuilder()
.setPreimage(ByteString.copyFrom(preimage.getByteArray()))
.build();
assertThatCode(() -> reportingStreamObserver.onNext(htlcAttempt)).doesNotThrowAnyException();
assertThat(sendToRouteObserver.seenPreimage).isEqualTo(preimage);
}
private static class MySendToRouteObserver implements SendToRouteObserver {
private static class TestableSendToRouteObserver implements SendToRouteObserver {
@Nullable
private Throwable seenThrowable;
@Nullable
private Object seenValue;
private HexString seenPreimage;
@Override
public void onError(Throwable throwable) {
@@ -50,8 +50,8 @@ class ReportingStreamObserverTest {
}
@Override
public void onValue(Object value) {
seenValue = value;
public void onValue(HexString preimage) {
seenPreimage = preimage;
}
}
}

View File

@@ -3,6 +3,7 @@ package de.cotto.lndmanagej.pickhardtpayments;
import de.cotto.lndmanagej.grpc.SendToRouteObserver;
import de.cotto.lndmanagej.model.Coins;
import de.cotto.lndmanagej.model.Edge;
import de.cotto.lndmanagej.model.HexString;
import de.cotto.lndmanagej.model.PaymentAttemptHop;
import de.cotto.lndmanagej.model.Route;
import de.cotto.lndmanagej.service.LiquidityInformationUpdater;
@@ -12,19 +13,30 @@ import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
@Component
public class MultiPathPaymentObserver {
private final LiquidityInformationUpdater liquidityInformationUpdater;
private final Logger logger = LoggerFactory.getLogger(getClass());
private final Map<HexString, Coins> inFlight = new ConcurrentHashMap<>();
public MultiPathPaymentObserver(LiquidityInformationUpdater liquidityInformationUpdater) {
this.liquidityInformationUpdater = liquidityInformationUpdater;
}
public SendToRouteObserver forRoute(Route route) {
return new SendToRouteObserverImpl(route);
public SendToRouteObserver getFor(Route route, HexString paymentHash) {
return new SendToRouteObserverImpl(route, paymentHash);
}
public Coins getInFlight(HexString paymentHash) {
return inFlight.getOrDefault(paymentHash, Coins.NONE);
}
private void addInFlight(HexString paymentHash, Coins amount) {
inFlight.compute(paymentHash, (key, value) -> value == null ? amount : amount.add(value));
}
private List<PaymentAttemptHop> topPaymentAttemptHops(Route route) {
@@ -45,20 +57,24 @@ public class MultiPathPaymentObserver {
private class SendToRouteObserverImpl implements SendToRouteObserver {
private final Route route;
private final HexString paymentHash;
public SendToRouteObserverImpl(Route route) {
public SendToRouteObserverImpl(Route route, HexString paymentHash) {
this.route = route;
this.paymentHash = paymentHash;
addInFlight(paymentHash, route.getAmount());
}
@Override
public void onError(Throwable throwable) {
logger.warn("Send to route failed for route {}: ", route, throwable);
liquidityInformationUpdater.removeInFlight(topPaymentAttemptHops(route));
addInFlight(paymentHash, route.getAmount().negate());
}
@Override
public void onValue(Object value) {
logger.info("Got value {}: ", value);
public void onValue(HexString preimage) {
addInFlight(paymentHash, route.getAmount().negate());
}
}
}

View File

@@ -2,8 +2,10 @@ package de.cotto.lndmanagej.pickhardtpayments;
import de.cotto.lndmanagej.grpc.GrpcPayments;
import de.cotto.lndmanagej.grpc.GrpcSendToRoute;
import de.cotto.lndmanagej.grpc.SendToRouteObserver;
import de.cotto.lndmanagej.model.Coins;
import de.cotto.lndmanagej.model.DecodedPaymentRequest;
import de.cotto.lndmanagej.model.HexString;
import de.cotto.lndmanagej.model.Pubkey;
import de.cotto.lndmanagej.model.Route;
import de.cotto.lndmanagej.pickhardtpayments.model.MultiPathPayment;
@@ -40,8 +42,10 @@ public class MultiPathPaymentSender {
MultiPathPayment multiPathPayment =
multiPathPaymentSplitter.getMultiPathPaymentTo(destination, amount, feeRateWeight);
List<Route> routes = multiPathPayment.routes();
HexString paymentHash = decodedPaymentRequest.paymentHash();
for (Route route : routes) {
grpcSendToRoute.sendToRoute(route, decodedPaymentRequest, multiPathPaymentObserver.forRoute(route));
SendToRouteObserver sendToRouteObserver = multiPathPaymentObserver.getFor(route, paymentHash);
grpcSendToRoute.sendToRoute(route, decodedPaymentRequest, sendToRouteObserver);
}
return multiPathPayment;
}

View File

@@ -3,6 +3,7 @@ package de.cotto.lndmanagej.pickhardtpayments;
import de.cotto.lndmanagej.grpc.SendToRouteObserver;
import de.cotto.lndmanagej.model.Coins;
import de.cotto.lndmanagej.model.Edge;
import de.cotto.lndmanagej.model.HexString;
import de.cotto.lndmanagej.model.PaymentAttemptHop;
import de.cotto.lndmanagej.service.LiquidityInformationUpdater;
import org.junit.jupiter.api.Test;
@@ -15,7 +16,11 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import static de.cotto.lndmanagej.model.DecodedPaymentRequestFixtures.DECODED_PAYMENT_REQUEST;
import static de.cotto.lndmanagej.model.RouteFixtures.ROUTE;
import static de.cotto.lndmanagej.model.RouteFixtures.ROUTE_2;
import static de.cotto.lndmanagej.model.RouteFixtures.ROUTE_3;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.mockito.Mockito.verify;
@@ -29,15 +34,65 @@ class MultiPathPaymentObserverTest {
@Test
void cancels_in_flight_on_error() {
SendToRouteObserver sendToRouteObserver = multiPathPaymentObserver.forRoute(ROUTE);
HexString paymentHash = DECODED_PAYMENT_REQUEST.paymentHash();
SendToRouteObserver sendToRouteObserver =
multiPathPaymentObserver.getFor(ROUTE, paymentHash);
sendToRouteObserver.onError(new NullPointerException());
assertThat(multiPathPaymentObserver.getInFlight(paymentHash)).isEqualTo(Coins.NONE);
}
@Test
void cancels_in_flight_using_liquidity_information_updater_on_error() {
SendToRouteObserver sendToRouteObserver =
multiPathPaymentObserver.getFor(ROUTE, DECODED_PAYMENT_REQUEST.paymentHash());
sendToRouteObserver.onError(new NullPointerException());
verify(liquidityInformationUpdater).removeInFlight(hops());
}
@Test
void accepts_value() {
SendToRouteObserver sendToRouteObserver = multiPathPaymentObserver.forRoute(ROUTE);
assertThatCode(() -> sendToRouteObserver.onValue("")).doesNotThrowAnyException();
SendToRouteObserver sendToRouteObserver =
multiPathPaymentObserver.getFor(ROUTE, DECODED_PAYMENT_REQUEST.paymentHash());
assertThatCode(() -> sendToRouteObserver.onValue(HexString.EMPTY)).doesNotThrowAnyException();
}
@Test
void inFlight_initially_zero() {
assertThat(multiPathPaymentObserver.getInFlight(DECODED_PAYMENT_REQUEST.paymentHash()))
.isEqualTo(Coins.NONE);
}
@Test
void inFlight_initialized_with_route_amount() {
HexString paymentHash = DECODED_PAYMENT_REQUEST.paymentHash();
multiPathPaymentObserver.getFor(ROUTE, paymentHash);
assertThat(multiPathPaymentObserver.getInFlight(paymentHash)).isEqualTo(ROUTE.getAmount());
}
@Test
void inFlight_reset_on_success() {
HexString paymentHash = DECODED_PAYMENT_REQUEST.paymentHash();
SendToRouteObserver observer = multiPathPaymentObserver.getFor(ROUTE, paymentHash);
observer.onValue(new HexString("AABBCC"));
assertThat(multiPathPaymentObserver.getInFlight(paymentHash)).isEqualTo(Coins.NONE);
}
@Test
void inFlight_reset_on_failure() {
HexString paymentHash = DECODED_PAYMENT_REQUEST.paymentHash();
SendToRouteObserver observer = multiPathPaymentObserver.getFor(ROUTE, paymentHash);
observer.onValue(HexString.EMPTY);
assertThat(multiPathPaymentObserver.getInFlight(paymentHash)).isEqualTo(Coins.NONE);
}
@Test
void inFlight_updated_with_several_routes() {
HexString paymentHash = DECODED_PAYMENT_REQUEST.paymentHash();
multiPathPaymentObserver.getFor(ROUTE, paymentHash);
multiPathPaymentObserver.getFor(ROUTE_2, HexString.EMPTY);
multiPathPaymentObserver.getFor(ROUTE_3, paymentHash);
Coins expectedAmount = ROUTE.getAmount().add(ROUTE_3.getAmount());
assertThat(multiPathPaymentObserver.getInFlight(paymentHash)).isEqualTo(expectedAmount);
}
private List<PaymentAttemptHop> hops() {

View File

@@ -3,6 +3,7 @@ package de.cotto.lndmanagej.pickhardtpayments;
import de.cotto.lndmanagej.grpc.GrpcPayments;
import de.cotto.lndmanagej.grpc.GrpcSendToRoute;
import de.cotto.lndmanagej.grpc.SendToRouteObserver;
import de.cotto.lndmanagej.model.HexString;
import de.cotto.lndmanagej.model.Route;
import de.cotto.lndmanagej.pickhardtpayments.model.MultiPathPayment;
import org.junit.jupiter.api.Test;
@@ -84,10 +85,11 @@ class MultiPathPaymentSenderTest {
void registers_observers_for_routes() {
when(grpcPayments.decodePaymentRequest(any())).thenReturn(Optional.of(DECODED_PAYMENT_REQUEST));
when(multiPathPaymentSplitter.getMultiPathPaymentTo(any(), any(), anyInt())).thenReturn(MULTI_PATH_PAYMENT);
HexString paymentHash = DECODED_PAYMENT_REQUEST.paymentHash();
Map<Route, SendToRouteObserver> expected = new LinkedHashMap<>();
for (Route route : MULTI_PATH_PAYMENT.routes()) {
SendToRouteObserver expectedObserver = mock(SendToRouteObserver.class);
when(multiPathPaymentObserver.forRoute(route)).thenReturn(expectedObserver);
when(multiPathPaymentObserver.getFor(route, paymentHash)).thenReturn(expectedObserver);
expected.put(route, expectedObserver);
}
multiPathPaymentSender.payPaymentRequest(PAYMENT_REQUEST, FEE_RATE_WEIGHT);