From b7b3cbc84af3d3b1d2d45fb9b18bf262d7c4479a Mon Sep 17 00:00:00 2001 From: Rusty Russell Date: Fri, 14 Jul 2023 09:58:45 +0930 Subject: [PATCH] db: enforce that bindings be done in order. This is almost always true already; fix up the few non-standard ones. This is enforced with an assert, and I ran the entire test suite to double-check. Signed-off-by: Rusty Russell --- db/bindings.c | 16 +++++++++++----- db/common.h | 3 +++ db/utils.c | 1 + plugins/bkpr/recorder.c | 23 +++++++++++------------ wallet/wallet.c | 17 +++++++++-------- 5 files changed, 35 insertions(+), 25 deletions(-) diff --git a/db/bindings.c b/db/bindings.c index 34941d7f2..a9ca4c630 100644 --- a/db/bindings.c +++ b/db/bindings.c @@ -16,6 +16,12 @@ #define NSEC_IN_SEC 1000000000 +static int check_bind_pos(struct db_stmt *stmt, int pos) +{ + assert(pos == ++stmt->bind_pos); + return pos; +} + /* Local helpers once you have column number */ static bool db_column_is_null(struct db_stmt *stmt, int col) { @@ -37,7 +43,7 @@ static bool db_column_null_warn(struct db_stmt *stmt, const char *colname, void db_bind_int(struct db_stmt *stmt, int pos, int val) { - assert(pos < tal_count(stmt->bindings)); + pos = check_bind_pos(stmt, pos); memcheck(&val, sizeof(val)); stmt->bindings[pos].type = DB_BINDING_INT; stmt->bindings[pos].v.i = val; @@ -60,21 +66,21 @@ int db_col_is_null(struct db_stmt *stmt, const char *colname) void db_bind_null(struct db_stmt *stmt, int pos) { - assert(pos < tal_count(stmt->bindings)); + pos = check_bind_pos(stmt, pos); stmt->bindings[pos].type = DB_BINDING_NULL; } void db_bind_u64(struct db_stmt *stmt, int pos, u64 val) { memcheck(&val, sizeof(val)); - assert(pos < tal_count(stmt->bindings)); + pos = check_bind_pos(stmt, pos); stmt->bindings[pos].type = DB_BINDING_UINT64; stmt->bindings[pos].v.u64 = val; } void db_bind_blob(struct db_stmt *stmt, int pos, const u8 *val, size_t len) { - assert(pos < tal_count(stmt->bindings)); + pos = check_bind_pos(stmt, pos); stmt->bindings[pos].type = DB_BINDING_BLOB; stmt->bindings[pos].v.blob = memcheck(val, len); stmt->bindings[pos].len = len; @@ -82,7 +88,7 @@ void db_bind_blob(struct db_stmt *stmt, int pos, const u8 *val, size_t len) void db_bind_text(struct db_stmt *stmt, int pos, const char *val) { - assert(pos < tal_count(stmt->bindings)); + pos = check_bind_pos(stmt, pos); stmt->bindings[pos].type = DB_BINDING_TEXT; stmt->bindings[pos].v.text = val; stmt->bindings[pos].len = strlen(val); diff --git a/db/common.h b/db/common.h index f1ef49c1f..2e87bf05e 100644 --- a/db/common.h +++ b/db/common.h @@ -104,6 +104,9 @@ struct db_stmt { /* Our entry in the list of pending statements. */ struct list_node list; + /* Bind counter */ + int bind_pos; + /* Database we are querying */ struct db *db; diff --git a/db/utils.c b/db/utils.c index 48ae718a9..33e797ab1 100644 --- a/db/utils.c +++ b/db/utils.c @@ -81,6 +81,7 @@ static struct db_stmt *db_prepare_core(struct db *db, stmt->query = db_query; stmt->executed = false; stmt->inner_stmt = NULL; + stmt->bind_pos = -1; tal_add_destructor(stmt, db_stmt_free); diff --git a/plugins/bkpr/recorder.c b/plugins/bkpr/recorder.c index 3a600614c..14ba19e13 100644 --- a/plugins/bkpr/recorder.c +++ b/plugins/bkpr/recorder.c @@ -731,11 +731,11 @@ static struct chain_event *find_chain_event(const tal_t *ctx, " LEFT OUTER JOIN accounts a" " ON e.account_id = a.id" " WHERE " - " e.account_id = ?" + " e.spending_txid = ?" + " AND e.account_id = ?" " AND e.utxo_txid = ?" - " AND e.outnum = ?" - " AND e.spending_txid = ?")); - db_bind_txid(stmt, 3, spending_txid); + " AND e.outnum = ?")); + db_bind_txid(stmt, 0, spending_txid); } else { stmt = db_prepare_v2(db, SQL("SELECT" " e.id" @@ -760,18 +760,17 @@ static struct chain_event *find_chain_event(const tal_t *ctx, " LEFT OUTER JOIN accounts a" " ON e.account_id = a.id" " WHERE " - " e.account_id = ?" + " e.tag = ?" + " AND e.account_id = ?" " AND e.utxo_txid = ?" " AND e.outnum = ?" - " AND e.spending_txid IS NULL" - " AND e.tag = ?")); - - db_bind_text(stmt, 3, tag); + " AND e.spending_txid IS NULL")); + db_bind_text(stmt, 0, tag); } - db_bind_u64(stmt, 0, acct->db_id); - db_bind_txid(stmt, 1, &outpoint->txid); - db_bind_int(stmt, 2, outpoint->n); + db_bind_u64(stmt, 1, acct->db_id); + db_bind_txid(stmt, 2, &outpoint->txid); + db_bind_int(stmt, 3, outpoint->n); db_query_prepared(stmt); if (db_step(stmt)) diff --git a/wallet/wallet.c b/wallet/wallet.c index 3900838c3..aa220d083 100644 --- a/wallet/wallet.c +++ b/wallet/wallet.c @@ -2736,7 +2736,6 @@ void wallet_htlc_update(struct wallet *wallet, const u64 htlc_dbid, " WHERE id=?")); db_bind_int(stmt, 0, htlc_state_in_db(new_state)); - db_bind_u64(stmt, 7, htlc_dbid); if (payment_key) db_bind_preimage(stmt, 1, payment_key); @@ -2763,6 +2762,7 @@ void wallet_htlc_update(struct wallet *wallet, const u64 htlc_dbid, else db_bind_null(stmt, 6); + db_bind_u64(stmt, 7, htlc_dbid); db_exec_prepared_v2(take(stmt)); if (terminal) { @@ -3278,6 +3278,7 @@ void wallet_payment_delete(struct wallet *wallet, " AND groupid = ?" " AND partid = ?" " AND status = ?")); + db_bind_sha256(stmt, 0, payment_hash); db_bind_u64(stmt, 1, *groupid); db_bind_u64(stmt, 2, *partid); db_bind_u64(stmt, 3, *status); @@ -3287,9 +3288,9 @@ void wallet_payment_delete(struct wallet *wallet, SQL("DELETE FROM payments" " WHERE payment_hash = ?" " AND status = ?")); + db_bind_sha256(stmt, 0, payment_hash); db_bind_u64(stmt, 1, *status); } - db_bind_sha256(stmt, 0, payment_hash); db_exec_prepared_v2(take(stmt)); } @@ -3572,9 +3573,9 @@ void wallet_payment_set_failinfo(struct wallet *wallet, " , failcode=?" " , failnode=?" " , failscid=?" + " , faildirection=?" " , failupdate=?" " , faildetail=?" - " , faildirection=?" " WHERE payment_hash=?" " AND partid=?;")); if (failonionreply) @@ -3592,18 +3593,18 @@ void wallet_payment_set_failinfo(struct wallet *wallet, if (failchannel) { db_bind_short_channel_id(stmt, 5, failchannel); - db_bind_int(stmt, 8, faildirection); + db_bind_int(stmt, 6, faildirection); } else { db_bind_null(stmt, 5); - db_bind_null(stmt, 8); + db_bind_null(stmt, 6); } - db_bind_talarr(stmt, 6, failupdate); + db_bind_talarr(stmt, 7, failupdate); if (faildetail != NULL) - db_bind_text(stmt, 7, faildetail); + db_bind_text(stmt, 8, faildetail); else - db_bind_null(stmt, 7); + db_bind_null(stmt, 8); db_bind_sha256(stmt, 9, payment_hash); db_bind_u64(stmt, 10, partid);