set the tag field on registerpayment

This commit is contained in:
Jesse de Wit
2023-02-13 10:07:10 +01:00
parent ef3a001d54
commit 8e2c9bd9ce
4 changed files with 31 additions and 11 deletions

20
db.go
View File

@@ -64,18 +64,22 @@ func setFundingTx(paymentHash []byte, channelPoint *wire.OutPoint) error {
return err return err
} }
func registerPayment(destination, paymentHash, paymentSecret []byte, incomingAmountMsat, outgoingAmountMsat int64) error { func registerPayment(destination, paymentHash, paymentSecret []byte, incomingAmountMsat, outgoingAmountMsat int64, tag string) error {
var t *string
if tag != "" {
t = &tag
}
commandTag, err := pgxPool.Exec(context.Background(), commandTag, err := pgxPool.Exec(context.Background(),
`INSERT INTO `INSERT INTO
payments (destination, payment_hash, payment_secret, incoming_amount_msat, outgoing_amount_msat) payments (destination, payment_hash, payment_secret, incoming_amount_msat, outgoing_amount_msat, tag)
VALUES ($1, $2, $3, $4, $5) VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT DO NOTHING`, ON CONFLICT DO NOTHING`,
destination, paymentHash, paymentSecret, incomingAmountMsat, outgoingAmountMsat) destination, paymentHash, paymentSecret, incomingAmountMsat, outgoingAmountMsat, t)
log.Printf("registerPayment(%x, %x, %x, %v, %v) rows: %v err: %v", log.Printf("registerPayment(%x, %x, %x, %v, %v, %v) rows: %v err: %v",
destination, paymentHash, paymentSecret, incomingAmountMsat, outgoingAmountMsat, commandTag.RowsAffected(), err) destination, paymentHash, paymentSecret, incomingAmountMsat, outgoingAmountMsat, tag, commandTag.RowsAffected(), err)
if err != nil { if err != nil {
return fmt.Errorf("registerPayment(%x, %x, %x, %v, %v) error: %w", return fmt.Errorf("registerPayment(%x, %x, %x, %v, %v, %v) error: %w",
destination, paymentHash, paymentSecret, incomingAmountMsat, outgoingAmountMsat, err) destination, paymentHash, paymentSecret, incomingAmountMsat, outgoingAmountMsat, tag, err)
} }
return nil return nil
} }

View File

@@ -0,0 +1 @@
ALTER TABLE public.payments DROP COLUMN tag;

View File

@@ -0,0 +1 @@
ALTER TABLE public.payments ADD tag jsonb NULL;

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"encoding/hex" "encoding/hex"
"encoding/json"
"fmt" "fmt"
"log" "log"
"net" "net"
@@ -91,14 +92,27 @@ func (s *server) RegisterPayment(ctx context.Context, in *lspdrpc.RegisterPaymen
log.Printf("proto.Unmarshal(%x) error: %v", data, err) log.Printf("proto.Unmarshal(%x) error: %v", data, err)
return nil, fmt.Errorf("proto.Unmarshal(%x) error: %w", data, err) return nil, fmt.Errorf("proto.Unmarshal(%x) error: %w", data, err)
} }
log.Printf("RegisterPayment - Destination: %x, pi.PaymentHash: %x, pi.PaymentSecret: %x, pi.IncomingAmountMsat: %v, pi.OutgoingAmountMsat: %v", log.Printf("RegisterPayment - Destination: %x, pi.PaymentHash: %x, pi.PaymentSecret: %x, pi.IncomingAmountMsat: %v, pi.OutgoingAmountMsat: %v, pi.Tag: %v",
pi.Destination, pi.PaymentHash, pi.PaymentSecret, pi.IncomingAmountMsat, pi.OutgoingAmountMsat) pi.Destination, pi.PaymentHash, pi.PaymentSecret, pi.IncomingAmountMsat, pi.OutgoingAmountMsat, pi.Tag)
if len(pi.Tag) > 1000 {
return nil, fmt.Errorf("tag too long")
}
if len(pi.Tag) != 0 {
var tag json.RawMessage
err = json.Unmarshal([]byte(pi.Tag), &tag)
if err != nil {
return nil, fmt.Errorf("tag is not a valid json object")
}
}
err = checkPayment(node.nodeConfig, pi.IncomingAmountMsat, pi.OutgoingAmountMsat) err = checkPayment(node.nodeConfig, pi.IncomingAmountMsat, pi.OutgoingAmountMsat)
if err != nil { if err != nil {
log.Printf("checkPayment(%v, %v) error: %v", pi.IncomingAmountMsat, pi.OutgoingAmountMsat, err) log.Printf("checkPayment(%v, %v) error: %v", pi.IncomingAmountMsat, pi.OutgoingAmountMsat, err)
return nil, fmt.Errorf("checkPayment(%v, %v) error: %v", pi.IncomingAmountMsat, pi.OutgoingAmountMsat, err) return nil, fmt.Errorf("checkPayment(%v, %v) error: %v", pi.IncomingAmountMsat, pi.OutgoingAmountMsat, err)
} }
err = registerPayment(pi.Destination, pi.PaymentHash, pi.PaymentSecret, pi.IncomingAmountMsat, pi.OutgoingAmountMsat) err = registerPayment(pi.Destination, pi.PaymentHash, pi.PaymentSecret, pi.IncomingAmountMsat, pi.OutgoingAmountMsat, pi.Tag)
if err != nil { if err != nil {
log.Printf("RegisterPayment() error: %v", err) log.Printf("RegisterPayment() error: %v", err)
return nil, fmt.Errorf("RegisterPayment() error: %w", err) return nil, fmt.Errorf("RegisterPayment() error: %w", err)