cleanup: convert intercept and database to types

This commit is contained in:
Jesse de Wit
2023-03-24 15:53:19 +01:00
parent 9781ac6bb0
commit 086d500750
11 changed files with 258 additions and 142 deletions

View File

@@ -24,6 +24,7 @@ import (
) )
type ClnHtlcInterceptor struct { type ClnHtlcInterceptor struct {
interceptor *Interceptor
config *config.NodeConfig config *config.NodeConfig
pluginAddress string pluginAddress string
client *ClnClient client *ClnClient
@@ -35,15 +36,7 @@ type ClnHtlcInterceptor struct {
cancel context.CancelFunc cancel context.CancelFunc
} }
func NewClnHtlcInterceptor(conf *config.NodeConfig) (*ClnHtlcInterceptor, error) { func NewClnHtlcInterceptor(conf *config.NodeConfig, client *ClnClient, interceptor *Interceptor) (*ClnHtlcInterceptor, error) {
if conf.Cln == nil {
return nil, fmt.Errorf("missing cln config")
}
client, err := NewClnClient(conf.Cln.SocketPath)
if err != nil {
return nil, err
}
i := &ClnHtlcInterceptor{ i := &ClnHtlcInterceptor{
config: conf, config: conf,
pluginAddress: conf.Cln.PluginAddress, pluginAddress: conf.Cln.PluginAddress,
@@ -169,7 +162,7 @@ func (i *ClnHtlcInterceptor) intercept() error {
interceptorClient.Send(i.defaultResolution(request)) interceptorClient.Send(i.defaultResolution(request))
i.doneWg.Done() i.doneWg.Done()
} }
interceptResult := intercept(i.client, i.config, nextHop, paymentHash, request.Onion.ForwardMsat, request.Onion.OutgoingCltvValue, request.Htlc.CltvExpiry) interceptResult := i.interceptor.Intercept(nextHop, paymentHash, request.Onion.ForwardMsat, request.Onion.OutgoingCltvValue, request.Htlc.CltvExpiry)
switch interceptResult.action { switch interceptResult.action {
case INTERCEPT_RESUME_WITH_ONION: case INTERCEPT_RESUME_WITH_ONION:
interceptorClient.Send(i.resumeWithOnion(request, interceptResult)) interceptorClient.Send(i.resumeWithOnion(request, interceptResult))

View File

@@ -7,6 +7,8 @@ import (
"log" "log"
"time" "time"
"github.com/breez/lspd/interceptor"
"github.com/breez/lspd/lnd"
"github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnrpc/chainrpc" "github.com/lightningnetwork/lnd/lnrpc/chainrpc"
@@ -36,14 +38,32 @@ func (cfe *copyFromEvents) Err() error {
return cfe.err return cfe.err
} }
func channelsSynchronize(ctx context.Context, client *LndClient) { type ForwardingHistorySync struct {
client *LndClient
interceptStore interceptor.InterceptStore
forwardingStore lnd.ForwardingEventStore
}
func NewForwardingHistorySync(
client *LndClient,
interceptStore interceptor.InterceptStore,
forwardingStore lnd.ForwardingEventStore,
) *ForwardingHistorySync {
return &ForwardingHistorySync{
client: client,
interceptStore: interceptStore,
forwardingStore: forwardingStore,
}
}
func (s *ForwardingHistorySync) ChannelsSynchronize(ctx context.Context) {
lastSync := time.Now().Add(-6 * time.Minute) lastSync := time.Now().Add(-6 * time.Minute)
for { for {
if ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
stream, err := client.chainNotifierClient.RegisterBlockEpochNtfn(ctx, &chainrpc.BlockEpoch{}) stream, err := s.client.chainNotifierClient.RegisterBlockEpochNtfn(ctx, &chainrpc.BlockEpoch{})
if err != nil { if err != nil {
log.Printf("chainNotifierClient.RegisterBlockEpochNtfn(): %v", err) log.Printf("chainNotifierClient.RegisterBlockEpochNtfn(): %v", err)
<-time.After(time.Second) <-time.After(time.Second)
@@ -67,7 +87,7 @@ func channelsSynchronize(ctx context.Context, client *LndClient) {
return return
case <-time.After(1 * time.Minute): case <-time.After(1 * time.Minute):
} }
err = channelsSynchronizeOnce(client) err = s.ChannelsSynchronizeOnce()
lastSync = time.Now() lastSync = time.Now()
log.Printf("channelsSynchronizeOnce() err: %v", err) log.Printf("channelsSynchronizeOnce() err: %v", err)
} }
@@ -75,9 +95,9 @@ func channelsSynchronize(ctx context.Context, client *LndClient) {
} }
} }
func channelsSynchronizeOnce(client *LndClient) error { func (s *ForwardingHistorySync) ChannelsSynchronizeOnce() error {
log.Printf("channelsSynchronizeOnce - begin") log.Printf("channelsSynchronizeOnce - begin")
channels, err := client.client.ListChannels(context.Background(), &lnrpc.ListChannelsRequest{PrivateOnly: true}) channels, err := s.client.client.ListChannels(context.Background(), &lnrpc.ListChannelsRequest{PrivateOnly: true})
if err != nil { if err != nil {
log.Printf("ListChannels error: %v", err) log.Printf("ListChannels error: %v", err)
return fmt.Errorf("client.ListChannels() error: %w", err) return fmt.Errorf("client.ListChannels() error: %w", err)
@@ -97,7 +117,7 @@ func channelsSynchronizeOnce(client *LndClient) error {
confirmedChanId = 0 confirmedChanId = 0
} }
} }
err = insertChannel(c.ChanId, confirmedChanId, c.ChannelPoint, nodeID, lastUpdate) err = s.interceptStore.InsertChannel(c.ChanId, confirmedChanId, c.ChannelPoint, nodeID, lastUpdate)
if err != nil { if err != nil {
log.Printf("insertChannel(%v, %v, %x) in channelsSynchronizeOnce error: %v", c.ChanId, c.ChannelPoint, nodeID, err) log.Printf("insertChannel(%v, %v, %x) in channelsSynchronizeOnce error: %v", c.ChanId, c.ChannelPoint, nodeID, err)
continue continue
@@ -108,13 +128,13 @@ func channelsSynchronizeOnce(client *LndClient) error {
return nil return nil
} }
func forwardingHistorySynchronize(ctx context.Context, client *LndClient) { func (s *ForwardingHistorySync) ForwardingHistorySynchronize(ctx context.Context) {
for { for {
if ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
err := forwardingHistorySynchronizeOnce(client) err := s.ForwardingHistorySynchronizeOnce()
log.Printf("forwardingHistorySynchronizeOnce() err: %v", err) log.Printf("forwardingHistorySynchronizeOnce() err: %v", err)
select { select {
case <-time.After(1 * time.Minute): case <-time.After(1 * time.Minute):
@@ -123,8 +143,8 @@ func forwardingHistorySynchronize(ctx context.Context, client *LndClient) {
} }
} }
func forwardingHistorySynchronizeOnce(client *LndClient) error { func (s *ForwardingHistorySync) ForwardingHistorySynchronizeOnce() error {
last, err := lastForwardingEvent() last, err := s.forwardingStore.LastForwardingEvent()
if err != nil { if err != nil {
return fmt.Errorf("lastForwardingEvent() error: %w", err) return fmt.Errorf("lastForwardingEvent() error: %w", err)
} }
@@ -138,7 +158,7 @@ func forwardingHistorySynchronizeOnce(client *LndClient) error {
endTime := uint64(now.Add(time.Hour * 24).Unix()) endTime := uint64(now.Add(time.Hour * 24).Unix())
indexOffset := uint32(0) indexOffset := uint32(0)
for { for {
forwardHistory, err := client.client.ForwardingHistory(context.Background(), &lnrpc.ForwardingHistoryRequest{ forwardHistory, err := s.client.client.ForwardingHistory(context.Background(), &lnrpc.ForwardingHistoryRequest{
StartTime: uint64(last), StartTime: uint64(last),
EndTime: endTime, EndTime: endTime,
NumMaxEvents: 10000, NumMaxEvents: 10000,
@@ -154,7 +174,7 @@ func forwardingHistorySynchronizeOnce(client *LndClient) error {
} }
indexOffset = forwardHistory.LastOffsetIndex indexOffset = forwardHistory.LastOffsetIndex
cfe := copyFromEvents{events: forwardHistory.ForwardingEvents, idx: -1} cfe := copyFromEvents{events: forwardHistory.ForwardingEvents, idx: -1}
err = insertForwardingEvents(&cfe) err = s.forwardingStore.InsertForwardingEvents(&cfe)
if err != nil { if err != nil {
log.Printf("insertForwardingEvents() error: %v", err) log.Printf("insertForwardingEvents() error: %v", err)
return fmt.Errorf("insertForwardingEvents() error: %w", err) return fmt.Errorf("insertForwardingEvents() error: %w", err)

View File

@@ -11,6 +11,7 @@ import (
"github.com/breez/lspd/chain" "github.com/breez/lspd/chain"
"github.com/breez/lspd/config" "github.com/breez/lspd/config"
"github.com/breez/lspd/interceptor"
"github.com/breez/lspd/lightning" "github.com/breez/lspd/lightning"
"github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
@@ -51,10 +52,28 @@ type interceptResult struct {
onionBlob []byte onionBlob []byte
} }
func intercept(client lightning.Client, config *config.NodeConfig, nextHop string, reqPaymentHash []byte, reqOutgoingAmountMsat uint64, reqOutgoingExpiry uint32, reqIncomingExpiry uint32) interceptResult { type Interceptor struct {
client lightning.Client
config *config.NodeConfig
store interceptor.InterceptStore
}
func NewInterceptor(
client lightning.Client,
config *config.NodeConfig,
store interceptor.InterceptStore,
) *Interceptor {
return &Interceptor{
client: client,
config: config,
store: store,
}
}
func (i *Interceptor) Intercept(nextHop string, reqPaymentHash []byte, reqOutgoingAmountMsat uint64, reqOutgoingExpiry uint32, reqIncomingExpiry uint32) interceptResult {
reqPaymentHashStr := hex.EncodeToString(reqPaymentHash) reqPaymentHashStr := hex.EncodeToString(reqPaymentHash)
resp, _, _ := payHashGroup.Do(reqPaymentHashStr, func() (interface{}, error) { resp, _, _ := payHashGroup.Do(reqPaymentHashStr, func() (interface{}, error) {
paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, channelPoint, err := paymentInfo(reqPaymentHash) paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, channelPoint, err := i.store.PaymentInfo(reqPaymentHash)
if err != nil { if err != nil {
log.Printf("paymentInfo(%x) error: %v", reqPaymentHash, err) log.Printf("paymentInfo(%x) error: %v", reqPaymentHash, err)
return interceptResult{ return interceptResult{
@@ -72,14 +91,14 @@ func intercept(client lightning.Client, config *config.NodeConfig, nextHop strin
if channelPoint == nil { if channelPoint == nil {
if bytes.Equal(paymentHash, reqPaymentHash) { if bytes.Equal(paymentHash, reqPaymentHash) {
if int64(reqIncomingExpiry)-int64(reqOutgoingExpiry) < int64(config.TimeLockDelta) { if int64(reqIncomingExpiry)-int64(reqOutgoingExpiry) < int64(i.config.TimeLockDelta) {
return interceptResult{ return interceptResult{
action: INTERCEPT_FAIL_HTLC_WITH_CODE, action: INTERCEPT_FAIL_HTLC_WITH_CODE,
failureCode: FAILURE_TEMPORARY_CHANNEL_FAILURE, failureCode: FAILURE_TEMPORARY_CHANNEL_FAILURE,
}, nil }, nil
} }
channelPoint, err = openChannel(client, config, reqPaymentHash, destination, incomingAmountMsat) channelPoint, err = i.openChannel(reqPaymentHash, destination, incomingAmountMsat)
if err != nil { if err != nil {
log.Printf("openChannel(%x, %v) err: %v", destination, incomingAmountMsat, err) log.Printf("openChannel(%x, %v) err: %v", destination, incomingAmountMsat, err)
return interceptResult{ return interceptResult{
@@ -173,11 +192,11 @@ func intercept(client lightning.Client, config *config.NodeConfig, nextHop strin
deadline := time.Now().Add(60 * time.Second) deadline := time.Now().Add(60 * time.Second)
for { for {
chanResult, _ := client.GetChannel(destination, *channelPoint) chanResult, _ := i.client.GetChannel(destination, *channelPoint)
if chanResult != nil { if chanResult != nil {
log.Printf("channel opended successfully alias: %v, confirmed: %v", chanResult.InitialChannelID.ToString(), chanResult.ConfirmedChannelID.ToString()) log.Printf("channel opended successfully alias: %v, confirmed: %v", chanResult.InitialChannelID.ToString(), chanResult.ConfirmedChannelID.ToString())
err := insertChannel( err := i.store.InsertChannel(
uint64(chanResult.InitialChannelID), uint64(chanResult.InitialChannelID),
uint64(chanResult.ConfirmedChannelID), uint64(chanResult.ConfirmedChannelID),
channelPoint.String(), channelPoint.String(),
@@ -226,20 +245,9 @@ func intercept(client lightning.Client, config *config.NodeConfig, nextHop strin
return resp.(interceptResult) return resp.(interceptResult)
} }
func checkPayment(config *config.NodeConfig, incomingAmountMsat, outgoingAmountMsat int64) error { func (i *Interceptor) openChannel(paymentHash, destination []byte, incomingAmountMsat int64) (*wire.OutPoint, error) {
fees := incomingAmountMsat * config.ChannelFeePermyriad / 10_000 / 1_000 * 1_000 capacity := incomingAmountMsat/1000 + i.config.AdditionalChannelCapacity
if fees < config.ChannelMinimumFeeMsat { if capacity == i.config.PublicChannelAmount {
fees = config.ChannelMinimumFeeMsat
}
if incomingAmountMsat-outgoingAmountMsat < fees {
return fmt.Errorf("not enough fees")
}
return nil
}
func openChannel(client lightning.Client, config *config.NodeConfig, paymentHash, destination []byte, incomingAmountMsat int64) (*wire.OutPoint, error) {
capacity := incomingAmountMsat/1000 + config.AdditionalChannelCapacity
if capacity == config.PublicChannelAmount {
capacity++ capacity++
} }
@@ -257,7 +265,7 @@ func openChannel(client lightning.Client, config *config.NodeConfig, paymentHash
feeStr = fmt.Sprintf("%.5f", *feeEstimation) feeStr = fmt.Sprintf("%.5f", *feeEstimation)
} else { } else {
log.Printf("Error estimating chain fee, fallback to target conf: %v", err) log.Printf("Error estimating chain fee, fallback to target conf: %v", err)
targetConf = &config.TargetConf targetConf = &i.config.TargetConf
confStr = fmt.Sprintf("%v", *targetConf) confStr = fmt.Sprintf("%v", *targetConf)
} }
} }
@@ -269,10 +277,10 @@ func openChannel(client lightning.Client, config *config.NodeConfig, paymentHash
feeStr, feeStr,
confStr, confStr,
) )
channelPoint, err := client.OpenChannel(&lightning.OpenChannelRequest{ channelPoint, err := i.client.OpenChannel(&lightning.OpenChannelRequest{
Destination: destination, Destination: destination,
CapacitySat: uint64(capacity), CapacitySat: uint64(capacity),
MinConfs: config.MinConfs, MinConfs: i.config.MinConfs,
IsPrivate: true, IsPrivate: true,
IsZeroConf: true, IsZeroConf: true,
FeeSatPerVByte: feeEstimation, FeeSatPerVByte: feeEstimation,
@@ -289,6 +297,6 @@ func openChannel(client lightning.Client, config *config.NodeConfig, paymentHash
capacity, capacity,
channelPoint.String(), channelPoint.String(),
) )
err = setFundingTx(paymentHash, channelPoint) err = i.store.SetFundingTx(paymentHash, channelPoint)
return channelPoint, err return channelPoint, err
} }

14
interceptor/store.go Normal file
View File

@@ -0,0 +1,14 @@
package interceptor
import (
"time"
"github.com/btcsuite/btcd/wire"
)
type InterceptStore interface {
PaymentInfo(htlcPaymentHash []byte) ([]byte, []byte, []byte, int64, int64, *wire.OutPoint, error)
SetFundingTx(paymentHash []byte, channelPoint *wire.OutPoint) error
RegisterPayment(destination, paymentHash, paymentSecret []byte, incomingAmountMsat, outgoingAmountMsat int64, tag string) error
InsertChannel(initialChanID, confirmedChanId uint64, channelPoint string, nodeID []byte, lastUpdate time.Time) error
}

View File

@@ -0,0 +1,12 @@
package lnd
type CopyFromSource interface {
Next() bool
Values() ([]interface{}, error)
Err() error
}
type ForwardingEventStore interface {
LastForwardingEvent() (int64, error)
InsertForwardingEvents(rowSrc CopyFromSource) error
}

View File

@@ -15,6 +15,8 @@ import (
) )
type LndHtlcInterceptor struct { type LndHtlcInterceptor struct {
fwsync *ForwardingHistorySync
interceptor *Interceptor
config *config.NodeConfig config *config.NodeConfig
client *LndClient client *LndClient
stopRequested bool stopRequested bool
@@ -24,17 +26,17 @@ type LndHtlcInterceptor struct {
cancel context.CancelFunc cancel context.CancelFunc
} }
func NewLndHtlcInterceptor(conf *config.NodeConfig) (*LndHtlcInterceptor, error) { func NewLndHtlcInterceptor(
if conf.Lnd == nil { conf *config.NodeConfig,
return nil, fmt.Errorf("missing lnd configuration") client *LndClient,
} fwsync *ForwardingHistorySync,
client, err := NewLndClient(conf.Lnd) interceptor *Interceptor,
if err != nil { ) (*LndHtlcInterceptor, error) {
return nil, err
}
i := &LndHtlcInterceptor{ i := &LndHtlcInterceptor{
config: conf, config: conf,
client: client, client: client,
fwsync: fwsync,
interceptor: interceptor,
} }
i.initWg.Add(1) i.initWg.Add(1)
@@ -47,8 +49,8 @@ func (i *LndHtlcInterceptor) Start() error {
i.ctx = ctx i.ctx = ctx
i.cancel = cancel i.cancel = cancel
i.stopRequested = false i.stopRequested = false
go forwardingHistorySynchronize(ctx, i.client) go i.fwsync.ForwardingHistorySynchronize(ctx)
go channelsSynchronize(ctx, i.client) go i.fwsync.ChannelsSynchronize(ctx)
return i.intercept() return i.intercept()
} }
@@ -149,7 +151,7 @@ func (i *LndHtlcInterceptor) intercept() error {
i.doneWg.Add(1) i.doneWg.Add(1)
go func() { go func() {
interceptResult := intercept(i.client, i.config, nextHop, request.PaymentHash, request.OutgoingAmountMsat, request.OutgoingExpiry, request.IncomingExpiry) interceptResult := i.interceptor.Intercept(nextHop, request.PaymentHash, request.OutgoingAmountMsat, request.OutgoingExpiry, request.IncomingExpiry)
switch interceptResult.action { switch interceptResult.action {
case INTERCEPT_RESUME_WITH_ONION: case INTERCEPT_RESUME_WITH_ONION:
interceptorClient.Send(&routerrpc.ForwardHtlcInterceptResponse{ interceptorClient.Send(&routerrpc.ForwardHtlcInterceptResponse{

41
main.go
View File

@@ -13,6 +13,7 @@ import (
"github.com/breez/lspd/chain" "github.com/breez/lspd/chain"
"github.com/breez/lspd/config" "github.com/breez/lspd/config"
"github.com/breez/lspd/mempool" "github.com/breez/lspd/mempool"
"github.com/breez/lspd/postgresql"
"github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2"
) )
@@ -63,43 +64,59 @@ func main() {
log.Printf("using mempool api for fee estimation: %v, fee strategy: %v:%v", mempoolUrl, envFeeStrategy, feeStrategy) log.Printf("using mempool api for fee estimation: %v, fee strategy: %v:%v", mempoolUrl, envFeeStrategy, feeStrategy)
} }
databaseUrl := os.Getenv("DATABASE_URL")
pool, err := postgresql.PgConnect(databaseUrl)
if err != nil {
log.Fatalf("pgConnect() error: %v", err)
}
interceptStore := postgresql.NewPostgresInterceptStore(pool)
forwardingStore := postgresql.NewForwardingEventStore(pool)
var interceptors []HtlcInterceptor var interceptors []HtlcInterceptor
for _, node := range nodes { for _, node := range nodes {
var interceptor HtlcInterceptor var htlcInterceptor HtlcInterceptor
if node.Lnd != nil { if node.Lnd != nil {
interceptor, err = NewLndHtlcInterceptor(node) client, err := NewLndClient(node.Lnd)
if err != nil {
log.Fatalf("failed to initialize LND client: %v", err)
}
fwsync := NewForwardingHistorySync(client, interceptStore, forwardingStore)
interceptor := NewInterceptor(client, node, interceptStore)
htlcInterceptor, err = NewLndHtlcInterceptor(node, client, fwsync, interceptor)
if err != nil { if err != nil {
log.Fatalf("failed to initialize LND interceptor: %v", err) log.Fatalf("failed to initialize LND interceptor: %v", err)
} }
} }
if node.Cln != nil { if node.Cln != nil {
interceptor, err = NewClnHtlcInterceptor(node) client, err := NewClnClient(node.Cln.SocketPath)
if err != nil {
log.Fatalf("failed to initialize CLN client: %v", err)
}
interceptor := NewInterceptor(client, node, interceptStore)
htlcInterceptor, err = NewClnHtlcInterceptor(node, client, interceptor)
if err != nil { if err != nil {
log.Fatalf("failed to initialize CLN interceptor: %v", err) log.Fatalf("failed to initialize CLN interceptor: %v", err)
} }
} }
if interceptor == nil { if htlcInterceptor == nil {
log.Fatalf("node has to be either cln or lnd") log.Fatalf("node has to be either cln or lnd")
} }
interceptors = append(interceptors, interceptor) interceptors = append(interceptors, htlcInterceptor)
} }
address := os.Getenv("LISTEN_ADDRESS") address := os.Getenv("LISTEN_ADDRESS")
certMagicDomain := os.Getenv("CERTMAGIC_DOMAIN") certMagicDomain := os.Getenv("CERTMAGIC_DOMAIN")
s, err := NewGrpcServer(nodes, address, certMagicDomain) s, err := NewGrpcServer(nodes, address, certMagicDomain, interceptStore)
if err != nil { if err != nil {
log.Fatalf("failed to initialize grpc server: %v", err) log.Fatalf("failed to initialize grpc server: %v", err)
} }
databaseUrl := os.Getenv("DATABASE_URL")
err = pgConnect(databaseUrl)
if err != nil {
log.Fatalf("pgConnect() error: %v", err)
}
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(len(interceptors) + 1) wg.Add(len(interceptors) + 1)

17
postgresql/connect.go Normal file
View File

@@ -0,0 +1,17 @@
package postgresql
import (
"context"
"fmt"
"github.com/jackc/pgx/v4/pgxpool"
)
func PgConnect(databaseUrl string) (*pgxpool.Pool, error) {
var err error
pgxPool, err := pgxpool.Connect(context.Background(), databaseUrl)
if err != nil {
return nil, fmt.Errorf("pgxpool.Connect(%v): %w", databaseUrl, err)
}
return pgxPool, nil
}

View File

@@ -0,0 +1,68 @@
package postgresql
import (
"context"
"fmt"
"log"
"github.com/breez/lspd/lnd"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
)
type ForwardingEventStore struct {
pool *pgxpool.Pool
}
func NewForwardingEventStore(pool *pgxpool.Pool) *ForwardingEventStore {
return &ForwardingEventStore{pool: pool}
}
func (s *ForwardingEventStore) LastForwardingEvent() (int64, error) {
var last int64
err := s.pool.QueryRow(context.Background(),
`SELECT coalesce(MAX("timestamp"), 0) AS last FROM forwarding_history`).Scan(&last)
if err != nil {
return 0, err
}
return last, nil
}
func (s *ForwardingEventStore) InsertForwardingEvents(rowSrc lnd.CopyFromSource) error {
tx, err := s.pool.Begin(context.Background())
if err != nil {
return fmt.Errorf("pgxPool.Begin() error: %w", err)
}
defer tx.Rollback(context.Background())
_, err = tx.Exec(context.Background(), `
CREATE TEMP TABLE tmp_table ON COMMIT DROP AS
SELECT *
FROM forwarding_history
WITH NO DATA;
`)
if err != nil {
return fmt.Errorf("CREATE TEMP TABLE error: %w", err)
}
count, err := tx.CopyFrom(context.Background(),
pgx.Identifier{"tmp_table"},
[]string{"timestamp", "chanid_in", "chanid_out", "amt_msat_in", "amt_msat_out"}, rowSrc)
if err != nil {
return fmt.Errorf("CopyFrom() error: %w", err)
}
log.Printf("count1: %v", count)
cmdTag, err := tx.Exec(context.Background(), `
INSERT INTO forwarding_history
SELECT *
FROM tmp_table
ON CONFLICT DO NOTHING
`)
if err != nil {
return fmt.Errorf("INSERT INTO forwarding_history error: %w", err)
}
log.Printf("count2: %v", cmdTag.RowsAffected())
return tx.Commit(context.Background())
}

View File

@@ -1,4 +1,4 @@
package main package postgresql
import ( import (
"context" "context"
@@ -13,27 +13,22 @@ import (
"github.com/jackc/pgx/v4/pgxpool" "github.com/jackc/pgx/v4/pgxpool"
) )
var ( type PostgresInterceptStore struct {
pgxPool *pgxpool.Pool pool *pgxpool.Pool
)
func pgConnect(databaseUrl string) error {
var err error
pgxPool, err = pgxpool.Connect(context.Background(), databaseUrl)
if err != nil {
return fmt.Errorf("pgxpool.Connect(%v): %w", databaseUrl, err)
}
return nil
} }
func paymentInfo(htlcPaymentHash []byte) ([]byte, []byte, []byte, int64, int64, *wire.OutPoint, error) { func NewPostgresInterceptStore(pool *pgxpool.Pool) *PostgresInterceptStore {
return &PostgresInterceptStore{pool: pool}
}
func (s *PostgresInterceptStore) PaymentInfo(htlcPaymentHash []byte) ([]byte, []byte, []byte, int64, int64, *wire.OutPoint, error) {
var ( var (
paymentHash, paymentSecret, destination []byte paymentHash, paymentSecret, destination []byte
incomingAmountMsat, outgoingAmountMsat int64 incomingAmountMsat, outgoingAmountMsat int64
fundingTxID []byte fundingTxID []byte
fundingTxOutnum pgtype.Int4 fundingTxOutnum pgtype.Int4
) )
err := pgxPool.QueryRow(context.Background(), err := s.pool.QueryRow(context.Background(),
`SELECT payment_hash, payment_secret, destination, incoming_amount_msat, outgoing_amount_msat, funding_tx_id, funding_tx_outnum `SELECT payment_hash, payment_secret, destination, incoming_amount_msat, outgoing_amount_msat, funding_tx_id, funding_tx_outnum
FROM payments FROM payments
WHERE payment_hash=$1 OR sha256('probing-01:' || payment_hash)=$1`, WHERE payment_hash=$1 OR sha256('probing-01:' || payment_hash)=$1`,
@@ -55,8 +50,8 @@ func paymentInfo(htlcPaymentHash []byte) ([]byte, []byte, []byte, int64, int64,
return paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, cp, nil return paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, cp, nil
} }
func setFundingTx(paymentHash []byte, channelPoint *wire.OutPoint) error { func (s *PostgresInterceptStore) SetFundingTx(paymentHash []byte, channelPoint *wire.OutPoint) error {
commandTag, err := pgxPool.Exec(context.Background(), commandTag, err := s.pool.Exec(context.Background(),
`UPDATE payments `UPDATE payments
SET funding_tx_id = $2, funding_tx_outnum = $3 SET funding_tx_id = $2, funding_tx_outnum = $3
WHERE payment_hash=$1`, WHERE payment_hash=$1`,
@@ -65,12 +60,12 @@ func setFundingTx(paymentHash []byte, channelPoint *wire.OutPoint) error {
return err return err
} }
func registerPayment(destination, paymentHash, paymentSecret []byte, incomingAmountMsat, outgoingAmountMsat int64, tag string) error { func (s *PostgresInterceptStore) RegisterPayment(destination, paymentHash, paymentSecret []byte, incomingAmountMsat, outgoingAmountMsat int64, tag string) error {
var t *string var t *string
if tag != "" { if tag != "" {
t = &tag t = &tag
} }
commandTag, err := pgxPool.Exec(context.Background(), commandTag, err := s.pool.Exec(context.Background(),
`INSERT INTO `INSERT INTO
payments (destination, payment_hash, payment_secret, incoming_amount_msat, outgoing_amount_msat, tag) payments (destination, payment_hash, payment_secret, incoming_amount_msat, outgoing_amount_msat, tag)
VALUES ($1, $2, $3, $4, $5, $6) VALUES ($1, $2, $3, $4, $5, $6)
@@ -85,14 +80,14 @@ func registerPayment(destination, paymentHash, paymentSecret []byte, incomingAmo
return nil return nil
} }
func insertChannel(initialChanID, confirmedChanId uint64, channelPoint string, nodeID []byte, lastUpdate time.Time) error { func (s *PostgresInterceptStore) InsertChannel(initialChanID, confirmedChanId uint64, channelPoint string, nodeID []byte, lastUpdate time.Time) error {
query := `INSERT INTO query := `INSERT INTO
channels (initial_chanid, confirmed_chanid, channel_point, nodeid, last_update) channels (initial_chanid, confirmed_chanid, channel_point, nodeid, last_update)
VALUES ($1, NULLIF($2, 0::int8), $3, $4, $5) VALUES ($1, NULLIF($2, 0::int8), $3, $4, $5)
ON CONFLICT (channel_point) DO UPDATE SET confirmed_chanid=NULLIF($2, 0::int8), last_update=$5` ON CONFLICT (channel_point) DO UPDATE SET confirmed_chanid=NULLIF($2, 0::int8), last_update=$5`
c, err := pgxPool.Exec(context.Background(), c, err := s.pool.Exec(context.Background(),
query, int64(initialChanID), int64(confirmedChanId), channelPoint, nodeID, lastUpdate) query, int64(initialChanID), int64(confirmedChanId), channelPoint, nodeID, lastUpdate)
if err != nil { if err != nil {
log.Printf("insertChannel(%v, %v, %s, %x) error: %v", log.Printf("insertChannel(%v, %v, %s, %x) error: %v",
@@ -104,52 +99,3 @@ func insertChannel(initialChanID, confirmedChanId uint64, channelPoint string, n
initialChanID, confirmedChanId, nodeID, c.String()) initialChanID, confirmedChanId, nodeID, c.String())
return nil return nil
} }
func lastForwardingEvent() (int64, error) {
var last int64
err := pgxPool.QueryRow(context.Background(),
`SELECT coalesce(MAX("timestamp"), 0) AS last FROM forwarding_history`).Scan(&last)
if err != nil {
return 0, err
}
return last, nil
}
func insertForwardingEvents(rowSrc pgx.CopyFromSource) error {
tx, err := pgxPool.Begin(context.Background())
if err != nil {
return fmt.Errorf("pgxPool.Begin() error: %w", err)
}
defer tx.Rollback(context.Background())
_, err = tx.Exec(context.Background(), `
CREATE TEMP TABLE tmp_table ON COMMIT DROP AS
SELECT *
FROM forwarding_history
WITH NO DATA;
`)
if err != nil {
return fmt.Errorf("CREATE TEMP TABLE error: %w", err)
}
count, err := tx.CopyFrom(context.Background(),
pgx.Identifier{"tmp_table"},
[]string{"timestamp", "chanid_in", "chanid_out", "amt_msat_in", "amt_msat_out"}, rowSrc)
if err != nil {
return fmt.Errorf("CopyFrom() error: %w", err)
}
log.Printf("count1: %v", count)
cmdTag, err := tx.Exec(context.Background(), `
INSERT INTO forwarding_history
SELECT *
FROM tmp_table
ON CONFLICT DO NOTHING
`)
if err != nil {
return fmt.Errorf("INSERT INTO forwarding_history error: %w", err)
}
log.Printf("count2: %v", cmdTag.RowsAffected())
return tx.Commit(context.Background())
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/breez/lspd/btceclegacy" "github.com/breez/lspd/btceclegacy"
"github.com/breez/lspd/config" "github.com/breez/lspd/config"
"github.com/breez/lspd/interceptor"
"github.com/breez/lspd/lightning" "github.com/breez/lspd/lightning"
lspdrpc "github.com/breez/lspd/rpc" lspdrpc "github.com/breez/lspd/rpc"
ecies "github.com/ecies/go/v2" ecies "github.com/ecies/go/v2"
@@ -37,6 +38,7 @@ type server struct {
lis net.Listener lis net.Listener
s *grpc.Server s *grpc.Server
nodes map[string]*node nodes map[string]*node
store interceptor.InterceptStore
} }
type node struct { type node struct {
@@ -114,7 +116,7 @@ func (s *server) RegisterPayment(ctx context.Context, in *lspdrpc.RegisterPaymen
log.Printf("checkPayment(%v, %v) error: %v", pi.IncomingAmountMsat, pi.OutgoingAmountMsat, err) log.Printf("checkPayment(%v, %v) error: %v", pi.IncomingAmountMsat, pi.OutgoingAmountMsat, err)
return nil, fmt.Errorf("checkPayment(%v, %v) error: %v", pi.IncomingAmountMsat, pi.OutgoingAmountMsat, err) return nil, fmt.Errorf("checkPayment(%v, %v) error: %v", pi.IncomingAmountMsat, pi.OutgoingAmountMsat, err)
} }
err = registerPayment(pi.Destination, pi.PaymentHash, pi.PaymentSecret, pi.IncomingAmountMsat, pi.OutgoingAmountMsat, pi.Tag) err = s.store.RegisterPayment(pi.Destination, pi.PaymentHash, pi.PaymentSecret, pi.IncomingAmountMsat, pi.OutgoingAmountMsat, pi.Tag)
if err != nil { if err != nil {
log.Printf("RegisterPayment() error: %v", err) log.Printf("RegisterPayment() error: %v", err)
return nil, fmt.Errorf("RegisterPayment() error: %w", err) return nil, fmt.Errorf("RegisterPayment() error: %w", err)
@@ -261,7 +263,12 @@ func (s *server) CheckChannels(ctx context.Context, in *lspdrpc.Encrypted) (*lsp
return &lspdrpc.Encrypted{Data: encrypted}, nil return &lspdrpc.Encrypted{Data: encrypted}, nil
} }
func NewGrpcServer(configs []*config.NodeConfig, address string, certmagicDomain string) (*server, error) { func NewGrpcServer(
configs []*config.NodeConfig,
address string,
certmagicDomain string,
store interceptor.InterceptStore,
) (*server, error) {
if len(configs) == 0 { if len(configs) == 0 {
return nil, fmt.Errorf("no nodes supplied") return nil, fmt.Errorf("no nodes supplied")
} }
@@ -319,6 +326,7 @@ func NewGrpcServer(configs []*config.NodeConfig, address string, certmagicDomain
address: address, address: address,
certmagicDomain: certmagicDomain, certmagicDomain: certmagicDomain,
nodes: nodes, nodes: nodes,
store: store,
}, nil }, nil
} }
@@ -409,3 +417,14 @@ func getNode(ctx context.Context) (*node, error) {
return node, nil return node, nil
} }
func checkPayment(config *config.NodeConfig, incomingAmountMsat, outgoingAmountMsat int64) error {
fees := incomingAmountMsat * config.ChannelFeePermyriad / 10_000 / 1_000 * 1_000
if fees < config.ChannelMinimumFeeMsat {
fees = config.ChannelMinimumFeeMsat
}
if incomingAmountMsat-outgoingAmountMsat < fees {
return fmt.Errorf("not enough fees")
}
return nil
}