msg_queue: don't allow magic MSG_PASS_FD message for peers.

msg_queue was originally designed for inter-daemon comms, and so it has
a special mechanism to mark that we're trying to send an fd.  Unfortunately,
a peer could also send such a message, confusing us!

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
Rusty Russell
2022-01-11 11:46:18 +10:30
parent a93c49ca65
commit d51fb5207a
6 changed files with 21 additions and 15 deletions

View File

@@ -4045,7 +4045,7 @@ int main(int argc, char *argv[])
peer->have_sigs[LOCAL] = peer->have_sigs[REMOTE] = false; peer->have_sigs[LOCAL] = peer->have_sigs[REMOTE] = false;
peer->announce_depth_reached = false; peer->announce_depth_reached = false;
peer->channel_local_active = false; peer->channel_local_active = false;
peer->from_master = msg_queue_new(peer); peer->from_master = msg_queue_new(peer, true);
peer->shutdown_sent[LOCAL] = false; peer->shutdown_sent[LOCAL] = false;
peer->shutdown_wrong_funding = NULL; peer->shutdown_wrong_funding = NULL;
peer->last_update_timestamp = 0; peer->last_update_timestamp = 0;
@@ -4053,7 +4053,7 @@ int main(int argc, char *argv[])
#if EXPERIMENTAL_FEATURES #if EXPERIMENTAL_FEATURES
peer->stfu = false; peer->stfu = false;
peer->stfu_sent[LOCAL] = peer->stfu_sent[REMOTE] = false; peer->stfu_sent[LOCAL] = peer->stfu_sent[REMOTE] = false;
peer->update_queue = msg_queue_new(peer); peer->update_queue = msg_queue_new(peer, false);
#endif #endif
/* We send these to HSM to get real signatures; don't have valgrind /* We send these to HSM to get real signatures; don't have valgrind

View File

@@ -54,7 +54,7 @@ static struct io_plan *daemon_conn_write_next(struct io_conn *conn,
} }
if (msg) { if (msg) {
int fd = msg_extract_fd(msg); int fd = msg_extract_fd(dc->out, msg);
if (fd >= 0) { if (fd >= 0) {
tal_free(msg); tal_free(msg);
return io_send_fd(conn, fd, true, return io_send_fd(conn, fd, true,
@@ -82,7 +82,7 @@ bool daemon_conn_sync_flush(struct daemon_conn *dc)
/* Flush existing messages. */ /* Flush existing messages. */
while ((msg = msg_dequeue(dc->out)) != NULL) { while ((msg = msg_dequeue(dc->out)) != NULL) {
int fd = msg_extract_fd(msg); int fd = msg_extract_fd(dc->out, msg);
if (fd >= 0) { if (fd >= 0) {
tal_free(msg); tal_free(msg);
if (!fdpass_send(daemon_fd, fd)) if (!fdpass_send(daemon_fd, fd))
@@ -125,7 +125,7 @@ struct daemon_conn *daemon_conn_new_(const tal_t *ctx, int fd,
dc->outq_empty = outq_empty; dc->outq_empty = outq_empty;
dc->arg = arg; dc->arg = arg;
dc->msg_in = NULL; dc->msg_in = NULL;
dc->out = msg_queue_new(dc); dc->out = msg_queue_new(dc, true);
dc->conn = io_new_conn(dc, fd, daemon_conn_start, dc); dc->conn = io_new_conn(dc, fd, daemon_conn_start, dc);
tal_add_destructor2(dc->conn, destroy_dc_from_conn, dc); tal_add_destructor2(dc->conn, destroy_dc_from_conn, dc);

View File

@@ -5,12 +5,14 @@
#include <wire/wire.h> #include <wire/wire.h>
struct msg_queue { struct msg_queue {
bool fd_passing;
const u8 **q; const u8 **q;
}; };
struct msg_queue *msg_queue_new(const tal_t *ctx) struct msg_queue *msg_queue_new(const tal_t *ctx, bool fd_passing)
{ {
struct msg_queue *q = tal(ctx, struct msg_queue); struct msg_queue *q = tal(ctx, struct msg_queue);
q->fd_passing = fd_passing;
q->q = tal_arr(q, const u8 *, 0); q->q = tal_arr(q, const u8 *, 0);
return q; return q;
} }
@@ -30,6 +32,7 @@ size_t msg_queue_length(const struct msg_queue *q)
void msg_enqueue(struct msg_queue *q, const u8 *add) void msg_enqueue(struct msg_queue *q, const u8 *add)
{ {
if (q->fd_passing)
assert(fromwire_peektype(add) != MSG_PASS_FD); assert(fromwire_peektype(add) != MSG_PASS_FD);
do_enqueue(q, add); do_enqueue(q, add);
} }
@@ -37,6 +40,7 @@ void msg_enqueue(struct msg_queue *q, const u8 *add)
void msg_enqueue_fd(struct msg_queue *q, int fd) void msg_enqueue_fd(struct msg_queue *q, int fd)
{ {
u8 *fdmsg = tal_arr(q, u8, 0); u8 *fdmsg = tal_arr(q, u8, 0);
assert(q->fd_passing);
towire_u16(&fdmsg, MSG_PASS_FD); towire_u16(&fdmsg, MSG_PASS_FD);
towire_u32(&fdmsg, fd); towire_u32(&fdmsg, fd);
do_enqueue(q, take(fdmsg)); do_enqueue(q, take(fdmsg));
@@ -56,11 +60,12 @@ const u8 *msg_dequeue(struct msg_queue *q)
return msg; return msg;
} }
int msg_extract_fd(const u8 *msg) int msg_extract_fd(const struct msg_queue *q, const u8 *msg)
{ {
const u8 *p = msg + sizeof(u16); const u8 *p = msg + sizeof(u16);
size_t len = tal_count(msg) - sizeof(u16); size_t len = tal_count(msg) - sizeof(u16);
assert(q->fd_passing);
if (fromwire_peektype(msg) != MSG_PASS_FD) if (fromwire_peektype(msg) != MSG_PASS_FD)
return -1; return -1;

View File

@@ -8,8 +8,9 @@
/* Reserved type used to indicate we're actually passing an fd. */ /* Reserved type used to indicate we're actually passing an fd. */
#define MSG_PASS_FD 0xFFFF #define MSG_PASS_FD 0xFFFF
/* Allocate a new msg queue. */ /* Allocate a new msg queue; if we control all msgs we send/receive,
struct msg_queue *msg_queue_new(const tal_t *ctx); * we can pass fds. Otherwise, set @fd_passing to false. */
struct msg_queue *msg_queue_new(const tal_t *ctx, bool fd_passing);
/* If add is taken(), freed after sending. msg_wake() implied. */ /* If add is taken(), freed after sending. msg_wake() implied. */
void msg_enqueue(struct msg_queue *q, const u8 *add TAKES); void msg_enqueue(struct msg_queue *q, const u8 *add TAKES);
@@ -27,7 +28,7 @@ void msg_wake(const struct msg_queue *q);
const u8 *msg_dequeue(struct msg_queue *q); const u8 *msg_dequeue(struct msg_queue *q);
/* Returns -1 if not an fd: close after sending. */ /* Returns -1 if not an fd: close after sending. */
int msg_extract_fd(const u8 *msg); int msg_extract_fd(const struct msg_queue *q, const u8 *msg);
#define msg_queue_wait(conn, q, next, arg) \ #define msg_queue_wait(conn, q, next, arg) \
io_out_wait((conn), (q), (next), (arg)) io_out_wait((conn), (q), (next), (arg))

View File

@@ -355,8 +355,8 @@ static struct peer *new_peer(struct daemon *daemon,
peer->peer_in = NULL; peer->peer_in = NULL;
peer->sent_to_peer = NULL; peer->sent_to_peer = NULL;
peer->urgent = false; peer->urgent = false;
peer->peer_outq = msg_queue_new(peer); peer->peer_outq = msg_queue_new(peer, false);
peer->subd_outq = msg_queue_new(peer); peer->subd_outq = msg_queue_new(peer, false);
#if DEVELOPER #if DEVELOPER
peer->dev_writes_enabled = NULL; peer->dev_writes_enabled = NULL;

View File

@@ -657,7 +657,7 @@ static struct io_plan *msg_send_next(struct io_conn *conn, struct subd *sd)
if (!msg) if (!msg)
return msg_queue_wait(conn, sd->outq, msg_send_next, sd); return msg_queue_wait(conn, sd->outq, msg_send_next, sd);
fd = msg_extract_fd(msg); fd = msg_extract_fd(sd->outq, msg);
if (fd >= 0) { if (fd >= 0) {
tal_free(msg); tal_free(msg);
return io_send_fd(conn, fd, true, msg_send_next, sd); return io_send_fd(conn, fd, true, msg_send_next, sd);
@@ -741,7 +741,7 @@ static struct subd *new_subd(struct lightningd *ld,
sd->errcb = errcb; sd->errcb = errcb;
sd->billboardcb = billboardcb; sd->billboardcb = billboardcb;
sd->fds_in = NULL; sd->fds_in = NULL;
sd->outq = msg_queue_new(sd); sd->outq = msg_queue_new(sd, true);
tal_add_destructor(sd, destroy_subd); tal_add_destructor(sd, destroy_subd);
list_head_init(&sd->reqs); list_head_init(&sd->reqs);
sd->channel = channel; sd->channel = channel;