db: enforce that bindings be done in order.

This is almost always true already; fix up the few non-standard ones.

This is enforced with an assert, and I ran the entire test suite to
double-check.

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
Rusty Russell
2023-07-14 09:58:45 +09:30
parent d17506b899
commit b7b3cbc84a
5 changed files with 35 additions and 25 deletions

View File

@@ -16,6 +16,12 @@
#define NSEC_IN_SEC 1000000000
static int check_bind_pos(struct db_stmt *stmt, int pos)
{
assert(pos == ++stmt->bind_pos);
return pos;
}
/* Local helpers once you have column number */
static bool db_column_is_null(struct db_stmt *stmt, int col)
{
@@ -37,7 +43,7 @@ static bool db_column_null_warn(struct db_stmt *stmt, const char *colname,
void db_bind_int(struct db_stmt *stmt, int pos, int val)
{
assert(pos < tal_count(stmt->bindings));
pos = check_bind_pos(stmt, pos);
memcheck(&val, sizeof(val));
stmt->bindings[pos].type = DB_BINDING_INT;
stmt->bindings[pos].v.i = val;
@@ -60,21 +66,21 @@ int db_col_is_null(struct db_stmt *stmt, const char *colname)
void db_bind_null(struct db_stmt *stmt, int pos)
{
assert(pos < tal_count(stmt->bindings));
pos = check_bind_pos(stmt, pos);
stmt->bindings[pos].type = DB_BINDING_NULL;
}
void db_bind_u64(struct db_stmt *stmt, int pos, u64 val)
{
memcheck(&val, sizeof(val));
assert(pos < tal_count(stmt->bindings));
pos = check_bind_pos(stmt, pos);
stmt->bindings[pos].type = DB_BINDING_UINT64;
stmt->bindings[pos].v.u64 = val;
}
void db_bind_blob(struct db_stmt *stmt, int pos, const u8 *val, size_t len)
{
assert(pos < tal_count(stmt->bindings));
pos = check_bind_pos(stmt, pos);
stmt->bindings[pos].type = DB_BINDING_BLOB;
stmt->bindings[pos].v.blob = memcheck(val, len);
stmt->bindings[pos].len = len;
@@ -82,7 +88,7 @@ void db_bind_blob(struct db_stmt *stmt, int pos, const u8 *val, size_t len)
void db_bind_text(struct db_stmt *stmt, int pos, const char *val)
{
assert(pos < tal_count(stmt->bindings));
pos = check_bind_pos(stmt, pos);
stmt->bindings[pos].type = DB_BINDING_TEXT;
stmt->bindings[pos].v.text = val;
stmt->bindings[pos].len = strlen(val);