diff --git a/server.go b/server.go index ff84cc6..2577e04 100644 --- a/server.go +++ b/server.go @@ -13,18 +13,18 @@ import ( "github.com/breez/lspd/btceclegacy" lspdrpc "github.com/breez/lspd/rpc" "github.com/golang/protobuf/proto" + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/caddyserver/certmagic" - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/lightningnetwork/lnd/lnwire" "golang.org/x/sync/singleflight" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" - "google.golang.org/grpc/status" ) const ( @@ -40,7 +40,10 @@ const ( maxInactiveDuration = 45 * 24 * 3600 ) -type server struct{} +type server struct { + lis net.Listener + s *grpc.Server +} var ( client *LndClient @@ -241,21 +244,11 @@ func getNotFakeChannels(nodeID string, channelPoints map[string]uint64) (map[str return r, nil } -func main() { - if len(os.Args) > 1 && os.Args[1] == "genkey" { - p, err := btcec.NewPrivateKey() - if err != nil { - log.Fatalf("btcec.NewPrivateKey() error: %v", err) - } - fmt.Printf("LSPD_PRIVATE_KEY=\"%x\"\n", p.Serialize()) - return - } - - err := pgConnect() - if err != nil { - log.Fatalf("pgConnect() error: %v", err) - } +func NewGrpcServer() *server { + return &server{} +} +func (s *server) Start() error { privateKeyBytes, err := hex.DecodeString(os.Getenv("LSPD_PRIVATE_KEY")) if err != nil { log.Fatalf("hex.DecodeString(os.Getenv(\"LSPD_PRIVATE_KEY\")=%v) error: %v", os.Getenv("LSPD_PRIVATE_KEY"), err) @@ -282,6 +275,51 @@ func main() { } } + srv := grpc.NewServer( + grpc_middleware.WithUnaryServerChain(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if md, ok := metadata.FromIncomingContext(ctx); ok { + for _, auth := range md.Get("authorization") { + if auth == "Bearer "+os.Getenv("TOKEN") { + return handler(ctx, req) + } + } + } + return nil, status.Errorf(codes.PermissionDenied, "Not authorized") + }), + ) + lspdrpc.RegisterChannelOpenerServer(srv, &server{}) + + s.s = srv + s.lis = lis + if err := srv.Serve(lis); err != nil { + return fmt.Errorf("failed to serve: %v", err) + } + + return nil +} + +func (s *server) Stop() { + srv := s.s + if srv != nil { + srv.GracefulStop() + } +} + +func main() { + if len(os.Args) > 1 && os.Args[1] == "genkey" { + p, err := btcec.NewPrivateKey() + if err != nil { + log.Fatalf("btcec.NewPrivateKey() error: %v", err) + } + fmt.Printf("LSPD_PRIVATE_KEY=\"%x\"\n", p.Serialize()) + return + } + + err := pgConnect() + if err != nil { + log.Fatalf("pgConnect() error: %v", err) + } + client = NewLndClient() info, err := client.GetInfo() @@ -300,20 +338,11 @@ func main() { go forwardingHistorySynchronize(client) go channelsSynchronize(client) - s := grpc.NewServer( - grpc_middleware.WithUnaryServerChain(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - if md, ok := metadata.FromIncomingContext(ctx); ok { - for _, auth := range md.Get("authorization") { - if auth == "Bearer "+os.Getenv("TOKEN") { - return handler(ctx, req) - } - } - } - return nil, status.Errorf(codes.PermissionDenied, "Not authorized") - }), - ) - lspdrpc.RegisterChannelOpenerServer(s, &server{}) - if err := s.Serve(lis); err != nil { - log.Fatalf("failed to serve: %v", err) + s := NewGrpcServer() + err = s.Start() + if err != nil { + log.Fatalf("%v", err) } + + log.Printf("lspd exited") }