diff --git a/lightningd/subd.c b/lightningd/subd.c index 517ad5b70..066419d29 100644 --- a/lightningd/subd.c +++ b/lightningd/subd.c @@ -19,53 +19,72 @@ #include #include -static bool move_fd(int from, int to) +/* Carefully move fd *@from to @to: on success *from set to to */ +static bool move_fd(int *from, int to) { - assert(from >= 0); + assert(*from >= 0); /* dup2 with same arguments may be a no-op, but * the later close would make the fd invalid. * Handle this edge case. */ - if (from == to) + if (*from == to) return true; - if (dup2(from, to) == -1) + if (dup2(*from, to) == -1) return false; /* dup2 does not duplicate flags, copy it here. * This should be benign; the only POSIX-defined * flag is FD_CLOEXEC, and we only use it rarely. */ - if (fcntl(to, F_SETFD, fcntl(from, F_GETFD)) < 0) + if (fcntl(to, F_SETFD, fcntl(*from, F_GETFD)) < 0) return false; - close(from); + close(*from); + *from = to; return true; } -/* Like the above, but move the fd from whatever it currently has - * to any other unused fd number that is *not* its current value. - */ -static bool move_fd_any(int *fd) +/* Returns index of fds which is == this fd, or -1 */ +static int fd_used(int **fds, size_t num_fds, int fd) { - int old_fd = *fd; - int new_fd; + for (size_t i = 0; i < num_fds; i++) { + if (*fds[i] == fd) + return i; + } + return -1; +} - assert(old_fd >= 0); +/* Move an series of fd pointers into 0, 1, ... */ +static bool shuffle_fds(int **fds, size_t num_fds) +{ + /* If we need to move an fd out the way, this is a good place to start + * looking */ + size_t next_free_fd = num_fds; + for (size_t i = 0; i < num_fds; i++) { + int in_the_way; - if ((new_fd = dup(old_fd)) == -1) - return false; + /* Already in the right place? Great! */ + if (*fds[i] == i) + continue; + /* Is something we care about in the way? */ + in_the_way = fd_used(fds + i, num_fds - i, i); + if (in_the_way != -1) { + /* Find a high-numbered unused fd. */ + while (fd_used(fds + i, num_fds - i, next_free_fd) != -1) + next_free_fd++; + /* Trick: in_the_way is offset by i! */ + if (!move_fd(fds[i + in_the_way], next_free_fd)) + return false; + next_free_fd++; + } - /* dup does not duplicate flags. */ - if (fcntl(new_fd, F_SETFD, fcntl(old_fd, F_GETFD)) < 0) - return false; - - close(old_fd); - - *fd = new_fd; - - assert(old_fd != *fd); + /* Now there should be nothing in the way. */ + assert(fd_used(fds, num_fds, i) == -1); + if (!move_fd(fds[i], i)) + return false; + } return true; } @@ -191,65 +210,37 @@ static int subd(const char *path, const char *name, goto close_execfail_fail; if (childpid == 0) { - int fdnum = 3, stdin_is_now = STDIN_FILENO; size_t num_args; char *args[] = { NULL, NULL, NULL, NULL, NULL }; + int **fds = tal_arr(tmpctx, int *, 3); + int stdout = STDOUT_FILENO, stderr = STDERR_FILENO; close(childmsg[0]); close(execfail[0]); - // msg = STDIN - if (childmsg[1] != STDIN_FILENO) { - /* Do we need to move STDIN out the way? */ - stdin_is_now = dup(STDIN_FILENO); - if (!move_fd(childmsg[1], STDIN_FILENO)) - goto child_errno_fail; - } + /* msg = STDIN (0) */ + fds[0] = &childmsg[1]; + /* These are untouched */ + fds[1] = &stdout; + fds[2] = &stderr; - /* Dup any extra fds up first. */ while ((fd = va_arg(*ap, int *)) != NULL) { - int actual_fd = *fd; - /* If this were stdin, we moved it above! */ - if (actual_fd == STDIN_FILENO) - actual_fd = stdin_is_now; - - /* If we would overwrite important fds, move those. */ - if (fdnum == dev_disconnect_fd) { - if (!move_fd_any(&dev_disconnect_fd)) - goto child_errno_fail; - } - if (fdnum == execfail[1]) { - if (!move_fd_any(&execfail[1])) - goto child_errno_fail; - } - - if (!move_fd(actual_fd, fdnum)) - goto child_errno_fail; - fdnum++; + assert(*fd != -1); + tal_arr_expand(&fds, fd); } - /* Move dev_disconnect_fd *after* the extra fds above. */ - if (dev_disconnect_fd != -1) { - /* Do not overwrite execfail[1]. */ - if (fdnum == execfail[1]) { - if (!move_fd_any(&execfail[1])) - goto child_errno_fail; - } - if (!move_fd(dev_disconnect_fd, fdnum)) - goto child_errno_fail; - dev_disconnect_fd = fdnum; - fdnum++; - } + /* If we have a dev_disconnect_fd, add it after. */ + if (dev_disconnect_fd != -1) + tal_arr_expand(&fds, &dev_disconnect_fd); - /* Move execfail[1] *after* the fds we will pass - * to the subdaemon. */ - if (!move_fd(execfail[1], fdnum)) + /* Finally, the fd to report exec errors on */ + tal_arr_expand(&fds, &execfail[1]); + + if (!shuffle_fds(fds, tal_count(fds))) goto child_errno_fail; - execfail[1] = fdnum; - fdnum++; /* Make (fairly!) sure all other fds are closed. */ - closefrom(fdnum); + closefrom(tal_count(fds) + 1); num_args = 0; args[num_args++] = tal_strdup(NULL, path); diff --git a/lightningd/test/run-shuffle_fds.c b/lightningd/test/run-shuffle_fds.c new file mode 100644 index 000000000..7f9a38060 --- /dev/null +++ b/lightningd/test/run-shuffle_fds.c @@ -0,0 +1,185 @@ +#include "config.h" +#include +#include +#include + +#undef dup2 +#undef close +#undef fcntl + +#define dup2 test_dup2 +#define close test_close +#define fcntl test_fcntl + +/* Indexed by fd num, -1 == not open. */ +#define MAX_TEST_FDS 100 +static int test_fd_arr[MAX_TEST_FDS]; + +static int test_dup2(int oldfd, int newfd) +{ + /* Must not clobber an existing fd */ + assert(test_fd_arr[newfd] == -1); + test_fd_arr[newfd] = test_fd_arr[oldfd]; + return 0; +} + +static int test_close(int fd) +{ + assert(test_fd_arr[fd] != -1); + test_fd_arr[fd] = -1; + return 0; +} + +static int test_fcntl(int fd, int cmd, ... /* arg */ ) +{ + assert(test_fd_arr[fd] != -1); + return 0; +} + +#include "../subd.c" +#include +#include + +/* AUTOGENERATED MOCKS START */ +/* Generated stub for db_begin_transaction_ */ +void db_begin_transaction_(struct db *db UNNEEDED, const char *location UNNEEDED) +{ fprintf(stderr, "db_begin_transaction_ called!\n"); abort(); } +/* Generated stub for db_commit_transaction */ +void db_commit_transaction(struct db *db UNNEEDED) +{ fprintf(stderr, "db_commit_transaction called!\n"); abort(); } +/* Generated stub for db_in_transaction */ +bool db_in_transaction(struct db *db UNNEEDED) +{ fprintf(stderr, "db_in_transaction called!\n"); abort(); } +/* Generated stub for fatal */ +void fatal(const char *fmt UNNEEDED, ...) +{ fprintf(stderr, "fatal called!\n"); abort(); } +/* Generated stub for fromwire_bigsize */ +bigsize_t fromwire_bigsize(const u8 **cursor UNNEEDED, size_t *max UNNEEDED) +{ fprintf(stderr, "fromwire_bigsize called!\n"); abort(); } +/* Generated stub for fromwire_channel_id */ +void fromwire_channel_id(const u8 **cursor UNNEEDED, size_t *max UNNEEDED, + struct channel_id *channel_id UNNEEDED) +{ fprintf(stderr, "fromwire_channel_id called!\n"); abort(); } +/* Generated stub for fromwire_node_id */ +void fromwire_node_id(const u8 **cursor UNNEEDED, size_t *max UNNEEDED, struct node_id *id UNNEEDED) +{ fprintf(stderr, "fromwire_node_id called!\n"); abort(); } +/* Generated stub for fromwire_status_fail */ +bool fromwire_status_fail(const tal_t *ctx UNNEEDED, const void *p UNNEEDED, enum status_failreason *failreason UNNEEDED, wirestring **desc UNNEEDED) +{ fprintf(stderr, "fromwire_status_fail called!\n"); abort(); } +/* Generated stub for fromwire_status_peer_billboard */ +bool fromwire_status_peer_billboard(const tal_t *ctx UNNEEDED, const void *p UNNEEDED, bool *perm UNNEEDED, wirestring **happenings UNNEEDED) +{ fprintf(stderr, "fromwire_status_peer_billboard called!\n"); abort(); } +/* Generated stub for fromwire_status_peer_error */ +bool fromwire_status_peer_error(const tal_t *ctx UNNEEDED, const void *p UNNEEDED, struct channel_id *channel UNNEEDED, wirestring **desc UNNEEDED, bool *warning UNNEEDED, struct per_peer_state **pps UNNEEDED, u8 **error_for_them UNNEEDED) +{ fprintf(stderr, "fromwire_status_peer_error called!\n"); abort(); } +/* Generated stub for fromwire_status_version */ +bool fromwire_status_version(const tal_t *ctx UNNEEDED, const void *p UNNEEDED, wirestring **version UNNEEDED) +{ fprintf(stderr, "fromwire_status_version called!\n"); abort(); } +/* Generated stub for json_add_member */ +void json_add_member(struct json_stream *js UNNEEDED, + const char *fieldname UNNEEDED, + bool quote UNNEEDED, + const char *fmt UNNEEDED, ...) +{ fprintf(stderr, "json_add_member called!\n"); abort(); } +/* Generated stub for json_member_direct */ +char *json_member_direct(struct json_stream *js UNNEEDED, + const char *fieldname UNNEEDED, size_t extra UNNEEDED) +{ fprintf(stderr, "json_member_direct called!\n"); abort(); } +/* Generated stub for log_ */ +void log_(struct log *log UNNEEDED, enum log_level level UNNEEDED, + const struct node_id *node_id UNNEEDED, + bool call_notifier UNNEEDED, + const char *fmt UNNEEDED, ...) + +{ fprintf(stderr, "log_ called!\n"); abort(); } +/* Generated stub for log_prefix */ +const char *log_prefix(const struct log *log UNNEEDED) +{ fprintf(stderr, "log_prefix called!\n"); abort(); } +/* Generated stub for log_print_level */ +enum log_level log_print_level(struct log *log UNNEEDED) +{ fprintf(stderr, "log_print_level called!\n"); abort(); } +/* Generated stub for log_status_msg */ +bool log_status_msg(struct log *log UNNEEDED, + const struct node_id *node_id UNNEEDED, + const u8 *msg UNNEEDED) +{ fprintf(stderr, "log_status_msg called!\n"); abort(); } +/* Generated stub for new_log */ +struct log *new_log(const tal_t *ctx UNNEEDED, struct log_book *record UNNEEDED, + const struct node_id *default_node_id UNNEEDED, + const char *fmt UNNEEDED, ...) +{ fprintf(stderr, "new_log called!\n"); abort(); } +/* Generated stub for per_peer_state_set_fds_arr */ +void per_peer_state_set_fds_arr(struct per_peer_state *pps UNNEEDED, const int *fds UNNEEDED) +{ fprintf(stderr, "per_peer_state_set_fds_arr called!\n"); abort(); } +/* Generated stub for subdaemon_path */ +const char *subdaemon_path(const tal_t *ctx UNNEEDED, const struct lightningd *ld UNNEEDED, const char *name UNNEEDED) +{ fprintf(stderr, "subdaemon_path called!\n"); abort(); } +/* Generated stub for towire_bigsize */ +void towire_bigsize(u8 **pptr UNNEEDED, const bigsize_t val UNNEEDED) +{ fprintf(stderr, "towire_bigsize called!\n"); abort(); } +/* Generated stub for towire_channel_id */ +void towire_channel_id(u8 **pptr UNNEEDED, const struct channel_id *channel_id UNNEEDED) +{ fprintf(stderr, "towire_channel_id called!\n"); abort(); } +/* Generated stub for towire_node_id */ +void towire_node_id(u8 **pptr UNNEEDED, const struct node_id *id UNNEEDED) +{ fprintf(stderr, "towire_node_id called!\n"); abort(); } +/* Generated stub for version */ +const char *version(void) +{ fprintf(stderr, "version called!\n"); abort(); } +/* AUTOGENERATED MOCKS END */ + +static void run_test_(int fd0, ...) +{ + va_list ap; + int fd, i; + int *test_fds = tal_arr(tmpctx, int, 1); + int **test_fd_ptrs; + + /* They start all closed */ + memset(test_fd_arr, 0xFF, sizeof(test_fd_arr)); + + test_fds[0] = fd0; + test_fd_arr[fd0] = fd0; + + va_start(ap, fd0); + while ((fd = va_arg(ap, int)) != -1) { + tal_arr_expand(&test_fds, fd); + test_fd_arr[fd] = fd; + } + va_end(ap); + + test_fd_ptrs = tal_arr(tmpctx, int *, tal_count(test_fds)); + for (i = 0; i < tal_count(test_fds); i++) + test_fd_ptrs[i] = &test_fds[i]; + + assert(shuffle_fds(test_fd_ptrs, tal_count(test_fd_ptrs))); + + /* Make sure fds ended up where expected */ + i = 0; + assert(test_fd_arr[i++] == fd0); + va_start(ap, fd0); + while ((fd = va_arg(ap, int)) != -1) + assert(test_fd_arr[i++] == fd); + va_end(ap); + + /* And rest were closed */ + for (; i < MAX_TEST_FDS; i++) + assert(test_fd_arr[i] == -1); +} + +#define run_test(...) \ + run_test_(__VA_ARGS__, -1) + +int main(int argc, char *argv[]) +{ + common_setup(argv[0]); + + run_test(0); + run_test(1); + run_test(0, 1); + run_test(0, 1, 3); + run_test(3, 2, 1, 0); + run_test(5, 2, 1); + + common_shutdown(); +}