diff --git a/channeld/channeld.c b/channeld/channeld.c index 2cccbfefb..980bb1033 100644 --- a/channeld/channeld.c +++ b/channeld/channeld.c @@ -4045,7 +4045,7 @@ int main(int argc, char *argv[]) peer->have_sigs[LOCAL] = peer->have_sigs[REMOTE] = false; peer->announce_depth_reached = false; peer->channel_local_active = false; - peer->from_master = msg_queue_new(peer); + peer->from_master = msg_queue_new(peer, true); peer->shutdown_sent[LOCAL] = false; peer->shutdown_wrong_funding = NULL; peer->last_update_timestamp = 0; @@ -4053,7 +4053,7 @@ int main(int argc, char *argv[]) #if EXPERIMENTAL_FEATURES peer->stfu = false; peer->stfu_sent[LOCAL] = peer->stfu_sent[REMOTE] = false; - peer->update_queue = msg_queue_new(peer); + peer->update_queue = msg_queue_new(peer, false); #endif /* We send these to HSM to get real signatures; don't have valgrind diff --git a/common/daemon_conn.c b/common/daemon_conn.c index 0a9a1ef25..25bbac2bc 100644 --- a/common/daemon_conn.c +++ b/common/daemon_conn.c @@ -54,7 +54,7 @@ static struct io_plan *daemon_conn_write_next(struct io_conn *conn, } if (msg) { - int fd = msg_extract_fd(msg); + int fd = msg_extract_fd(dc->out, msg); if (fd >= 0) { tal_free(msg); return io_send_fd(conn, fd, true, @@ -82,7 +82,7 @@ bool daemon_conn_sync_flush(struct daemon_conn *dc) /* Flush existing messages. */ while ((msg = msg_dequeue(dc->out)) != NULL) { - int fd = msg_extract_fd(msg); + int fd = msg_extract_fd(dc->out, msg); if (fd >= 0) { tal_free(msg); if (!fdpass_send(daemon_fd, fd)) @@ -125,7 +125,7 @@ struct daemon_conn *daemon_conn_new_(const tal_t *ctx, int fd, dc->outq_empty = outq_empty; dc->arg = arg; dc->msg_in = NULL; - dc->out = msg_queue_new(dc); + dc->out = msg_queue_new(dc, true); dc->conn = io_new_conn(dc, fd, daemon_conn_start, dc); tal_add_destructor2(dc->conn, destroy_dc_from_conn, dc); diff --git a/common/msg_queue.c b/common/msg_queue.c index 120a77a9f..f3926ab67 100644 --- a/common/msg_queue.c +++ b/common/msg_queue.c @@ -5,12 +5,14 @@ #include struct msg_queue { + bool fd_passing; const u8 **q; }; -struct msg_queue *msg_queue_new(const tal_t *ctx) +struct msg_queue *msg_queue_new(const tal_t *ctx, bool fd_passing) { struct msg_queue *q = tal(ctx, struct msg_queue); + q->fd_passing = fd_passing; q->q = tal_arr(q, const u8 *, 0); return q; } @@ -30,13 +32,15 @@ size_t msg_queue_length(const struct msg_queue *q) void msg_enqueue(struct msg_queue *q, const u8 *add) { - assert(fromwire_peektype(add) != MSG_PASS_FD); + if (q->fd_passing) + assert(fromwire_peektype(add) != MSG_PASS_FD); do_enqueue(q, add); } void msg_enqueue_fd(struct msg_queue *q, int fd) { u8 *fdmsg = tal_arr(q, u8, 0); + assert(q->fd_passing); towire_u16(&fdmsg, MSG_PASS_FD); towire_u32(&fdmsg, fd); do_enqueue(q, take(fdmsg)); @@ -56,11 +60,12 @@ const u8 *msg_dequeue(struct msg_queue *q) return msg; } -int msg_extract_fd(const u8 *msg) +int msg_extract_fd(const struct msg_queue *q, const u8 *msg) { const u8 *p = msg + sizeof(u16); size_t len = tal_count(msg) - sizeof(u16); + assert(q->fd_passing); if (fromwire_peektype(msg) != MSG_PASS_FD) return -1; diff --git a/common/msg_queue.h b/common/msg_queue.h index 15c9c66d4..8bdbe1557 100644 --- a/common/msg_queue.h +++ b/common/msg_queue.h @@ -8,8 +8,9 @@ /* Reserved type used to indicate we're actually passing an fd. */ #define MSG_PASS_FD 0xFFFF -/* Allocate a new msg queue. */ -struct msg_queue *msg_queue_new(const tal_t *ctx); +/* Allocate a new msg queue; if we control all msgs we send/receive, + * we can pass fds. Otherwise, set @fd_passing to false. */ +struct msg_queue *msg_queue_new(const tal_t *ctx, bool fd_passing); /* If add is taken(), freed after sending. msg_wake() implied. */ void msg_enqueue(struct msg_queue *q, const u8 *add TAKES); @@ -27,7 +28,7 @@ void msg_wake(const struct msg_queue *q); const u8 *msg_dequeue(struct msg_queue *q); /* Returns -1 if not an fd: close after sending. */ -int msg_extract_fd(const u8 *msg); +int msg_extract_fd(const struct msg_queue *q, const u8 *msg); #define msg_queue_wait(conn, q, next, arg) \ io_out_wait((conn), (q), (next), (arg)) diff --git a/connectd/connectd.c b/connectd/connectd.c index 59492daa1..287c4b5fc 100644 --- a/connectd/connectd.c +++ b/connectd/connectd.c @@ -355,8 +355,8 @@ static struct peer *new_peer(struct daemon *daemon, peer->peer_in = NULL; peer->sent_to_peer = NULL; peer->urgent = false; - peer->peer_outq = msg_queue_new(peer); - peer->subd_outq = msg_queue_new(peer); + peer->peer_outq = msg_queue_new(peer, false); + peer->subd_outq = msg_queue_new(peer, false); #if DEVELOPER peer->dev_writes_enabled = NULL; diff --git a/lightningd/subd.c b/lightningd/subd.c index 17ba6dc09..36c836308 100644 --- a/lightningd/subd.c +++ b/lightningd/subd.c @@ -657,7 +657,7 @@ static struct io_plan *msg_send_next(struct io_conn *conn, struct subd *sd) if (!msg) return msg_queue_wait(conn, sd->outq, msg_send_next, sd); - fd = msg_extract_fd(msg); + fd = msg_extract_fd(sd->outq, msg); if (fd >= 0) { tal_free(msg); return io_send_fd(conn, fd, true, msg_send_next, sd); @@ -741,7 +741,7 @@ static struct subd *new_subd(struct lightningd *ld, sd->errcb = errcb; sd->billboardcb = billboardcb; sd->fds_in = NULL; - sd->outq = msg_queue_new(sd); + sd->outq = msg_queue_new(sd, true); tal_add_destructor(sd, destroy_subd); list_head_init(&sd->reqs); sd->channel = channel;