diff --git a/cln_interceptor.go b/cln_interceptor.go index bc7b491..cce2c94 100644 --- a/cln_interceptor.go +++ b/cln_interceptor.go @@ -24,6 +24,7 @@ import ( ) type ClnHtlcInterceptor struct { + interceptor *Interceptor config *config.NodeConfig pluginAddress string client *ClnClient @@ -35,15 +36,7 @@ type ClnHtlcInterceptor struct { cancel context.CancelFunc } -func NewClnHtlcInterceptor(conf *config.NodeConfig) (*ClnHtlcInterceptor, error) { - if conf.Cln == nil { - return nil, fmt.Errorf("missing cln config") - } - - client, err := NewClnClient(conf.Cln.SocketPath) - if err != nil { - return nil, err - } +func NewClnHtlcInterceptor(conf *config.NodeConfig, client *ClnClient, interceptor *Interceptor) (*ClnHtlcInterceptor, error) { i := &ClnHtlcInterceptor{ config: conf, pluginAddress: conf.Cln.PluginAddress, @@ -169,7 +162,7 @@ func (i *ClnHtlcInterceptor) intercept() error { interceptorClient.Send(i.defaultResolution(request)) i.doneWg.Done() } - interceptResult := intercept(i.client, i.config, nextHop, paymentHash, request.Onion.ForwardMsat, request.Onion.OutgoingCltvValue, request.Htlc.CltvExpiry) + interceptResult := i.interceptor.Intercept(nextHop, paymentHash, request.Onion.ForwardMsat, request.Onion.OutgoingCltvValue, request.Htlc.CltvExpiry) switch interceptResult.action { case INTERCEPT_RESUME_WITH_ONION: interceptorClient.Send(i.resumeWithOnion(request, interceptResult)) diff --git a/forwarding_history.go b/forwarding_history.go index e45166d..1fa960d 100644 --- a/forwarding_history.go +++ b/forwarding_history.go @@ -7,6 +7,8 @@ import ( "log" "time" + "github.com/breez/lspd/interceptor" + "github.com/breez/lspd/lnd" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc/chainrpc" @@ -36,14 +38,32 @@ func (cfe *copyFromEvents) Err() error { return cfe.err } -func channelsSynchronize(ctx context.Context, client *LndClient) { +type ForwardingHistorySync struct { + client *LndClient + interceptStore interceptor.InterceptStore + forwardingStore lnd.ForwardingEventStore +} + +func NewForwardingHistorySync( + client *LndClient, + interceptStore interceptor.InterceptStore, + forwardingStore lnd.ForwardingEventStore, +) *ForwardingHistorySync { + return &ForwardingHistorySync{ + client: client, + interceptStore: interceptStore, + forwardingStore: forwardingStore, + } +} + +func (s *ForwardingHistorySync) ChannelsSynchronize(ctx context.Context) { lastSync := time.Now().Add(-6 * time.Minute) for { if ctx.Err() != nil { return } - stream, err := client.chainNotifierClient.RegisterBlockEpochNtfn(ctx, &chainrpc.BlockEpoch{}) + stream, err := s.client.chainNotifierClient.RegisterBlockEpochNtfn(ctx, &chainrpc.BlockEpoch{}) if err != nil { log.Printf("chainNotifierClient.RegisterBlockEpochNtfn(): %v", err) <-time.After(time.Second) @@ -67,7 +87,7 @@ func channelsSynchronize(ctx context.Context, client *LndClient) { return case <-time.After(1 * time.Minute): } - err = channelsSynchronizeOnce(client) + err = s.ChannelsSynchronizeOnce() lastSync = time.Now() log.Printf("channelsSynchronizeOnce() err: %v", err) } @@ -75,9 +95,9 @@ func channelsSynchronize(ctx context.Context, client *LndClient) { } } -func channelsSynchronizeOnce(client *LndClient) error { +func (s *ForwardingHistorySync) ChannelsSynchronizeOnce() error { log.Printf("channelsSynchronizeOnce - begin") - channels, err := client.client.ListChannels(context.Background(), &lnrpc.ListChannelsRequest{PrivateOnly: true}) + channels, err := s.client.client.ListChannels(context.Background(), &lnrpc.ListChannelsRequest{PrivateOnly: true}) if err != nil { log.Printf("ListChannels error: %v", err) return fmt.Errorf("client.ListChannels() error: %w", err) @@ -97,7 +117,7 @@ func channelsSynchronizeOnce(client *LndClient) error { confirmedChanId = 0 } } - err = insertChannel(c.ChanId, confirmedChanId, c.ChannelPoint, nodeID, lastUpdate) + err = s.interceptStore.InsertChannel(c.ChanId, confirmedChanId, c.ChannelPoint, nodeID, lastUpdate) if err != nil { log.Printf("insertChannel(%v, %v, %x) in channelsSynchronizeOnce error: %v", c.ChanId, c.ChannelPoint, nodeID, err) continue @@ -108,13 +128,13 @@ func channelsSynchronizeOnce(client *LndClient) error { return nil } -func forwardingHistorySynchronize(ctx context.Context, client *LndClient) { +func (s *ForwardingHistorySync) ForwardingHistorySynchronize(ctx context.Context) { for { if ctx.Err() != nil { return } - err := forwardingHistorySynchronizeOnce(client) + err := s.ForwardingHistorySynchronizeOnce() log.Printf("forwardingHistorySynchronizeOnce() err: %v", err) select { case <-time.After(1 * time.Minute): @@ -123,8 +143,8 @@ func forwardingHistorySynchronize(ctx context.Context, client *LndClient) { } } -func forwardingHistorySynchronizeOnce(client *LndClient) error { - last, err := lastForwardingEvent() +func (s *ForwardingHistorySync) ForwardingHistorySynchronizeOnce() error { + last, err := s.forwardingStore.LastForwardingEvent() if err != nil { return fmt.Errorf("lastForwardingEvent() error: %w", err) } @@ -138,7 +158,7 @@ func forwardingHistorySynchronizeOnce(client *LndClient) error { endTime := uint64(now.Add(time.Hour * 24).Unix()) indexOffset := uint32(0) for { - forwardHistory, err := client.client.ForwardingHistory(context.Background(), &lnrpc.ForwardingHistoryRequest{ + forwardHistory, err := s.client.client.ForwardingHistory(context.Background(), &lnrpc.ForwardingHistoryRequest{ StartTime: uint64(last), EndTime: endTime, NumMaxEvents: 10000, @@ -154,7 +174,7 @@ func forwardingHistorySynchronizeOnce(client *LndClient) error { } indexOffset = forwardHistory.LastOffsetIndex cfe := copyFromEvents{events: forwardHistory.ForwardingEvents, idx: -1} - err = insertForwardingEvents(&cfe) + err = s.forwardingStore.InsertForwardingEvents(&cfe) if err != nil { log.Printf("insertForwardingEvents() error: %v", err) return fmt.Errorf("insertForwardingEvents() error: %w", err) diff --git a/intercept.go b/intercept.go index 52b6cce..1460a2b 100644 --- a/intercept.go +++ b/intercept.go @@ -11,6 +11,7 @@ import ( "github.com/breez/lspd/chain" "github.com/breez/lspd/config" + "github.com/breez/lspd/interceptor" "github.com/breez/lspd/lightning" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" @@ -51,10 +52,28 @@ type interceptResult struct { onionBlob []byte } -func intercept(client lightning.Client, config *config.NodeConfig, nextHop string, reqPaymentHash []byte, reqOutgoingAmountMsat uint64, reqOutgoingExpiry uint32, reqIncomingExpiry uint32) interceptResult { +type Interceptor struct { + client lightning.Client + config *config.NodeConfig + store interceptor.InterceptStore +} + +func NewInterceptor( + client lightning.Client, + config *config.NodeConfig, + store interceptor.InterceptStore, +) *Interceptor { + return &Interceptor{ + client: client, + config: config, + store: store, + } +} + +func (i *Interceptor) Intercept(nextHop string, reqPaymentHash []byte, reqOutgoingAmountMsat uint64, reqOutgoingExpiry uint32, reqIncomingExpiry uint32) interceptResult { reqPaymentHashStr := hex.EncodeToString(reqPaymentHash) resp, _, _ := payHashGroup.Do(reqPaymentHashStr, func() (interface{}, error) { - paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, channelPoint, err := paymentInfo(reqPaymentHash) + paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, channelPoint, err := i.store.PaymentInfo(reqPaymentHash) if err != nil { log.Printf("paymentInfo(%x) error: %v", reqPaymentHash, err) return interceptResult{ @@ -72,14 +91,14 @@ func intercept(client lightning.Client, config *config.NodeConfig, nextHop strin if channelPoint == nil { if bytes.Equal(paymentHash, reqPaymentHash) { - if int64(reqIncomingExpiry)-int64(reqOutgoingExpiry) < int64(config.TimeLockDelta) { + if int64(reqIncomingExpiry)-int64(reqOutgoingExpiry) < int64(i.config.TimeLockDelta) { return interceptResult{ action: INTERCEPT_FAIL_HTLC_WITH_CODE, failureCode: FAILURE_TEMPORARY_CHANNEL_FAILURE, }, nil } - channelPoint, err = openChannel(client, config, reqPaymentHash, destination, incomingAmountMsat) + channelPoint, err = i.openChannel(reqPaymentHash, destination, incomingAmountMsat) if err != nil { log.Printf("openChannel(%x, %v) err: %v", destination, incomingAmountMsat, err) return interceptResult{ @@ -173,11 +192,11 @@ func intercept(client lightning.Client, config *config.NodeConfig, nextHop strin deadline := time.Now().Add(60 * time.Second) for { - chanResult, _ := client.GetChannel(destination, *channelPoint) + chanResult, _ := i.client.GetChannel(destination, *channelPoint) if chanResult != nil { log.Printf("channel opended successfully alias: %v, confirmed: %v", chanResult.InitialChannelID.ToString(), chanResult.ConfirmedChannelID.ToString()) - err := insertChannel( + err := i.store.InsertChannel( uint64(chanResult.InitialChannelID), uint64(chanResult.ConfirmedChannelID), channelPoint.String(), @@ -226,20 +245,9 @@ func intercept(client lightning.Client, config *config.NodeConfig, nextHop strin return resp.(interceptResult) } -func checkPayment(config *config.NodeConfig, incomingAmountMsat, outgoingAmountMsat int64) error { - fees := incomingAmountMsat * config.ChannelFeePermyriad / 10_000 / 1_000 * 1_000 - if fees < config.ChannelMinimumFeeMsat { - fees = config.ChannelMinimumFeeMsat - } - if incomingAmountMsat-outgoingAmountMsat < fees { - return fmt.Errorf("not enough fees") - } - return nil -} - -func openChannel(client lightning.Client, config *config.NodeConfig, paymentHash, destination []byte, incomingAmountMsat int64) (*wire.OutPoint, error) { - capacity := incomingAmountMsat/1000 + config.AdditionalChannelCapacity - if capacity == config.PublicChannelAmount { +func (i *Interceptor) openChannel(paymentHash, destination []byte, incomingAmountMsat int64) (*wire.OutPoint, error) { + capacity := incomingAmountMsat/1000 + i.config.AdditionalChannelCapacity + if capacity == i.config.PublicChannelAmount { capacity++ } @@ -257,7 +265,7 @@ func openChannel(client lightning.Client, config *config.NodeConfig, paymentHash feeStr = fmt.Sprintf("%.5f", *feeEstimation) } else { log.Printf("Error estimating chain fee, fallback to target conf: %v", err) - targetConf = &config.TargetConf + targetConf = &i.config.TargetConf confStr = fmt.Sprintf("%v", *targetConf) } } @@ -269,10 +277,10 @@ func openChannel(client lightning.Client, config *config.NodeConfig, paymentHash feeStr, confStr, ) - channelPoint, err := client.OpenChannel(&lightning.OpenChannelRequest{ + channelPoint, err := i.client.OpenChannel(&lightning.OpenChannelRequest{ Destination: destination, CapacitySat: uint64(capacity), - MinConfs: config.MinConfs, + MinConfs: i.config.MinConfs, IsPrivate: true, IsZeroConf: true, FeeSatPerVByte: feeEstimation, @@ -289,6 +297,6 @@ func openChannel(client lightning.Client, config *config.NodeConfig, paymentHash capacity, channelPoint.String(), ) - err = setFundingTx(paymentHash, channelPoint) + err = i.store.SetFundingTx(paymentHash, channelPoint) return channelPoint, err } diff --git a/interceptor/store.go b/interceptor/store.go new file mode 100644 index 0000000..3956ec2 --- /dev/null +++ b/interceptor/store.go @@ -0,0 +1,14 @@ +package interceptor + +import ( + "time" + + "github.com/btcsuite/btcd/wire" +) + +type InterceptStore interface { + PaymentInfo(htlcPaymentHash []byte) ([]byte, []byte, []byte, int64, int64, *wire.OutPoint, error) + SetFundingTx(paymentHash []byte, channelPoint *wire.OutPoint) error + RegisterPayment(destination, paymentHash, paymentSecret []byte, incomingAmountMsat, outgoingAmountMsat int64, tag string) error + InsertChannel(initialChanID, confirmedChanId uint64, channelPoint string, nodeID []byte, lastUpdate time.Time) error +} diff --git a/lnd/forwarding_event_store.go b/lnd/forwarding_event_store.go new file mode 100644 index 0000000..22fc4a4 --- /dev/null +++ b/lnd/forwarding_event_store.go @@ -0,0 +1,12 @@ +package lnd + +type CopyFromSource interface { + Next() bool + Values() ([]interface{}, error) + Err() error +} + +type ForwardingEventStore interface { + LastForwardingEvent() (int64, error) + InsertForwardingEvents(rowSrc CopyFromSource) error +} diff --git a/lnd_interceptor.go b/lnd_interceptor.go index f1cf06d..ca84cd8 100644 --- a/lnd_interceptor.go +++ b/lnd_interceptor.go @@ -15,6 +15,8 @@ import ( ) type LndHtlcInterceptor struct { + fwsync *ForwardingHistorySync + interceptor *Interceptor config *config.NodeConfig client *LndClient stopRequested bool @@ -24,17 +26,17 @@ type LndHtlcInterceptor struct { cancel context.CancelFunc } -func NewLndHtlcInterceptor(conf *config.NodeConfig) (*LndHtlcInterceptor, error) { - if conf.Lnd == nil { - return nil, fmt.Errorf("missing lnd configuration") - } - client, err := NewLndClient(conf.Lnd) - if err != nil { - return nil, err - } +func NewLndHtlcInterceptor( + conf *config.NodeConfig, + client *LndClient, + fwsync *ForwardingHistorySync, + interceptor *Interceptor, +) (*LndHtlcInterceptor, error) { i := &LndHtlcInterceptor{ - config: conf, - client: client, + config: conf, + client: client, + fwsync: fwsync, + interceptor: interceptor, } i.initWg.Add(1) @@ -47,8 +49,8 @@ func (i *LndHtlcInterceptor) Start() error { i.ctx = ctx i.cancel = cancel i.stopRequested = false - go forwardingHistorySynchronize(ctx, i.client) - go channelsSynchronize(ctx, i.client) + go i.fwsync.ForwardingHistorySynchronize(ctx) + go i.fwsync.ChannelsSynchronize(ctx) return i.intercept() } @@ -149,7 +151,7 @@ func (i *LndHtlcInterceptor) intercept() error { i.doneWg.Add(1) go func() { - interceptResult := intercept(i.client, i.config, nextHop, request.PaymentHash, request.OutgoingAmountMsat, request.OutgoingExpiry, request.IncomingExpiry) + interceptResult := i.interceptor.Intercept(nextHop, request.PaymentHash, request.OutgoingAmountMsat, request.OutgoingExpiry, request.IncomingExpiry) switch interceptResult.action { case INTERCEPT_RESUME_WITH_ONION: interceptorClient.Send(&routerrpc.ForwardHtlcInterceptResponse{ diff --git a/main.go b/main.go index 4c5ad3f..1f40262 100644 --- a/main.go +++ b/main.go @@ -13,6 +13,7 @@ import ( "github.com/breez/lspd/chain" "github.com/breez/lspd/config" "github.com/breez/lspd/mempool" + "github.com/breez/lspd/postgresql" "github.com/btcsuite/btcd/btcec/v2" ) @@ -63,43 +64,59 @@ func main() { log.Printf("using mempool api for fee estimation: %v, fee strategy: %v:%v", mempoolUrl, envFeeStrategy, feeStrategy) } + databaseUrl := os.Getenv("DATABASE_URL") + pool, err := postgresql.PgConnect(databaseUrl) + if err != nil { + log.Fatalf("pgConnect() error: %v", err) + } + + interceptStore := postgresql.NewPostgresInterceptStore(pool) + forwardingStore := postgresql.NewForwardingEventStore(pool) + var interceptors []HtlcInterceptor for _, node := range nodes { - var interceptor HtlcInterceptor + var htlcInterceptor HtlcInterceptor if node.Lnd != nil { - interceptor, err = NewLndHtlcInterceptor(node) + client, err := NewLndClient(node.Lnd) + if err != nil { + log.Fatalf("failed to initialize LND client: %v", err) + } + + fwsync := NewForwardingHistorySync(client, interceptStore, forwardingStore) + interceptor := NewInterceptor(client, node, interceptStore) + htlcInterceptor, err = NewLndHtlcInterceptor(node, client, fwsync, interceptor) if err != nil { log.Fatalf("failed to initialize LND interceptor: %v", err) } } if node.Cln != nil { - interceptor, err = NewClnHtlcInterceptor(node) + client, err := NewClnClient(node.Cln.SocketPath) + if err != nil { + log.Fatalf("failed to initialize CLN client: %v", err) + } + + interceptor := NewInterceptor(client, node, interceptStore) + htlcInterceptor, err = NewClnHtlcInterceptor(node, client, interceptor) if err != nil { log.Fatalf("failed to initialize CLN interceptor: %v", err) } } - if interceptor == nil { + if htlcInterceptor == nil { log.Fatalf("node has to be either cln or lnd") } - interceptors = append(interceptors, interceptor) + interceptors = append(interceptors, htlcInterceptor) } address := os.Getenv("LISTEN_ADDRESS") certMagicDomain := os.Getenv("CERTMAGIC_DOMAIN") - s, err := NewGrpcServer(nodes, address, certMagicDomain) + s, err := NewGrpcServer(nodes, address, certMagicDomain, interceptStore) if err != nil { log.Fatalf("failed to initialize grpc server: %v", err) } - databaseUrl := os.Getenv("DATABASE_URL") - err = pgConnect(databaseUrl) - if err != nil { - log.Fatalf("pgConnect() error: %v", err) - } - var wg sync.WaitGroup wg.Add(len(interceptors) + 1) diff --git a/postgresql/connect.go b/postgresql/connect.go new file mode 100644 index 0000000..b14942f --- /dev/null +++ b/postgresql/connect.go @@ -0,0 +1,17 @@ +package postgresql + +import ( + "context" + "fmt" + + "github.com/jackc/pgx/v4/pgxpool" +) + +func PgConnect(databaseUrl string) (*pgxpool.Pool, error) { + var err error + pgxPool, err := pgxpool.Connect(context.Background(), databaseUrl) + if err != nil { + return nil, fmt.Errorf("pgxpool.Connect(%v): %w", databaseUrl, err) + } + return pgxPool, nil +} diff --git a/postgresql/forwarding_event_store.go b/postgresql/forwarding_event_store.go new file mode 100644 index 0000000..42cda1c --- /dev/null +++ b/postgresql/forwarding_event_store.go @@ -0,0 +1,68 @@ +package postgresql + +import ( + "context" + "fmt" + "log" + + "github.com/breez/lspd/lnd" + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgxpool" +) + +type ForwardingEventStore struct { + pool *pgxpool.Pool +} + +func NewForwardingEventStore(pool *pgxpool.Pool) *ForwardingEventStore { + return &ForwardingEventStore{pool: pool} +} + +func (s *ForwardingEventStore) LastForwardingEvent() (int64, error) { + var last int64 + err := s.pool.QueryRow(context.Background(), + `SELECT coalesce(MAX("timestamp"), 0) AS last FROM forwarding_history`).Scan(&last) + if err != nil { + return 0, err + } + return last, nil +} + +func (s *ForwardingEventStore) InsertForwardingEvents(rowSrc lnd.CopyFromSource) error { + + tx, err := s.pool.Begin(context.Background()) + if err != nil { + return fmt.Errorf("pgxPool.Begin() error: %w", err) + } + defer tx.Rollback(context.Background()) + + _, err = tx.Exec(context.Background(), ` + CREATE TEMP TABLE tmp_table ON COMMIT DROP AS + SELECT * + FROM forwarding_history + WITH NO DATA; + `) + if err != nil { + return fmt.Errorf("CREATE TEMP TABLE error: %w", err) + } + + count, err := tx.CopyFrom(context.Background(), + pgx.Identifier{"tmp_table"}, + []string{"timestamp", "chanid_in", "chanid_out", "amt_msat_in", "amt_msat_out"}, rowSrc) + if err != nil { + return fmt.Errorf("CopyFrom() error: %w", err) + } + log.Printf("count1: %v", count) + + cmdTag, err := tx.Exec(context.Background(), ` + INSERT INTO forwarding_history + SELECT * + FROM tmp_table + ON CONFLICT DO NOTHING + `) + if err != nil { + return fmt.Errorf("INSERT INTO forwarding_history error: %w", err) + } + log.Printf("count2: %v", cmdTag.RowsAffected()) + return tx.Commit(context.Background()) +} diff --git a/db.go b/postgresql/intercept_store.go similarity index 57% rename from db.go rename to postgresql/intercept_store.go index cdd2b43..c3dec86 100644 --- a/db.go +++ b/postgresql/intercept_store.go @@ -1,4 +1,4 @@ -package main +package postgresql import ( "context" @@ -13,27 +13,22 @@ import ( "github.com/jackc/pgx/v4/pgxpool" ) -var ( - pgxPool *pgxpool.Pool -) - -func pgConnect(databaseUrl string) error { - var err error - pgxPool, err = pgxpool.Connect(context.Background(), databaseUrl) - if err != nil { - return fmt.Errorf("pgxpool.Connect(%v): %w", databaseUrl, err) - } - return nil +type PostgresInterceptStore struct { + pool *pgxpool.Pool } -func paymentInfo(htlcPaymentHash []byte) ([]byte, []byte, []byte, int64, int64, *wire.OutPoint, error) { +func NewPostgresInterceptStore(pool *pgxpool.Pool) *PostgresInterceptStore { + return &PostgresInterceptStore{pool: pool} +} + +func (s *PostgresInterceptStore) PaymentInfo(htlcPaymentHash []byte) ([]byte, []byte, []byte, int64, int64, *wire.OutPoint, error) { var ( paymentHash, paymentSecret, destination []byte incomingAmountMsat, outgoingAmountMsat int64 fundingTxID []byte fundingTxOutnum pgtype.Int4 ) - err := pgxPool.QueryRow(context.Background(), + err := s.pool.QueryRow(context.Background(), `SELECT payment_hash, payment_secret, destination, incoming_amount_msat, outgoing_amount_msat, funding_tx_id, funding_tx_outnum FROM payments WHERE payment_hash=$1 OR sha256('probing-01:' || payment_hash)=$1`, @@ -55,8 +50,8 @@ func paymentInfo(htlcPaymentHash []byte) ([]byte, []byte, []byte, int64, int64, return paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, cp, nil } -func setFundingTx(paymentHash []byte, channelPoint *wire.OutPoint) error { - commandTag, err := pgxPool.Exec(context.Background(), +func (s *PostgresInterceptStore) SetFundingTx(paymentHash []byte, channelPoint *wire.OutPoint) error { + commandTag, err := s.pool.Exec(context.Background(), `UPDATE payments SET funding_tx_id = $2, funding_tx_outnum = $3 WHERE payment_hash=$1`, @@ -65,12 +60,12 @@ func setFundingTx(paymentHash []byte, channelPoint *wire.OutPoint) error { return err } -func registerPayment(destination, paymentHash, paymentSecret []byte, incomingAmountMsat, outgoingAmountMsat int64, tag string) error { +func (s *PostgresInterceptStore) RegisterPayment(destination, paymentHash, paymentSecret []byte, incomingAmountMsat, outgoingAmountMsat int64, tag string) error { var t *string if tag != "" { t = &tag } - commandTag, err := pgxPool.Exec(context.Background(), + commandTag, err := s.pool.Exec(context.Background(), `INSERT INTO payments (destination, payment_hash, payment_secret, incoming_amount_msat, outgoing_amount_msat, tag) VALUES ($1, $2, $3, $4, $5, $6) @@ -85,14 +80,14 @@ func registerPayment(destination, paymentHash, paymentSecret []byte, incomingAmo return nil } -func insertChannel(initialChanID, confirmedChanId uint64, channelPoint string, nodeID []byte, lastUpdate time.Time) error { +func (s *PostgresInterceptStore) InsertChannel(initialChanID, confirmedChanId uint64, channelPoint string, nodeID []byte, lastUpdate time.Time) error { query := `INSERT INTO channels (initial_chanid, confirmed_chanid, channel_point, nodeid, last_update) VALUES ($1, NULLIF($2, 0::int8), $3, $4, $5) ON CONFLICT (channel_point) DO UPDATE SET confirmed_chanid=NULLIF($2, 0::int8), last_update=$5` - c, err := pgxPool.Exec(context.Background(), + c, err := s.pool.Exec(context.Background(), query, int64(initialChanID), int64(confirmedChanId), channelPoint, nodeID, lastUpdate) if err != nil { log.Printf("insertChannel(%v, %v, %s, %x) error: %v", @@ -104,52 +99,3 @@ func insertChannel(initialChanID, confirmedChanId uint64, channelPoint string, n initialChanID, confirmedChanId, nodeID, c.String()) return nil } - -func lastForwardingEvent() (int64, error) { - var last int64 - err := pgxPool.QueryRow(context.Background(), - `SELECT coalesce(MAX("timestamp"), 0) AS last FROM forwarding_history`).Scan(&last) - if err != nil { - return 0, err - } - return last, nil -} - -func insertForwardingEvents(rowSrc pgx.CopyFromSource) error { - - tx, err := pgxPool.Begin(context.Background()) - if err != nil { - return fmt.Errorf("pgxPool.Begin() error: %w", err) - } - defer tx.Rollback(context.Background()) - - _, err = tx.Exec(context.Background(), ` - CREATE TEMP TABLE tmp_table ON COMMIT DROP AS - SELECT * - FROM forwarding_history - WITH NO DATA; - `) - if err != nil { - return fmt.Errorf("CREATE TEMP TABLE error: %w", err) - } - - count, err := tx.CopyFrom(context.Background(), - pgx.Identifier{"tmp_table"}, - []string{"timestamp", "chanid_in", "chanid_out", "amt_msat_in", "amt_msat_out"}, rowSrc) - if err != nil { - return fmt.Errorf("CopyFrom() error: %w", err) - } - log.Printf("count1: %v", count) - - cmdTag, err := tx.Exec(context.Background(), ` - INSERT INTO forwarding_history - SELECT * - FROM tmp_table - ON CONFLICT DO NOTHING - `) - if err != nil { - return fmt.Errorf("INSERT INTO forwarding_history error: %w", err) - } - log.Printf("count2: %v", cmdTag.RowsAffected()) - return tx.Commit(context.Background()) -} diff --git a/server.go b/server.go index f14152f..184e794 100644 --- a/server.go +++ b/server.go @@ -12,6 +12,7 @@ import ( "github.com/breez/lspd/btceclegacy" "github.com/breez/lspd/config" + "github.com/breez/lspd/interceptor" "github.com/breez/lspd/lightning" lspdrpc "github.com/breez/lspd/rpc" ecies "github.com/ecies/go/v2" @@ -37,6 +38,7 @@ type server struct { lis net.Listener s *grpc.Server nodes map[string]*node + store interceptor.InterceptStore } type node struct { @@ -114,7 +116,7 @@ func (s *server) RegisterPayment(ctx context.Context, in *lspdrpc.RegisterPaymen log.Printf("checkPayment(%v, %v) error: %v", pi.IncomingAmountMsat, pi.OutgoingAmountMsat, err) return nil, fmt.Errorf("checkPayment(%v, %v) error: %v", pi.IncomingAmountMsat, pi.OutgoingAmountMsat, err) } - err = registerPayment(pi.Destination, pi.PaymentHash, pi.PaymentSecret, pi.IncomingAmountMsat, pi.OutgoingAmountMsat, pi.Tag) + err = s.store.RegisterPayment(pi.Destination, pi.PaymentHash, pi.PaymentSecret, pi.IncomingAmountMsat, pi.OutgoingAmountMsat, pi.Tag) if err != nil { log.Printf("RegisterPayment() error: %v", err) return nil, fmt.Errorf("RegisterPayment() error: %w", err) @@ -261,7 +263,12 @@ func (s *server) CheckChannels(ctx context.Context, in *lspdrpc.Encrypted) (*lsp return &lspdrpc.Encrypted{Data: encrypted}, nil } -func NewGrpcServer(configs []*config.NodeConfig, address string, certmagicDomain string) (*server, error) { +func NewGrpcServer( + configs []*config.NodeConfig, + address string, + certmagicDomain string, + store interceptor.InterceptStore, +) (*server, error) { if len(configs) == 0 { return nil, fmt.Errorf("no nodes supplied") } @@ -319,6 +326,7 @@ func NewGrpcServer(configs []*config.NodeConfig, address string, certmagicDomain address: address, certmagicDomain: certmagicDomain, nodes: nodes, + store: store, }, nil } @@ -409,3 +417,14 @@ func getNode(ctx context.Context) (*node, error) { return node, nil } + +func checkPayment(config *config.NodeConfig, incomingAmountMsat, outgoingAmountMsat int64) error { + fees := incomingAmountMsat * config.ChannelFeePermyriad / 10_000 / 1_000 * 1_000 + if fees < config.ChannelMinimumFeeMsat { + fees = config.ChannelMinimumFeeMsat + } + if incomingAmountMsat-outgoingAmountMsat < fees { + return fmt.Errorf("not enough fees") + } + return nil +}