mirror of
https://github.com/aljazceru/lspd.git
synced 2025-12-18 14:24:21 +01:00
cleanup: convert intercept and database to types
This commit is contained in:
@@ -24,6 +24,7 @@ import (
|
||||
)
|
||||
|
||||
type ClnHtlcInterceptor struct {
|
||||
interceptor *Interceptor
|
||||
config *config.NodeConfig
|
||||
pluginAddress string
|
||||
client *ClnClient
|
||||
@@ -35,15 +36,7 @@ type ClnHtlcInterceptor struct {
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewClnHtlcInterceptor(conf *config.NodeConfig) (*ClnHtlcInterceptor, error) {
|
||||
if conf.Cln == nil {
|
||||
return nil, fmt.Errorf("missing cln config")
|
||||
}
|
||||
|
||||
client, err := NewClnClient(conf.Cln.SocketPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func NewClnHtlcInterceptor(conf *config.NodeConfig, client *ClnClient, interceptor *Interceptor) (*ClnHtlcInterceptor, error) {
|
||||
i := &ClnHtlcInterceptor{
|
||||
config: conf,
|
||||
pluginAddress: conf.Cln.PluginAddress,
|
||||
@@ -169,7 +162,7 @@ func (i *ClnHtlcInterceptor) intercept() error {
|
||||
interceptorClient.Send(i.defaultResolution(request))
|
||||
i.doneWg.Done()
|
||||
}
|
||||
interceptResult := intercept(i.client, i.config, nextHop, paymentHash, request.Onion.ForwardMsat, request.Onion.OutgoingCltvValue, request.Htlc.CltvExpiry)
|
||||
interceptResult := i.interceptor.Intercept(nextHop, paymentHash, request.Onion.ForwardMsat, request.Onion.OutgoingCltvValue, request.Htlc.CltvExpiry)
|
||||
switch interceptResult.action {
|
||||
case INTERCEPT_RESUME_WITH_ONION:
|
||||
interceptorClient.Send(i.resumeWithOnion(request, interceptResult))
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/breez/lspd/interceptor"
|
||||
"github.com/breez/lspd/lnd"
|
||||
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
||||
"github.com/lightningnetwork/lnd/lnrpc"
|
||||
"github.com/lightningnetwork/lnd/lnrpc/chainrpc"
|
||||
@@ -36,14 +38,32 @@ func (cfe *copyFromEvents) Err() error {
|
||||
return cfe.err
|
||||
}
|
||||
|
||||
func channelsSynchronize(ctx context.Context, client *LndClient) {
|
||||
type ForwardingHistorySync struct {
|
||||
client *LndClient
|
||||
interceptStore interceptor.InterceptStore
|
||||
forwardingStore lnd.ForwardingEventStore
|
||||
}
|
||||
|
||||
func NewForwardingHistorySync(
|
||||
client *LndClient,
|
||||
interceptStore interceptor.InterceptStore,
|
||||
forwardingStore lnd.ForwardingEventStore,
|
||||
) *ForwardingHistorySync {
|
||||
return &ForwardingHistorySync{
|
||||
client: client,
|
||||
interceptStore: interceptStore,
|
||||
forwardingStore: forwardingStore,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ForwardingHistorySync) ChannelsSynchronize(ctx context.Context) {
|
||||
lastSync := time.Now().Add(-6 * time.Minute)
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
stream, err := client.chainNotifierClient.RegisterBlockEpochNtfn(ctx, &chainrpc.BlockEpoch{})
|
||||
stream, err := s.client.chainNotifierClient.RegisterBlockEpochNtfn(ctx, &chainrpc.BlockEpoch{})
|
||||
if err != nil {
|
||||
log.Printf("chainNotifierClient.RegisterBlockEpochNtfn(): %v", err)
|
||||
<-time.After(time.Second)
|
||||
@@ -67,7 +87,7 @@ func channelsSynchronize(ctx context.Context, client *LndClient) {
|
||||
return
|
||||
case <-time.After(1 * time.Minute):
|
||||
}
|
||||
err = channelsSynchronizeOnce(client)
|
||||
err = s.ChannelsSynchronizeOnce()
|
||||
lastSync = time.Now()
|
||||
log.Printf("channelsSynchronizeOnce() err: %v", err)
|
||||
}
|
||||
@@ -75,9 +95,9 @@ func channelsSynchronize(ctx context.Context, client *LndClient) {
|
||||
}
|
||||
}
|
||||
|
||||
func channelsSynchronizeOnce(client *LndClient) error {
|
||||
func (s *ForwardingHistorySync) ChannelsSynchronizeOnce() error {
|
||||
log.Printf("channelsSynchronizeOnce - begin")
|
||||
channels, err := client.client.ListChannels(context.Background(), &lnrpc.ListChannelsRequest{PrivateOnly: true})
|
||||
channels, err := s.client.client.ListChannels(context.Background(), &lnrpc.ListChannelsRequest{PrivateOnly: true})
|
||||
if err != nil {
|
||||
log.Printf("ListChannels error: %v", err)
|
||||
return fmt.Errorf("client.ListChannels() error: %w", err)
|
||||
@@ -97,7 +117,7 @@ func channelsSynchronizeOnce(client *LndClient) error {
|
||||
confirmedChanId = 0
|
||||
}
|
||||
}
|
||||
err = insertChannel(c.ChanId, confirmedChanId, c.ChannelPoint, nodeID, lastUpdate)
|
||||
err = s.interceptStore.InsertChannel(c.ChanId, confirmedChanId, c.ChannelPoint, nodeID, lastUpdate)
|
||||
if err != nil {
|
||||
log.Printf("insertChannel(%v, %v, %x) in channelsSynchronizeOnce error: %v", c.ChanId, c.ChannelPoint, nodeID, err)
|
||||
continue
|
||||
@@ -108,13 +128,13 @@ func channelsSynchronizeOnce(client *LndClient) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func forwardingHistorySynchronize(ctx context.Context, client *LndClient) {
|
||||
func (s *ForwardingHistorySync) ForwardingHistorySynchronize(ctx context.Context) {
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err := forwardingHistorySynchronizeOnce(client)
|
||||
err := s.ForwardingHistorySynchronizeOnce()
|
||||
log.Printf("forwardingHistorySynchronizeOnce() err: %v", err)
|
||||
select {
|
||||
case <-time.After(1 * time.Minute):
|
||||
@@ -123,8 +143,8 @@ func forwardingHistorySynchronize(ctx context.Context, client *LndClient) {
|
||||
}
|
||||
}
|
||||
|
||||
func forwardingHistorySynchronizeOnce(client *LndClient) error {
|
||||
last, err := lastForwardingEvent()
|
||||
func (s *ForwardingHistorySync) ForwardingHistorySynchronizeOnce() error {
|
||||
last, err := s.forwardingStore.LastForwardingEvent()
|
||||
if err != nil {
|
||||
return fmt.Errorf("lastForwardingEvent() error: %w", err)
|
||||
}
|
||||
@@ -138,7 +158,7 @@ func forwardingHistorySynchronizeOnce(client *LndClient) error {
|
||||
endTime := uint64(now.Add(time.Hour * 24).Unix())
|
||||
indexOffset := uint32(0)
|
||||
for {
|
||||
forwardHistory, err := client.client.ForwardingHistory(context.Background(), &lnrpc.ForwardingHistoryRequest{
|
||||
forwardHistory, err := s.client.client.ForwardingHistory(context.Background(), &lnrpc.ForwardingHistoryRequest{
|
||||
StartTime: uint64(last),
|
||||
EndTime: endTime,
|
||||
NumMaxEvents: 10000,
|
||||
@@ -154,7 +174,7 @@ func forwardingHistorySynchronizeOnce(client *LndClient) error {
|
||||
}
|
||||
indexOffset = forwardHistory.LastOffsetIndex
|
||||
cfe := copyFromEvents{events: forwardHistory.ForwardingEvents, idx: -1}
|
||||
err = insertForwardingEvents(&cfe)
|
||||
err = s.forwardingStore.InsertForwardingEvents(&cfe)
|
||||
if err != nil {
|
||||
log.Printf("insertForwardingEvents() error: %v", err)
|
||||
return fmt.Errorf("insertForwardingEvents() error: %w", err)
|
||||
|
||||
56
intercept.go
56
intercept.go
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/breez/lspd/chain"
|
||||
"github.com/breez/lspd/config"
|
||||
"github.com/breez/lspd/interceptor"
|
||||
"github.com/breez/lspd/lightning"
|
||||
"github.com/btcsuite/btcd/btcec/v2"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
@@ -51,10 +52,28 @@ type interceptResult struct {
|
||||
onionBlob []byte
|
||||
}
|
||||
|
||||
func intercept(client lightning.Client, config *config.NodeConfig, nextHop string, reqPaymentHash []byte, reqOutgoingAmountMsat uint64, reqOutgoingExpiry uint32, reqIncomingExpiry uint32) interceptResult {
|
||||
type Interceptor struct {
|
||||
client lightning.Client
|
||||
config *config.NodeConfig
|
||||
store interceptor.InterceptStore
|
||||
}
|
||||
|
||||
func NewInterceptor(
|
||||
client lightning.Client,
|
||||
config *config.NodeConfig,
|
||||
store interceptor.InterceptStore,
|
||||
) *Interceptor {
|
||||
return &Interceptor{
|
||||
client: client,
|
||||
config: config,
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Interceptor) Intercept(nextHop string, reqPaymentHash []byte, reqOutgoingAmountMsat uint64, reqOutgoingExpiry uint32, reqIncomingExpiry uint32) interceptResult {
|
||||
reqPaymentHashStr := hex.EncodeToString(reqPaymentHash)
|
||||
resp, _, _ := payHashGroup.Do(reqPaymentHashStr, func() (interface{}, error) {
|
||||
paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, channelPoint, err := paymentInfo(reqPaymentHash)
|
||||
paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, channelPoint, err := i.store.PaymentInfo(reqPaymentHash)
|
||||
if err != nil {
|
||||
log.Printf("paymentInfo(%x) error: %v", reqPaymentHash, err)
|
||||
return interceptResult{
|
||||
@@ -72,14 +91,14 @@ func intercept(client lightning.Client, config *config.NodeConfig, nextHop strin
|
||||
|
||||
if channelPoint == nil {
|
||||
if bytes.Equal(paymentHash, reqPaymentHash) {
|
||||
if int64(reqIncomingExpiry)-int64(reqOutgoingExpiry) < int64(config.TimeLockDelta) {
|
||||
if int64(reqIncomingExpiry)-int64(reqOutgoingExpiry) < int64(i.config.TimeLockDelta) {
|
||||
return interceptResult{
|
||||
action: INTERCEPT_FAIL_HTLC_WITH_CODE,
|
||||
failureCode: FAILURE_TEMPORARY_CHANNEL_FAILURE,
|
||||
}, nil
|
||||
}
|
||||
|
||||
channelPoint, err = openChannel(client, config, reqPaymentHash, destination, incomingAmountMsat)
|
||||
channelPoint, err = i.openChannel(reqPaymentHash, destination, incomingAmountMsat)
|
||||
if err != nil {
|
||||
log.Printf("openChannel(%x, %v) err: %v", destination, incomingAmountMsat, err)
|
||||
return interceptResult{
|
||||
@@ -173,11 +192,11 @@ func intercept(client lightning.Client, config *config.NodeConfig, nextHop strin
|
||||
deadline := time.Now().Add(60 * time.Second)
|
||||
|
||||
for {
|
||||
chanResult, _ := client.GetChannel(destination, *channelPoint)
|
||||
chanResult, _ := i.client.GetChannel(destination, *channelPoint)
|
||||
if chanResult != nil {
|
||||
log.Printf("channel opended successfully alias: %v, confirmed: %v", chanResult.InitialChannelID.ToString(), chanResult.ConfirmedChannelID.ToString())
|
||||
|
||||
err := insertChannel(
|
||||
err := i.store.InsertChannel(
|
||||
uint64(chanResult.InitialChannelID),
|
||||
uint64(chanResult.ConfirmedChannelID),
|
||||
channelPoint.String(),
|
||||
@@ -226,20 +245,9 @@ func intercept(client lightning.Client, config *config.NodeConfig, nextHop strin
|
||||
return resp.(interceptResult)
|
||||
}
|
||||
|
||||
func checkPayment(config *config.NodeConfig, incomingAmountMsat, outgoingAmountMsat int64) error {
|
||||
fees := incomingAmountMsat * config.ChannelFeePermyriad / 10_000 / 1_000 * 1_000
|
||||
if fees < config.ChannelMinimumFeeMsat {
|
||||
fees = config.ChannelMinimumFeeMsat
|
||||
}
|
||||
if incomingAmountMsat-outgoingAmountMsat < fees {
|
||||
return fmt.Errorf("not enough fees")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func openChannel(client lightning.Client, config *config.NodeConfig, paymentHash, destination []byte, incomingAmountMsat int64) (*wire.OutPoint, error) {
|
||||
capacity := incomingAmountMsat/1000 + config.AdditionalChannelCapacity
|
||||
if capacity == config.PublicChannelAmount {
|
||||
func (i *Interceptor) openChannel(paymentHash, destination []byte, incomingAmountMsat int64) (*wire.OutPoint, error) {
|
||||
capacity := incomingAmountMsat/1000 + i.config.AdditionalChannelCapacity
|
||||
if capacity == i.config.PublicChannelAmount {
|
||||
capacity++
|
||||
}
|
||||
|
||||
@@ -257,7 +265,7 @@ func openChannel(client lightning.Client, config *config.NodeConfig, paymentHash
|
||||
feeStr = fmt.Sprintf("%.5f", *feeEstimation)
|
||||
} else {
|
||||
log.Printf("Error estimating chain fee, fallback to target conf: %v", err)
|
||||
targetConf = &config.TargetConf
|
||||
targetConf = &i.config.TargetConf
|
||||
confStr = fmt.Sprintf("%v", *targetConf)
|
||||
}
|
||||
}
|
||||
@@ -269,10 +277,10 @@ func openChannel(client lightning.Client, config *config.NodeConfig, paymentHash
|
||||
feeStr,
|
||||
confStr,
|
||||
)
|
||||
channelPoint, err := client.OpenChannel(&lightning.OpenChannelRequest{
|
||||
channelPoint, err := i.client.OpenChannel(&lightning.OpenChannelRequest{
|
||||
Destination: destination,
|
||||
CapacitySat: uint64(capacity),
|
||||
MinConfs: config.MinConfs,
|
||||
MinConfs: i.config.MinConfs,
|
||||
IsPrivate: true,
|
||||
IsZeroConf: true,
|
||||
FeeSatPerVByte: feeEstimation,
|
||||
@@ -289,6 +297,6 @@ func openChannel(client lightning.Client, config *config.NodeConfig, paymentHash
|
||||
capacity,
|
||||
channelPoint.String(),
|
||||
)
|
||||
err = setFundingTx(paymentHash, channelPoint)
|
||||
err = i.store.SetFundingTx(paymentHash, channelPoint)
|
||||
return channelPoint, err
|
||||
}
|
||||
|
||||
14
interceptor/store.go
Normal file
14
interceptor/store.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package interceptor
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
)
|
||||
|
||||
type InterceptStore interface {
|
||||
PaymentInfo(htlcPaymentHash []byte) ([]byte, []byte, []byte, int64, int64, *wire.OutPoint, error)
|
||||
SetFundingTx(paymentHash []byte, channelPoint *wire.OutPoint) error
|
||||
RegisterPayment(destination, paymentHash, paymentSecret []byte, incomingAmountMsat, outgoingAmountMsat int64, tag string) error
|
||||
InsertChannel(initialChanID, confirmedChanId uint64, channelPoint string, nodeID []byte, lastUpdate time.Time) error
|
||||
}
|
||||
12
lnd/forwarding_event_store.go
Normal file
12
lnd/forwarding_event_store.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package lnd
|
||||
|
||||
type CopyFromSource interface {
|
||||
Next() bool
|
||||
Values() ([]interface{}, error)
|
||||
Err() error
|
||||
}
|
||||
|
||||
type ForwardingEventStore interface {
|
||||
LastForwardingEvent() (int64, error)
|
||||
InsertForwardingEvents(rowSrc CopyFromSource) error
|
||||
}
|
||||
@@ -15,6 +15,8 @@ import (
|
||||
)
|
||||
|
||||
type LndHtlcInterceptor struct {
|
||||
fwsync *ForwardingHistorySync
|
||||
interceptor *Interceptor
|
||||
config *config.NodeConfig
|
||||
client *LndClient
|
||||
stopRequested bool
|
||||
@@ -24,17 +26,17 @@ type LndHtlcInterceptor struct {
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewLndHtlcInterceptor(conf *config.NodeConfig) (*LndHtlcInterceptor, error) {
|
||||
if conf.Lnd == nil {
|
||||
return nil, fmt.Errorf("missing lnd configuration")
|
||||
}
|
||||
client, err := NewLndClient(conf.Lnd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func NewLndHtlcInterceptor(
|
||||
conf *config.NodeConfig,
|
||||
client *LndClient,
|
||||
fwsync *ForwardingHistorySync,
|
||||
interceptor *Interceptor,
|
||||
) (*LndHtlcInterceptor, error) {
|
||||
i := &LndHtlcInterceptor{
|
||||
config: conf,
|
||||
client: client,
|
||||
config: conf,
|
||||
client: client,
|
||||
fwsync: fwsync,
|
||||
interceptor: interceptor,
|
||||
}
|
||||
|
||||
i.initWg.Add(1)
|
||||
@@ -47,8 +49,8 @@ func (i *LndHtlcInterceptor) Start() error {
|
||||
i.ctx = ctx
|
||||
i.cancel = cancel
|
||||
i.stopRequested = false
|
||||
go forwardingHistorySynchronize(ctx, i.client)
|
||||
go channelsSynchronize(ctx, i.client)
|
||||
go i.fwsync.ForwardingHistorySynchronize(ctx)
|
||||
go i.fwsync.ChannelsSynchronize(ctx)
|
||||
|
||||
return i.intercept()
|
||||
}
|
||||
@@ -149,7 +151,7 @@ func (i *LndHtlcInterceptor) intercept() error {
|
||||
|
||||
i.doneWg.Add(1)
|
||||
go func() {
|
||||
interceptResult := intercept(i.client, i.config, nextHop, request.PaymentHash, request.OutgoingAmountMsat, request.OutgoingExpiry, request.IncomingExpiry)
|
||||
interceptResult := i.interceptor.Intercept(nextHop, request.PaymentHash, request.OutgoingAmountMsat, request.OutgoingExpiry, request.IncomingExpiry)
|
||||
switch interceptResult.action {
|
||||
case INTERCEPT_RESUME_WITH_ONION:
|
||||
interceptorClient.Send(&routerrpc.ForwardHtlcInterceptResponse{
|
||||
|
||||
41
main.go
41
main.go
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/breez/lspd/chain"
|
||||
"github.com/breez/lspd/config"
|
||||
"github.com/breez/lspd/mempool"
|
||||
"github.com/breez/lspd/postgresql"
|
||||
"github.com/btcsuite/btcd/btcec/v2"
|
||||
)
|
||||
|
||||
@@ -63,43 +64,59 @@ func main() {
|
||||
log.Printf("using mempool api for fee estimation: %v, fee strategy: %v:%v", mempoolUrl, envFeeStrategy, feeStrategy)
|
||||
}
|
||||
|
||||
databaseUrl := os.Getenv("DATABASE_URL")
|
||||
pool, err := postgresql.PgConnect(databaseUrl)
|
||||
if err != nil {
|
||||
log.Fatalf("pgConnect() error: %v", err)
|
||||
}
|
||||
|
||||
interceptStore := postgresql.NewPostgresInterceptStore(pool)
|
||||
forwardingStore := postgresql.NewForwardingEventStore(pool)
|
||||
|
||||
var interceptors []HtlcInterceptor
|
||||
for _, node := range nodes {
|
||||
var interceptor HtlcInterceptor
|
||||
var htlcInterceptor HtlcInterceptor
|
||||
if node.Lnd != nil {
|
||||
interceptor, err = NewLndHtlcInterceptor(node)
|
||||
client, err := NewLndClient(node.Lnd)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize LND client: %v", err)
|
||||
}
|
||||
|
||||
fwsync := NewForwardingHistorySync(client, interceptStore, forwardingStore)
|
||||
interceptor := NewInterceptor(client, node, interceptStore)
|
||||
htlcInterceptor, err = NewLndHtlcInterceptor(node, client, fwsync, interceptor)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize LND interceptor: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if node.Cln != nil {
|
||||
interceptor, err = NewClnHtlcInterceptor(node)
|
||||
client, err := NewClnClient(node.Cln.SocketPath)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize CLN client: %v", err)
|
||||
}
|
||||
|
||||
interceptor := NewInterceptor(client, node, interceptStore)
|
||||
htlcInterceptor, err = NewClnHtlcInterceptor(node, client, interceptor)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize CLN interceptor: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if interceptor == nil {
|
||||
if htlcInterceptor == nil {
|
||||
log.Fatalf("node has to be either cln or lnd")
|
||||
}
|
||||
|
||||
interceptors = append(interceptors, interceptor)
|
||||
interceptors = append(interceptors, htlcInterceptor)
|
||||
}
|
||||
|
||||
address := os.Getenv("LISTEN_ADDRESS")
|
||||
certMagicDomain := os.Getenv("CERTMAGIC_DOMAIN")
|
||||
s, err := NewGrpcServer(nodes, address, certMagicDomain)
|
||||
s, err := NewGrpcServer(nodes, address, certMagicDomain, interceptStore)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize grpc server: %v", err)
|
||||
}
|
||||
|
||||
databaseUrl := os.Getenv("DATABASE_URL")
|
||||
err = pgConnect(databaseUrl)
|
||||
if err != nil {
|
||||
log.Fatalf("pgConnect() error: %v", err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(interceptors) + 1)
|
||||
|
||||
|
||||
17
postgresql/connect.go
Normal file
17
postgresql/connect.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
)
|
||||
|
||||
func PgConnect(databaseUrl string) (*pgxpool.Pool, error) {
|
||||
var err error
|
||||
pgxPool, err := pgxpool.Connect(context.Background(), databaseUrl)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pgxpool.Connect(%v): %w", databaseUrl, err)
|
||||
}
|
||||
return pgxPool, nil
|
||||
}
|
||||
68
postgresql/forwarding_event_store.go
Normal file
68
postgresql/forwarding_event_store.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/breez/lspd/lnd"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
)
|
||||
|
||||
type ForwardingEventStore struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewForwardingEventStore(pool *pgxpool.Pool) *ForwardingEventStore {
|
||||
return &ForwardingEventStore{pool: pool}
|
||||
}
|
||||
|
||||
func (s *ForwardingEventStore) LastForwardingEvent() (int64, error) {
|
||||
var last int64
|
||||
err := s.pool.QueryRow(context.Background(),
|
||||
`SELECT coalesce(MAX("timestamp"), 0) AS last FROM forwarding_history`).Scan(&last)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return last, nil
|
||||
}
|
||||
|
||||
func (s *ForwardingEventStore) InsertForwardingEvents(rowSrc lnd.CopyFromSource) error {
|
||||
|
||||
tx, err := s.pool.Begin(context.Background())
|
||||
if err != nil {
|
||||
return fmt.Errorf("pgxPool.Begin() error: %w", err)
|
||||
}
|
||||
defer tx.Rollback(context.Background())
|
||||
|
||||
_, err = tx.Exec(context.Background(), `
|
||||
CREATE TEMP TABLE tmp_table ON COMMIT DROP AS
|
||||
SELECT *
|
||||
FROM forwarding_history
|
||||
WITH NO DATA;
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("CREATE TEMP TABLE error: %w", err)
|
||||
}
|
||||
|
||||
count, err := tx.CopyFrom(context.Background(),
|
||||
pgx.Identifier{"tmp_table"},
|
||||
[]string{"timestamp", "chanid_in", "chanid_out", "amt_msat_in", "amt_msat_out"}, rowSrc)
|
||||
if err != nil {
|
||||
return fmt.Errorf("CopyFrom() error: %w", err)
|
||||
}
|
||||
log.Printf("count1: %v", count)
|
||||
|
||||
cmdTag, err := tx.Exec(context.Background(), `
|
||||
INSERT INTO forwarding_history
|
||||
SELECT *
|
||||
FROM tmp_table
|
||||
ON CONFLICT DO NOTHING
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("INSERT INTO forwarding_history error: %w", err)
|
||||
}
|
||||
log.Printf("count2: %v", cmdTag.RowsAffected())
|
||||
return tx.Commit(context.Background())
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -13,27 +13,22 @@ import (
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
)
|
||||
|
||||
var (
|
||||
pgxPool *pgxpool.Pool
|
||||
)
|
||||
|
||||
func pgConnect(databaseUrl string) error {
|
||||
var err error
|
||||
pgxPool, err = pgxpool.Connect(context.Background(), databaseUrl)
|
||||
if err != nil {
|
||||
return fmt.Errorf("pgxpool.Connect(%v): %w", databaseUrl, err)
|
||||
}
|
||||
return nil
|
||||
type PostgresInterceptStore struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func paymentInfo(htlcPaymentHash []byte) ([]byte, []byte, []byte, int64, int64, *wire.OutPoint, error) {
|
||||
func NewPostgresInterceptStore(pool *pgxpool.Pool) *PostgresInterceptStore {
|
||||
return &PostgresInterceptStore{pool: pool}
|
||||
}
|
||||
|
||||
func (s *PostgresInterceptStore) PaymentInfo(htlcPaymentHash []byte) ([]byte, []byte, []byte, int64, int64, *wire.OutPoint, error) {
|
||||
var (
|
||||
paymentHash, paymentSecret, destination []byte
|
||||
incomingAmountMsat, outgoingAmountMsat int64
|
||||
fundingTxID []byte
|
||||
fundingTxOutnum pgtype.Int4
|
||||
)
|
||||
err := pgxPool.QueryRow(context.Background(),
|
||||
err := s.pool.QueryRow(context.Background(),
|
||||
`SELECT payment_hash, payment_secret, destination, incoming_amount_msat, outgoing_amount_msat, funding_tx_id, funding_tx_outnum
|
||||
FROM payments
|
||||
WHERE payment_hash=$1 OR sha256('probing-01:' || payment_hash)=$1`,
|
||||
@@ -55,8 +50,8 @@ func paymentInfo(htlcPaymentHash []byte) ([]byte, []byte, []byte, int64, int64,
|
||||
return paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, cp, nil
|
||||
}
|
||||
|
||||
func setFundingTx(paymentHash []byte, channelPoint *wire.OutPoint) error {
|
||||
commandTag, err := pgxPool.Exec(context.Background(),
|
||||
func (s *PostgresInterceptStore) SetFundingTx(paymentHash []byte, channelPoint *wire.OutPoint) error {
|
||||
commandTag, err := s.pool.Exec(context.Background(),
|
||||
`UPDATE payments
|
||||
SET funding_tx_id = $2, funding_tx_outnum = $3
|
||||
WHERE payment_hash=$1`,
|
||||
@@ -65,12 +60,12 @@ func setFundingTx(paymentHash []byte, channelPoint *wire.OutPoint) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func registerPayment(destination, paymentHash, paymentSecret []byte, incomingAmountMsat, outgoingAmountMsat int64, tag string) error {
|
||||
func (s *PostgresInterceptStore) RegisterPayment(destination, paymentHash, paymentSecret []byte, incomingAmountMsat, outgoingAmountMsat int64, tag string) error {
|
||||
var t *string
|
||||
if tag != "" {
|
||||
t = &tag
|
||||
}
|
||||
commandTag, err := pgxPool.Exec(context.Background(),
|
||||
commandTag, err := s.pool.Exec(context.Background(),
|
||||
`INSERT INTO
|
||||
payments (destination, payment_hash, payment_secret, incoming_amount_msat, outgoing_amount_msat, tag)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
@@ -85,14 +80,14 @@ func registerPayment(destination, paymentHash, paymentSecret []byte, incomingAmo
|
||||
return nil
|
||||
}
|
||||
|
||||
func insertChannel(initialChanID, confirmedChanId uint64, channelPoint string, nodeID []byte, lastUpdate time.Time) error {
|
||||
func (s *PostgresInterceptStore) InsertChannel(initialChanID, confirmedChanId uint64, channelPoint string, nodeID []byte, lastUpdate time.Time) error {
|
||||
|
||||
query := `INSERT INTO
|
||||
channels (initial_chanid, confirmed_chanid, channel_point, nodeid, last_update)
|
||||
VALUES ($1, NULLIF($2, 0::int8), $3, $4, $5)
|
||||
ON CONFLICT (channel_point) DO UPDATE SET confirmed_chanid=NULLIF($2, 0::int8), last_update=$5`
|
||||
|
||||
c, err := pgxPool.Exec(context.Background(),
|
||||
c, err := s.pool.Exec(context.Background(),
|
||||
query, int64(initialChanID), int64(confirmedChanId), channelPoint, nodeID, lastUpdate)
|
||||
if err != nil {
|
||||
log.Printf("insertChannel(%v, %v, %s, %x) error: %v",
|
||||
@@ -104,52 +99,3 @@ func insertChannel(initialChanID, confirmedChanId uint64, channelPoint string, n
|
||||
initialChanID, confirmedChanId, nodeID, c.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
func lastForwardingEvent() (int64, error) {
|
||||
var last int64
|
||||
err := pgxPool.QueryRow(context.Background(),
|
||||
`SELECT coalesce(MAX("timestamp"), 0) AS last FROM forwarding_history`).Scan(&last)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return last, nil
|
||||
}
|
||||
|
||||
func insertForwardingEvents(rowSrc pgx.CopyFromSource) error {
|
||||
|
||||
tx, err := pgxPool.Begin(context.Background())
|
||||
if err != nil {
|
||||
return fmt.Errorf("pgxPool.Begin() error: %w", err)
|
||||
}
|
||||
defer tx.Rollback(context.Background())
|
||||
|
||||
_, err = tx.Exec(context.Background(), `
|
||||
CREATE TEMP TABLE tmp_table ON COMMIT DROP AS
|
||||
SELECT *
|
||||
FROM forwarding_history
|
||||
WITH NO DATA;
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("CREATE TEMP TABLE error: %w", err)
|
||||
}
|
||||
|
||||
count, err := tx.CopyFrom(context.Background(),
|
||||
pgx.Identifier{"tmp_table"},
|
||||
[]string{"timestamp", "chanid_in", "chanid_out", "amt_msat_in", "amt_msat_out"}, rowSrc)
|
||||
if err != nil {
|
||||
return fmt.Errorf("CopyFrom() error: %w", err)
|
||||
}
|
||||
log.Printf("count1: %v", count)
|
||||
|
||||
cmdTag, err := tx.Exec(context.Background(), `
|
||||
INSERT INTO forwarding_history
|
||||
SELECT *
|
||||
FROM tmp_table
|
||||
ON CONFLICT DO NOTHING
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("INSERT INTO forwarding_history error: %w", err)
|
||||
}
|
||||
log.Printf("count2: %v", cmdTag.RowsAffected())
|
||||
return tx.Commit(context.Background())
|
||||
}
|
||||
23
server.go
23
server.go
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/breez/lspd/btceclegacy"
|
||||
"github.com/breez/lspd/config"
|
||||
"github.com/breez/lspd/interceptor"
|
||||
"github.com/breez/lspd/lightning"
|
||||
lspdrpc "github.com/breez/lspd/rpc"
|
||||
ecies "github.com/ecies/go/v2"
|
||||
@@ -37,6 +38,7 @@ type server struct {
|
||||
lis net.Listener
|
||||
s *grpc.Server
|
||||
nodes map[string]*node
|
||||
store interceptor.InterceptStore
|
||||
}
|
||||
|
||||
type node struct {
|
||||
@@ -114,7 +116,7 @@ func (s *server) RegisterPayment(ctx context.Context, in *lspdrpc.RegisterPaymen
|
||||
log.Printf("checkPayment(%v, %v) error: %v", pi.IncomingAmountMsat, pi.OutgoingAmountMsat, err)
|
||||
return nil, fmt.Errorf("checkPayment(%v, %v) error: %v", pi.IncomingAmountMsat, pi.OutgoingAmountMsat, err)
|
||||
}
|
||||
err = registerPayment(pi.Destination, pi.PaymentHash, pi.PaymentSecret, pi.IncomingAmountMsat, pi.OutgoingAmountMsat, pi.Tag)
|
||||
err = s.store.RegisterPayment(pi.Destination, pi.PaymentHash, pi.PaymentSecret, pi.IncomingAmountMsat, pi.OutgoingAmountMsat, pi.Tag)
|
||||
if err != nil {
|
||||
log.Printf("RegisterPayment() error: %v", err)
|
||||
return nil, fmt.Errorf("RegisterPayment() error: %w", err)
|
||||
@@ -261,7 +263,12 @@ func (s *server) CheckChannels(ctx context.Context, in *lspdrpc.Encrypted) (*lsp
|
||||
return &lspdrpc.Encrypted{Data: encrypted}, nil
|
||||
}
|
||||
|
||||
func NewGrpcServer(configs []*config.NodeConfig, address string, certmagicDomain string) (*server, error) {
|
||||
func NewGrpcServer(
|
||||
configs []*config.NodeConfig,
|
||||
address string,
|
||||
certmagicDomain string,
|
||||
store interceptor.InterceptStore,
|
||||
) (*server, error) {
|
||||
if len(configs) == 0 {
|
||||
return nil, fmt.Errorf("no nodes supplied")
|
||||
}
|
||||
@@ -319,6 +326,7 @@ func NewGrpcServer(configs []*config.NodeConfig, address string, certmagicDomain
|
||||
address: address,
|
||||
certmagicDomain: certmagicDomain,
|
||||
nodes: nodes,
|
||||
store: store,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -409,3 +417,14 @@ func getNode(ctx context.Context) (*node, error) {
|
||||
|
||||
return node, nil
|
||||
}
|
||||
|
||||
func checkPayment(config *config.NodeConfig, incomingAmountMsat, outgoingAmountMsat int64) error {
|
||||
fees := incomingAmountMsat * config.ChannelFeePermyriad / 10_000 / 1_000 * 1_000
|
||||
if fees < config.ChannelMinimumFeeMsat {
|
||||
fees = config.ChannelMinimumFeeMsat
|
||||
}
|
||||
if incomingAmountMsat-outgoingAmountMsat < fees {
|
||||
return fmt.Errorf("not enough fees")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user