diff --git a/common/crypto_sync.c b/common/crypto_sync.c index 878ff0628..633474bc5 100644 --- a/common/crypto_sync.c +++ b/common/crypto_sync.c @@ -25,6 +25,9 @@ bool sync_crypto_write(struct crypto_state *cs, int fd, const void *msg TAKES) case DEV_DISCONNECT_AFTER: post_sabotage = true; break; + case DEV_DISCONNECT_BLACKHOLE: + dev_blackhole_fd(fd); + break; default: break; } diff --git a/common/cryptomsg.c b/common/cryptomsg.c index aeed15c35..db4288c9e 100644 --- a/common/cryptomsg.c +++ b/common/cryptomsg.c @@ -352,6 +352,9 @@ struct io_plan *peer_write_message(struct io_conn *conn, case DEV_DISCONNECT_AFTER: post = peer_write_postclose; break; + case DEV_DISCONNECT_BLACKHOLE: + dev_blackhole_fd(io_conn_fd(conn)); + break; default: break; } diff --git a/common/dev_disconnect.c b/common/dev_disconnect.c index b1fd99369..e7d9cec50 100644 --- a/common/dev_disconnect.c +++ b/common/dev_disconnect.c @@ -3,8 +3,10 @@ #include #include #include +#include #include #include +#include #include #include #include @@ -64,7 +66,7 @@ void dev_sabotage_fd(int fd) int fds[2]; if (socketpair(AF_UNIX, SOCK_STREAM, 0, fds) != 0) - errx(1, "dev_sabotage_fd: creating socketpair"); + err(1, "dev_sabotage_fd: creating socketpair"); /* Close one. */ close(fds[0]); @@ -72,3 +74,40 @@ void dev_sabotage_fd(int fd) dup2(fds[1], fd); close(fds[1]); } + +/* Replace fd with blackhole until dev_disconnect file is truncated. */ +void dev_blackhole_fd(int fd) +{ + int fds[2]; + int i; + struct stat st; + + if (socketpair(AF_UNIX, SOCK_STREAM, 0, fds) != 0) + err(1, "dev_blackhole_fd: creating socketpair"); + + switch (fork()) { + case -1: + err(1, "dev_blackhole_fd: forking"); + case 0: + /* Close everything but the dev_disconnect_fd, the socket + * which is pretending to be the peer, and stderr. */ + for (i = 0; i < sysconf(_SC_OPEN_MAX); i++) + if (i != fds[0] + && i != dev_disconnect_fd + && i != STDERR_FILENO) + close(i); + + /* Close once dev_disconnect file is truncated. */ + for (;;) { + if (fstat(dev_disconnect_fd, &st) != 0) + err(1, "fstat of dev_disconnect_fd failed"); + if (st.st_size == 0) + _exit(0); + sleep(1); + } + } + + close(fds[0]); + dup2(fds[1], fd); + close(fds[1]); +} diff --git a/common/dev_disconnect.h b/common/dev_disconnect.h index efa83a5d5..108b9c92f 100644 --- a/common/dev_disconnect.h +++ b/common/dev_disconnect.h @@ -6,6 +6,7 @@ #define DEV_DISCONNECT_BEFORE '-' #define DEV_DISCONNECT_AFTER '+' #define DEV_DISCONNECT_DROPPKT '@' +#define DEV_DISCONNECT_BLACKHOLE '0' #define DEV_DISCONNECT_NORMAL 0 /* Force a close fd before or after a certain packet type */ @@ -14,6 +15,9 @@ char dev_disconnect(int pkt_type); /* Make next write on fd fail as if they'd disconnected. */ void dev_sabotage_fd(int fd); +/* No more data to arrive, what's written is swallowed. */ +void dev_blackhole_fd(int fd); + /* For debug code to set in daemon. */ void dev_disconnect_init(int fd); diff --git a/lightningd/peer_control.c b/lightningd/peer_control.c index c6084a08d..ecb0f1bed 100644 --- a/lightningd/peer_control.c +++ b/lightningd/peer_control.c @@ -263,6 +263,11 @@ void dev_sabotage_fd(int fd) abort(); } +void dev_blackhole_fd(int fd) +{ + abort(); +} + /* Send (encrypted) error message, then close. */ static struct io_plan *send_error(struct io_conn *conn, struct peer_crypto_state *pcs) diff --git a/lightningd/test/run-cryptomsg.c b/lightningd/test/run-cryptomsg.c index 76466552c..1f0b4f57c 100644 --- a/lightningd/test/run-cryptomsg.c +++ b/lightningd/test/run-cryptomsg.c @@ -39,10 +39,14 @@ static void do_write(const void *buf, size_t len) #define status_trace(fmt, ...) \ printf(fmt "\n", __VA_ARGS__) -void dev_sabotage_fd(int fd) -{ - abort(); -} +/* AUTOGENERATED MOCKS START */ +/* Generated stub for dev_blackhole_fd */ +void dev_blackhole_fd(int fd UNNEEDED) +{ fprintf(stderr, "dev_blackhole_fd called!\n"); abort(); } +/* Generated stub for dev_sabotage_fd */ +void dev_sabotage_fd(int fd UNNEEDED) +{ fprintf(stderr, "dev_sabotage_fd called!\n"); abort(); } +/* AUTOGENERATED MOCKS END */ char dev_disconnect(int pkt_type) { diff --git a/tests/utils.py b/tests/utils.py index ecb9e9ac8..776009c83 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -59,15 +59,15 @@ class TailableProc(object): self.running = True def stop(self): - self.proc.terminate() - self.proc.kill() - self.proc.wait() - self.thread.join() if self.outputDir: logpath = os.path.join(self.outputDir, 'log') with open(logpath, 'w') as f: for l in self.logs: f.write(l + '\n') + self.proc.terminate() + self.proc.kill() + self.proc.wait() + self.thread.join() def tail(self): """Tail the stdout of the process and remember it.