database: pull out database code into a new module

We're going to reuse the database controllers for the accounting plugin
This commit is contained in:
niftynei
2022-01-03 12:45:35 -06:00
committed by Rusty Russell
parent 03c950bae8
commit ce12d2b8a9
37 changed files with 1478 additions and 1292 deletions

23
db/Makefile Normal file
View File

@@ -0,0 +1,23 @@
#! /usr/bin/make
DB_LIB_SRC := \
db/bindings.c \
db/exec.c \
db/utils.c
DB_DRIVERS := \
db/db_postgres.c \
db/db_sqlite3.c
DB_SRC := $(DB_LIB_SRC) $(DB_DRIVERS)
DB_HEADERS := $(DB_LIB_SRC:.c=.h) db/common.h
DB_OBJS := $(DB_LIB_SRC:.c=.o) $(DB_DRIVERS:.c=.o)
$(DB_OBJS): $(DB_HEADERS)
# Make sure these depend on everything.
ALL_C_SOURCES += $(DB_SRC)
ALL_C_HEADERS += $(DB_HEADERS)
# DB_SQL_FILES is the list of database files
DB_SQL_FILES := db/exec.c

554
db/bindings.c Normal file
View File

@@ -0,0 +1,554 @@
#include "config.h"
#include <bitcoin/privkey.h>
#include <bitcoin/psbt.h>
#include <ccan/mem/mem.h>
#include <ccan/take/take.h>
#include <ccan/tal/str/str.h>
#include <ccan/tal/tal.h>
#include <common/channel_id.h>
#include <common/htlc_state.h>
#include <common/node_id.h>
#include <common/onionreply.h>
#include <db/bindings.h>
#include <db/common.h>
#include <db/utils.h>
#define NSEC_IN_SEC 1000000000
/* Local helpers once you have column number */
static bool db_column_is_null(struct db_stmt *stmt, int col)
{
return stmt->db->config->column_is_null_fn(stmt, col);
}
/* Returns true (and warns) if it's nul */
static bool db_column_null_warn(struct db_stmt *stmt, const char *colname,
int col)
{
if (!db_column_is_null(stmt, col))
return false;
/* FIXME: log broken? */
#if DEVELOPER
db_fatal("Accessing a null column %s/%i in query %s",
colname, col, stmt->query->query);
#endif /* DEVELOPER */
return true;
}
void db_bind_int(struct db_stmt *stmt, int pos, int val)
{
assert(pos < tal_count(stmt->bindings));
memcheck(&val, sizeof(val));
stmt->bindings[pos].type = DB_BINDING_INT;
stmt->bindings[pos].v.i = val;
}
int db_col_int(struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
if (db_column_null_warn(stmt, colname, col))
return 0;
return stmt->db->config->column_int_fn(stmt, col);
}
int db_col_is_null(struct db_stmt *stmt, const char *colname)
{
return db_column_is_null(stmt, db_query_colnum(stmt, colname));
}
void db_bind_null(struct db_stmt *stmt, int pos)
{
assert(pos < tal_count(stmt->bindings));
stmt->bindings[pos].type = DB_BINDING_NULL;
}
void db_bind_u64(struct db_stmt *stmt, int pos, u64 val)
{
memcheck(&val, sizeof(val));
assert(pos < tal_count(stmt->bindings));
stmt->bindings[pos].type = DB_BINDING_UINT64;
stmt->bindings[pos].v.u64 = val;
}
void db_bind_blob(struct db_stmt *stmt, int pos, const u8 *val, size_t len)
{
assert(pos < tal_count(stmt->bindings));
stmt->bindings[pos].type = DB_BINDING_BLOB;
stmt->bindings[pos].v.blob = memcheck(val, len);
stmt->bindings[pos].len = len;
}
void db_bind_text(struct db_stmt *stmt, int pos, const char *val)
{
assert(pos < tal_count(stmt->bindings));
stmt->bindings[pos].type = DB_BINDING_TEXT;
stmt->bindings[pos].v.text = val;
stmt->bindings[pos].len = strlen(val);
}
void db_bind_preimage(struct db_stmt *stmt, int pos, const struct preimage *p)
{
db_bind_blob(stmt, pos, p->r, sizeof(struct preimage));
}
void db_bind_sha256(struct db_stmt *stmt, int pos, const struct sha256 *s)
{
db_bind_blob(stmt, pos, s->u.u8, sizeof(struct sha256));
}
void db_bind_sha256d(struct db_stmt *stmt, int pos, const struct sha256_double *s)
{
db_bind_sha256(stmt, pos, &s->sha);
}
void db_bind_secret(struct db_stmt *stmt, int pos, const struct secret *s)
{
assert(sizeof(s->data) == 32);
db_bind_blob(stmt, pos, s->data, sizeof(s->data));
}
void db_bind_secret_arr(struct db_stmt *stmt, int col, const struct secret *s)
{
size_t num = tal_count(s), elsize = sizeof(s->data);
u8 *ser = tal_arr(stmt, u8, num * elsize);
for (size_t i = 0; i < num; ++i)
memcpy(ser + i * elsize, &s[i], elsize);
db_bind_blob(stmt, col, ser, tal_count(ser));
}
void db_bind_txid(struct db_stmt *stmt, int pos, const struct bitcoin_txid *t)
{
db_bind_sha256d(stmt, pos, &t->shad);
}
void db_bind_channel_id(struct db_stmt *stmt, int pos, const struct channel_id *id)
{
db_bind_blob(stmt, pos, id->id, sizeof(id->id));
}
void db_bind_node_id(struct db_stmt *stmt, int pos, const struct node_id *id)
{
db_bind_blob(stmt, pos, id->k, sizeof(id->k));
}
void db_bind_node_id_arr(struct db_stmt *stmt, int col,
const struct node_id *ids)
{
/* Copy into contiguous array: ARM will add padding to struct node_id! */
size_t n = tal_count(ids);
u8 *arr = tal_arr(stmt, u8, n * sizeof(ids[0].k));
for (size_t i = 0; i < n; ++i) {
assert(node_id_valid(&ids[i]));
memcpy(arr + sizeof(ids[i].k) * i,
ids[i].k,
sizeof(ids[i].k));
}
db_bind_blob(stmt, col, arr, tal_count(arr));
}
void db_bind_pubkey(struct db_stmt *stmt, int pos, const struct pubkey *pk)
{
u8 *der = tal_arr(stmt, u8, PUBKEY_CMPR_LEN);
pubkey_to_der(der, pk);
db_bind_blob(stmt, pos, der, PUBKEY_CMPR_LEN);
}
void db_bind_short_channel_id(struct db_stmt *stmt, int col,
const struct short_channel_id *id)
{
char *ser = short_channel_id_to_str(stmt, id);
db_bind_text(stmt, col, ser);
}
void db_bind_short_channel_id_arr(struct db_stmt *stmt, int col,
const struct short_channel_id *id)
{
u8 *ser = tal_arr(stmt, u8, 0);
size_t num = tal_count(id);
for (size_t i = 0; i < num; ++i)
towire_short_channel_id(&ser, &id[i]);
db_bind_talarr(stmt, col, ser);
}
void db_bind_signature(struct db_stmt *stmt, int col,
const secp256k1_ecdsa_signature *sig)
{
u8 *buf = tal_arr(stmt, u8, 64);
int ret = secp256k1_ecdsa_signature_serialize_compact(secp256k1_ctx,
buf, sig);
assert(ret == 1);
db_bind_blob(stmt, col, buf, 64);
}
void db_bind_timeabs(struct db_stmt *stmt, int col, struct timeabs t)
{
u64 timestamp = t.ts.tv_nsec + (((u64) t.ts.tv_sec) * ((u64) NSEC_IN_SEC));
db_bind_u64(stmt, col, timestamp);
}
void db_bind_tx(struct db_stmt *stmt, int col, const struct wally_tx *tx)
{
u8 *ser = linearize_wtx(stmt, tx);
assert(ser);
db_bind_talarr(stmt, col, ser);
}
void db_bind_psbt(struct db_stmt *stmt, int col, const struct wally_psbt *psbt)
{
size_t bytes_written;
const u8 *ser = psbt_get_bytes(stmt, psbt, &bytes_written);
assert(ser);
db_bind_blob(stmt, col, ser, bytes_written);
}
void db_bind_amount_msat(struct db_stmt *stmt, int pos,
const struct amount_msat *msat)
{
db_bind_u64(stmt, pos, msat->millisatoshis); /* Raw: low level function */
}
void db_bind_amount_sat(struct db_stmt *stmt, int pos,
const struct amount_sat *sat)
{
db_bind_u64(stmt, pos, sat->satoshis); /* Raw: low level function */
}
void db_bind_json_escape(struct db_stmt *stmt, int pos,
const struct json_escape *esc)
{
db_bind_text(stmt, pos, esc->s);
}
void db_bind_onionreply(struct db_stmt *stmt, int pos, const struct onionreply *r)
{
db_bind_talarr(stmt, pos, r->contents);
}
void db_bind_talarr(struct db_stmt *stmt, int col, const u8 *arr)
{
if (!arr)
db_bind_null(stmt, col);
else
db_bind_blob(stmt, col, arr, tal_bytelen(arr));
}
static size_t db_column_bytes(struct db_stmt *stmt, int col)
{
if (db_column_is_null(stmt, col))
return 0;
return stmt->db->config->column_bytes_fn(stmt, col);
}
static const void *db_column_blob(struct db_stmt *stmt, int col)
{
if (db_column_is_null(stmt, col))
return NULL;
return stmt->db->config->column_blob_fn(stmt, col);
}
u64 db_col_u64(struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
if (db_column_null_warn(stmt, colname, col))
return 0;
return stmt->db->config->column_u64_fn(stmt, col);
}
int db_col_int_or_default(struct db_stmt *stmt, const char *colname, int def)
{
size_t col = db_query_colnum(stmt, colname);
if (db_column_is_null(stmt, col))
return def;
else
return stmt->db->config->column_int_fn(stmt, col);
}
size_t db_col_bytes(struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
if (db_column_null_warn(stmt, colname, col))
return 0;
return stmt->db->config->column_bytes_fn(stmt, col);
}
const void *db_col_blob(struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
if (db_column_null_warn(stmt, colname, col))
return NULL;
return stmt->db->config->column_blob_fn(stmt, col);
}
char *db_col_strdup(const tal_t *ctx,
struct db_stmt *stmt,
const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
if (db_column_null_warn(stmt, colname, col))
return NULL;
return tal_strdup(ctx, (char *)stmt->db->config->column_text_fn(stmt, col));
}
void db_col_preimage(struct db_stmt *stmt, const char *colname,
struct preimage *preimage)
{
size_t col = db_query_colnum(stmt, colname);
const u8 *raw;
size_t size = sizeof(struct preimage);
assert(db_column_bytes(stmt, col) == size);
raw = db_column_blob(stmt, col);
memcpy(preimage, raw, size);
}
void db_col_channel_id(struct db_stmt *stmt, const char *colname, struct channel_id *dest)
{
size_t col = db_query_colnum(stmt, colname);
assert(db_column_bytes(stmt, col) == sizeof(dest->id));
memcpy(dest->id, db_column_blob(stmt, col), sizeof(dest->id));
}
void db_col_node_id(struct db_stmt *stmt, const char *colname, struct node_id *dest)
{
size_t col = db_query_colnum(stmt, colname);
assert(db_column_bytes(stmt, col) == sizeof(dest->k));
memcpy(dest->k, db_column_blob(stmt, col), sizeof(dest->k));
}
struct node_id *db_col_node_id_arr(const tal_t *ctx, struct db_stmt *stmt,
const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
struct node_id *ret;
size_t n = db_column_bytes(stmt, col) / sizeof(ret->k);
const u8 *arr = db_column_blob(stmt, col);
assert(n * sizeof(ret->k) == (size_t)db_column_bytes(stmt, col));
ret = tal_arr(ctx, struct node_id, n);
db_column_null_warn(stmt, colname, col);
for (size_t i = 0; i < n; i++)
memcpy(ret[i].k, arr + i * sizeof(ret[i].k), sizeof(ret[i].k));
return ret;
}
void db_col_pubkey(struct db_stmt *stmt,
const char *colname,
struct pubkey *dest)
{
size_t col = db_query_colnum(stmt, colname);
bool ok;
assert(db_column_bytes(stmt, col) == PUBKEY_CMPR_LEN);
ok = pubkey_from_der(db_column_blob(stmt, col), PUBKEY_CMPR_LEN, dest);
assert(ok);
}
/* Yes, we put this in as a string. Past mistakes; do not use! */
bool db_col_short_channel_id_str(struct db_stmt *stmt, const char *colname,
struct short_channel_id *dest)
{
size_t col = db_query_colnum(stmt, colname);
const char *source = db_column_blob(stmt, col);
size_t sourcelen = db_column_bytes(stmt, col);
db_column_null_warn(stmt, colname, col);
return short_channel_id_from_str(source, sourcelen, dest);
}
struct short_channel_id *
db_col_short_channel_id_arr(const tal_t *ctx, struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
const u8 *ser;
size_t len;
struct short_channel_id *ret;
db_column_null_warn(stmt, colname, col);
ser = db_column_blob(stmt, col);
len = db_column_bytes(stmt, col);
ret = tal_arr(ctx, struct short_channel_id, 0);
while (len != 0) {
struct short_channel_id scid;
fromwire_short_channel_id(&ser, &len, &scid);
tal_arr_expand(&ret, scid);
}
return ret;
}
bool db_col_signature(struct db_stmt *stmt, const char *colname,
secp256k1_ecdsa_signature *sig)
{
size_t col = db_query_colnum(stmt, colname);
assert(db_column_bytes(stmt, col) == 64);
return secp256k1_ecdsa_signature_parse_compact(
secp256k1_ctx, sig, db_column_blob(stmt, col)) == 1;
}
struct timeabs db_col_timeabs(struct db_stmt *stmt, const char *colname)
{
struct timeabs t;
u64 timestamp = db_col_u64(stmt, colname);
t.ts.tv_sec = timestamp / NSEC_IN_SEC;
t.ts.tv_nsec = timestamp % NSEC_IN_SEC;
return t;
}
struct bitcoin_tx *db_col_tx(const tal_t *ctx, struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
const u8 *src = db_column_blob(stmt, col);
size_t len = db_column_bytes(stmt, col);
db_column_null_warn(stmt, colname, col);
return pull_bitcoin_tx(ctx, &src, &len);
}
struct wally_psbt *db_col_psbt(const tal_t *ctx, struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
const u8 *src = db_column_blob(stmt, col);
size_t len = db_column_bytes(stmt, col);
db_column_null_warn(stmt, colname, col);
return psbt_from_bytes(ctx, src, len);
}
struct bitcoin_tx *db_col_psbt_to_tx(const tal_t *ctx, struct db_stmt *stmt, const char *colname)
{
struct wally_psbt *psbt = db_col_psbt(ctx, stmt, colname);
if (!psbt)
return NULL;
return bitcoin_tx_with_psbt(ctx, psbt);
}
void *db_col_arr_(const tal_t *ctx, struct db_stmt *stmt, const char *colname,
size_t bytes, const char *label, const char *caller)
{
size_t col = db_query_colnum(stmt, colname);
size_t sourcelen;
void *p;
if (db_column_is_null(stmt, col))
return NULL;
sourcelen = db_column_bytes(stmt, col);
if (sourcelen % bytes != 0)
db_fatal("%s: %s/%zu column size for %zu not a multiple of %s (%zu)",
caller, colname, col, sourcelen, label, bytes);
p = tal_arr_label(ctx, char, sourcelen, label);
memcpy(p, db_column_blob(stmt, col), sourcelen);
return p;
}
void db_col_amount_msat_or_default(struct db_stmt *stmt,
const char *colname,
struct amount_msat *msat,
struct amount_msat def)
{
size_t col = db_query_colnum(stmt, colname);
if (db_column_is_null(stmt, col))
*msat = def;
else
msat->millisatoshis = db_col_u64(stmt, colname); /* Raw: low level function */
}
void db_col_amount_msat(struct db_stmt *stmt, const char *colname,
struct amount_msat *msat)
{
msat->millisatoshis = db_col_u64(stmt, colname); /* Raw: low level function */
}
void db_col_amount_sat(struct db_stmt *stmt, const char *colname, struct amount_sat *sat)
{
sat->satoshis = db_col_u64(stmt, colname); /* Raw: low level function */
}
struct json_escape *db_col_json_escape(const tal_t *ctx,
struct db_stmt *stmt, const char *colname)
{
size_t col = db_query_colnum(stmt, colname);
return json_escape_string_(ctx, db_column_blob(stmt, col),
db_column_bytes(stmt, col));
}
void db_col_sha256(struct db_stmt *stmt, const char *colname, struct sha256 *sha)
{
size_t col = db_query_colnum(stmt, colname);
const u8 *raw;
size_t size = sizeof(struct sha256);
assert(db_column_bytes(stmt, col) == size);
raw = db_column_blob(stmt, col);
memcpy(sha, raw, size);
}
void db_col_sha256d(struct db_stmt *stmt, const char *colname,
struct sha256_double *shad)
{
size_t col = db_query_colnum(stmt, colname);
const u8 *raw;
size_t size = sizeof(struct sha256_double);
assert(db_column_bytes(stmt, col) == size);
raw = db_column_blob(stmt, col);
memcpy(shad, raw, size);
}
void db_col_secret(struct db_stmt *stmt, const char *colname, struct secret *s)
{
size_t col = db_query_colnum(stmt, colname);
const u8 *raw;
assert(db_column_bytes(stmt, col) == sizeof(struct secret));
raw = db_column_blob(stmt, col);
memcpy(s, raw, sizeof(struct secret));
}
struct secret *db_col_secret_arr(const tal_t *ctx,
struct db_stmt *stmt,
const char *colname)
{
return db_col_arr(ctx, stmt, colname, struct secret);
}
void db_col_txid(struct db_stmt *stmt, const char *colname, struct bitcoin_txid *t)
{
db_col_sha256d(stmt, colname, &t->shad);
}
struct onionreply *db_col_onionreply(const tal_t *ctx,
struct db_stmt *stmt, const char *colname)
{
struct onionreply *r = tal(ctx, struct onionreply);
r->contents = db_col_arr(ctx, stmt, colname, u8);
return r;
}
void db_col_ignore(struct db_stmt *stmt, const char *colname)
{
#if DEVELOPER
db_query_colnum(stmt, colname);
#endif
}

118
db/bindings.h Normal file
View File

@@ -0,0 +1,118 @@
#ifndef LIGHTNING_DB_BINDINGS_H
#define LIGHTNING_DB_BINDINGS_H
#include "config.h"
#include <bitcoin/preimage.h>
#include <bitcoin/pubkey.h>
#include <bitcoin/short_channel_id.h>
#include <bitcoin/tx.h>
#include <ccan/json_escape/json_escape.h>
#include <ccan/time/time.h>
struct channel_id;
struct db_stmt;
struct node_id;
struct onionreply;
struct wally_psbt;
struct wally_tx;
int db_col_is_null(struct db_stmt *stmt, const char *colname);
void db_bind_int(struct db_stmt *stmt, int pos, int val);
int db_col_int(struct db_stmt *stmt, const char *colname);
void db_bind_null(struct db_stmt *stmt, int pos);
void db_bind_int(struct db_stmt *stmt, int pos, int val);
void db_bind_u64(struct db_stmt *stmt, int pos, u64 val);
void db_bind_blob(struct db_stmt *stmt, int pos, const u8 *val, size_t len);
void db_bind_text(struct db_stmt *stmt, int pos, const char *val);
void db_bind_preimage(struct db_stmt *stmt, int pos, const struct preimage *p);
void db_bind_sha256(struct db_stmt *stmt, int pos, const struct sha256 *s);
void db_bind_sha256d(struct db_stmt *stmt, int pos, const struct sha256_double *s);
void db_bind_secret(struct db_stmt *stmt, int pos, const struct secret *s);
void db_bind_secret_arr(struct db_stmt *stmt, int col, const struct secret *s);
void db_bind_txid(struct db_stmt *stmt, int pos, const struct bitcoin_txid *t);
void db_bind_channel_id(struct db_stmt *stmt, int pos, const struct channel_id *id);
void db_bind_node_id(struct db_stmt *stmt, int pos, const struct node_id *ni);
void db_bind_node_id_arr(struct db_stmt *stmt, int col,
const struct node_id *ids);
void db_bind_pubkey(struct db_stmt *stmt, int pos, const struct pubkey *p);
void db_bind_short_channel_id(struct db_stmt *stmt, int col,
const struct short_channel_id *id);
void db_bind_short_channel_id_arr(struct db_stmt *stmt, int col,
const struct short_channel_id *id);
void db_bind_signature(struct db_stmt *stmt, int col,
const secp256k1_ecdsa_signature *sig);
void db_bind_timeabs(struct db_stmt *stmt, int col, struct timeabs t);
void db_bind_tx(struct db_stmt *stmt, int col, const struct wally_tx *tx);
void db_bind_psbt(struct db_stmt *stmt, int col, const struct wally_psbt *psbt);
void db_bind_amount_msat(struct db_stmt *stmt, int pos,
const struct amount_msat *msat);
void db_bind_amount_sat(struct db_stmt *stmt, int pos,
const struct amount_sat *sat);
void db_bind_json_escape(struct db_stmt *stmt, int pos,
const struct json_escape *esc);
void db_bind_onionreply(struct db_stmt *stmt, int col,
const struct onionreply *r);
void db_bind_talarr(struct db_stmt *stmt, int col, const u8 *arr);
/* Modern variants: get columns by name from SELECT */
/* Bridge function to get column number from SELECT
(must exist) */
size_t db_query_colnum(const struct db_stmt *stmt, const char *colname);
u64 db_col_u64(struct db_stmt *stmt, const char *colname);
size_t db_col_bytes(struct db_stmt *stmt, const char *colname);
const void* db_col_blob(struct db_stmt *stmt, const char *colname);
char *db_col_strdup(const tal_t *ctx,
struct db_stmt *stmt,
const char *colname);
void db_col_preimage(struct db_stmt *stmt, const char *colname, struct preimage *preimage);
void db_col_amount_msat(struct db_stmt *stmt, const char *colname, struct amount_msat *msat);
void db_col_amount_sat(struct db_stmt *stmt, const char *colname, struct amount_sat *sat);
struct json_escape *db_col_json_escape(const tal_t *ctx, struct db_stmt *stmt, const char *colname);
void db_col_sha256(struct db_stmt *stmt, const char *colname, struct sha256 *sha);
void db_col_sha256d(struct db_stmt *stmt, const char *colname, struct sha256_double *shad);
void db_col_secret(struct db_stmt *stmt, const char *colname, struct secret *s);
struct secret *db_col_secret_arr(const tal_t *ctx, struct db_stmt *stmt,
const char *colname);
void db_col_txid(struct db_stmt *stmt, const char *colname, struct bitcoin_txid *t);
void db_col_channel_id(struct db_stmt *stmt, const char *colname, struct channel_id *dest);
void db_col_node_id(struct db_stmt *stmt, const char *colname, struct node_id *ni);
struct node_id *db_col_node_id_arr(const tal_t *ctx, struct db_stmt *stmt,
const char *colname);
void db_col_pubkey(struct db_stmt *stmt, const char *colname,
struct pubkey *p);
bool db_col_short_channel_id_str(struct db_stmt *stmt, const char *colname,
struct short_channel_id *dest);
struct short_channel_id *
db_col_short_channel_id_arr(const tal_t *ctx, struct db_stmt *stmt, const char *colname);
bool db_col_signature(struct db_stmt *stmt, const char *colname,
secp256k1_ecdsa_signature *sig);
struct timeabs db_col_timeabs(struct db_stmt *stmt, const char *colname);
struct bitcoin_tx *db_col_tx(const tal_t *ctx, struct db_stmt *stmt, const char *colname);
struct wally_psbt *db_col_psbt(const tal_t *ctx, struct db_stmt *stmt, const char *colname);
struct bitcoin_tx *db_col_psbt_to_tx(const tal_t *ctx, struct db_stmt *stmt, const char *colname);
struct onionreply *db_col_onionreply(const tal_t *ctx,
struct db_stmt *stmt, const char *colname);
#define db_col_arr(ctx, stmt, colname, type) \
((type *)db_col_arr_((ctx), (stmt), (colname), \
sizeof(type), TAL_LABEL(type, "[]"), \
__func__))
void *db_col_arr_(const tal_t *ctx, struct db_stmt *stmt, const char *colname,
size_t bytes, const char *label, const char *caller);
/* Some useful default variants */
int db_col_int_or_default(struct db_stmt *stmt, const char *colname, int def);
void db_col_amount_msat_or_default(struct db_stmt *stmt, const char *colname,
struct amount_msat *msat,
struct amount_msat def);
/* Explicitly ignore a column (so we don't complain you didn't use it!) */
void db_col_ignore(struct db_stmt *stmt, const char *colname);
#endif /* LIGHTNING_DB_BINDINGS_H */

201
db/common.h Normal file
View File

@@ -0,0 +1,201 @@
#ifndef LIGHTNING_DB_COMMON_H
#define LIGHTNING_DB_COMMON_H
#include "config.h"
#include <ccan/list/list.h>
#include <ccan/short_types/short_types.h>
#include <ccan/strset/strset.h>
#include <common/autodata.h>
#include <common/utils.h>
/**
* Macro to annotate a named SQL query.
*
* This macro is used to annotate SQL queries that might need rewriting for
* different SQL dialects. It is used both as a marker for the query
* extraction logic in devtools/sql-rewrite.py to identify queries, as well as
* a way to swap out the query text with it's name so that the query execution
* engine can then look up the rewritten query using its name.
*
*/
#define NAMED_SQL(name,x) x
/**
* Simple annotation macro that auto-generates names for NAMED_SQL
*
* If this macro is changed it is likely that the extraction logic in
* devtools/sql-rewrite.py needs to change as well, since they need to
* generate identical names to work correctly.
*/
#define SQL(x) NAMED_SQL( __FILE__ ":" stringify(__COUNTER__), x)
struct db {
char *filename;
const char *in_transaction;
/* DB-specific context */
void *conn;
/* The configuration for the current database driver */
const struct db_config *config;
/* Translated queries for the current database domain + driver */
const struct db_query_set *queries;
const char **changes;
/* List of statements that have been created but not executed yet. */
struct list_head pending_statements;
char *error;
/* Were there any modifying statements in the current transaction?
* Used to bump the data_version in the DB.*/
bool dirty;
/* The current DB version we expect to update if changes are
* committed. */
u32 data_version;
void (*report_changes_fn)(struct db *);
};
struct db_query {
const char *name;
const char *query;
/* How many placeholders are in the query (and how many will we have
to allocate when instantiating this query)? */
size_t placeholders;
/* Is this a read-only query? If it is there's no need to tell plugins
* about it. */
bool readonly;
/* If this is a select statement, what column names */
const struct sqlname_map *colnames;
size_t num_colnames;
};
enum db_binding_type {
DB_BINDING_UNINITIALIZED = 0,
DB_BINDING_NULL,
DB_BINDING_BLOB,
DB_BINDING_TEXT,
DB_BINDING_UINT64,
DB_BINDING_INT,
};
struct db_binding {
enum db_binding_type type;
union {
s32 i;
u64 u64;
const char* text;
const u8 *blob;
} v;
size_t len;
};
struct db_stmt {
/* Our entry in the list of pending statements. */
struct list_node list;
/* Database we are querying */
struct db *db;
/* Which SQL statement are we trying to execute? */
const struct db_query *query;
/* Which parameters are we binding to the statement? */
struct db_binding *bindings;
/* Where are we calling this statement from? */
const char *location;
const char *error;
/* Pointer to DB-specific statement. */
void *inner_stmt;
bool executed;
int row;
#if DEVELOPER
/* Map as we reference into a SELECT statement in query. */
struct strset *cols_used;
#endif
};
struct db_query_set {
const char *name;
const struct db_query *query_table;
size_t query_table_size;
};
struct db_config {
const char *name;
/* Function used to execute a statement that doesn't result in a
* response. */
bool (*exec_fn)(struct db_stmt *stmt);
/* Function to execute a query that will result in a response. */
bool (*query_fn)(struct db_stmt *stmt);
/* Function used to step forwards through query results. Returns
* `false` if there are no more rows to return. */
bool (*step_fn)(struct db_stmt *stmt);
bool (*begin_tx_fn)(struct db *db);
bool (*commit_tx_fn)(struct db *db);
/* The free function must make sure that any associated state stored
* in `stmt->inner_stmt` is freed correctly, setting the pointer to
* NULL after cleaning up. It will ultmately be called by the
* destructor of `struct db_stmt`, before clearing the db_stmt
* itself. */
void (*stmt_free_fn)(struct db_stmt *db_stmt);
/* Column access in a row. Only covers the primitives, others need to
* use these internally to translate (hence the non-allocating
* column_{text,blob}_fn since most other types want in place
* assignment. */
bool (*column_is_null_fn)(struct db_stmt *stmt, int col);
u64 (*column_u64_fn)(struct db_stmt *stmt, int col);
size_t (*column_bytes_fn)(struct db_stmt *stmt, int col);
const void *(*column_blob_fn)(struct db_stmt *stmt, int col);
const unsigned char *(*column_text_fn)(struct db_stmt *stmt, int col);
s64 (*column_int_fn)(struct db_stmt *stmt, int col);
u64 (*last_insert_id_fn)(struct db_stmt *stmt);
size_t (*count_changes_fn)(struct db_stmt *stmt);
bool (*setup_fn)(struct db *db);
void (*teardown_fn)(struct db *db);
bool (*vacuum_fn)(struct db *db);
bool (*rename_column)(struct db *db,
const char *tablename,
const char *from, const char *to);
bool (*delete_columns)(struct db *db,
const char *tablename,
const char **colnames, size_t num_cols);
};
void db_fatal(const char *fmt, ...)
PRINTF_FMT(1, 2);
/* Provide a way for DB backends to register themselves */
AUTODATA_TYPE(db_backends, struct db_config);
/* Provide a way for DB query sets to register themselves */
AUTODATA_TYPE(db_queries, struct db_query_set);
/* devtools/sql-rewrite.py generates this simple htable */
struct sqlname_map {
const char *sqlname;
int val;
};
#endif /* LIGHTNING_DB_COMMON_H */

348
db/db_postgres.c Normal file
View File

@@ -0,0 +1,348 @@
#include "config.h"
#include <ccan/ccan/tal/str/str.h>
#include <ccan/endian/endian.h>
#include <db/common.h>
#include <db/utils.h>
#if HAVE_POSTGRES
/* Indented in order not to trigger the inclusion order check */
#include <libpq-fe.h>
/* Cherry-picked from here: libpq/src/interfaces/ecpg/ecpglib/pg_type.h */
#define BYTEAOID 17
#define INT8OID 20
#define INT4OID 23
#define TEXTOID 25
static bool db_postgres_setup(struct db *db)
{
size_t prefix_len = strlen("postgres://");
/* We attempt to parse the connection string without the `postgres://`
prefix first, so we can correctly handle the key-value-pair style of
DSN that postgresql supports. If that fails we try with the full
string, which matches the `scheme://user:password@host:port/dbname`
style of DSNs. The call to `PQconninfoParse` here is just to verify
`PQconnectdb` would be able to parse it correctly, that's why the
result is discarded again immediately. */
PQconninfoOption *info =
PQconninfoParse(db->filename + prefix_len, NULL);
if (info != NULL) {
PQconninfoFree(info);
db->conn = PQconnectdb(db->filename + prefix_len);
} else {
db->conn = PQconnectdb(db->filename);
}
if (PQstatus(db->conn) != CONNECTION_OK) {
db->error = tal_fmt(db, "Could not connect to %s: %s", db->filename, PQerrorMessage(db->conn));
db->conn = NULL;
return false;
}
return true;
}
static bool db_postgres_begin_tx(struct db *db)
{
assert(db->conn);
PGresult *res;
res = PQexec(db->conn, "BEGIN;");
if (PQresultStatus(res) != PGRES_COMMAND_OK) {
db->error = tal_fmt(db, "BEGIN command failed: %s",
PQerrorMessage(db->conn));
PQclear(res);
return false;
}
PQclear(res);
return true;
}
static bool db_postgres_commit_tx(struct db *db)
{
assert(db->conn);
PGresult *res;
res = PQexec(db->conn, "COMMIT;");
if (PQresultStatus(res) != PGRES_COMMAND_OK) {
db->error = tal_fmt(db, "COMMIT command failed: %s",
PQerrorMessage(db->conn));
PQclear(res);
return false;
}
PQclear(res);
return true;
}
static PGresult *db_postgres_do_exec(struct db_stmt *stmt)
{
int slots = stmt->query->placeholders;
const char *paramValues[slots];
int paramLengths[slots];
int paramFormats[slots];
Oid paramTypes[slots];
int resultFormat = 1; /* We always want binary results. */
/* Since we pass in raw pointers to elements converted to network
* byte-order we need a place to temporarily stash them. */
s32 ints[slots];
u64 u64s[slots];
for (size_t i=0; i<slots; i++) {
struct db_binding *b = &stmt->bindings[i];
switch (b->type) {
case DB_BINDING_UNINITIALIZED:
db_fatal("DB binding not initialized: position=%zu, "
"query=\"%s\n",
i, stmt->query->query);
case DB_BINDING_UINT64:
paramLengths[i] = 8;
paramFormats[i] = 1;
u64s[i] = cpu_to_be64(b->v.u64);
paramValues[i] = (char*)&u64s[i];
paramTypes[i] = INT8OID;
break;
case DB_BINDING_INT:
paramLengths[i] = 4;
paramFormats[i] = 1;
ints[i] = cpu_to_be32(b->v.i);
paramValues[i] = (char*)&ints[i];
paramTypes[i] = INT4OID;
break;
case DB_BINDING_BLOB:
paramLengths[i] = b->len;
paramFormats[i] = 1;
paramValues[i] = (char*)b->v.blob;
paramTypes[i] = BYTEAOID;
break;
case DB_BINDING_TEXT:
paramLengths[i] = b->len;
paramFormats[i] = 1;
paramValues[i] = (char*)b->v.text;
paramTypes[i] = TEXTOID;
break;
case DB_BINDING_NULL:
paramLengths[i] = 0;
paramFormats[i] = 1;
paramValues[i] = NULL;
paramTypes[i] = 0;
break;
}
}
return PQexecParams(stmt->db->conn, stmt->query->query, slots,
paramTypes, paramValues, paramLengths, paramFormats,
resultFormat);
}
static bool db_postgres_query(struct db_stmt *stmt)
{
stmt->inner_stmt = db_postgres_do_exec(stmt);
int res;
res = PQresultStatus(stmt->inner_stmt);
if (res != PGRES_EMPTY_QUERY && res != PGRES_TUPLES_OK) {
stmt->error = PQerrorMessage(stmt->db->conn);
PQclear(stmt->inner_stmt);
stmt->inner_stmt = NULL;
return false;
}
stmt->row = -1;
return true;
}
static bool db_postgres_step(struct db_stmt *stmt)
{
stmt->row++;
if (stmt->row >= PQntuples(stmt->inner_stmt)) {
return false;
}
return true;
}
static bool db_postgres_column_is_null(struct db_stmt *stmt, int col)
{
PGresult *res = (PGresult*)stmt->inner_stmt;
return PQgetisnull(res, stmt->row, col);
}
static u64 db_postgres_column_u64(struct db_stmt *stmt, int col)
{
PGresult *res = (PGresult*)stmt->inner_stmt;
be64 bin;
size_t expected = sizeof(bin), actual = PQgetlength(res, stmt->row, col);
if (expected != actual)
db_fatal(
"u64 field doesn't match size: expected %zu, actual %zu\n",
expected, actual);
memcpy(&bin, PQgetvalue(res, stmt->row, col), sizeof(bin));
return be64_to_cpu(bin);
}
static s64 db_postgres_column_int(struct db_stmt *stmt, int col)
{
PGresult *res = (PGresult*)stmt->inner_stmt;
be32 bin;
size_t expected = sizeof(bin), actual = PQgetlength(res, stmt->row, col);
if (expected != actual)
db_fatal(
"s32 field doesn't match size: expected %zu, actual %zu\n",
expected, actual);
memcpy(&bin, PQgetvalue(res, stmt->row, col), sizeof(bin));
return be32_to_cpu(bin);
}
static size_t db_postgres_column_bytes(struct db_stmt *stmt, int col)
{
PGresult *res = (PGresult *)stmt->inner_stmt;
return PQgetlength(res, stmt->row, col);
}
static const void *db_postgres_column_blob(struct db_stmt *stmt, int col)
{
PGresult *res = (PGresult*)stmt->inner_stmt;
return PQgetvalue(res, stmt->row, col);
}
static const unsigned char *db_postgres_column_text(struct db_stmt *stmt, int col)
{
PGresult *res = (PGresult*)stmt->inner_stmt;
return (unsigned char*)PQgetvalue(res, stmt->row, col);
}
static void db_postgres_stmt_free(struct db_stmt *stmt)
{
if (stmt->inner_stmt)
PQclear(stmt->inner_stmt);
stmt->inner_stmt = NULL;
}
static bool db_postgres_exec(struct db_stmt *stmt)
{
bool ok;
stmt->inner_stmt = db_postgres_do_exec(stmt);
ok = PQresultStatus(stmt->inner_stmt) == PGRES_COMMAND_OK;
if (!ok)
stmt->error = PQerrorMessage(stmt->db->conn);
return ok;
}
static u64 db_postgres_last_insert_id(struct db_stmt *stmt)
{
PGresult *res = PQexec(stmt->db->conn, "SELECT lastval()");
int id = atoi(PQgetvalue(res, 0, 0));
PQclear(res);
return id;
}
static size_t db_postgres_count_changes(struct db_stmt *stmt)
{
PGresult *res = (PGresult*)stmt->inner_stmt;
char *count = PQcmdTuples(res);
return atoi(count);
}
static void db_postgres_teardown(struct db *db)
{
}
static bool db_postgres_vacuum(struct db *db)
{
PGresult *res;
#if DEVELOPER
/* This can use a lot of diskspacem breaking CI! */
if (getenv("LIGHTNINGD_POSTGRES_NO_VACUUM")
&& streq(getenv("LIGHTNINGD_POSTGRES_NO_VACUUM"), "1"))
return true;
#endif
res = PQexec(db->conn, "VACUUM FULL;");
if (PQresultStatus(res) != PGRES_COMMAND_OK) {
db->error = tal_fmt(db, "VACUUM command failed: %s",
PQerrorMessage(db->conn));
PQclear(res);
return false;
}
PQclear(res);
return true;
}
static bool db_postgres_rename_column(struct db *db,
const char *tablename,
const char *from, const char *to)
{
PGresult *res;
char *cmd;
cmd = tal_fmt(db, "ALTER TABLE %s RENAME %s TO %s;",
tablename, from, to);
res = PQexec(db->conn, cmd);
if (PQresultStatus(res) != PGRES_COMMAND_OK) {
db->error = tal_fmt(db, "Rename '%s' failed: %s",
cmd, PQerrorMessage(db->conn));
PQclear(res);
return false;
}
PQclear(res);
return true;
}
static bool db_postgres_delete_columns(struct db *db,
const char *tablename,
const char **colnames, size_t num_cols)
{
PGresult *res;
char *cmd;
cmd = tal_fmt(db, "ALTER TABLE %s ", tablename);
for (size_t i = 0; i < num_cols; i++) {
if (i != 0)
tal_append_fmt(&cmd, ", ");
tal_append_fmt(&cmd, "DROP %s", colnames[i]);
}
tal_append_fmt(&cmd, ";");
res = PQexec(db->conn, cmd);
if (PQresultStatus(res) != PGRES_COMMAND_OK) {
db->error = tal_fmt(db, "Delete '%s' failed: %s",
cmd, PQerrorMessage(db->conn));
PQclear(res);
return false;
}
PQclear(res);
return true;
}
struct db_config db_postgres_config = {
.name = "postgres",
.exec_fn = db_postgres_exec,
.query_fn = db_postgres_query,
.step_fn = db_postgres_step,
.begin_tx_fn = &db_postgres_begin_tx,
.commit_tx_fn = &db_postgres_commit_tx,
.stmt_free_fn = db_postgres_stmt_free,
.column_is_null_fn = db_postgres_column_is_null,
.column_u64_fn = db_postgres_column_u64,
.column_int_fn = db_postgres_column_int,
.column_bytes_fn = db_postgres_column_bytes,
.column_blob_fn = db_postgres_column_blob,
.column_text_fn = db_postgres_column_text,
.last_insert_id_fn = db_postgres_last_insert_id,
.count_changes_fn = db_postgres_count_changes,
.setup_fn = db_postgres_setup,
.teardown_fn = db_postgres_teardown,
.vacuum_fn = db_postgres_vacuum,
.rename_column = db_postgres_rename_column,
.delete_columns = db_postgres_delete_columns,
};
AUTODATA(db_backends, &db_postgres_config);
#endif /* HAVE_POSTGRES */

711
db/db_sqlite3.c Normal file
View File

@@ -0,0 +1,711 @@
#include "config.h"
#include <ccan/ccan/tal/str/str.h>
#include <common/utils.h>
#include <db/common.h>
#include <db/utils.h>
#if HAVE_SQLITE3
#include <sqlite3.h>
struct db_sqlite3 {
/* The actual db connection. */
sqlite3 *conn;
/* A replica db connection, if requested, or NULL otherwise. */
sqlite3 *backup_conn;
};
/**
* @param conn: The db->conn void * pointer.
*
* @return the actual sqlite3 connection.
*/
static inline
sqlite3 *conn2sql(void *conn)
{
struct db_sqlite3 *wrapper = (struct db_sqlite3 *) conn;
return wrapper->conn;
}
static void replicate_statement(struct db_sqlite3 *wrapper,
const char *qry)
{
sqlite3_stmt *stmt;
int err;
if (!wrapper->backup_conn)
return;
sqlite3_prepare_v2(wrapper->backup_conn,
qry, -1, &stmt, NULL);
err = sqlite3_step(stmt);
sqlite3_finalize(stmt);
if (err != SQLITE_DONE)
db_fatal("Failed to replicate query: %s: %s: %s",
sqlite3_errstr(err),
sqlite3_errmsg(wrapper->backup_conn),
qry);
}
static void db_sqlite3_changes_add(struct db_sqlite3 *wrapper,
struct db_stmt *stmt,
const char *qry)
{
replicate_statement(wrapper, qry);
db_changes_add(stmt, qry);
}
/* Check if both sqlite3 databases have a data_version variable,
* *and* are the same.
*/
static bool have_same_data_version(sqlite3 *a, sqlite3 *b)
{
sqlite3_stmt *stmt;
const char *qry = "SELECT intval FROM vars"
" WHERE name = 'data_version';";
int err;
u64 version_a;
u64 version_b;
sqlite3_prepare_v2(a, qry, -1, &stmt, NULL);
err = sqlite3_step(stmt);
if (err != SQLITE_ROW) {
sqlite3_finalize(stmt);
return false;
}
version_a = sqlite3_column_int64(stmt, 0);
sqlite3_finalize(stmt);
sqlite3_prepare_v2(b, qry, -1, &stmt, NULL);
err = sqlite3_step(stmt);
if (err != SQLITE_ROW) {
sqlite3_finalize(stmt);
return false;
}
version_b = sqlite3_column_int64(stmt, 0);
sqlite3_finalize(stmt);
return version_a == version_b;
}
#if !HAVE_SQLITE3_EXPANDED_SQL
/* Prior to sqlite3 v3.14, we have to use tracing to dump statements */
struct db_sqlite3_trace {
struct db_sqlite3 *wrapper;
struct db_stmt *stmt;
};
static void trace_sqlite3(void *stmtv, const char *stmt)
{
struct db_sqlite3_trace *trace = (struct db_sqlite3_trace *)stmtv;
struct db_sqlite3 *wrapper = trace->wrapper;
struct db_stmt *s = trace->stmt;
db_sqlite3_changes_add(wrapper, s, stmt);
}
#endif
static const char *db_sqlite3_fmt_error(struct db_stmt *stmt)
{
return tal_fmt(stmt, "%s: %s: %s", stmt->location, stmt->query->query,
sqlite3_errmsg(conn2sql(stmt->db->conn)));
}
static bool db_sqlite3_setup(struct db *db)
{
char *filename;
char *sep;
char *backup_filename = NULL;
sqlite3_stmt *stmt;
sqlite3 *sql;
int err, flags = SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE;
struct db_sqlite3 *wrapper;
if (!strstarts(db->filename, "sqlite3://") || strlen(db->filename) < 10)
db_fatal("Could not parse the wallet DSN: %s", db->filename);
/* Strip the scheme from the dsn. */
filename = db->filename + strlen("sqlite3://");
/* Look for a replica specification. */
sep = strchr(filename, ':');
if (sep) {
/* Split at ':'. */
filename = tal_strndup(db, filename, sep - filename);
backup_filename = tal_strdup(db, sep + 1);
}
wrapper = tal(db, struct db_sqlite3);
db->conn = wrapper;
err = sqlite3_open_v2(filename, &sql, flags, NULL);
if (err != SQLITE_OK) {
db_fatal("failed to open database %s: %s", filename,
sqlite3_errstr(err));
}
wrapper->conn = sql;
err = sqlite3_extended_result_codes(wrapper->conn, 1);
if (err != SQLITE_OK) {
db_fatal("failed to enable extended result codes: %s",
sqlite3_errstr(err));
}
if (!backup_filename)
wrapper->backup_conn = NULL;
else {
err = sqlite3_open_v2(backup_filename,
&wrapper->backup_conn,
flags, NULL);
if (err != SQLITE_OK) {
db_fatal("failed to open backup database %s: %s",
backup_filename,
sqlite3_errstr(err));
}
sqlite3_prepare_v2(wrapper->backup_conn,
"PRAGMA foreign_keys = ON;", -1, &stmt,
NULL);
err = sqlite3_step(stmt);
sqlite3_finalize(stmt);
if (err != SQLITE_DONE) {
db_fatal("failed to use backup database %s: %s",
backup_filename,
sqlite3_errstr(err));
}
}
/* If we have a backup db, but it does not have a matching
* data_version, copy over the main database. */
if (wrapper->backup_conn &&
!have_same_data_version(wrapper->conn, wrapper->backup_conn)) {
/* Copy the main database over the backup database. */
sqlite3_backup *copier = sqlite3_backup_init(wrapper->backup_conn,
"main",
wrapper->conn,
"main");
if (!copier) {
db_fatal("failed to initiate copy to %s: %s",
backup_filename,
sqlite3_errmsg(wrapper->backup_conn));
}
err = sqlite3_backup_step(copier, -1);
if (err != SQLITE_DONE) {
db_fatal("failed to copy database to %s: %s",
backup_filename,
sqlite3_errstr(err));
}
sqlite3_backup_finish(copier);
}
/* In case another process (litestream?) grabs a lock, we don't
* want to return SQLITE_BUSY immediately (which will cause a
* fatal error): give it 60 seconds.
* We *could* make this an option, but surely the user prefers a
* long timeout over an outright crash.
*/
sqlite3_busy_timeout(conn2sql(db->conn), 60000);
sqlite3_prepare_v2(conn2sql(db->conn),
"PRAGMA foreign_keys = ON;", -1, &stmt, NULL);
err = sqlite3_step(stmt);
sqlite3_finalize(stmt);
return err == SQLITE_DONE;
}
static bool db_sqlite3_query(struct db_stmt *stmt)
{
sqlite3_stmt *s;
sqlite3 *conn = conn2sql(stmt->db->conn);
int err;
err = sqlite3_prepare_v2(conn, stmt->query->query, -1, &s, NULL);
for (size_t i=0; i<stmt->query->placeholders; i++) {
struct db_binding *b = &stmt->bindings[i];
/* sqlite3 uses printf-like offsets, we don't... */
int pos = i+1;
switch (b->type) {
case DB_BINDING_UNINITIALIZED:
db_fatal("DB binding not initialized: position=%zu, "
"query=\"%s\n",
i, stmt->query->query);
case DB_BINDING_UINT64:
sqlite3_bind_int64(s, pos, b->v.u64);
break;
case DB_BINDING_INT:
sqlite3_bind_int(s, pos, b->v.i);
break;
case DB_BINDING_BLOB:
sqlite3_bind_blob(s, pos, b->v.blob, b->len,
SQLITE_TRANSIENT);
break;
case DB_BINDING_TEXT:
sqlite3_bind_text(s, pos, b->v.text, b->len,
SQLITE_TRANSIENT);
break;
case DB_BINDING_NULL:
sqlite3_bind_null(s, pos);
break;
}
}
if (err != SQLITE_OK) {
tal_free(stmt->error);
stmt->error = db_sqlite3_fmt_error(stmt);
return false;
}
stmt->inner_stmt = s;
return true;
}
static bool db_sqlite3_exec(struct db_stmt *stmt)
{
int err;
bool success;
struct db_sqlite3 *wrapper = (struct db_sqlite3 *) stmt->db->conn;
#if !HAVE_SQLITE3_EXPANDED_SQL
/* Register the tracing function if we don't have an explicit way of
* expanding the statement. */
struct db_sqlite3_trace trace;
trace.wrapper = wrapper;
trace.stmt = stmt;
sqlite3_trace(conn2sql(stmt->db->conn), trace_sqlite3, &trace);
#endif
if (!db_sqlite3_query(stmt)) {
/* If the prepare step caused an error we hand it up. */
success = false;
goto done;
}
err = sqlite3_step(stmt->inner_stmt);
if (err != SQLITE_DONE) {
tal_free(stmt->error);
stmt->error = db_sqlite3_fmt_error(stmt);
success = false;
goto done;
}
#if HAVE_SQLITE3_EXPANDED_SQL
/* Manually expand and call the callback */
char *expanded_sql;
expanded_sql = sqlite3_expanded_sql(stmt->inner_stmt);
db_sqlite3_changes_add(wrapper, stmt, expanded_sql);
sqlite3_free(expanded_sql);
#endif
success = true;
done:
#if !HAVE_SQLITE3_EXPANDED_SQL
/* Unregister the trace callback to avoid it accessing the potentially
* stale pointer to stmt */
sqlite3_trace(conn2sql(stmt->db->conn), NULL, NULL);
#endif
return success;
}
static bool db_sqlite3_step(struct db_stmt *stmt)
{
sqlite3_stmt *s = (sqlite3_stmt*)stmt->inner_stmt;
return sqlite3_step(s) == SQLITE_ROW;
}
static bool db_sqlite3_begin_tx(struct db *db)
{
int err;
char *errmsg;
struct db_sqlite3 *wrapper = (struct db_sqlite3 *) db->conn;
err = sqlite3_exec(conn2sql(db->conn),
"BEGIN TRANSACTION;", NULL, NULL, &errmsg);
if (err != SQLITE_OK) {
db->error = tal_fmt(db, "Failed to begin a transaction: %s", errmsg);
return false;
}
replicate_statement(wrapper, "BEGIN TRANSACTION;");
return true;
}
static bool db_sqlite3_commit_tx(struct db *db)
{
int err;
char *errmsg;
struct db_sqlite3 *wrapper = (struct db_sqlite3 *) db->conn;
err = sqlite3_exec(conn2sql(db->conn),
"COMMIT;", NULL, NULL, &errmsg);
if (err != SQLITE_OK) {
db->error = tal_fmt(db, "Failed to commit a transaction: %s", errmsg);
return false;
}
replicate_statement(wrapper, "COMMIT;");
return true;
}
static bool db_sqlite3_column_is_null(struct db_stmt *stmt, int col)
{
sqlite3_stmt *s = (sqlite3_stmt*)stmt->inner_stmt;
return sqlite3_column_type(s, col) == SQLITE_NULL;
}
static u64 db_sqlite3_column_u64(struct db_stmt *stmt, int col)
{
sqlite3_stmt *s = (sqlite3_stmt*)stmt->inner_stmt;
return sqlite3_column_int64(s, col);
}
static s64 db_sqlite3_column_int(struct db_stmt *stmt, int col)
{
sqlite3_stmt *s = (sqlite3_stmt*)stmt->inner_stmt;
return sqlite3_column_int(s, col);
}
static size_t db_sqlite3_column_bytes(struct db_stmt *stmt, int col)
{
sqlite3_stmt *s = (sqlite3_stmt*)stmt->inner_stmt;
return sqlite3_column_bytes(s, col);
}
static const void *db_sqlite3_column_blob(struct db_stmt *stmt, int col)
{
sqlite3_stmt *s = (sqlite3_stmt*)stmt->inner_stmt;
return sqlite3_column_blob(s, col);
}
static const unsigned char *db_sqlite3_column_text(struct db_stmt *stmt, int col)
{
sqlite3_stmt *s = (sqlite3_stmt*)stmt->inner_stmt;
return sqlite3_column_text(s, col);
}
static void db_sqlite3_stmt_free(struct db_stmt *stmt)
{
if (stmt->inner_stmt)
sqlite3_finalize(stmt->inner_stmt);
stmt->inner_stmt = NULL;
}
static size_t db_sqlite3_count_changes(struct db_stmt *stmt)
{
sqlite3 *s = conn2sql(stmt->db->conn);
return sqlite3_changes(s);
}
static void db_sqlite3_close(struct db *db)
{
struct db_sqlite3 *wrapper = (struct db_sqlite3 *) db->conn;
if (wrapper->backup_conn)
sqlite3_close(wrapper->backup_conn);
sqlite3_close(wrapper->conn);
db->conn = tal_free(db->conn);
}
static u64 db_sqlite3_last_insert_id(struct db_stmt *stmt)
{
sqlite3 *s = conn2sql(stmt->db->conn);
return sqlite3_last_insert_rowid(s);
}
static bool db_sqlite3_vacuum(struct db *db)
{
int err;
sqlite3_stmt *stmt;
struct db_sqlite3 *wrapper = (struct db_sqlite3 *) db->conn;
sqlite3_prepare_v2(conn2sql(db->conn), "VACUUM;", -1, &stmt, NULL);
err = sqlite3_step(stmt);
if (err != SQLITE_DONE)
db->error = tal_fmt(db, "%s",
sqlite3_errmsg(conn2sql(db->conn)));
sqlite3_finalize(stmt);
replicate_statement(wrapper, "VACUUM;");
return err == SQLITE_DONE;
}
static bool colname_to_delete(const char **colnames,
size_t num_colnames,
const char *columnname)
{
for (size_t i = 0; i < num_colnames; i++) {
if (streq(columnname, colnames[i]))
return true;
}
return false;
}
static const char *find_column_name(const tal_t *ctx,
const char *sqlpart,
size_t *after)
{
size_t start = 0;
while (isspace(sqlpart[start]))
start++;
*after = strspn(sqlpart + start, "abcdefghijklmnopqrstuvwxyz_0123456789") + start;
if (*after == start)
return NULL;
return tal_strndup(ctx, sqlpart + start, *after - start);
}
/* Move table out the way, return columns */
static char **prepare_table_manip(const tal_t *ctx,
struct db *db, const char *tablename)
{
sqlite3_stmt *stmt;
const char *sql;
char *cmd, *bracket;
char **parts;
int err;
struct db_sqlite3 *wrapper = (struct db_sqlite3 *)db->conn;
/* Get schema. */
sqlite3_prepare_v2(wrapper->conn, "SELECT sql FROM sqlite_master WHERE type = ? AND name = ?;", -1, &stmt, NULL);
sqlite3_bind_text(stmt, 1, "table", strlen("table"), SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, 2, tablename, strlen(tablename), SQLITE_TRANSIENT);
err = sqlite3_step(stmt);
if (err != SQLITE_ROW) {
db->error = tal_fmt(db, "getting schema: %s",
sqlite3_errmsg(wrapper->conn));
sqlite3_finalize(stmt);
return NULL;
}
sql = tal_strdup(tmpctx, (const char *)sqlite3_column_text(stmt, 0));
sqlite3_finalize(stmt);
bracket = strchr(sql, '(');
if (!strstarts(sql, "CREATE TABLE") || !bracket) {
db->error = tal_fmt(db, "strange schema for %s: %s",
tablename, sql);
return NULL;
}
/* Split after ( by commas: any lower case is assumed to be a field */
parts = tal_strsplit(ctx, bracket + 1, ",", STR_EMPTY_OK);
/* Turn off foreign keys first. */
sqlite3_prepare_v2(wrapper->conn, "PRAGMA foreign_keys = OFF;", -1, &stmt, NULL);
if (sqlite3_step(stmt) != SQLITE_DONE)
goto sqlite_stmt_err;
sqlite3_finalize(stmt);
cmd = tal_fmt(tmpctx, "ALTER TABLE %s RENAME TO temp_%s;",
tablename, tablename);
sqlite3_prepare_v2(wrapper->conn, cmd, -1, &stmt, NULL);
if (sqlite3_step(stmt) != SQLITE_DONE)
goto sqlite_stmt_err;
sqlite3_finalize(stmt);
/* Make sure we do the same to backup! */
replicate_statement(wrapper, "PRAGMA foreign_keys = OFF;");
replicate_statement(wrapper, cmd);
return parts;
sqlite_stmt_err:
db->error = tal_fmt(db, "%s", sqlite3_errmsg(wrapper->conn));
sqlite3_finalize(stmt);
return tal_free(parts);
}
static bool complete_table_manip(struct db *db,
const char *tablename,
const char **coldefs,
const char **oldcolnames)
{
sqlite3_stmt *stmt;
char *create_cmd, *insert_cmd, *drop_cmd;
struct db_sqlite3 *wrapper = (struct db_sqlite3 *)db->conn;
/* Create table */
create_cmd = tal_fmt(tmpctx, "CREATE TABLE %s (", tablename);
for (size_t i = 0; i < tal_count(coldefs); i++) {
if (i != 0)
tal_append_fmt(&create_cmd, ", ");
tal_append_fmt(&create_cmd, "%s", coldefs[i]);
}
tal_append_fmt(&create_cmd, ";");
sqlite3_prepare_v2(wrapper->conn, create_cmd, -1, &stmt, NULL);
if (sqlite3_step(stmt) != SQLITE_DONE)
goto sqlite_stmt_err;
sqlite3_finalize(stmt);
/* Make sure we do the same to backup! */
replicate_statement(wrapper, create_cmd);
/* Populate table from old one */
insert_cmd = tal_fmt(tmpctx, "INSERT INTO %s SELECT ", tablename);
for (size_t i = 0; i < tal_count(oldcolnames); i++) {
if (i != 0)
tal_append_fmt(&insert_cmd, ", ");
tal_append_fmt(&insert_cmd, "%s", oldcolnames[i]);
}
tal_append_fmt(&insert_cmd, " FROM temp_%s;", tablename);
sqlite3_prepare_v2(wrapper->conn, insert_cmd, -1, &stmt, NULL);
if (sqlite3_step(stmt) != SQLITE_DONE)
goto sqlite_stmt_err;
sqlite3_finalize(stmt);
replicate_statement(wrapper, insert_cmd);
/* Cleanup temp table */
drop_cmd = tal_fmt(tmpctx, "DROP TABLE temp_%s;", tablename);
sqlite3_prepare_v2(wrapper->conn, drop_cmd, -1, &stmt, NULL);
if (sqlite3_step(stmt) != SQLITE_DONE)
goto sqlite_stmt_err;
sqlite3_finalize(stmt);
replicate_statement(wrapper, drop_cmd);
/* Allow links between them (esp. cascade deletes!) */
sqlite3_prepare_v2(wrapper->conn, "PRAGMA foreign_keys = ON;", -1, &stmt, NULL);
if (sqlite3_step(stmt) != SQLITE_DONE)
goto sqlite_stmt_err;
sqlite3_finalize(stmt);
replicate_statement(wrapper, "PRAGMA foreign_keys = ON;");
return true;
sqlite_stmt_err:
db->error = tal_fmt(db, "%s", sqlite3_errmsg(wrapper->conn));
sqlite3_finalize(stmt);
return false;
}
static bool db_sqlite3_rename_column(struct db *db,
const char *tablename,
const char *from, const char *to)
{
char **parts;
const char **coldefs, **oldcolnames;
bool colname_found = false;
parts = prepare_table_manip(tmpctx, db, tablename);
if (!parts)
return false;
coldefs = tal_arr(tmpctx, const char *, 0);
oldcolnames = tal_arr(tmpctx, const char *, 0);
for (size_t i = 0; parts[i]; i++) {
/* columnname DETAILS */
size_t after_name;
const char *colname = find_column_name(tmpctx, parts[i],
&after_name);
/* Things like "PRIMARY KEY xxx" must be copied verbatim */
if (!colname) {
tal_arr_expand(&coldefs, parts[i]);
continue;
}
if (streq(colname, from)) {
char *newdef;
colname_found = true;
/* Create column with new name */
newdef = tal_fmt(coldefs,
"%s%s", to, parts[i] + after_name);
tal_arr_expand(&coldefs, newdef);
tal_arr_expand(&oldcolnames, colname);
} else {
/* Not mentioned, keep it as is! */
tal_arr_expand(&coldefs, parts[i]);
tal_arr_expand(&oldcolnames, colname);
}
}
if (!colname_found) {
db->error = tal_fmt(db, "No column called %s", from);
return false;
}
return complete_table_manip(db, tablename, coldefs, oldcolnames);
}
static bool db_sqlite3_delete_columns(struct db *db,
const char *tablename,
const char **colnames, size_t num_cols)
{
char **parts;
const char **coldefs, **oldcolnames;
size_t colnames_found = 0;
parts = prepare_table_manip(tmpctx, db, tablename);
if (!parts)
return false;
coldefs = tal_arr(tmpctx, const char *, 0);
oldcolnames = tal_arr(tmpctx, const char *, 0);
for (size_t i = 0; parts[i]; i++) {
/* columnname DETAILS */
size_t after_name;
const char *colname = find_column_name(tmpctx, parts[i],
&after_name);
/* Things like "PRIMARY KEY xxx" must be copied verbatim */
if (!colname) {
tal_arr_expand(&coldefs, parts[i]);
continue;
}
/* Don't mention columns we're supposed to delete */
if (colname_to_delete(colnames, num_cols, colname)) {
colnames_found++;
continue;
}
/* Keep it as is! */
tal_arr_expand(&coldefs, parts[i]);
tal_arr_expand(&oldcolnames, colname);
}
if (colnames_found != num_cols) {
db->error = tal_fmt(db, "Only %zu/%zu columns found",
colnames_found, num_cols);
return false;
}
return complete_table_manip(db, tablename, coldefs, oldcolnames);
}
struct db_config db_sqlite3_config = {
.name = "sqlite3",
.exec_fn = &db_sqlite3_exec,
.query_fn = &db_sqlite3_query,
.step_fn = &db_sqlite3_step,
.begin_tx_fn = &db_sqlite3_begin_tx,
.commit_tx_fn = &db_sqlite3_commit_tx,
.stmt_free_fn = &db_sqlite3_stmt_free,
.column_is_null_fn = &db_sqlite3_column_is_null,
.column_u64_fn = &db_sqlite3_column_u64,
.column_int_fn = &db_sqlite3_column_int,
.column_bytes_fn = &db_sqlite3_column_bytes,
.column_blob_fn = &db_sqlite3_column_blob,
.column_text_fn = &db_sqlite3_column_text,
.last_insert_id_fn = &db_sqlite3_last_insert_id,
.count_changes_fn = &db_sqlite3_count_changes,
.setup_fn = &db_sqlite3_setup,
.teardown_fn = &db_sqlite3_close,
.vacuum_fn = db_sqlite3_vacuum,
.rename_column = db_sqlite3_rename_column,
.delete_columns = db_sqlite3_delete_columns,
};
AUTODATA(db_backends, &db_sqlite3_config);
#endif /* HAVE_SQLITE3 */

162
db/exec.c Normal file
View File

@@ -0,0 +1,162 @@
#include "config.h"
#include <ccan/tal/tal.h>
#include <db/bindings.h>
#include <db/common.h>
#include <db/exec.h>
#include <db/utils.h>
/**
* db_get_version - Determine the current DB schema version
*
* Will attempt to determine the current schema version of the
* database @db by querying the `version` table. If the table does not
* exist it'll return schema version -1, so that migration 0 is
* applied, which should create the `version` table.
*/
int db_get_version(struct db *db)
{
int res = -1;
struct db_stmt *stmt = db_prepare_v2(db, SQL("SELECT version FROM version LIMIT 1"));
/*
* Tentatively execute a query, but allow failures. Some databases
* like postgres will terminate the DB transaction if there is an
* error during the execution of a query, e.g., trying to access a
* table that doesn't exist yet, so we need to terminate and restart
* the DB transaction.
*/
if (!db_query_prepared(stmt)) {
db_commit_transaction(stmt->db);
db_begin_transaction(stmt->db);
tal_free(stmt);
return res;
}
if (db_step(stmt))
res = db_col_int(stmt, "version");
tal_free(stmt);
return res;
}
u32 db_data_version_get(struct db *db)
{
struct db_stmt *stmt;
u32 version;
stmt = db_prepare_v2(db, SQL("SELECT intval FROM vars WHERE name = 'data_version'"));
db_query_prepared(stmt);
db_step(stmt);
version = db_col_int(stmt, "intval");
tal_free(stmt);
return version;
}
void db_set_intvar(struct db *db, char *varname, s64 val)
{
size_t changes;
struct db_stmt *stmt = db_prepare_v2(db, SQL("UPDATE vars SET intval=? WHERE name=?;"));
db_bind_int(stmt, 0, val);
db_bind_text(stmt, 1, varname);
if (!db_exec_prepared_v2(stmt))
db_fatal("Error executing update: %s", stmt->error);
changes = db_count_changes(stmt);
tal_free(stmt);
if (changes == 0) {
stmt = db_prepare_v2(db, SQL("INSERT INTO vars (name, intval) VALUES (?, ?);"));
db_bind_text(stmt, 0, varname);
db_bind_int(stmt, 1, val);
if (!db_exec_prepared_v2(stmt))
db_fatal("Error executing insert: %s", stmt->error);
tal_free(stmt);
}
}
s64 db_get_intvar(struct db *db, char *varname, s64 defval)
{
s64 res = defval;
struct db_stmt *stmt = db_prepare_v2(
db, SQL("SELECT intval FROM vars WHERE name= ? LIMIT 1"));
db_bind_text(stmt, 0, varname);
if (!db_query_prepared(stmt))
goto done;
if (db_step(stmt))
res = db_col_int(stmt, "intval");
done:
tal_free(stmt);
return res;
}
/* Leak tracking. */
/* By making the update conditional on the current value we expect we
* are implementing an optimistic lock: if the update results in
* changes on the DB we know that the data_version did not change
* under our feet and no other transaction ran in the meantime.
*
* Notice that this update effectively locks the row, so that other
* operations attempting to change this outside the transaction will
* wait for this transaction to complete. The external change will
* ultimately fail the changes test below, it'll just delay its abort
* until our transaction is committed.
*/
static void db_data_version_incr(struct db *db)
{
struct db_stmt *stmt = db_prepare_v2(
db, SQL("UPDATE vars "
"SET intval = intval + 1 "
"WHERE name = 'data_version'"
" AND intval = ?"));
db_bind_int(stmt, 0, db->data_version);
db_exec_prepared_v2(stmt);
if (db_count_changes(stmt) != 1)
db_fatal("Optimistic lock on the database failed. There"
" may be a concurrent access to the database."
" Aborting since concurrent access is unsafe.");
tal_free(stmt);
db->data_version++;
}
void db_begin_transaction_(struct db *db, const char *location)
{
bool ok;
if (db->in_transaction)
db_fatal("Already in transaction from %s", db->in_transaction);
/* No writes yet. */
db->dirty = false;
db_prepare_for_changes(db);
ok = db->config->begin_tx_fn(db);
if (!ok)
db_fatal("Failed to start DB transaction: %s", db->error);
db->in_transaction = location;
}
bool db_in_transaction(struct db *db)
{
return db->in_transaction;
}
void db_commit_transaction(struct db *db)
{
bool ok;
assert(db->in_transaction);
db_assert_no_outstanding_statements(db);
/* Increment before reporting changes to an eventual plugin. */
if (db->dirty)
db_data_version_incr(db);
db_report_changes(db, NULL, 0);
ok = db->config->commit_tx_fn(db);
if (!ok)
db_fatal("Failed to commit DB transaction: %s", db->error);
db->in_transaction = NULL;
db->dirty = false;
}

52
db/exec.h Normal file
View File

@@ -0,0 +1,52 @@
#ifndef LIGHTNING_DB_EXEC_H
#define LIGHTNING_DB_EXEC_H
#include "config.h"
#include <ccan/short_types/short_types.h>
#include <ccan/take/take.h>
struct db;
/**
* db_set_intvar - Set an integer variable in the database
*
* Utility function to store generic integer values in the
* database.
*/
void db_set_intvar(struct db *db, char *varname, s64 val);
/**
* db_get_intvar - Retrieve an integer variable from the database
*
* Either returns the value in the database, or @defval if
* the query failed or no such variable exists.
*/
s64 db_get_intvar(struct db *db, char *varname, s64 defval);
/* Get the current data version (entries). */
u32 db_data_version_get(struct db *db);
/* Get the current database version (migrations). */
int db_get_version(struct db *db);
/**
* db_begin_transaction - Begin a transaction
*
* Begin a new DB transaction. fatal() on database error.
*/
#define db_begin_transaction(db) \
db_begin_transaction_((db), __FILE__ ":" stringify(__LINE__))
void db_begin_transaction_(struct db *db, const char *location);
bool db_in_transaction(struct db *db);
/**
* db_commit_transaction - Commit a running transaction
*
* Requires that we are currently in a transaction. fatal() if we
* fail to commit.
*/
void db_commit_transaction(struct db *db);
#endif /* LIGHTNING_DB_EXEC_H */

324
db/utils.c Normal file
View File

@@ -0,0 +1,324 @@
#include "config.h"
#include <ccan/tal/str/str.h>
#include <common/utils.h>
#include <db/common.h>
#include <db/utils.h>
/* Matches the hash function used in devtools/sql-rewrite.py */
static u32 hash_djb2(const char *str)
{
u32 hash = 5381;
for (size_t i = 0; str[i]; i++)
hash = ((hash << 5) + hash) ^ str[i];
return hash;
}
size_t db_query_colnum(const struct db_stmt *stmt,
const char *colname)
{
u32 col;
assert(stmt->query->colnames != NULL);
col = hash_djb2(colname) % stmt->query->num_colnames;
/* Will crash on NULL, which is the Right Thing */
while (!streq(stmt->query->colnames[col].sqlname,
colname)) {
col = (col + 1) % stmt->query->num_colnames;
}
#if DEVELOPER
strset_add(stmt->cols_used, colname);
#endif
return stmt->query->colnames[col].val;
}
static void db_stmt_free(struct db_stmt *stmt)
{
if (!stmt->executed)
db_fatal("Freeing an un-executed statement from %s: %s",
stmt->location, stmt->query->query);
#if DEVELOPER
/* If they never got a db_step, we don't track */
if (stmt->cols_used) {
for (size_t i = 0; i < stmt->query->num_colnames; i++) {
if (!stmt->query->colnames[i].sqlname)
continue;
if (!strset_get(stmt->cols_used,
stmt->query->colnames[i].sqlname)) {
db_fatal("Never accessed column %s in query %s",
stmt->query->colnames[i].sqlname,
stmt->query->query);
}
}
strset_clear(stmt->cols_used);
}
#endif
if (stmt->inner_stmt)
stmt->db->config->stmt_free_fn(stmt);
assert(stmt->inner_stmt == NULL);
}
struct db_stmt *db_prepare_v2_(const char *location, struct db *db,
const char *query_id)
{
struct db_stmt *stmt = tal(db, struct db_stmt);
size_t num_slots, pos;
/* Normalize query_id paths, because unit tests are compiled with this
* prefix. */
if (strncmp(query_id, "./", 2) == 0)
query_id += 2;
if (!db->in_transaction)
db_fatal("Attempting to prepare a db_stmt outside of a "
"transaction: %s", location);
/* Look up the query by its ID */
pos = hash_djb2(query_id) % db->queries->query_table_size;
for (;;) {
if (!db->queries->query_table[pos].name)
db_fatal("Could not resolve query %s", query_id);
if (streq(query_id, db->queries->query_table[pos].name)) {
stmt->query = &db->queries->query_table[pos];
break;
}
pos = (pos + 1) % db->queries->query_table_size;
}
num_slots = stmt->query->placeholders;
/* Allocate the slots for placeholders/bindings, zeroed next since
* that sets the type to DB_BINDING_UNINITIALIZED for later checks. */
stmt->bindings = tal_arr(stmt, struct db_binding, num_slots);
for (size_t i=0; i<num_slots; i++)
stmt->bindings[i].type = DB_BINDING_UNINITIALIZED;
stmt->location = location;
stmt->error = NULL;
stmt->db = db;
stmt->executed = false;
stmt->inner_stmt = NULL;
tal_add_destructor(stmt, db_stmt_free);
list_add(&db->pending_statements, &stmt->list);
#if DEVELOPER
stmt->cols_used = NULL;
#endif /* DEVELOPER */
return stmt;
}
#define db_prepare_v2(db,query) \
db_prepare_v2_(__FILE__ ":" stringify(__LINE__), db, query)
bool db_query_prepared(struct db_stmt *stmt)
{
/* Make sure we don't accidentally execute a modifying query using a
* read-only path. */
bool ret;
assert(stmt->query->readonly);
ret = stmt->db->config->query_fn(stmt);
stmt->executed = true;
list_del_from(&stmt->db->pending_statements, &stmt->list);
return ret;
}
bool db_step(struct db_stmt *stmt)
{
bool ret;
assert(stmt->executed);
ret = stmt->db->config->step_fn(stmt);
#if DEVELOPER
/* We only track cols_used if we return a result! */
if (ret && !stmt->cols_used) {
stmt->cols_used = tal(stmt, struct strset);
strset_init(stmt->cols_used);
}
#endif
return ret;
}
bool db_exec_prepared_v2(struct db_stmt *stmt TAKES)
{
bool ret = stmt->db->config->exec_fn(stmt);
/* If this was a write we need to bump the data_version upon commit. */
stmt->db->dirty = stmt->db->dirty || !stmt->query->readonly;
stmt->executed = true;
list_del_from(&stmt->db->pending_statements, &stmt->list);
/* The driver itself doesn't call `fatal` since we want to override it
* for testing. Instead we check here that the error message is set if
* we report an error. */
if (!ret) {
assert(stmt->error);
db_fatal("Error executing statement: %s", stmt->error);
}
if (taken(stmt))
tal_free(stmt);
return ret;
}
size_t db_count_changes(struct db_stmt *stmt)
{
assert(stmt->executed);
return stmt->db->config->count_changes_fn(stmt);
}
const char **db_changes(struct db *db)
{
return db->changes;
}
u64 db_last_insert_id_v2(struct db_stmt *stmt TAKES)
{
u64 id;
assert(stmt->executed);
id = stmt->db->config->last_insert_id_fn(stmt);
if (taken(stmt))
tal_free(stmt);
return id;
}
/* We expect min changes (ie. BEGIN TRANSACTION): report if more.
* Optionally add "final" at the end (ie. COMMIT). */
void db_report_changes(struct db *db, const char *final, size_t min)
{
assert(db->changes);
assert(tal_count(db->changes) >= min);
/* Having changes implies that we have a dirty TX. The opposite is
* currently not true, e.g., the postgres driver doesn't record
* changes yet. */
assert(!tal_count(db->changes) || db->dirty);
if (tal_count(db->changes) > min && db->report_changes_fn)
db->report_changes_fn(db);
db->changes = tal_free(db->changes);
}
void db_changes_add(struct db_stmt *stmt, const char * expanded)
{
struct db *db = stmt->db;
if (stmt->query->readonly) {
return;
}
/* We get a "COMMIT;" after we've sent our changes. */
if (!db->changes) {
assert(streq(expanded, "COMMIT;"));
return;
}
tal_arr_expand(&db->changes, tal_strdup(db->changes, expanded));
}
#if DEVELOPER
void db_assert_no_outstanding_statements(struct db *db)
{
struct db_stmt *stmt;
stmt = list_top(&db->pending_statements, struct db_stmt, list);
if (stmt)
db_fatal("Unfinalized statement %s", stmt->location);
}
#else
void db_assert_no_outstanding_statements(struct db *db)
{
}
#endif
static void destroy_db(struct db *db)
{
db_assert_no_outstanding_statements(db);
if (db->config->teardown_fn)
db->config->teardown_fn(db);
}
static struct db_config *db_config_find(const char *dsn)
{
size_t num_configs;
struct db_config **configs = autodata_get(db_backends, &num_configs);
const char *sep, *driver_name;
sep = strstr(dsn, "://");
if (!sep)
db_fatal("%s doesn't look like a valid data-source name (missing \"://\" separator.", dsn);
driver_name = tal_strndup(tmpctx, dsn, sep - dsn);
for (size_t i=0; i<num_configs; i++) {
if (streq(driver_name, configs[i]->name)) {
tal_free(driver_name);
return configs[i];
}
}
tal_free(driver_name);
return NULL;
}
static struct db_query_set *db_queries_find(const struct db_config *config)
{
size_t num_queries;
struct db_query_set **queries = autodata_get(db_queries, &num_queries);
for (size_t i = 0; i < num_queries; i++) {
if (streq(config->name, queries[i]->name)) {
return queries[i];
}
}
return NULL;
}
void db_prepare_for_changes(struct db *db)
{
assert(!db->changes);
db->changes = tal_arr(db, const char *, 0);
}
struct db *db_open(const tal_t *ctx, char *filename)
{
struct db *db;
db = tal(ctx, struct db);
db->filename = tal_strdup(db, filename);
list_head_init(&db->pending_statements);
if (!strstr(db->filename, "://"))
db_fatal("Could not extract driver name from \"%s\"", db->filename);
db->config = db_config_find(db->filename);
if (!db->config)
db_fatal("Unable to find DB driver for %s", db->filename);
db->queries = db_queries_find(db->config);
if (!db->queries)
db_fatal("Unable to find DB queries for %s", db->config->name);
tal_add_destructor(db, destroy_db);
db->in_transaction = NULL;
db->changes = NULL;
/* This must be outside a transaction, so catch it */
assert(!db->in_transaction);
db_prepare_for_changes(db);
if (db->config->setup_fn && !db->config->setup_fn(db))
db_fatal("Error calling DB setup: %s", db->error);
db_report_changes(db, NULL, 0);
return db;
}

100
db/utils.h Normal file
View File

@@ -0,0 +1,100 @@
#ifndef LIGHTNING_DB_UTILS_H
#define LIGHTNING_DB_UTILS_H
#include "config.h"
#include <ccan/take/take.h>
#include <ccan/tal/tal.h>
struct db;
struct db_stmt;
size_t db_query_colnum(const struct db_stmt *stmt,
const char *colname);
/* Return next 'row' result of statement */
bool db_step(struct db_stmt *stmt);
/* TODO(cdecker) Remove the v2 suffix after finishing the migration */
#define db_prepare_v2(db,query) \
db_prepare_v2_(__FILE__ ":" stringify(__LINE__), db, query)
/**
* db_exec_prepared -- Execute a prepared statement
*
* After preparing a statement using `db_prepare`, and after binding all
* non-null variables using the `db_bind_*` functions, it can be executed with
* this function. It is a small, transaction-aware, wrapper around `db_step`,
* that calls fatal() if the execution fails. This may take ownership of
* `stmt` if annotated with `take()`and will free it before returning.
*
* If you'd like to issue a query and access the rows returned by the query
* please use `db_query_prepared` instead, since this function will not expose
* returned results, and the `stmt` can only be used for calls to
* `db_count_changes` and `db_last_insert_id` after executing.
*
* @stmt: The prepared statement to execute
*/
bool db_exec_prepared_v2(struct db_stmt *stmt TAKES);
/**
* db_query_prepared -- Execute a prepared query
*
* After preparing a query using `db_prepare`, and after binding all non-null
* variables using the `db_bind_*` functions, it can be executed with this
* function. This function must be called before calling `db_step` or any of
* the `db_col_*` column access functions.
*
* If you are not executing a read-only statement, please use
* `db_exec_prepared` instead.
*
* @stmt: The prepared statement to execute
*/
bool db_query_prepared(struct db_stmt *stmt);
size_t db_count_changes(struct db_stmt *stmt);
void db_report_changes(struct db *db, const char *final, size_t min);
void db_prepare_for_changes(struct db *db);
u64 db_last_insert_id_v2(struct db_stmt *stmt);
/**
* db_prepare -- Prepare a DB query/command
*
* Create an instance of `struct db_stmt` that encapsulates a SQL query or command.
*
* @query MUST be wrapped in a `SQL()` macro call, since that allows the
* extraction and translation of the query into the target SQL dialect.
*
* It does not execute the query and does not check its validity, but
* allocates the placeholders detected in the query. The placeholders in the
* `stmt` can then be bound using the `db_bind_*` functions, and executed
* using `db_exec_prepared` for write-only statements and `db_query_prepared`
* for read-only statements.
*
* @db: Database to query/exec
* @query: The SQL statement to compile
*/
struct db_stmt *db_prepare_v2_(const char *location, struct db *db,
const char *query_id);
/**
* db_open - Open or create a database
*/
struct db *db_open(const tal_t *ctx, char *filename);
/**
* Report a statement that changes the wallet
*
* Allows the DB driver to report an expanded statement during
* execution. Changes are queued up and reported to the `db_write` plugin hook
* upon committing.
*/
void db_changes_add(struct db_stmt *db_stmt, const char * expanded);
void db_assert_no_outstanding_statements(struct db *db);
/**
* Access pending changes that have been added to the current transaction.
*/
const char **db_changes(struct db *db);
#endif /* LIGHTNING_DB_UTILS_H */