diff --git a/channelAcceptor.go b/channelAcceptor.go index 5656eb5..ff19a46 100644 --- a/channelAcceptor.go +++ b/channelAcceptor.go @@ -11,8 +11,8 @@ import ( log "github.com/sirupsen/logrus" ) -func dispatchChannelAcceptor(ctx context.Context) { - client := ctx.Value(clientKey).(lnrpc.LightningClient) +func (app *app) dispatchChannelAcceptor(ctx context.Context) { + client := app.client // wait group for channel acceptor defer ctx.Value(ctxKeyWaitGroup).(*sync.WaitGroup).Done() @@ -27,10 +27,11 @@ func dispatchChannelAcceptor(ctx context.Context) { err = acceptClient.RecvMsg(&req) if err != nil { log.Errorf(err.Error()) + return } // print the incoming channel request - alias, err := getNodeAlias(ctx, hex.EncodeToString(req.NodePubkey)) + alias, err := app.getNodeAlias(ctx, hex.EncodeToString(req.NodePubkey)) if err != nil { log.Errorf(err.Error()) } diff --git a/helpers.go b/helpers.go index c36f59c..ec5bb3f 100644 --- a/helpers.go +++ b/helpers.go @@ -8,7 +8,11 @@ import ( ) func trimPubKey(pubkey []byte) string { - return fmt.Sprintf("%s...%s", hex.EncodeToString(pubkey)[:6], hex.EncodeToString(pubkey)[len(hex.EncodeToString(pubkey))-6:]) + if len(pubkey) > 12 { + return fmt.Sprintf("%s...%s", hex.EncodeToString(pubkey)[:6], hex.EncodeToString(pubkey)[len(hex.EncodeToString(pubkey))-6:]) + } else { + return hex.EncodeToString(pubkey) + } } func welcome() { diff --git a/htlcInterceptor.go b/htlcInterceptor.go index dae7a51..57f10ef 100644 --- a/htlcInterceptor.go +++ b/htlcInterceptor.go @@ -11,11 +11,10 @@ import ( "github.com/lightningnetwork/lnd/lnrpc/routerrpc" "github.com/lightningnetwork/lnd/routing/route" log "github.com/sirupsen/logrus" - "google.golang.org/grpc" ) -func dispatchHTLCAcceptor(ctx context.Context) { - conn := ctx.Value(connKey).(*grpc.ClientConn) +func (app *app) dispatchHTLCAcceptor(ctx context.Context) { + conn := app.conn router := routerrpc.NewRouterClient(conn) // htlc event subscriber, reports on incoming htlc events @@ -25,7 +24,7 @@ func dispatchHTLCAcceptor(ctx context.Context) { } go func() { - err := logHtlcEvents(ctx, stream) + err := app.logHtlcEvents(ctx, stream) if err != nil { log.Error("htlc events error", "err", err) @@ -39,7 +38,7 @@ func dispatchHTLCAcceptor(ctx context.Context) { } go func() { - err := interceptHtlcEvents(ctx, interceptor) + err := app.interceptHtlcEvents(ctx, interceptor) if err != nil { log.Error("interceptor error", "err", err) @@ -49,7 +48,7 @@ func dispatchHTLCAcceptor(ctx context.Context) { log.Info("Listening for incoming HTLCs") } -func logHtlcEvents(ctx context.Context, stream routerrpc.Router_SubscribeHtlcEventsClient) error { +func (app *app) logHtlcEvents(ctx context.Context, stream routerrpc.Router_SubscribeHtlcEventsClient) error { for { event, err := stream.Recv() if err != nil { @@ -78,7 +77,7 @@ func logHtlcEvents(ctx context.Context, stream routerrpc.Router_SubscribeHtlcEve } } -func interceptHtlcEvents(ctx context.Context, interceptor routerrpc.Router_HtlcInterceptorClient) error { +func (app *app) interceptHtlcEvents(ctx context.Context, interceptor routerrpc.Router_HtlcInterceptorClient) error { for { event, err := interceptor.Recv() if err != nil { @@ -87,23 +86,25 @@ func interceptHtlcEvents(ctx context.Context, interceptor routerrpc.Router_HtlcI go func() { // decision for routing decision_chan := make(chan bool, 1) - go htlcInterceptDecision(ctx, event, decision_chan) + go app.htlcInterceptDecision(ctx, event, decision_chan) - channelEdge, err := getPubKeyFromChannel(ctx, event.IncomingCircuitKey.ChanId) + channelEdge, err := app.getPubKeyFromChannel(ctx, event.IncomingCircuitKey.ChanId) if err != nil { log.Error("Error getting pubkey for channel %d", event.IncomingCircuitKey.ChanId) } - alias, err := getNodeAlias(ctx, channelEdge.node1Pub.String()) - if err != nil { - log.Errorf(err.Error()) - } - var forward_info_string string - if alias != "" { - forward_info_string = fmt.Sprintf("from %s (%d sat, htlc_id:%d, chan_id:%d->%d)", alias, event.IncomingAmountMsat/1000, event.IncomingCircuitKey.HtlcId, event.IncomingCircuitKey.ChanId, event.OutgoingRequestedChanId) + var remote_pubkey, alias string + if channelEdge.node1Pub.String() != app.myPubkey { + remote_pubkey = channelEdge.node1Pub.String() } else { - forward_info_string = fmt.Sprintf("(%d sat, htlc_id:%d, chan_id:%d->%d)", event.IncomingAmountMsat/1000, event.IncomingCircuitKey.HtlcId, event.IncomingCircuitKey.ChanId, event.OutgoingRequestedChanId) + remote_pubkey = channelEdge.node2Pub.String() } + alias, err = app.getNodeAlias(ctx, remote_pubkey) + if err != nil { + log.Error("Error getting alias for node %s", remote_pubkey) + } + forward_info_string := fmt.Sprintf("from %s (%d sat, htlc_id:%d, chan_id:%d->%d)", alias, event.IncomingAmountMsat/1000, event.IncomingCircuitKey.HtlcId, event.IncomingCircuitKey.ChanId, event.OutgoingRequestedChanId) + response := &routerrpc.ForwardHtlcInterceptResponse{ IncomingCircuitKey: event.IncomingCircuitKey, } @@ -122,7 +123,7 @@ func interceptHtlcEvents(ctx context.Context, interceptor routerrpc.Router_HtlcI } } -func htlcInterceptDecision(ctx context.Context, event *routerrpc.ForwardHtlcInterceptRequest, decision_chan chan bool) { +func (app *app) htlcInterceptDecision(ctx context.Context, event *routerrpc.ForwardHtlcInterceptRequest, decision_chan chan bool) { var accept bool if Configuration.ForwardMode == "whitelist" { @@ -156,8 +157,8 @@ func htlcInterceptDecision(ctx context.Context, event *routerrpc.ForwardHtlcInte } // Heavily inspired by by Joost Jager's circuitbreaker -func getNodeAlias(ctx context.Context, pubkey string) (string, error) { - client := ctx.Value(clientKey).(lnrpc.LightningClient) +func (app *app) getNodeAlias(ctx context.Context, pubkey string) (string, error) { + client := app.client ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() @@ -175,12 +176,25 @@ func getNodeAlias(ctx context.Context, pubkey string) (string, error) { return info.Node.Alias, nil } +func (app *app) getMyPubkey(ctx context.Context) (string, error) { + client := app.client + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + info, err := client.GetInfo(ctx, &lnrpc.GetInfoRequest{}) + if err != nil { + return "", err + } + + return info.IdentityPubkey, nil +} + type channelEdge struct { node1Pub, node2Pub route.Vertex } -func getPubKeyFromChannel(ctx context.Context, chan_id uint64) (*channelEdge, error) { - client := ctx.Value(clientKey).(lnrpc.LightningClient) +func (app *app) getPubKeyFromChannel(ctx context.Context, chan_id uint64) (*channelEdge, error) { + client := app.client ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() diff --git a/main.go b/main.go index 1da3470..449e42a 100644 --- a/main.go +++ b/main.go @@ -20,10 +20,11 @@ const ( ctxKeyWaitGroup key = iota ) -type ContextKey string - -var connKey ContextKey = "connKey" -var clientKey ContextKey = "clientKey" +type app struct { + client lnrpc.LightningClient + conn *grpc.ClientConn + myPubkey string +} // gets the lnd grpc connection func getClientConnection(ctx context.Context) (*grpc.ClientConn, error) { @@ -59,24 +60,36 @@ func getClientConnection(ctx context.Context) (*grpc.ClientConn, error) { func main() { ctx := context.Background() - conn, err := getClientConnection(ctx) - if err != nil { - panic(err) + for { + conn, err := getClientConnection(ctx) + if err != nil { + log.Errorf("Could not connect to lnd: %s", err) + return + } + client := lnrpc.NewLightningClient(conn) + + app := app{ + client: client, + conn: conn, + } + app.myPubkey, err = app.getMyPubkey(ctx) + if err != nil { + log.Errorf("Could not get my pubkey: %s", err) + return + } + + var wg sync.WaitGroup + ctx = context.WithValue(ctx, ctxKeyWaitGroup, &wg) + wg.Add(1) + + // channel acceptor + go app.dispatchChannelAcceptor(ctx) + + // htlc acceptor + go app.dispatchHTLCAcceptor(ctx) + + wg.Wait() + log.Info("All routines stopped. Waiting for new connection.") } - client := lnrpc.NewLightningClient(conn) - var wg sync.WaitGroup - ctx = context.WithValue(ctx, ctxKeyWaitGroup, &wg) - wg.Add(1) - - ctx = context.WithValue(ctx, clientKey, client) - ctx = context.WithValue(ctx, connKey, conn) - - // channel acceptor - go dispatchChannelAcceptor(ctx) - - // htlc acceptor - go dispatchHTLCAcceptor(ctx) - - wg.Wait() }