From bdaec48400703d336f109a2fffdc21f30fe5fa94 Mon Sep 17 00:00:00 2001 From: Rusty Russell Date: Wed, 13 Oct 2021 14:12:43 +1030 Subject: [PATCH] wallet: wrap htlc_state enum in db function. All enums in the db should be wrapped this way on reading/writing them. Signed-off-by: Rusty Russell --- wallet/wallet.c | 11 ++++---- wallet/wallet.h | 71 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 6 deletions(-) diff --git a/wallet/wallet.c b/wallet/wallet.c index 0fc3f2878..227740077 100644 --- a/wallet/wallet.c +++ b/wallet/wallet.c @@ -968,7 +968,7 @@ static struct fee_states *wallet_channel_fee_states_load(struct wallet *w, /* Start with blank slate. */ fee_states = new_fee_states(w, opener, NULL); while (db_step(stmt)) { - enum htlc_state hstate = db_column_int(stmt, 0); + enum htlc_state hstate = htlc_state_in_db(db_column_int(stmt, 0)); u32 feerate = db_column_int(stmt, 1); if (fee_states->feerate[hstate] != NULL) { @@ -1004,7 +1004,7 @@ static struct height_states *wallet_channel_height_states_load(struct wallet *w, /* Start with blank slate. */ states = new_height_states(w, opener, NULL); while (db_step(stmt)) { - enum htlc_state hstate = db_column_int(stmt, 0); + enum htlc_state hstate = htlc_state_in_db(db_column_int(stmt, 0)); u32 blockheight = db_column_int(stmt, 1); if (states->height[hstate] != NULL) { @@ -1936,7 +1936,7 @@ void wallet_channel_save(struct wallet *w, struct channel *chan) stmt = db_prepare_v2(w->db, SQL("INSERT INTO channel_feerates " " VALUES(?, ?, ?)")); db_bind_u64(stmt, 0, chan->dbid); - db_bind_int(stmt, 1, i); + db_bind_int(stmt, 1, htlc_state_in_db(i)); db_bind_int(stmt, 2, *chan->fee_states->feerate[i]); db_exec_prepared_v2(take(stmt)); } @@ -1955,7 +1955,7 @@ void wallet_channel_save(struct wallet *w, struct channel *chan) stmt = db_prepare_v2(w->db, SQL("INSERT INTO channel_blockheights " " VALUES(?, ?, ?)")); db_bind_u64(stmt, 0, chan->dbid); - db_bind_int(stmt, 1, i); + db_bind_int(stmt, 1, htlc_state_in_db(i)); db_bind_int(stmt, 2, *chan->blockheight_states->height[i]); db_exec_prepared_v2(take(stmt)); } @@ -2429,8 +2429,7 @@ void wallet_htlc_update(struct wallet *wallet, const u64 htlc_dbid, "we_filled=?" " WHERE id=?")); - /* FIXME: htlc_state_in_db */ - db_bind_int(stmt, 0, new_state); + db_bind_int(stmt, 0, htlc_state_in_db(new_state)); db_bind_u64(stmt, 6, htlc_dbid); if (payment_key) diff --git a/wallet/wallet.h b/wallet/wallet.h index 95e4814aa..7cf7eaaf1 100644 --- a/wallet/wallet.h +++ b/wallet/wallet.h @@ -159,6 +159,77 @@ static inline const char* forward_status_name(enum forward_status status) bool string_to_forward_status(const char *status_str, enum forward_status *status); +/* DB wrapper to check htlc_state */ +static inline enum htlc_state htlc_state_in_db(enum htlc_state s) +{ + switch (s) { + case SENT_ADD_HTLC: + BUILD_ASSERT(SENT_ADD_HTLC == 0); + return s; + case SENT_ADD_COMMIT: + BUILD_ASSERT(SENT_ADD_COMMIT == 1); + return s; + case RCVD_ADD_REVOCATION: + BUILD_ASSERT(RCVD_ADD_REVOCATION == 2); + return s; + case RCVD_ADD_ACK_COMMIT: + BUILD_ASSERT(RCVD_ADD_ACK_COMMIT == 3); + return s; + case SENT_ADD_ACK_REVOCATION: + BUILD_ASSERT(SENT_ADD_ACK_REVOCATION == 4); + return s; + case RCVD_REMOVE_HTLC: + BUILD_ASSERT(RCVD_REMOVE_HTLC == 5); + return s; + case RCVD_REMOVE_COMMIT: + BUILD_ASSERT(RCVD_REMOVE_COMMIT == 6); + return s; + case SENT_REMOVE_REVOCATION: + BUILD_ASSERT(SENT_REMOVE_REVOCATION == 7); + return s; + case SENT_REMOVE_ACK_COMMIT: + BUILD_ASSERT(SENT_REMOVE_ACK_COMMIT == 8); + return s; + case RCVD_REMOVE_ACK_REVOCATION: + BUILD_ASSERT(RCVD_REMOVE_ACK_REVOCATION == 9); + return s; + case RCVD_ADD_HTLC: + BUILD_ASSERT(RCVD_ADD_HTLC == 10); + return s; + case RCVD_ADD_COMMIT: + BUILD_ASSERT(RCVD_ADD_COMMIT == 11); + return s; + case SENT_ADD_REVOCATION: + BUILD_ASSERT(SENT_ADD_REVOCATION == 12); + return s; + case SENT_ADD_ACK_COMMIT: + BUILD_ASSERT(SENT_ADD_ACK_COMMIT == 13); + return s; + case RCVD_ADD_ACK_REVOCATION: + BUILD_ASSERT(RCVD_ADD_ACK_REVOCATION == 14); + return s; + case SENT_REMOVE_HTLC: + BUILD_ASSERT(SENT_REMOVE_HTLC == 15); + return s; + case SENT_REMOVE_COMMIT: + BUILD_ASSERT(SENT_REMOVE_COMMIT == 16); + return s; + case RCVD_REMOVE_REVOCATION: + BUILD_ASSERT(RCVD_REMOVE_REVOCATION == 17); + return s; + case RCVD_REMOVE_ACK_COMMIT: + BUILD_ASSERT(RCVD_REMOVE_ACK_COMMIT == 18); + return s; + case SENT_REMOVE_ACK_REVOCATION: + BUILD_ASSERT(SENT_REMOVE_ACK_REVOCATION == 19); + return s; + case HTLC_STATE_INVALID: + /* Not in db! */ + break; + } + fatal("%s: %u is invalid", __func__, s); +} + struct forwarding { struct short_channel_id channel_in, channel_out; struct amount_msat msat_in, msat_out, fee;