diff --git a/db.go b/db.go index 4e40169..b0332ec 100644 --- a/db.go +++ b/db.go @@ -64,18 +64,22 @@ func setFundingTx(paymentHash []byte, channelPoint *wire.OutPoint) error { return err } -func registerPayment(destination, paymentHash, paymentSecret []byte, incomingAmountMsat, outgoingAmountMsat int64) error { +func registerPayment(destination, paymentHash, paymentSecret []byte, incomingAmountMsat, outgoingAmountMsat int64, tag string) error { + var t *string + if tag != "" { + t = &tag + } commandTag, err := pgxPool.Exec(context.Background(), `INSERT INTO - payments (destination, payment_hash, payment_secret, incoming_amount_msat, outgoing_amount_msat) - VALUES ($1, $2, $3, $4, $5) + payments (destination, payment_hash, payment_secret, incoming_amount_msat, outgoing_amount_msat, tag) + VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT DO NOTHING`, - destination, paymentHash, paymentSecret, incomingAmountMsat, outgoingAmountMsat) - log.Printf("registerPayment(%x, %x, %x, %v, %v) rows: %v err: %v", - destination, paymentHash, paymentSecret, incomingAmountMsat, outgoingAmountMsat, commandTag.RowsAffected(), err) + destination, paymentHash, paymentSecret, incomingAmountMsat, outgoingAmountMsat, t) + log.Printf("registerPayment(%x, %x, %x, %v, %v, %v) rows: %v err: %v", + destination, paymentHash, paymentSecret, incomingAmountMsat, outgoingAmountMsat, tag, commandTag.RowsAffected(), err) if err != nil { - return fmt.Errorf("registerPayment(%x, %x, %x, %v, %v) error: %w", - destination, paymentHash, paymentSecret, incomingAmountMsat, outgoingAmountMsat, err) + return fmt.Errorf("registerPayment(%x, %x, %x, %v, %v, %v) error: %w", + destination, paymentHash, paymentSecret, incomingAmountMsat, outgoingAmountMsat, tag, err) } return nil } diff --git a/postgresql/migrations/000009_register_payment_tag.down.sql b/postgresql/migrations/000009_register_payment_tag.down.sql new file mode 100644 index 0000000..a41df51 --- /dev/null +++ b/postgresql/migrations/000009_register_payment_tag.down.sql @@ -0,0 +1 @@ +ALTER TABLE public.payments DROP COLUMN tag; diff --git a/postgresql/migrations/000009_register_payment_tag.up.sql b/postgresql/migrations/000009_register_payment_tag.up.sql new file mode 100644 index 0000000..89f004a --- /dev/null +++ b/postgresql/migrations/000009_register_payment_tag.up.sql @@ -0,0 +1 @@ +ALTER TABLE public.payments ADD tag jsonb NULL; diff --git a/server.go b/server.go index b045d74..0bb13c0 100644 --- a/server.go +++ b/server.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "encoding/hex" + "encoding/json" "fmt" "log" "net" @@ -91,14 +92,27 @@ func (s *server) RegisterPayment(ctx context.Context, in *lspdrpc.RegisterPaymen log.Printf("proto.Unmarshal(%x) error: %v", data, err) return nil, fmt.Errorf("proto.Unmarshal(%x) error: %w", data, err) } - log.Printf("RegisterPayment - Destination: %x, pi.PaymentHash: %x, pi.PaymentSecret: %x, pi.IncomingAmountMsat: %v, pi.OutgoingAmountMsat: %v", - pi.Destination, pi.PaymentHash, pi.PaymentSecret, pi.IncomingAmountMsat, pi.OutgoingAmountMsat) + log.Printf("RegisterPayment - Destination: %x, pi.PaymentHash: %x, pi.PaymentSecret: %x, pi.IncomingAmountMsat: %v, pi.OutgoingAmountMsat: %v, pi.Tag: %v", + pi.Destination, pi.PaymentHash, pi.PaymentSecret, pi.IncomingAmountMsat, pi.OutgoingAmountMsat, pi.Tag) + + if len(pi.Tag) > 1000 { + return nil, fmt.Errorf("tag too long") + } + + if len(pi.Tag) != 0 { + var tag json.RawMessage + err = json.Unmarshal([]byte(pi.Tag), &tag) + if err != nil { + return nil, fmt.Errorf("tag is not a valid json object") + } + } + err = checkPayment(node.nodeConfig, pi.IncomingAmountMsat, pi.OutgoingAmountMsat) if err != nil { log.Printf("checkPayment(%v, %v) error: %v", pi.IncomingAmountMsat, pi.OutgoingAmountMsat, err) return nil, fmt.Errorf("checkPayment(%v, %v) error: %v", pi.IncomingAmountMsat, pi.OutgoingAmountMsat, err) } - err = registerPayment(pi.Destination, pi.PaymentHash, pi.PaymentSecret, pi.IncomingAmountMsat, pi.OutgoingAmountMsat) + err = registerPayment(pi.Destination, pi.PaymentHash, pi.PaymentSecret, pi.IncomingAmountMsat, pi.OutgoingAmountMsat, pi.Tag) if err != nil { log.Printf("RegisterPayment() error: %v", err) return nil, fmt.Errorf("RegisterPayment() error: %w", err)