diff --git a/integration_tests/grpc_test.go b/integration_tests/grpc_test.go index af5bafc..3057603 100644 --- a/integration_tests/grpc_test.go +++ b/integration_tests/grpc_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "net" "testing" "github.com/getAlby/lndhub.go/common" @@ -53,7 +54,18 @@ func (suite *GrpcTestSuite) SetupSuite() { suite.invoiceUpdateSubCancelFn = cancel go svc.InvoiceUpdateSubscription(ctx) - go svc.StartGrpcServer(ctx) + //start grpc server + lis, err := net.Listen("tcp", fmt.Sprintf(":%d", svc.Config.GRPCPort)) + if err != nil { + svc.Logger.Fatalf("Failed to start grpc server: %v", err) + } + grpcServer := svc.NewGrpcServer(ctx) + go func() { + err = grpcServer.Serve(lis) + if err != nil { + svc.Logger.Error(err) + } + }() go StartGrpcClient(ctx, svc.Config.GRPCPort, suite.invoiceChan) diff --git a/lib/service/checkpayments.go b/lib/service/checkpayments.go index 76a7538..10e1e00 100644 --- a/lib/service/checkpayments.go +++ b/lib/service/checkpayments.go @@ -4,6 +4,7 @@ import ( "context" "encoding/hex" "fmt" + "sync" "github.com/getAlby/lndhub.go/db/models" "github.com/getsentry/sentry-go" @@ -22,13 +23,19 @@ func (svc *LndhubService) CheckAllPendingOutgoingPayments(ctx context.Context) ( } svc.Logger.Infof("Found %d pending payments", len(pendingPayments)) //call trackoutgoingpaymentstatus for each one + var wg sync.WaitGroup for _, inv := range pendingPayments { + wg.Add(1) //spawn goroutines //https://go.dev/doc/faq#closures_and_goroutines inv := inv svc.Logger.Infof("Spawning tracker for payment with hash %s", inv.RHash) - go svc.TrackOutgoingPaymentstatus(ctx, &inv) + go func() { + svc.TrackOutgoingPaymentstatus(ctx, &inv) + wg.Done() + }() } + wg.Wait() return nil } diff --git a/lib/service/grpc_server.go b/lib/service/grpc_server.go index 1f1369c..1133d01 100644 --- a/lib/service/grpc_server.go +++ b/lib/service/grpc_server.go @@ -2,8 +2,6 @@ package service import ( "context" - "fmt" - "net" "github.com/getAlby/lndhub.go/common" "github.com/getAlby/lndhub.go/db/models" @@ -13,21 +11,14 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) -func (svc *LndhubService) StartGrpcServer(ctx context.Context) { - lis, err := net.Listen("tcp", fmt.Sprintf(":%d", svc.Config.GRPCPort)) - if err != nil { - svc.Logger.Fatalf("Failed to start grpc server: %v", err) - } +func (svc *LndhubService) NewGrpcServer(ctx context.Context) *grpc.Server { s := grpc.NewServer() grpcServer, err := NewGrpcServer(svc, ctx) if err != nil { svc.Logger.Fatalf("Failed to init grpc server, %s", err.Error()) } lndhubrpc.RegisterInvoiceSubscriptionServer(s, grpcServer) - svc.Logger.Infof("gRPC server started at %v", lis.Addr()) - if err := s.Serve(lis); err != nil { - svc.Logger.Fatalf("failed to serve: %v", err) - } + return s } // server is used to implement helloworld.GreeterServer. diff --git a/main.go b/main.go index c96e09e..828933e 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "embed" "fmt" "log" + "net" "net/http" "os" "os/signal" @@ -35,6 +36,7 @@ import ( "github.com/uptrace/bun/migrate" "github.com/ziflex/lecho/v3" "golang.org/x/time/rate" + "google.golang.org/grpc" ) //go:embed templates/index.html @@ -82,13 +84,14 @@ func main() { } // Migrate the DB - ctx := context.Background() + //Todo: use timeout for startupcontext + startupCtx := context.Background() migrator := migrate.NewMigrator(dbConn, migrations.Migrations) - err = migrator.Init(ctx) + err = migrator.Init(startupCtx) if err != nil { logger.Fatalf("Error initializing db migrator: %v", err) } - _, err = migrator.Migrate(ctx) + _, err = migrator.Migrate(startupCtx) if err != nil { logger.Fatalf("Error migrating database: %v", err) } @@ -140,7 +143,7 @@ func main() { if err != nil { e.Logger.Fatalf("Error initializing the LND connection: %v", err) } - getInfo, err := lndClient.GetInfo(ctx, &lnrpc.GetInfoRequest{}) + getInfo, err := lndClient.GetInfo(startupCtx, &lnrpc.GetInfoRequest{}) if err != nil { e.Logger.Fatalf("Error getting node info: %v", err) } @@ -171,32 +174,55 @@ func main() { e.GET("/swagger/*", echoSwagger.WrapHandler) var backgroundWg sync.WaitGroup - ctx, cancelBackgroundRoutines := context.WithCancel(context.Background()) + backGroundCtx, _ := signal.NotifyContext(context.Background(), os.Interrupt) // Subscribe to LND invoice updates in the background backgroundWg.Add(1) go func() { - err = svc.InvoiceUpdateSubscription(ctx) + err = svc.InvoiceUpdateSubscription(backGroundCtx) if err != nil { svc.Logger.Error(err) } + svc.Logger.Info("Invoice routine done") backgroundWg.Done() }() // Check the status of all pending outgoing payments // A goroutine will be spawned for each one - err = svc.CheckAllPendingOutgoingPayments(ctx) - if err != nil { - svc.Logger.Error(err) - } + backgroundWg.Add(1) + go func() { + err = svc.CheckAllPendingOutgoingPayments(backGroundCtx) + if err != nil { + svc.Logger.Error(err) + } + svc.Logger.Info("Pending payment check routines done") + backgroundWg.Done() + }() //Start webhook subscription if svc.Config.WebhookUrl != "" { - go svc.StartWebhookSubscribtion(ctx, svc.Config.WebhookUrl) + backgroundWg.Add(1) + go func() { + svc.StartWebhookSubscribtion(backGroundCtx, svc.Config.WebhookUrl) + svc.Logger.Info("Webhook routine done") + backgroundWg.Done() + }() } + var grpcServer *grpc.Server if svc.Config.EnableGRPC { //start grpc server - go svc.StartGrpcServer(ctx) + lis, err := net.Listen("tcp", fmt.Sprintf(":%d", svc.Config.GRPCPort)) + if err != nil { + svc.Logger.Fatalf("Failed to start grpc server: %v", err) + } + grpcServer = svc.NewGrpcServer(startupCtx) + go func() { + svc.Logger.Infof("Starting grpc server at port %d", svc.Config.GRPCPort) + err = grpcServer.Serve(lis) + if err != nil { + svc.Logger.Error(err) + } + }() } //Start Prometheus server if necessary @@ -224,11 +250,7 @@ func main() { } }() - // Wait for interrupt signal to gracefully shutdown the server with a timeout of 10 seconds. - // Use a buffered channel to avoid missing signals as recommended for signal.Notify - quit := make(chan os.Signal, 1) - signal.Notify(quit, os.Interrupt) - <-quit + <-backGroundCtx.Done() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := e.Shutdown(ctx); err != nil { @@ -239,10 +261,13 @@ func main() { e.Logger.Fatal(err) } } - //cancel and wait for graceful shutdown of background routines - cancelBackgroundRoutines() + if c.EnableGRPC { + grpcServer.Stop() + svc.Logger.Info("GRPC server exited.") + } + //Wait for graceful shutdown of background routines backgroundWg.Wait() - fmt.Println("LNDhub exiting gracefully. Goodbye.") + svc.Logger.Info("LNDhub exiting gracefully. Goodbye.") } func createRateLimitMiddleware(seconds int, burst int) echo.MiddlewareFunc {