db: enforce that bindings be done in order.

This is almost always true already; fix up the few non-standard ones.

This is enforced with an assert, and I ran the entire test suite to
double-check.

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
Rusty Russell
2023-07-14 09:58:45 +09:30
parent d17506b899
commit b7b3cbc84a
5 changed files with 35 additions and 25 deletions

View File

@@ -16,6 +16,12 @@
#define NSEC_IN_SEC 1000000000
static int check_bind_pos(struct db_stmt *stmt, int pos)
{
assert(pos == ++stmt->bind_pos);
return pos;
}
/* Local helpers once you have column number */
static bool db_column_is_null(struct db_stmt *stmt, int col)
{
@@ -37,7 +43,7 @@ static bool db_column_null_warn(struct db_stmt *stmt, const char *colname,
void db_bind_int(struct db_stmt *stmt, int pos, int val)
{
assert(pos < tal_count(stmt->bindings));
pos = check_bind_pos(stmt, pos);
memcheck(&val, sizeof(val));
stmt->bindings[pos].type = DB_BINDING_INT;
stmt->bindings[pos].v.i = val;
@@ -60,21 +66,21 @@ int db_col_is_null(struct db_stmt *stmt, const char *colname)
void db_bind_null(struct db_stmt *stmt, int pos)
{
assert(pos < tal_count(stmt->bindings));
pos = check_bind_pos(stmt, pos);
stmt->bindings[pos].type = DB_BINDING_NULL;
}
void db_bind_u64(struct db_stmt *stmt, int pos, u64 val)
{
memcheck(&val, sizeof(val));
assert(pos < tal_count(stmt->bindings));
pos = check_bind_pos(stmt, pos);
stmt->bindings[pos].type = DB_BINDING_UINT64;
stmt->bindings[pos].v.u64 = val;
}
void db_bind_blob(struct db_stmt *stmt, int pos, const u8 *val, size_t len)
{
assert(pos < tal_count(stmt->bindings));
pos = check_bind_pos(stmt, pos);
stmt->bindings[pos].type = DB_BINDING_BLOB;
stmt->bindings[pos].v.blob = memcheck(val, len);
stmt->bindings[pos].len = len;
@@ -82,7 +88,7 @@ void db_bind_blob(struct db_stmt *stmt, int pos, const u8 *val, size_t len)
void db_bind_text(struct db_stmt *stmt, int pos, const char *val)
{
assert(pos < tal_count(stmt->bindings));
pos = check_bind_pos(stmt, pos);
stmt->bindings[pos].type = DB_BINDING_TEXT;
stmt->bindings[pos].v.text = val;
stmt->bindings[pos].len = strlen(val);

View File

@@ -104,6 +104,9 @@ struct db_stmt {
/* Our entry in the list of pending statements. */
struct list_node list;
/* Bind counter */
int bind_pos;
/* Database we are querying */
struct db *db;

View File

@@ -81,6 +81,7 @@ static struct db_stmt *db_prepare_core(struct db *db,
stmt->query = db_query;
stmt->executed = false;
stmt->inner_stmt = NULL;
stmt->bind_pos = -1;
tal_add_destructor(stmt, db_stmt_free);

View File

@@ -731,11 +731,11 @@ static struct chain_event *find_chain_event(const tal_t *ctx,
" LEFT OUTER JOIN accounts a"
" ON e.account_id = a.id"
" WHERE "
" e.account_id = ?"
" e.spending_txid = ?"
" AND e.account_id = ?"
" AND e.utxo_txid = ?"
" AND e.outnum = ?"
" AND e.spending_txid = ?"));
db_bind_txid(stmt, 3, spending_txid);
" AND e.outnum = ?"));
db_bind_txid(stmt, 0, spending_txid);
} else {
stmt = db_prepare_v2(db, SQL("SELECT"
" e.id"
@@ -760,18 +760,17 @@ static struct chain_event *find_chain_event(const tal_t *ctx,
" LEFT OUTER JOIN accounts a"
" ON e.account_id = a.id"
" WHERE "
" e.account_id = ?"
" e.tag = ?"
" AND e.account_id = ?"
" AND e.utxo_txid = ?"
" AND e.outnum = ?"
" AND e.spending_txid IS NULL"
" AND e.tag = ?"));
db_bind_text(stmt, 3, tag);
" AND e.spending_txid IS NULL"));
db_bind_text(stmt, 0, tag);
}
db_bind_u64(stmt, 0, acct->db_id);
db_bind_txid(stmt, 1, &outpoint->txid);
db_bind_int(stmt, 2, outpoint->n);
db_bind_u64(stmt, 1, acct->db_id);
db_bind_txid(stmt, 2, &outpoint->txid);
db_bind_int(stmt, 3, outpoint->n);
db_query_prepared(stmt);
if (db_step(stmt))

View File

@@ -2736,7 +2736,6 @@ void wallet_htlc_update(struct wallet *wallet, const u64 htlc_dbid,
" WHERE id=?"));
db_bind_int(stmt, 0, htlc_state_in_db(new_state));
db_bind_u64(stmt, 7, htlc_dbid);
if (payment_key)
db_bind_preimage(stmt, 1, payment_key);
@@ -2763,6 +2762,7 @@ void wallet_htlc_update(struct wallet *wallet, const u64 htlc_dbid,
else
db_bind_null(stmt, 6);
db_bind_u64(stmt, 7, htlc_dbid);
db_exec_prepared_v2(take(stmt));
if (terminal) {
@@ -3278,6 +3278,7 @@ void wallet_payment_delete(struct wallet *wallet,
" AND groupid = ?"
" AND partid = ?"
" AND status = ?"));
db_bind_sha256(stmt, 0, payment_hash);
db_bind_u64(stmt, 1, *groupid);
db_bind_u64(stmt, 2, *partid);
db_bind_u64(stmt, 3, *status);
@@ -3287,9 +3288,9 @@ void wallet_payment_delete(struct wallet *wallet,
SQL("DELETE FROM payments"
" WHERE payment_hash = ?"
" AND status = ?"));
db_bind_sha256(stmt, 0, payment_hash);
db_bind_u64(stmt, 1, *status);
}
db_bind_sha256(stmt, 0, payment_hash);
db_exec_prepared_v2(take(stmt));
}
@@ -3572,9 +3573,9 @@ void wallet_payment_set_failinfo(struct wallet *wallet,
" , failcode=?"
" , failnode=?"
" , failscid=?"
" , faildirection=?"
" , failupdate=?"
" , faildetail=?"
" , faildirection=?"
" WHERE payment_hash=?"
" AND partid=?;"));
if (failonionreply)
@@ -3592,18 +3593,18 @@ void wallet_payment_set_failinfo(struct wallet *wallet,
if (failchannel) {
db_bind_short_channel_id(stmt, 5, failchannel);
db_bind_int(stmt, 8, faildirection);
db_bind_int(stmt, 6, faildirection);
} else {
db_bind_null(stmt, 5);
db_bind_null(stmt, 8);
db_bind_null(stmt, 6);
}
db_bind_talarr(stmt, 6, failupdate);
db_bind_talarr(stmt, 7, failupdate);
if (faildetail != NULL)
db_bind_text(stmt, 7, faildetail);
db_bind_text(stmt, 8, faildetail);
else
db_bind_null(stmt, 7);
db_bind_null(stmt, 8);
db_bind_sha256(stmt, 9, payment_hash);
db_bind_u64(stmt, 10, partid);