diff --git a/intercept.go b/intercept.go index 99f708e..0fd5b92 100644 --- a/intercept.go +++ b/intercept.go @@ -24,6 +24,17 @@ import ( sphinx "github.com/lightningnetwork/lightning-onion" ) +func checkPayment(incomingAmountMsat, outgoingAmountMsat int64) error { + var fees int64 = 0 + if incomingAmountMsat > channelFeeStartAmount { + fees += (incomingAmountMsat - channelFeeStartAmount) * channelFeeAmountNumerator / channelFeeAmountDenominator + } + if incomingAmountMsat-outgoingAmountMsat < fees { + return fmt.Errorf("not enough fees") + } + return nil +} + func openChannel(ctx context.Context, client lnrpc.LightningClient, paymentHash, destination []byte, incomingAmountMsat int64) ([]byte, uint32, error) { capacity := incomingAmountMsat/1000 + channelFeeStartAmount channelPoint, err := client.OpenChannelSync(ctx, &lnrpc.OpenChannelRequest{ diff --git a/server.go b/server.go index 4a06de3..868ae8d 100644 --- a/server.go +++ b/server.go @@ -29,14 +29,15 @@ import ( ) const ( - channelAmount = 1_000_000 - targetConf = 1 - minHtlcMsat = 600 - baseFeeMsat = 1000 - feeRate = 0.000001 - timeLockDelta = 144 - channelFeeStartAmount = 100_000 - channelFeeAmount = 0.001 + channelAmount = 1_000_000 + targetConf = 1 + minHtlcMsat = 600 + baseFeeMsat = 1000 + feeRate = 0.000001 + timeLockDelta = 144 + channelFeeStartAmount = 100_000 + channelFeeAmountNumerator = 1 + channelFeeAmountDenominator = 1000 ) type server struct{} @@ -63,7 +64,7 @@ func (s *server) ChannelInformation(ctx context.Context, in *lspdrpc.ChannelInfo FeeRate: feeRate, TimeLockDelta: timeLockDelta, ChannelFeeStartAmount: channelFeeStartAmount, - ChannelFeeRate: channelFeeAmount, + ChannelFeeRate: 1.0 * channelFeeAmountNumerator / channelFeeAmountDenominator, LspPubkey: publicKey.SerializeCompressed(), }, nil } @@ -82,6 +83,11 @@ func (s *server) RegisterPayment(ctx context.Context, in *lspdrpc.RegisterPaymen } log.Printf("RegisterPayment - Destination: %x, pi.PaymentHash: %x, pi.PaymentSecret: %x, pi.IncomingAmountMsat: %v, pi.OutgoingAmountMsat: %v", pi.Destination, pi.PaymentHash, pi.PaymentSecret, pi.IncomingAmountMsat, pi.OutgoingAmountMsat) + err = checkPayment(pi.IncomingAmountMsat, pi.OutgoingAmountMsat) + if err != nil { + log.Printf("checkPayment(%v, %v) error: %v", pi.IncomingAmountMsat, pi.OutgoingAmountMsat, err) + return nil, fmt.Errorf("checkPayment(%v, %v) error: %v", pi.IncomingAmountMsat, pi.OutgoingAmountMsat, err) + } err = registerPayment(pi.Destination, pi.PaymentHash, pi.PaymentSecret, pi.IncomingAmountMsat, pi.OutgoingAmountMsat) if err != nil { log.Printf("RegisterPayment() error: %v", err)