diff --git a/db/db.go b/db/db.go index a194ae7..7f0a437 100644 --- a/db/db.go +++ b/db/db.go @@ -4,19 +4,25 @@ import ( "database/sql" "fmt" "strings" + "time" + "github.com/getAlby/lndhub.go/lib/service" "github.com/uptrace/bun" "github.com/uptrace/bun/dialect/pgdialect" "github.com/uptrace/bun/driver/pgdriver" "github.com/uptrace/bun/extra/bundebug" ) -func Open(dsn string) (*bun.DB, error) { +func Open(config *service.Config) (*bun.DB, error) { var db *bun.DB + dsn := config.DatabaseUri switch { case strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") || strings.HasPrefix(dsn, "unix://"): dbConn := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn))) db = bun.NewDB(dbConn, pgdialect.New()) + db.SetMaxOpenConns(config.DatabaseMaxConns) + db.SetMaxIdleConns(config.DatabaseMaxIdleConns) + db.SetConnMaxLifetime(time.Duration(config.DatabaseConnMaxLifetime) * time.Second) default: return nil, fmt.Errorf("Invalid database connection string %s, only (postgres|postgresql|unix):// is supported", dsn) } diff --git a/integration_tests/grpc_test.go b/integration_tests/grpc_test.go index af5bafc..3057603 100644 --- a/integration_tests/grpc_test.go +++ b/integration_tests/grpc_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "net" "testing" "github.com/getAlby/lndhub.go/common" @@ -53,7 +54,18 @@ func (suite *GrpcTestSuite) SetupSuite() { suite.invoiceUpdateSubCancelFn = cancel go svc.InvoiceUpdateSubscription(ctx) - go svc.StartGrpcServer(ctx) + //start grpc server + lis, err := net.Listen("tcp", fmt.Sprintf(":%d", svc.Config.GRPCPort)) + if err != nil { + svc.Logger.Fatalf("Failed to start grpc server: %v", err) + } + grpcServer := svc.NewGrpcServer(ctx) + go func() { + err = grpcServer.Serve(lis) + if err != nil { + svc.Logger.Error(err) + } + }() go StartGrpcClient(ctx, svc.Config.GRPCPort, suite.invoiceChan) diff --git a/integration_tests/hodl_invoice_test.go b/integration_tests/hodl_invoice_test.go index c731dd7..7b73fa8 100644 --- a/integration_tests/hodl_invoice_test.go +++ b/integration_tests/hodl_invoice_test.go @@ -115,8 +115,12 @@ func (suite *HodlInvoiceSuite) TestHodlInvoice() { assert.Equal(suite.T(), common.InvoiceStateInitialized, inv.State) //start payment checking loop - err = suite.service.CheckAllPendingOutgoingPayments(context.Background()) - assert.NoError(suite.T(), err) + go func() { + err = suite.service.CheckAllPendingOutgoingPayments(context.Background()) + assert.NoError(suite.T(), err) + }() + //wait a bit for routine to start + time.Sleep(time.Second) //send cancel invoice with lnrpc.payment suite.hodlLND.SettlePayment(lnrpc.Payment{ PaymentHash: hex.EncodeToString(invoice.RHash), @@ -177,8 +181,12 @@ func (suite *HodlInvoiceSuite) TestHodlInvoice() { assert.NoError(suite.T(), err) assert.Equal(suite.T(), common.InvoiceStateInitialized, inv.State) //start payment checking loop - err = suite.service.CheckAllPendingOutgoingPayments(context.Background()) - assert.NoError(suite.T(), err) + go func() { + err = suite.service.CheckAllPendingOutgoingPayments(context.Background()) + assert.NoError(suite.T(), err) + }() + //wait a bit for routine to start + time.Sleep(time.Second) //send settle invoice with lnrpc.payment suite.hodlLND.SettlePayment(lnrpc.Payment{ PaymentHash: hex.EncodeToString(invoice.RHash), diff --git a/integration_tests/util.go b/integration_tests/util.go index 3c4bbd4..c270141 100644 --- a/integration_tests/util.go +++ b/integration_tests/util.go @@ -47,12 +47,15 @@ const ( func LndHubTestServiceInit(lndClientMock lnd.LightningClientWrapper) (svc *service.LndhubService, err error) { dbUri := "postgresql://user:password@localhost/lndhub?sslmode=disable" c := &service.Config{ - DatabaseUri: dbUri, - JWTSecret: []byte("SECRET"), - JWTAccessTokenExpiry: 3600, - JWTRefreshTokenExpiry: 3600, - LNDAddress: mockLNDAddress, - LNDMacaroonHex: mockLNDMacaroonHex, + DatabaseUri: dbUri, + DatabaseMaxConns: 1, + DatabaseMaxIdleConns: 1, + DatabaseConnMaxLifetime: 10, + JWTSecret: []byte("SECRET"), + JWTAccessTokenExpiry: 3600, + JWTRefreshTokenExpiry: 3600, + LNDAddress: mockLNDAddress, + LNDMacaroonHex: mockLNDMacaroonHex, } rabbitmqUri, ok := os.LookupEnv("RABBITMQ_URI") @@ -61,7 +64,7 @@ func LndHubTestServiceInit(lndClientMock lnd.LightningClientWrapper) (svc *servi c.RabbitMQInvoiceExchange = "test_lndhub_invoices" } - dbConn, err := db.Open(c.DatabaseUri) + dbConn, err := db.Open(c) if err != nil { return nil, fmt.Errorf("failed to connect to database: %w", err) } @@ -95,7 +98,7 @@ func LndHubTestServiceInit(lndClientMock lnd.LightningClientWrapper) (svc *servi } func clearTable(svc *service.LndhubService, tableName string) error { - dbConn, err := db.Open(svc.Config.DatabaseUri) + dbConn, err := db.Open(svc.Config) if err != nil { return fmt.Errorf("failed to connect to database: %w", err) } diff --git a/lib/service/checkpayments.go b/lib/service/checkpayments.go index 76a7538..10e1e00 100644 --- a/lib/service/checkpayments.go +++ b/lib/service/checkpayments.go @@ -4,6 +4,7 @@ import ( "context" "encoding/hex" "fmt" + "sync" "github.com/getAlby/lndhub.go/db/models" "github.com/getsentry/sentry-go" @@ -22,13 +23,19 @@ func (svc *LndhubService) CheckAllPendingOutgoingPayments(ctx context.Context) ( } svc.Logger.Infof("Found %d pending payments", len(pendingPayments)) //call trackoutgoingpaymentstatus for each one + var wg sync.WaitGroup for _, inv := range pendingPayments { + wg.Add(1) //spawn goroutines //https://go.dev/doc/faq#closures_and_goroutines inv := inv svc.Logger.Infof("Spawning tracker for payment with hash %s", inv.RHash) - go svc.TrackOutgoingPaymentstatus(ctx, &inv) + go func() { + svc.TrackOutgoingPaymentstatus(ctx, &inv) + wg.Done() + }() } + wg.Wait() return nil } diff --git a/lib/service/config.go b/lib/service/config.go index 74b6333..e607325 100644 --- a/lib/service/config.go +++ b/lib/service/config.go @@ -7,6 +7,9 @@ import ( type Config struct { DatabaseUri string `envconfig:"DATABASE_URI" required:"true"` + DatabaseMaxConns int `envconfig:"DATABASE_MAX_CONNS" default:"10"` + DatabaseMaxIdleConns int `envconfig:"DATABASE_MAX_IDLE_CONNS" default:"5"` + DatabaseConnMaxLifetime int `envconfig:"DATABASE_CONN_MAX_LIFETIME" default:"1800"` // 30 minutes SentryDSN string `envconfig:"SENTRY_DSN"` SentryTracesSampleRate float64 `envconfig:"SENTRY_TRACES_SAMPLE_RATE"` LogFilePath string `envconfig:"LOG_FILE_PATH"` diff --git a/lib/service/grpc_server.go b/lib/service/grpc_server.go index 1f1369c..1133d01 100644 --- a/lib/service/grpc_server.go +++ b/lib/service/grpc_server.go @@ -2,8 +2,6 @@ package service import ( "context" - "fmt" - "net" "github.com/getAlby/lndhub.go/common" "github.com/getAlby/lndhub.go/db/models" @@ -13,21 +11,14 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) -func (svc *LndhubService) StartGrpcServer(ctx context.Context) { - lis, err := net.Listen("tcp", fmt.Sprintf(":%d", svc.Config.GRPCPort)) - if err != nil { - svc.Logger.Fatalf("Failed to start grpc server: %v", err) - } +func (svc *LndhubService) NewGrpcServer(ctx context.Context) *grpc.Server { s := grpc.NewServer() grpcServer, err := NewGrpcServer(svc, ctx) if err != nil { svc.Logger.Fatalf("Failed to init grpc server, %s", err.Error()) } lndhubrpc.RegisterInvoiceSubscriptionServer(s, grpcServer) - svc.Logger.Infof("gRPC server started at %v", lis.Addr()) - if err := s.Serve(lis); err != nil { - svc.Logger.Fatalf("failed to serve: %v", err) - } + return s } // server is used to implement helloworld.GreeterServer. diff --git a/lib/service/invoicesubscription.go b/lib/service/invoicesubscription.go index 7440ffb..8545cc6 100644 --- a/lib/service/invoicesubscription.go +++ b/lib/service/invoicesubscription.go @@ -243,11 +243,6 @@ func (svc *LndhubService) InvoiceUpdateSubscription(ctx context.Context) error { if err != nil { svc.Logger.Errorf("Error processing invoice update subscription: %v", err) sentry.CaptureException(err) - // TODO: close the stream somehoe before retrying? - // Wait 30 seconds and try to reconnect - // TODO: implement some backoff - time.Sleep(30 * time.Second) - invoiceSubscriptionStream, _ = svc.ConnectInvoiceSubscription(ctx) continue } @@ -256,7 +251,7 @@ func (svc *LndhubService) InvoiceUpdateSubscription(ctx context.Context) error { // Processing open invoices here could cause a race condition: // We could get this notification faster than we finish the AddInvoice call if rawInvoice.State == lnrpc.Invoice_OPEN { - svc.Logger.Infof("Invoice state is open. Ignoring update. r_hash:%v", hex.EncodeToString(rawInvoice.RHash)) + svc.Logger.Debugf("Invoice state is open. Ignoring update. r_hash:%v", hex.EncodeToString(rawInvoice.RHash)) continue } diff --git a/main.go b/main.go index 7ebdb3d..e85f238 100644 --- a/main.go +++ b/main.go @@ -5,9 +5,11 @@ import ( "embed" "fmt" "log" + "net" "net/http" "os" "os/signal" + "sync" "time" cache "github.com/SporkHubr/echo-http-cache" @@ -34,6 +36,7 @@ import ( "github.com/uptrace/bun/migrate" "github.com/ziflex/lecho/v3" "golang.org/x/time/rate" + "google.golang.org/grpc" ) //go:embed templates/index.html @@ -75,19 +78,20 @@ func main() { logger := lib.Logger(c.LogFilePath) // Open a DB connection based on the configured DATABASE_URI - dbConn, err := db.Open(c.DatabaseUri) + dbConn, err := db.Open(c) if err != nil { logger.Fatalf("Error initializing db connection: %v", err) } // Migrate the DB - ctx := context.Background() + //Todo: use timeout for startupcontext + startupCtx := context.Background() migrator := migrate.NewMigrator(dbConn, migrations.Migrations) - err = migrator.Init(ctx) + err = migrator.Init(startupCtx) if err != nil { logger.Fatalf("Error initializing db migrator: %v", err) } - _, err = migrator.Migrate(ctx) + _, err = migrator.Migrate(startupCtx) if err != nil { logger.Fatalf("Error migrating database: %v", err) } @@ -139,7 +143,7 @@ func main() { if err != nil { e.Logger.Fatalf("Error initializing the LND connection: %v", err) } - getInfo, err := lndClient.GetInfo(ctx, &lnrpc.GetInfoRequest{}) + getInfo, err := lndClient.GetInfo(startupCtx, &lnrpc.GetInfoRequest{}) if err != nil { e.Logger.Fatalf("Error getting node info: %v", err) } @@ -169,21 +173,39 @@ func main() { docs.SwaggerInfo.Host = c.Host e.GET("/swagger/*", echoSwagger.WrapHandler) + var backgroundWg sync.WaitGroup + backGroundCtx, _ := signal.NotifyContext(context.Background(), os.Interrupt) // Subscribe to LND invoice updates in the background - go svc.InvoiceUpdateSubscription(context.Background()) + backgroundWg.Add(1) + go func() { + err = svc.InvoiceUpdateSubscription(backGroundCtx) + if err != nil { + svc.Logger.Error(err) + } + svc.Logger.Info("Invoice routine done") + backgroundWg.Done() + }() // Check the status of all pending outgoing payments // A goroutine will be spawned for each one - err = svc.CheckAllPendingOutgoingPayments(context.Background()) - if err != nil { - svc.Logger.Error(err) - } + backgroundWg.Add(1) + go func() { + err = svc.CheckAllPendingOutgoingPayments(backGroundCtx) + if err != nil { + svc.Logger.Error(err) + } + svc.Logger.Info("Pending payment check routines done") + backgroundWg.Done() + }() //Start webhook subscription if svc.Config.WebhookUrl != "" { - webhookCtx, cancelWebhook := context.WithCancel(context.Background()) - go svc.StartWebhookSubscribtion(webhookCtx, svc.Config.WebhookUrl) - defer cancelWebhook() + backgroundWg.Add(1) + go func() { + svc.StartWebhookSubscribtion(backGroundCtx, svc.Config.WebhookUrl) + svc.Logger.Info("Webhook routine done") + backgroundWg.Done() + }() } //Start rabbit publisher if svc.Config.RabbitMQUri != "" { @@ -192,11 +214,21 @@ func main() { defer cancelRabbit() } + var grpcServer *grpc.Server if svc.Config.EnableGRPC { //start grpc server - grpcContext, grpcCancel := context.WithCancel(context.Background()) - go svc.StartGrpcServer(grpcContext) - defer grpcCancel() + lis, err := net.Listen("tcp", fmt.Sprintf(":%d", svc.Config.GRPCPort)) + if err != nil { + svc.Logger.Fatalf("Failed to start grpc server: %v", err) + } + grpcServer = svc.NewGrpcServer(startupCtx) + go func() { + svc.Logger.Infof("Starting grpc server at port %d", svc.Config.GRPCPort) + err = grpcServer.Serve(lis) + if err != nil { + svc.Logger.Error(err) + } + }() } //Start Prometheus server if necessary @@ -224,11 +256,7 @@ func main() { } }() - // Wait for interrupt signal to gracefully shutdown the server with a timeout of 10 seconds. - // Use a buffered channel to avoid missing signals as recommended for signal.Notify - quit := make(chan os.Signal, 1) - signal.Notify(quit, os.Interrupt) - <-quit + <-backGroundCtx.Done() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := e.Shutdown(ctx); err != nil { @@ -239,7 +267,13 @@ func main() { e.Logger.Fatal(err) } } - + if c.EnableGRPC { + grpcServer.Stop() + svc.Logger.Info("GRPC server exited.") + } + //Wait for graceful shutdown of background routines + backgroundWg.Wait() + svc.Logger.Info("LNDhub exiting gracefully. Goodbye.") } func createRateLimitMiddleware(seconds int, burst int) echo.MiddlewareFunc {