db: create simple hashtable of fields in SELECT.

This simplistically maps names to numbers, eg:

	SELECT foo, bar FROM tbl;

'foo' -> 0
'bar' -> 1

If a statement is too complex for our simple parsing, we treat it as a
single field (which currently it always is).

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
Rusty Russell
2021-11-15 04:23:46 +10:30
parent 57328fe59e
commit 787fbb1228
3 changed files with 72 additions and 5 deletions

View File

@@ -83,13 +83,58 @@ rewriters = {
"postgres": PostgresRewriter(), "postgres": PostgresRewriter(),
} }
# djb2 is simple and effective: see http://www.cse.yorku.ca/~oz/hash.html
def hash_djb2(string):
val = 5381
for s in string:
val = ((val * 33) & 0xFFFFFFFF) ^ ord(s)
return val
def colname_htable(query):
assert query.upper().startswith("SELECT")
colquery = query[6:query.upper().index(" FROM ")]
colnames = colquery.split(',')
# If split caused unbalanced brackets, it's complex: assume
# a single field!
if any([colname.count('(') != colname.count(')') for colname in colnames]):
return [('"' + colquery.strip() + '"', 0)]
# 50% density htable
tablesize = len(colnames) * 2 - 1
table = [("NULL", -1)] * tablesize
for colnum, colname in enumerate(colnames):
colname = colname.strip()
# SELECT xxx AS yyy -> Y
as_clause = colname.upper().find(" AS ")
if as_clause != -1:
colname = colname[as_clause + 4:].strip()
pos = hash_djb2(colname) % tablesize
while table[pos][0] != "NULL":
pos = (pos + 1) % tablesize
table[pos] = ('"' + colname + '"', colnum)
return table
template = Template("""#ifndef LIGHTNINGD_WALLET_GEN_DB_${f.upper()} template = Template("""#ifndef LIGHTNINGD_WALLET_GEN_DB_${f.upper()}
#define LIGHTNINGD_WALLET_GEN_DB_${f.upper()} #define LIGHTNINGD_WALLET_GEN_DB_${f.upper()}
#include <config.h> #include <config.h>
#include <ccan/array_size/array_size.h>
#include <wallet/db_common.h> #include <wallet/db_common.h>
#if HAVE_${f.upper()} #if HAVE_${f.upper()}
% for colname, table in colhtables.items():
static const struct sqlname_map ${colname}[] = {
% for t in table:
{ ${t[0]}, ${t[1]} },
% endfor
};
% endfor
struct db_query db_${f}_queries[] = { struct db_query db_${f}_queries[] = {
@@ -99,6 +144,10 @@ struct db_query db_${f}_queries[] = {
.query = "${elem['query']}", .query = "${elem['query']}",
.placeholders = ${elem['placeholders']}, .placeholders = ${elem['placeholders']},
.readonly = ${elem['readonly']}, .readonly = ${elem['readonly']},
% if elem['colnames'] is not None:
.colnames = ${elem['colnames']},
.num_colnames = ARRAY_SIZE(${elem['colnames']}),
% endif
}, },
% endfor % endfor
}; };
@@ -129,6 +178,7 @@ def extract_queries(pofile):
if chunk != []: if chunk != []:
yield chunk yield chunk
colhtables = {}
queries = [] queries = []
for c in chunk(pofile): for c in chunk(pofile):
@@ -140,13 +190,21 @@ def extract_queries(pofile):
# Strip header and surrounding quotes # Strip header and surrounding quotes
query = c[i][7:][:-1] query = c[i][7:][:-1]
is_select = query.upper().startswith("SELECT")
if is_select:
colnames = 'col_table{}'.format(len(queries))
colhtables[colnames] = colname_htable(query)
else:
colnames = None
queries.append({ queries.append({
'name': query, 'name': query,
'query': query, 'query': query,
'placeholders': query.count('?'), 'placeholders': query.count('?'),
'readonly': "true" if query.upper().startswith("SELECT") else "false", 'readonly': "true" if is_select else "false",
'colnames': colnames,
}) })
return queries return colhtables, queries
if __name__ == "__main__": if __name__ == "__main__":
@@ -165,7 +223,7 @@ if __name__ == "__main__":
rewriter = rewriters[dialect] rewriter = rewriters[dialect]
queries = extract_queries(sys.argv[1]) colhtables, queries = extract_queries(sys.argv[1])
queries = rewriter.rewrite(queries) queries = rewriter.rewrite(queries)
print(template.render(f=dialect, queries=queries)) print(template.render(f=dialect, queries=queries, colhtables=colhtables))

View File

@@ -49,6 +49,10 @@ struct db_query {
/* Is this a read-only query? If it is there's no need to tell plugins /* Is this a read-only query? If it is there's no need to tell plugins
* about it. */ * about it. */
bool readonly; bool readonly;
/* If this is a select statement, what column names */
const struct sqlname_map *colnames;
size_t num_colnames;
}; };
enum db_binding_type { enum db_binding_type {
@@ -155,5 +159,10 @@ AUTODATA_TYPE(db_backends, struct db_config);
*/ */
void db_changes_add(struct db_stmt *db_stmt, const char * expanded); void db_changes_add(struct db_stmt *db_stmt, const char * expanded);
/* devtools/sql-rewrite.py generates this simple htable */
struct sqlname_map {
const char *sqlname;
int val;
};
#endif /* LIGHTNING_WALLET_DB_COMMON_H */ #endif /* LIGHTNING_WALLET_DB_COMMON_H */

View File

@@ -4321,7 +4321,7 @@ struct amount_msat wallet_total_forward_fees(struct wallet *w)
stmt = db_prepare_v2(w->db, SQL("SELECT" stmt = db_prepare_v2(w->db, SQL("SELECT"
" CAST(COALESCE(SUM(in_msatoshi - out_msatoshi), 0) AS BIGINT)" " CAST(COALESCE(SUM(in_msatoshi - out_msatoshi), 0) AS BIGINT)"
"FROM forwarded_payments " " FROM forwarded_payments "
"WHERE state = ?;")); "WHERE state = ?;"));
db_bind_int(stmt, 0, wallet_forward_status_in_db(FORWARD_SETTLED)); db_bind_int(stmt, 0, wallet_forward_status_in_db(FORWARD_SETTLED));
db_query_prepared(stmt); db_query_prepared(stmt);