diff --git a/cln_plugin/server.go b/cln_plugin/server.go index e0901a5..58cba3f 100644 --- a/cln_plugin/server.go +++ b/cln_plugin/server.go @@ -2,7 +2,6 @@ package cln_plugin import ( "fmt" - "io" "log" "net" "sync" @@ -12,13 +11,6 @@ import ( grpc "google.golang.org/grpc" ) -// A subscription represents a grpc client that is connected to the server. -type subscription struct { - stream proto.ClnPlugin_HtlcStreamServer - done chan struct{} - err chan error -} - // Internal htlc_accepted message meant for the sendQueue. type htlcAcceptedMsg struct { id string @@ -38,11 +30,11 @@ type server struct { subscriberTimeout time.Duration grpcServer *grpc.Server mtx sync.Mutex - subscription *subscription + stream proto.ClnPlugin_HtlcStreamServer newSubscriber chan struct{} started chan struct{} - startError chan error done chan struct{} + startError chan error sendQueue chan *htlcAcceptedMsg recvQueue chan *htlcResultMsg } @@ -114,9 +106,10 @@ func (s *server) Stop() { return } - close(s.done) s.grpcServer.Stop() s.grpcServer = nil + + close(s.done) log.Printf("Server stopped.") } @@ -125,7 +118,7 @@ func (s *server) Stop() { // from or to the subscriber, the subscription is closed. func (s *server) HtlcStream(stream proto.ClnPlugin_HtlcStreamServer) error { s.mtx.Lock() - if s.subscription == nil { + if s.stream == nil { log.Printf("Got a new HTLC stream subscription request.") } else { s.mtx.Unlock() @@ -134,12 +127,7 @@ func (s *server) HtlcStream(stream proto.ClnPlugin_HtlcStreamServer) error { return fmt.Errorf("already subscribed") } - sb := &subscription{ - stream: stream, - done: make(chan struct{}), - err: make(chan error, 1), - } - s.subscription = sb + s.stream = stream // Notify listeners that a new subscriber is active. Replace the chan with // a new one immediately in case this subscriber is dropped later. @@ -147,27 +135,15 @@ func (s *server) HtlcStream(stream proto.ClnPlugin_HtlcStreamServer) error { s.newSubscriber = make(chan struct{}) s.mtx.Unlock() - defer func() { - // When the HtlcStream function returns, that means the subscriber will - // be gone. Cleanup the subscription so we'll be ready to accept a new - // one later. - s.removeSubscriptionIfUnchanged(sb, nil) - }() + <-stream.Context().Done() + log.Printf("HtlcStream context is done. Return: %v", stream.Context().Err()) - select { - case <-s.done: - log.Printf("HTLC server signalled done. Return EOF.") - return io.EOF - case <-sb.done: - log.Printf("HTLC stream signalled done. Return EOF.") - return io.EOF - case err := <-sb.err: - log.Printf("HTLC stream signalled error. Return %v", err) - return err - case <-stream.Context().Done(): - log.Printf("HtlcStream context is done. Return: %v", stream.Context().Err()) - return stream.Context().Err() - } + // Remove the subscriber. + s.mtx.Lock() + s.stream = nil + s.mtx.Unlock() + + return stream.Context().Err() } // Enqueues a htlc_accepted message for send to the grpc client. @@ -212,13 +188,13 @@ func (s *server) listenHtlcRequests() { func (s *server) handleHtlcAccepted(msg *htlcAcceptedMsg) { for { s.mtx.Lock() - sb := s.subscription + stream := s.stream ns := s.newSubscriber s.mtx.Unlock() // If there is no active subscription, wait until there is a new // subscriber, or the message times out. - if sb == nil { + if stream == nil { select { case <-s.done: log.Printf("handleHtlcAccepted received server done. Stop processing.") @@ -234,6 +210,10 @@ func (s *server) handleHtlcAccepted(msg *htlcAcceptedMsg) { s.subscriberTimeout, msg.htlc, ) + + // If the subscriber timeout expires while holding the htlc + // we short circuit the htlc by sending the default result + // (continue) to cln. s.recvQueue <- &htlcResultMsg{ id: msg.id, result: s.defaultResult(), @@ -244,7 +224,7 @@ func (s *server) handleHtlcAccepted(msg *htlcAcceptedMsg) { } // There is a subscriber. Attempt to send the htlc_accepted message. - err := sb.stream.Send(&proto.HtlcAccepted{ + err := stream.Send(&proto.HtlcAccepted{ Correlationid: msg.id, Onion: &proto.Onion{ Payload: msg.htlc.Onion.Payload, @@ -272,9 +252,10 @@ func (s *server) handleHtlcAccepted(msg *htlcAcceptedMsg) { // If we end up here, there was an error sending the message to the // grpc client. - log.Printf("Error sending htlc_accepted message to subscriber. "+ - "Removing subscription: %v", err) - s.removeSubscriptionIfUnchanged(sb, err) + // TODO: If the Send errors, but the context is not done, this will + // currently retry immediately. Check whether the context is really + // done on an error! + log.Printf("Error sending htlc_accepted message to subscriber. Retrying: %v", err) } } @@ -306,11 +287,11 @@ func (s *server) recv() *proto.HtlcResolution { // surprise us. The newSubscriber chan is swapped whenever a new // subscriber arrives. s.mtx.Lock() - sb := s.subscription + stream := s.stream ns := s.newSubscriber s.mtx.Unlock() - if sb == nil { + if stream == nil { log.Printf("Got no subscribers for receive. Waiting for subscriber.") select { case <-s.done: @@ -323,7 +304,7 @@ func (s *server) recv() *proto.HtlcResolution { } // There is a subscription active. Attempt to receive a message. - r, err := sb.stream.Recv() + r, err := stream.Recv() if err == nil { log.Printf("Received HtlcResolution %+v", r) return r @@ -332,32 +313,13 @@ func (s *server) recv() *proto.HtlcResolution { // Receiving the message failed, so the subscription is broken. Remove // it if it hasn't been updated already. We'll try receiving again in // the next iteration of the for loop. - log.Printf("Recv() errored, removing subscription: %v", err) - s.removeSubscriptionIfUnchanged(sb, err) + // TODO: If the Recv errors, but the context is not done, this will + // currently retry immediately. Check whether the context is really + // done on an error! + log.Printf("Recv() errored, Retrying: %v", err) } } -// Stops and removes the subscription if this is the currently active -// subscription. If the subscription was changed in the meantime, this function -// does nothing. -func (s *server) removeSubscriptionIfUnchanged(sb *subscription, err error) { - s.mtx.Lock() - // If the subscription reference hasn't changed yet in the meantime, kill it. - if s.subscription == sb { - if err == nil { - log.Printf("Removing active subscription without error.") - close(sb.done) - } else { - log.Printf("Removing active subscription with error: %v", err) - sb.err <- err - } - s.subscription = nil - } else { - log.Printf("removeSubscriptionIfUnchanged: Subscription already removed.") - } - s.mtx.Unlock() -} - // Maps a grpc result to the corresponding result for cln. The cln message // is a raw json message, so it's easiest to use a map directly. func (s *server) mapResult(outcome interface{}) interface{} {