db: db_col_ variants for accessing SELECT statements by name.

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
Rusty Russell
2021-11-15 04:24:46 +10:30
parent 787fbb1228
commit 5b482eb04b
2 changed files with 401 additions and 1 deletions

View File

@@ -2168,6 +2168,323 @@ u8 *db_column_talarr(const tal_t *ctx, struct db_stmt *stmt, int col)
db_column_bytes(stmt, col), 0); db_column_bytes(stmt, col), 0);
} }
/* Modern variants: by name */
u64 db_col_u64(struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
if (db_column_is_null(stmt, col)) {
log_broken(stmt->db->log, "Accessing a null column %s/%zu in query %s", colname, col, stmt->query->query);
return 0;
}
return stmt->db->config->column_u64_fn(stmt, col);
}
int db_col_int_or_default(struct db_stmt *stmt, const char *colname, int def)
{
size_t col = db_query_colnum(stmt, colname);
if (db_column_is_null(stmt, col))
return def;
else
return db_column_int(stmt, col);
}
int db_col_int(struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
if (db_column_is_null(stmt, col)) {
log_broken(stmt->db->log, "Accessing a null column %s/%zu in query %s", colname, col, stmt->query->query);
return 0;
}
return stmt->db->config->column_int_fn(stmt, col);
}
size_t db_col_bytes(struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
if (db_column_is_null(stmt, col)) {
log_broken(stmt->db->log, "Accessing a null column %s/%zu in query %s", colname, col, stmt->query->query);
return 0;
}
return stmt->db->config->column_bytes_fn(stmt, col);
}
int db_col_is_null(struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
return stmt->db->config->column_is_null_fn(stmt, col);
}
const void *db_col_blob(struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
if (db_column_is_null(stmt, col)) {
log_broken(stmt->db->log, "Accessing a null column %s/%zu in query %s", colname, col, stmt->query->query);
return NULL;
}
return stmt->db->config->column_blob_fn(stmt, col);
}
const unsigned char *db_col_text(struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
if (db_column_is_null(stmt, col)) {
log_broken(stmt->db->log, "Accessing a null column %s/%zu in query %s", colname, col, stmt->query->query);
return NULL;
}
return stmt->db->config->column_text_fn(stmt, col);
}
void db_col_preimage(struct db_stmt *stmt, const char *colname,
struct preimage *preimage)
{
size_t col = db_query_colnum(stmt, colname);
const u8 *raw;
size_t size = sizeof(struct preimage);
assert(db_column_bytes(stmt, col) == size);
raw = db_column_blob(stmt, col);
memcpy(preimage, raw, size);
}
void db_col_channel_id(struct db_stmt *stmt, const char *colname, struct channel_id *dest)
{
size_t col = db_query_colnum(stmt, colname);
assert(db_column_bytes(stmt, col) == sizeof(dest->id));
memcpy(dest->id, db_column_blob(stmt, col), sizeof(dest->id));
}
void db_col_node_id(struct db_stmt *stmt, const char *colname, struct node_id *dest)
{
size_t col = db_query_colnum(stmt, colname);
assert(db_column_bytes(stmt, col) == sizeof(dest->k));
memcpy(dest->k, db_column_blob(stmt, col), sizeof(dest->k));
}
struct node_id *db_col_node_id_arr(const tal_t *ctx, struct db_stmt *stmt,
const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
struct node_id *ret;
size_t n = db_column_bytes(stmt, col) / sizeof(ret->k);
const u8 *arr = db_column_blob(stmt, col);
assert(n * sizeof(ret->k) == (size_t)db_column_bytes(stmt, col));
ret = tal_arr(ctx, struct node_id, n);
for (size_t i = 0; i < n; i++)
memcpy(ret[i].k, arr + i * sizeof(ret[i].k), sizeof(ret[i].k));
return ret;
}
void db_col_pubkey(struct db_stmt *stmt,
const char *colname,
struct pubkey *dest)
{
size_t col = db_query_colnum(stmt, colname);
bool ok;
assert(db_column_bytes(stmt, col) == PUBKEY_CMPR_LEN);
ok = pubkey_from_der(db_column_blob(stmt, col), PUBKEY_CMPR_LEN, dest);
assert(ok);
}
bool db_col_short_channel_id(struct db_stmt *stmt, const char *colname,
struct short_channel_id *dest)
{
size_t col = db_query_colnum(stmt, colname);
const char *source = db_column_blob(stmt, col);
size_t sourcelen = db_column_bytes(stmt, col);
return short_channel_id_from_str(source, sourcelen, dest);
}
struct short_channel_id *
db_col_short_channel_id_arr(const tal_t *ctx, struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
const u8 *ser;
size_t len;
struct short_channel_id *ret;
ser = db_column_blob(stmt, col);
len = db_column_bytes(stmt, col);
ret = tal_arr(ctx, struct short_channel_id, 0);
while (len != 0) {
struct short_channel_id scid;
fromwire_short_channel_id(&ser, &len, &scid);
tal_arr_expand(&ret, scid);
}
return ret;
}
bool db_col_signature(struct db_stmt *stmt, const char *colname,
secp256k1_ecdsa_signature *sig)
{
size_t col = db_query_colnum(stmt, colname);
assert(db_column_bytes(stmt, col) == 64);
return secp256k1_ecdsa_signature_parse_compact(
secp256k1_ctx, sig, db_column_blob(stmt, col)) == 1;
}
struct timeabs db_col_timeabs(struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
struct timeabs t;
u64 timestamp = db_column_u64(stmt, col);
t.ts.tv_sec = timestamp / NSEC_IN_SEC;
t.ts.tv_nsec = timestamp % NSEC_IN_SEC;
return t;
}
struct bitcoin_tx *db_col_tx(const tal_t *ctx, struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
const u8 *src = db_column_blob(stmt, col);
size_t len = db_column_bytes(stmt, col);
return pull_bitcoin_tx(ctx, &src, &len);
}
struct wally_psbt *db_col_psbt(const tal_t *ctx, struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
const u8 *src = db_column_blob(stmt, col);
size_t len = db_column_bytes(stmt, col);
return psbt_from_bytes(ctx, src, len);
}
struct bitcoin_tx *db_col_psbt_to_tx(const tal_t *ctx, struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
struct wally_psbt *psbt = db_column_psbt(ctx, stmt, col);
if (!psbt)
return NULL;
return bitcoin_tx_with_psbt(ctx, psbt);
}
void *db_col_arr_(const tal_t *ctx, struct db_stmt *stmt, const char *colname,
size_t bytes, const char *label, const char *caller)
{
size_t col = db_query_colnum(stmt, colname);
size_t sourcelen;
void *p;
if (db_column_is_null(stmt, col))
return NULL;
sourcelen = db_column_bytes(stmt, col);
if (sourcelen % bytes != 0)
db_fatal("%s: %s/%zu column size for %zu not a multiple of %s (%zu)",
caller, colname, col, sourcelen, label, bytes);
p = tal_arr_label(ctx, char, sourcelen, label);
memcpy(p, db_column_blob(stmt, col), sourcelen);
return p;
}
void db_col_amount_msat_or_default(struct db_stmt *stmt,
const char *colname,
struct amount_msat *msat,
struct amount_msat def)
{
size_t col = db_query_colnum(stmt, colname);
if (db_column_is_null(stmt, col))
*msat = def;
else
msat->millisatoshis = db_column_u64(stmt, col); /* Raw: low level function */
}
void db_col_amount_msat(struct db_stmt *stmt, const char *colname,
struct amount_msat *msat)
{
size_t col = db_query_colnum(stmt, colname);
msat->millisatoshis = db_column_u64(stmt, col); /* Raw: low level function */
}
void db_col_amount_sat(struct db_stmt *stmt, const char *colname, struct amount_sat *sat)
{
size_t col = db_query_colnum(stmt, colname);
sat->satoshis = db_column_u64(stmt, col); /* Raw: low level function */
}
struct json_escape *db_col_json_escape(const tal_t *ctx,
struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
return json_escape_string_(ctx, db_column_blob(stmt, col),
db_column_bytes(stmt, col));
}
void db_col_sha256(struct db_stmt *stmt, const char *colname, struct sha256 *sha)
{
size_t col = db_query_colnum(stmt, colname);
const u8 *raw;
size_t size = sizeof(struct sha256);
assert(db_column_bytes(stmt, col) == size);
raw = db_column_blob(stmt, col);
memcpy(sha, raw, size);
}
void db_col_sha256d(struct db_stmt *stmt, const char *colname,
struct sha256_double *shad)
{
size_t col = db_query_colnum(stmt, colname);
const u8 *raw;
size_t size = sizeof(struct sha256_double);
assert(db_column_bytes(stmt, col) == size);
raw = db_column_blob(stmt, col);
memcpy(shad, raw, size);
}
void db_col_secret(struct db_stmt *stmt, const char *colname, struct secret *s)
{
size_t col = db_query_colnum(stmt, colname);
const u8 *raw;
assert(db_column_bytes(stmt, col) == sizeof(struct secret));
raw = db_column_blob(stmt, col);
memcpy(s, raw, sizeof(struct secret));
}
struct secret *db_col_secret_arr(const tal_t *ctx,
struct db_stmt *stmt,
const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
return db_column_arr(ctx, stmt, col, struct secret);
}
void db_col_txid(struct db_stmt *stmt, const char *colname, struct bitcoin_txid *t)
{
size_t col = db_query_colnum(stmt, colname);
db_column_sha256d(stmt, col, &t->shad);
}
struct onionreply *db_col_onionreply(const tal_t *ctx,
struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
struct onionreply *r = tal(ctx, struct onionreply);
r->contents = tal_dup_arr(r, u8,
db_column_blob(stmt, col),
db_column_bytes(stmt, col), 0);
return r;
}
bool db_exec_prepared_v2(struct db_stmt *stmt TAKES) bool db_exec_prepared_v2(struct db_stmt *stmt TAKES)
{ {
bool ret = stmt->db->config->exec_fn(stmt); bool ret = stmt->db->config->exec_fn(stmt);
@@ -2224,3 +2541,28 @@ const char **db_changes(struct db *db)
{ {
return db->changes; return db->changes;
} }
/* Matches the hash function used in devtools/sql-rewrite.py */
static u32 hash_djb2(const char *str)
{
u32 hash = 5381;
for (size_t i = 0; str[i]; i++)
hash = ((hash << 5) + hash) ^ str[i];
return hash;
}
size_t db_query_colnum(const struct db_stmt *stmt,
const char *colname)
{
u32 col;
assert(stmt->query->colnames != NULL);
col = hash_djb2(colname) % stmt->query->num_colnames;
/* Will crash on NULL, which is the Right Thing */
while (!streq(stmt->query->colnames[col].sqlname,
colname)) {
col = (col + 1) % stmt->query->num_colnames;
}
return stmt->query->colnames[col].val;
}

View File

@@ -128,6 +128,7 @@ void db_bind_onionreply(struct db_stmt *stmt, int col,
void db_bind_talarr(struct db_stmt *stmt, int col, const u8 *arr); void db_bind_talarr(struct db_stmt *stmt, int col, const u8 *arr);
bool db_step(struct db_stmt *stmt); bool db_step(struct db_stmt *stmt);
u64 db_column_u64(struct db_stmt *stmt, int col); u64 db_column_u64(struct db_stmt *stmt, int col);
int db_column_int(struct db_stmt *stmt, int col); int db_column_int(struct db_stmt *stmt, int col);
size_t db_column_bytes(struct db_stmt *stmt, int col); size_t db_column_bytes(struct db_stmt *stmt, int col);
@@ -178,6 +179,63 @@ void db_column_amount_msat_or_default(struct db_stmt *stmt, int col,
struct amount_msat *msat, struct amount_msat *msat,
struct amount_msat def); struct amount_msat def);
/* Modern variants: get columns by name from SELECT */
/* Bridge function to get column number from SELECT
(must exist) */
size_t db_query_colnum(const struct db_stmt *stmt,
const char *colname);
u64 db_col_u64(struct db_stmt *stmt, const char *colname);
int db_col_int(struct db_stmt *stmt, const char *colname);
size_t db_col_bytes(struct db_stmt *stmt, const char *colname);
int db_col_is_null(struct db_stmt *stmt, const char *colname);
const void* db_col_blob(struct db_stmt *stmt, const char *colname);
const unsigned char *db_col_text(struct db_stmt *stmt, const char *colname);
void db_col_preimage(struct db_stmt *stmt, const char *colname, struct preimage *preimage);
void db_col_amount_msat(struct db_stmt *stmt, const char *colname, struct amount_msat *msat);
void db_col_amount_sat(struct db_stmt *stmt, const char *colname, struct amount_sat *sat);
struct json_escape *db_col_json_escape(const tal_t *ctx, struct db_stmt *stmt, const char *colname);
void db_col_sha256(struct db_stmt *stmt, const char *colname, struct sha256 *sha);
void db_col_sha256d(struct db_stmt *stmt, const char *colname, struct sha256_double *shad);
void db_col_secret(struct db_stmt *stmt, const char *colname, struct secret *s);
struct secret *db_col_secret_arr(const tal_t *ctx, struct db_stmt *stmt,
const char *colname);
void db_col_txid(struct db_stmt *stmt, const char *colname, struct bitcoin_txid *t);
void db_col_channel_id(struct db_stmt *stmt, const char *colname, struct channel_id *dest);
void db_col_node_id(struct db_stmt *stmt, const char *colname, struct node_id *ni);
struct node_id *db_col_node_id_arr(const tal_t *ctx, struct db_stmt *stmt,
const char *colname);
void db_col_pubkey(struct db_stmt *stmt, const char *colname,
struct pubkey *p);
bool db_col_short_channel_id(struct db_stmt *stmt, const char *colname,
struct short_channel_id *dest);
struct short_channel_id *
db_col_short_channel_id_arr(const tal_t *ctx, struct db_stmt *stmt, const char *colname);
bool db_col_signature(struct db_stmt *stmt, const char *colname,
secp256k1_ecdsa_signature *sig);
struct timeabs db_col_timeabs(struct db_stmt *stmt, const char *colname);
struct bitcoin_tx *db_col_tx(const tal_t *ctx, struct db_stmt *stmt, const char *colname);
struct wally_psbt *db_col_psbt(const tal_t *ctx, struct db_stmt *stmt, const char *colname);
struct bitcoin_tx *db_col_psbt_to_tx(const tal_t *ctx, struct db_stmt *stmt, const char *colname);
struct onionreply *db_col_onionreply(const tal_t *ctx,
struct db_stmt *stmt, const char *colname);
#define db_col_arr(ctx, stmt, colname, type) \
((type *)db_col_arr_((ctx), (stmt), (colname), \
sizeof(type), TAL_LABEL(type, "[]"), \
__func__))
void *db_col_arr_(const tal_t *ctx, struct db_stmt *stmt, const char *colname,
size_t bytes, const char *label, const char *caller);
/* Some useful default variants */
int db_col_int_or_default(struct db_stmt *stmt, const char *colname, int def);
void db_col_amount_msat_or_default(struct db_stmt *stmt, const char *colname,
struct amount_msat *msat,
struct amount_msat def);
/** /**
* db_exec_prepared -- Execute a prepared statement * db_exec_prepared -- Execute a prepared statement
* *
@@ -202,7 +260,7 @@ bool db_exec_prepared_v2(struct db_stmt *stmt TAKES);
* After preparing a query using `db_prepare`, and after binding all non-null * After preparing a query using `db_prepare`, and after binding all non-null
* variables using the `db_bind_*` functions, it can be executed with this * variables using the `db_bind_*` functions, it can be executed with this
* function. This function must be called before calling `db_step` or any of * function. This function must be called before calling `db_step` or any of
* the `db_column_*` column access functions. * the `db_col_*` column access functions.
* *
* If you are not executing a read-only statement, please use * If you are not executing a read-only statement, please use
* `db_exec_prepared` instead. * `db_exec_prepared` instead.