mirror of
https://github.com/aljazceru/lnd-manageJ.git
synced 2026-01-27 09:54:51 +01:00
track in-flight coins per payment hash
This commit is contained in:
@@ -37,7 +37,7 @@ public class GrpcSendToRoute {
|
||||
decodedPaymentRequest.paymentHash(),
|
||||
buildLndRoute(route, blockHeight, decodedPaymentRequest)
|
||||
);
|
||||
grpcRouterService.sendToRoute(request, new ReportingStreamObserver<>(observer));
|
||||
grpcRouterService.sendToRoute(request, new ReportingStreamObserver(observer));
|
||||
}
|
||||
|
||||
private lnrpc.Route buildLndRoute(Route route, int blockHeight, DecodedPaymentRequest decodedPaymentRequest) {
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package de.cotto.lndmanagej.grpc;
|
||||
|
||||
import de.cotto.lndmanagej.model.HexString;
|
||||
import io.grpc.stub.StreamObserver;
|
||||
import lnrpc.HTLCAttempt;
|
||||
|
||||
class ReportingStreamObserver<T> implements StreamObserver<T> {
|
||||
class ReportingStreamObserver implements StreamObserver<HTLCAttempt> {
|
||||
private final SendToRouteObserver sendToRouteObserver;
|
||||
|
||||
public ReportingStreamObserver(SendToRouteObserver sendToRouteObserver) {
|
||||
@@ -10,8 +12,9 @@ class ReportingStreamObserver<T> implements StreamObserver<T> {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onNext(T value) {
|
||||
sendToRouteObserver.onValue(value);
|
||||
public void onNext(HTLCAttempt htlcAttempt) {
|
||||
HexString preimage = new HexString(htlcAttempt.getPreimage().toByteArray());
|
||||
sendToRouteObserver.onValue(preimage);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package de.cotto.lndmanagej.grpc;
|
||||
|
||||
import de.cotto.lndmanagej.model.HexString;
|
||||
|
||||
public interface SendToRouteObserver {
|
||||
void onError(Throwable throwable);
|
||||
|
||||
void onValue(Object value);
|
||||
void onValue(HexString preimage);
|
||||
}
|
||||
|
||||
@@ -114,13 +114,15 @@ class GrpcSendToRouteTest {
|
||||
void reporter_reports_value_to_given_observer() {
|
||||
grpcSendToRoute.sendToRoute(ROUTE, DECODED_PAYMENT_REQUEST, observer);
|
||||
verify(grpcRouterService).sendToRoute(any(), captor.capture());
|
||||
HTLCAttempt value = htlcAttempt();
|
||||
HexString preimage = new HexString("0011FF");
|
||||
HTLCAttempt value = htlcAttempt(preimage);
|
||||
captor.getValue().onNext(value);
|
||||
verify(observer).onValue(value);
|
||||
verify(observer).onValue(preimage);
|
||||
}
|
||||
|
||||
private HTLCAttempt htlcAttempt() {
|
||||
return HTLCAttempt.newBuilder().build();
|
||||
private HTLCAttempt htlcAttempt(HexString hexString) {
|
||||
ByteString bytestring = ByteString.copyFrom(hexString.getByteArray());
|
||||
return HTLCAttempt.newBuilder().setPreimage(bytestring).build();
|
||||
}
|
||||
|
||||
private ByteString toByteString(HexString hexString) {
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
package de.cotto.lndmanagej.grpc;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import de.cotto.lndmanagej.model.HexString;
|
||||
import lnrpc.HTLCAttempt;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
@@ -9,14 +12,8 @@ import static org.assertj.core.api.Assertions.assertThatCode;
|
||||
|
||||
class ReportingStreamObserverTest {
|
||||
|
||||
private final MySendToRouteObserver sendToRouteObserver = new MySendToRouteObserver();
|
||||
private final ReportingStreamObserver<String> reportingStreamObserver =
|
||||
new ReportingStreamObserver<>(sendToRouteObserver);
|
||||
|
||||
@Test
|
||||
void onNext() {
|
||||
assertThatCode(() -> reportingStreamObserver.onNext("foo")).doesNotThrowAnyException();
|
||||
}
|
||||
private final TestableSendToRouteObserver sendToRouteObserver = new TestableSendToRouteObserver();
|
||||
private final ReportingStreamObserver reportingStreamObserver = new ReportingStreamObserver(sendToRouteObserver);
|
||||
|
||||
@Test
|
||||
void onCompleted() {
|
||||
@@ -31,18 +28,21 @@ class ReportingStreamObserverTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
void onValue() {
|
||||
String value = "";
|
||||
assertThatCode(() -> reportingStreamObserver.onNext(value)).doesNotThrowAnyException();
|
||||
assertThat(sendToRouteObserver.seenValue).isSameAs(value);
|
||||
void onNext_forwards_preimage() {
|
||||
HexString preimage = new HexString("AA00");
|
||||
HTLCAttempt htlcAttempt = HTLCAttempt.newBuilder()
|
||||
.setPreimage(ByteString.copyFrom(preimage.getByteArray()))
|
||||
.build();
|
||||
assertThatCode(() -> reportingStreamObserver.onNext(htlcAttempt)).doesNotThrowAnyException();
|
||||
assertThat(sendToRouteObserver.seenPreimage).isEqualTo(preimage);
|
||||
}
|
||||
|
||||
private static class MySendToRouteObserver implements SendToRouteObserver {
|
||||
private static class TestableSendToRouteObserver implements SendToRouteObserver {
|
||||
@Nullable
|
||||
private Throwable seenThrowable;
|
||||
|
||||
@Nullable
|
||||
private Object seenValue;
|
||||
private HexString seenPreimage;
|
||||
|
||||
@Override
|
||||
public void onError(Throwable throwable) {
|
||||
@@ -50,8 +50,8 @@ class ReportingStreamObserverTest {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onValue(Object value) {
|
||||
seenValue = value;
|
||||
public void onValue(HexString preimage) {
|
||||
seenPreimage = preimage;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package de.cotto.lndmanagej.pickhardtpayments;
|
||||
import de.cotto.lndmanagej.grpc.SendToRouteObserver;
|
||||
import de.cotto.lndmanagej.model.Coins;
|
||||
import de.cotto.lndmanagej.model.Edge;
|
||||
import de.cotto.lndmanagej.model.HexString;
|
||||
import de.cotto.lndmanagej.model.PaymentAttemptHop;
|
||||
import de.cotto.lndmanagej.model.Route;
|
||||
import de.cotto.lndmanagej.service.LiquidityInformationUpdater;
|
||||
@@ -12,19 +13,30 @@ import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
@Component
|
||||
public class MultiPathPaymentObserver {
|
||||
private final LiquidityInformationUpdater liquidityInformationUpdater;
|
||||
private final Logger logger = LoggerFactory.getLogger(getClass());
|
||||
private final Map<HexString, Coins> inFlight = new ConcurrentHashMap<>();
|
||||
|
||||
public MultiPathPaymentObserver(LiquidityInformationUpdater liquidityInformationUpdater) {
|
||||
this.liquidityInformationUpdater = liquidityInformationUpdater;
|
||||
}
|
||||
|
||||
public SendToRouteObserver forRoute(Route route) {
|
||||
return new SendToRouteObserverImpl(route);
|
||||
public SendToRouteObserver getFor(Route route, HexString paymentHash) {
|
||||
return new SendToRouteObserverImpl(route, paymentHash);
|
||||
}
|
||||
|
||||
public Coins getInFlight(HexString paymentHash) {
|
||||
return inFlight.getOrDefault(paymentHash, Coins.NONE);
|
||||
}
|
||||
|
||||
private void addInFlight(HexString paymentHash, Coins amount) {
|
||||
inFlight.compute(paymentHash, (key, value) -> value == null ? amount : amount.add(value));
|
||||
}
|
||||
|
||||
private List<PaymentAttemptHop> topPaymentAttemptHops(Route route) {
|
||||
@@ -45,20 +57,24 @@ public class MultiPathPaymentObserver {
|
||||
|
||||
private class SendToRouteObserverImpl implements SendToRouteObserver {
|
||||
private final Route route;
|
||||
private final HexString paymentHash;
|
||||
|
||||
public SendToRouteObserverImpl(Route route) {
|
||||
public SendToRouteObserverImpl(Route route, HexString paymentHash) {
|
||||
this.route = route;
|
||||
this.paymentHash = paymentHash;
|
||||
addInFlight(paymentHash, route.getAmount());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable throwable) {
|
||||
logger.warn("Send to route failed for route {}: ", route, throwable);
|
||||
liquidityInformationUpdater.removeInFlight(topPaymentAttemptHops(route));
|
||||
addInFlight(paymentHash, route.getAmount().negate());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onValue(Object value) {
|
||||
logger.info("Got value {}: ", value);
|
||||
public void onValue(HexString preimage) {
|
||||
addInFlight(paymentHash, route.getAmount().negate());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,8 +2,10 @@ package de.cotto.lndmanagej.pickhardtpayments;
|
||||
|
||||
import de.cotto.lndmanagej.grpc.GrpcPayments;
|
||||
import de.cotto.lndmanagej.grpc.GrpcSendToRoute;
|
||||
import de.cotto.lndmanagej.grpc.SendToRouteObserver;
|
||||
import de.cotto.lndmanagej.model.Coins;
|
||||
import de.cotto.lndmanagej.model.DecodedPaymentRequest;
|
||||
import de.cotto.lndmanagej.model.HexString;
|
||||
import de.cotto.lndmanagej.model.Pubkey;
|
||||
import de.cotto.lndmanagej.model.Route;
|
||||
import de.cotto.lndmanagej.pickhardtpayments.model.MultiPathPayment;
|
||||
@@ -40,8 +42,10 @@ public class MultiPathPaymentSender {
|
||||
MultiPathPayment multiPathPayment =
|
||||
multiPathPaymentSplitter.getMultiPathPaymentTo(destination, amount, feeRateWeight);
|
||||
List<Route> routes = multiPathPayment.routes();
|
||||
HexString paymentHash = decodedPaymentRequest.paymentHash();
|
||||
for (Route route : routes) {
|
||||
grpcSendToRoute.sendToRoute(route, decodedPaymentRequest, multiPathPaymentObserver.forRoute(route));
|
||||
SendToRouteObserver sendToRouteObserver = multiPathPaymentObserver.getFor(route, paymentHash);
|
||||
grpcSendToRoute.sendToRoute(route, decodedPaymentRequest, sendToRouteObserver);
|
||||
}
|
||||
return multiPathPayment;
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package de.cotto.lndmanagej.pickhardtpayments;
|
||||
import de.cotto.lndmanagej.grpc.SendToRouteObserver;
|
||||
import de.cotto.lndmanagej.model.Coins;
|
||||
import de.cotto.lndmanagej.model.Edge;
|
||||
import de.cotto.lndmanagej.model.HexString;
|
||||
import de.cotto.lndmanagej.model.PaymentAttemptHop;
|
||||
import de.cotto.lndmanagej.service.LiquidityInformationUpdater;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -15,7 +16,11 @@ import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
import static de.cotto.lndmanagej.model.DecodedPaymentRequestFixtures.DECODED_PAYMENT_REQUEST;
|
||||
import static de.cotto.lndmanagej.model.RouteFixtures.ROUTE;
|
||||
import static de.cotto.lndmanagej.model.RouteFixtures.ROUTE_2;
|
||||
import static de.cotto.lndmanagej.model.RouteFixtures.ROUTE_3;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatCode;
|
||||
import static org.mockito.Mockito.verify;
|
||||
|
||||
@@ -29,15 +34,65 @@ class MultiPathPaymentObserverTest {
|
||||
|
||||
@Test
|
||||
void cancels_in_flight_on_error() {
|
||||
SendToRouteObserver sendToRouteObserver = multiPathPaymentObserver.forRoute(ROUTE);
|
||||
HexString paymentHash = DECODED_PAYMENT_REQUEST.paymentHash();
|
||||
SendToRouteObserver sendToRouteObserver =
|
||||
multiPathPaymentObserver.getFor(ROUTE, paymentHash);
|
||||
sendToRouteObserver.onError(new NullPointerException());
|
||||
assertThat(multiPathPaymentObserver.getInFlight(paymentHash)).isEqualTo(Coins.NONE);
|
||||
}
|
||||
|
||||
@Test
|
||||
void cancels_in_flight_using_liquidity_information_updater_on_error() {
|
||||
SendToRouteObserver sendToRouteObserver =
|
||||
multiPathPaymentObserver.getFor(ROUTE, DECODED_PAYMENT_REQUEST.paymentHash());
|
||||
sendToRouteObserver.onError(new NullPointerException());
|
||||
verify(liquidityInformationUpdater).removeInFlight(hops());
|
||||
}
|
||||
|
||||
@Test
|
||||
void accepts_value() {
|
||||
SendToRouteObserver sendToRouteObserver = multiPathPaymentObserver.forRoute(ROUTE);
|
||||
assertThatCode(() -> sendToRouteObserver.onValue("")).doesNotThrowAnyException();
|
||||
SendToRouteObserver sendToRouteObserver =
|
||||
multiPathPaymentObserver.getFor(ROUTE, DECODED_PAYMENT_REQUEST.paymentHash());
|
||||
assertThatCode(() -> sendToRouteObserver.onValue(HexString.EMPTY)).doesNotThrowAnyException();
|
||||
}
|
||||
|
||||
@Test
|
||||
void inFlight_initially_zero() {
|
||||
assertThat(multiPathPaymentObserver.getInFlight(DECODED_PAYMENT_REQUEST.paymentHash()))
|
||||
.isEqualTo(Coins.NONE);
|
||||
}
|
||||
|
||||
@Test
|
||||
void inFlight_initialized_with_route_amount() {
|
||||
HexString paymentHash = DECODED_PAYMENT_REQUEST.paymentHash();
|
||||
multiPathPaymentObserver.getFor(ROUTE, paymentHash);
|
||||
assertThat(multiPathPaymentObserver.getInFlight(paymentHash)).isEqualTo(ROUTE.getAmount());
|
||||
}
|
||||
|
||||
@Test
|
||||
void inFlight_reset_on_success() {
|
||||
HexString paymentHash = DECODED_PAYMENT_REQUEST.paymentHash();
|
||||
SendToRouteObserver observer = multiPathPaymentObserver.getFor(ROUTE, paymentHash);
|
||||
observer.onValue(new HexString("AABBCC"));
|
||||
assertThat(multiPathPaymentObserver.getInFlight(paymentHash)).isEqualTo(Coins.NONE);
|
||||
}
|
||||
|
||||
@Test
|
||||
void inFlight_reset_on_failure() {
|
||||
HexString paymentHash = DECODED_PAYMENT_REQUEST.paymentHash();
|
||||
SendToRouteObserver observer = multiPathPaymentObserver.getFor(ROUTE, paymentHash);
|
||||
observer.onValue(HexString.EMPTY);
|
||||
assertThat(multiPathPaymentObserver.getInFlight(paymentHash)).isEqualTo(Coins.NONE);
|
||||
}
|
||||
|
||||
@Test
|
||||
void inFlight_updated_with_several_routes() {
|
||||
HexString paymentHash = DECODED_PAYMENT_REQUEST.paymentHash();
|
||||
multiPathPaymentObserver.getFor(ROUTE, paymentHash);
|
||||
multiPathPaymentObserver.getFor(ROUTE_2, HexString.EMPTY);
|
||||
multiPathPaymentObserver.getFor(ROUTE_3, paymentHash);
|
||||
Coins expectedAmount = ROUTE.getAmount().add(ROUTE_3.getAmount());
|
||||
assertThat(multiPathPaymentObserver.getInFlight(paymentHash)).isEqualTo(expectedAmount);
|
||||
}
|
||||
|
||||
private List<PaymentAttemptHop> hops() {
|
||||
|
||||
@@ -3,6 +3,7 @@ package de.cotto.lndmanagej.pickhardtpayments;
|
||||
import de.cotto.lndmanagej.grpc.GrpcPayments;
|
||||
import de.cotto.lndmanagej.grpc.GrpcSendToRoute;
|
||||
import de.cotto.lndmanagej.grpc.SendToRouteObserver;
|
||||
import de.cotto.lndmanagej.model.HexString;
|
||||
import de.cotto.lndmanagej.model.Route;
|
||||
import de.cotto.lndmanagej.pickhardtpayments.model.MultiPathPayment;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -84,10 +85,11 @@ class MultiPathPaymentSenderTest {
|
||||
void registers_observers_for_routes() {
|
||||
when(grpcPayments.decodePaymentRequest(any())).thenReturn(Optional.of(DECODED_PAYMENT_REQUEST));
|
||||
when(multiPathPaymentSplitter.getMultiPathPaymentTo(any(), any(), anyInt())).thenReturn(MULTI_PATH_PAYMENT);
|
||||
HexString paymentHash = DECODED_PAYMENT_REQUEST.paymentHash();
|
||||
Map<Route, SendToRouteObserver> expected = new LinkedHashMap<>();
|
||||
for (Route route : MULTI_PATH_PAYMENT.routes()) {
|
||||
SendToRouteObserver expectedObserver = mock(SendToRouteObserver.class);
|
||||
when(multiPathPaymentObserver.forRoute(route)).thenReturn(expectedObserver);
|
||||
when(multiPathPaymentObserver.getFor(route, paymentHash)).thenReturn(expectedObserver);
|
||||
expected.put(route, expectedObserver);
|
||||
}
|
||||
multiPathPaymentSender.payPaymentRequest(PAYMENT_REQUEST, FEE_RATE_WEIGHT);
|
||||
|
||||
Reference in New Issue
Block a user