diff --git a/lightningd/htlc_end.h b/lightningd/htlc_end.h index 6a7d8f108..73e1d3d37 100644 --- a/lightningd/htlc_end.h +++ b/lightningd/htlc_end.h @@ -15,6 +15,10 @@ struct htlc_key { /* Incoming HTLC */ struct htlc_in { + /* The database primary key for this htlc. Must be 0 until it + * is saved to the database, must be >0 after saving to the + * database. */ + u64 dbid; struct htlc_key key; u64 msatoshi; u32 cltv_expiry; @@ -39,6 +43,10 @@ struct htlc_in { }; struct htlc_out { + /* The database primary key for this htlc. Must be 0 until it + * is saved to the database, must be >0 after saving to the + * database. */ + u64 dbid; struct htlc_key key; u64 msatoshi; u32 cltv_expiry; diff --git a/wallet/db.c b/wallet/db.c index ad782cde9..7fa245dad 100644 --- a/wallet/db.c +++ b/wallet/db.c @@ -99,7 +99,7 @@ char *dbmigrations[] = { " direction INTEGER," " origin_htlc INTEGER," " msatoshi INTEGER," - " ctlv_expiry INTEGER," + " cltv_expiry INTEGER," " payment_hash BLOB," " payment_key BLOB," " routing_onion BLOB," @@ -107,7 +107,6 @@ char *dbmigrations[] = { " malformed_onion INTEGER," " hstate INTEGER," " shared_secret BLOB," - " preimage BLOB," " PRIMARY KEY (id)," " UNIQUE (channel_id, channel_htlc_id, direction)" ");", diff --git a/wallet/wallet.c b/wallet/wallet.c index b433b1949..b8122ca58 100644 --- a/wallet/wallet.c +++ b/wallet/wallet.c @@ -10,6 +10,8 @@ #include #define SQLITE_MAX_UINT 0x7FFFFFFFFFFFFFFF +#define DIRECTION_INCOMING 0 +#define DIRECTION_OUTGOING 1 struct wallet *wallet_new(const tal_t *ctx, struct log *log) { @@ -845,6 +847,220 @@ int wallet_extract_owned_outputs(struct wallet *w, const struct bitcoin_tx *tx, return num_utxos; } +bool wallet_htlc_save_in(struct wallet *wallet, + const struct wallet_channel *chan, struct htlc_in *in) +{ + bool ok = true; + tal_t *tmpctx = tal_tmpctx(wallet); + + ok &= db_exec( + __func__, wallet->db, + "INSERT INTO channel_htlcs " + "(channel_id, channel_htlc_id, direction, origin_htlc, msatoshi, cltv_expiry, payment_hash, payment_key, hstate, shared_secret, routing_onion) VALUES " + "(%" PRIu64 ", %"PRIu64", %d, NULL, %"PRIu64", %d, '%s', %s, %d, '%s', '%s');", + chan->id, in->key.id, DIRECTION_INCOMING, in->msatoshi, in->cltv_expiry, + tal_hexstr(tmpctx, &in->payment_hash, sizeof(struct sha256)), + in->preimage == NULL ? "NULL" : tal_hexstr(tmpctx, &in->preimage, + sizeof(struct preimage)), + in->hstate, + tal_hexstr(tmpctx, &in->shared_secret, sizeof(struct secret)), + tal_hexstr(tmpctx, &in->onion_routing_packet, sizeof(in->onion_routing_packet)) + ); + + tal_free(tmpctx); + if (ok) { + in->dbid = sqlite3_last_insert_rowid(wallet->db->sql); + } + return ok; +} + +bool wallet_htlc_save_out(struct wallet *wallet, + const struct wallet_channel *chan, + struct htlc_out *out) +{ + bool ok = true; + tal_t *tmpctx = tal_tmpctx(wallet); + + /* We absolutely need the incoming HTLC to be persisted before + * we can persist it's dependent */ + assert(out->in == NULL || out->in->dbid != 0); + + ok &= db_exec( + __func__, wallet->db, + "INSERT INTO channel_htlcs " + "(channel_id, channel_htlc_id, direction, origin_htlc, msatoshi, cltv_expiry, " + "payment_hash, payment_key, hstate, shared_secret, routing_onion) VALUES " + "(%" PRIu64 ", %" PRIu64 ", %d, %s, %" PRIu64", %d, '%s', %s, %d, NULL, '%s');", + chan->id, + out->key.id, + DIRECTION_OUTGOING, + out->in ? tal_fmt(tmpctx, "%" PRIu64, out->in->dbid) : "NULL", + out->msatoshi, + out->cltv_expiry, + tal_hexstr(tmpctx, &out->payment_hash, sizeof(struct sha256)), + out->preimage ? tal_hexstr(tmpctx, &out->preimage, sizeof(struct preimage)) : "NULL", + out->hstate, + tal_hexstr(tmpctx, &out->onion_routing_packet, sizeof(out->onion_routing_packet)) + ); + + tal_free(tmpctx); + if (ok) { + out->dbid = sqlite3_last_insert_rowid(wallet->db->sql); + } + return ok; +} + +bool wallet_htlc_update(struct wallet *wallet, const u64 htlc_dbid, + const enum htlc_state new_state, + const struct preimage *payment_key) +{ + bool ok = true; + tal_t *tmpctx = tal_tmpctx(wallet); + + if (payment_key) { + ok &= db_exec( + __func__, wallet->db, "UPDATE channel_htlcs SET hstate=%d, " + "payment_key='%s' WHERE id=%" PRIu64, + new_state, + tal_hexstr(tmpctx, payment_key, sizeof(struct preimage)), + htlc_dbid); + + } else { + ok &= db_exec( + __func__, wallet->db, + "UPDATE channel_htlcs SET hstate = %d WHERE id=%" PRIu64, + new_state, htlc_dbid); + } + tal_free(tmpctx); + return ok; +} + +static bool wallet_stmt2htlc_in(const struct wallet_channel *channel, + sqlite3_stmt *stmt, struct htlc_in *in) +{ + bool ok = true; + in->dbid = sqlite3_column_int64(stmt, 0); + in->key.id = sqlite3_column_int64(stmt, 1); + in->key.peer = channel->peer; + in->msatoshi = sqlite3_column_int64(stmt, 2); + in->cltv_expiry = sqlite3_column_int(stmt, 3); + in->hstate = sqlite3_column_int(stmt, 4); + + ok &= sqlite3_column_hexval(stmt, 5, &in->payment_hash, + sizeof(in->payment_hash)); + ok &= sqlite3_column_hexval(stmt, 6, &in->shared_secret, + sizeof(in->shared_secret)); + + if (sqlite3_column_type(stmt, 7) != SQLITE_NULL) { + in->preimage = tal(in, struct preimage); + ok &= sqlite3_column_hexval(stmt, 7, in->preimage, sizeof(*in->preimage)); + } else { + in->preimage = NULL; + } + + sqlite3_column_hexval(stmt, 8, &in->onion_routing_packet, + sizeof(in->onion_routing_packet)); + + in->failuremsg = NULL; + in->malformed = 0; + + return ok; +} +static bool wallet_stmt2htlc_out(const struct wallet_channel *channel, + sqlite3_stmt *stmt, struct htlc_out *out) +{ + bool ok = true; + out->dbid = sqlite3_column_int64(stmt, 0); + out->key.id = sqlite3_column_int64(stmt, 1); + out->key.peer = channel->peer; + out->msatoshi = sqlite3_column_int64(stmt, 2); + out->cltv_expiry = sqlite3_column_int(stmt, 3); + out->hstate = sqlite3_column_int(stmt, 4); + ok &= sqlite3_column_hexval(stmt, 5, &out->payment_hash, + sizeof(out->payment_hash)); + + if (sqlite3_column_type(stmt, 6) != SQLITE_NULL) { + out->origin_htlc_id = sqlite3_column_int64(stmt, 6); + } else { + out->origin_htlc_id = 0; + } + + if (sqlite3_column_type(stmt, 7) != SQLITE_NULL) { + out->preimage = tal(out, struct preimage); + ok &= sqlite3_column_hexval(stmt, 7, &out->preimage, sizeof(struct preimage)); + } else { + out->preimage = NULL; + } + + sqlite3_column_hexval(stmt, 8, &out->onion_routing_packet, + sizeof(out->onion_routing_packet)); + + out->failuremsg = NULL; + out->malformed = 0; + + /* Need to defer wiring until we can look up all incoming + * htlcs, will wire using origin_htlc_id */ + out->in = NULL; + + return ok; +} + +bool wallet_htlcs_load_for_channel(struct wallet *wallet, + struct wallet_channel *chan, + struct htlc_in_map *htlcs_in, + struct htlc_out_map *htlcs_out) +{ + bool ok = true; + int incount = 0, outcount = 0; + + log_debug(wallet->log, "Loading HTLCs for channel %"PRIu64, chan->id); + sqlite3_stmt *stmt = db_query( + __func__, wallet->db, + "SELECT id, channel_htlc_id, msatoshi, cltv_expiry, hstate, " + "payment_hash, shared_secret, payment_key FROM channel_htlcs WHERE " + "direction=%d AND channel_id=%" PRIu64 " AND hstate != %d", + DIRECTION_INCOMING, chan->id, SENT_REMOVE_ACK_REVOCATION); + + if (!stmt) { + log_broken(wallet->log, "Could not select htlc_ins: %s", wallet->db->err); + return false; + } + + while (ok && stmt && sqlite3_step(stmt) == SQLITE_ROW) { + struct htlc_in *in = tal(chan, struct htlc_in); + ok &= wallet_stmt2htlc_in(chan, stmt, in); + connect_htlc_in(htlcs_in, in); + ok &= htlc_in_check(in, "wallet_htlcs_load") != NULL; + incount++; + } + sqlite3_finalize(stmt); + + stmt = db_query( + __func__, wallet->db, + "SELECT id, channel_htlc_id, msatoshi, cltv_expiry, hstate, " + "payment_hash, origin_htlc, payment_key FROM channel_htlcs WHERE " + "direction=%d AND channel_id=%" PRIu64 " AND hstate != %d", + DIRECTION_OUTGOING, chan->id, RCVD_REMOVE_ACK_REVOCATION); + + if (!stmt) { + log_broken(wallet->log, "Could not select htlc_outs: %s", wallet->db->err); + return false; + } + + while (ok && stmt && sqlite3_step(stmt) == SQLITE_ROW) { + struct htlc_out *out = tal(chan, struct htlc_out); + ok &= wallet_stmt2htlc_out(chan, stmt, out); + connect_htlc_out(htlcs_out, out); + /* Cannot htlc_out_check because we haven't wired the + * dependencies in yet */ + outcount++; + } + sqlite3_finalize(stmt); + log_debug(wallet->log, "Restored %d incoming and %d outgoing HTLCS", incount, outcount); + + return ok; +} + /** * wallet_shachain_delete - Drop the shachain from the database * diff --git a/wallet/wallet.h b/wallet/wallet.h index a1d423a30..f642671c0 100644 --- a/wallet/wallet.h +++ b/wallet/wallet.h @@ -9,6 +9,7 @@ #include #include #include +#include #include struct lightningd; @@ -219,4 +220,48 @@ bool wallet_channels_load_active(struct wallet *w, struct list_head *peers); int wallet_extract_owned_outputs(struct wallet *w, const struct bitcoin_tx *tx, u64 *total_satoshi); +/** + * wallet_htlc_save_in - store a htlc_in in the database + * + * @wallet: wallet to store the htlc into + * @chan: the `wallet_channel` this HTLC is associated with + * @in: the htlc_in to store + * + * This will store the contents of the `struct htlc_in` in the + * database. Since `struct htlc_in` commonly only change state after + * being created we do not support updating arbitrary fields and this + * function will fail when attempting to call it multiple times for + * the same `struct htlc_in`. Instead `wallet_htlc_update` may be used + * for state transitions or to set the `payment_key` for completed + * HTLCs. + */ +bool wallet_htlc_save_in(struct wallet *wallet, + const struct wallet_channel *chan, struct htlc_in *in); + +/** + * wallet_htlc_save_out - store a htlc_out in the database + * + * See comment for wallet_htlc_save_in. + */ +bool wallet_htlc_save_out(struct wallet *wallet, + const struct wallet_channel *chan, + struct htlc_out *out); + +/** + * wallet_htlc_update - perform state transition or add payment_key + * + * @wallet: the wallet containing the HTLC to update + * @htlc_dbid: the database ID used to identify the HTLC + * @new_state: the state we should transition to + * @payment_key: the `payment_key` which hashes to the `payment_hash`, + * or NULL if unknown. + * + * Used to update the state of an HTLC, either a `struct htlc_in` or a + * `struct htlc_out` and optionally set the `payment_key` should the + * HTLC have been settled. + */ +bool wallet_htlc_update(struct wallet *wallet, const u64 htlc_dbid, + const enum htlc_state new_state, + const struct preimage *payment_key); + #endif /* WALLET_WALLET_H */ diff --git a/wallet/wallet_tests.c b/wallet/wallet_tests.c index 40bda648a..7740e87ca 100644 --- a/wallet/wallet_tests.c +++ b/wallet/wallet_tests.c @@ -294,6 +294,48 @@ static bool test_channel_config_crud(const tal_t *ctx) return true; } +static bool test_htlc_crud(const tal_t *ctx) +{ + struct htlc_in in; + struct htlc_out out; + struct preimage payment_key; + struct wallet_channel chan; + struct wallet *w = create_test_wallet(ctx); + + /* Make sure we have our references correct */ + db_exec(__func__, w->db, "INSERT INTO channels (id) VALUES (1);"); + chan.id = 1; + + memset(&in, 0, sizeof(in)); + memset(&out, 0, sizeof(out)); + memset(&in.payment_hash, 'A', sizeof(struct sha256)); + memset(&out.payment_hash, 'A', sizeof(struct sha256)); + memset(&payment_key, 'B', sizeof(payment_key)); + out.in = ∈ + out.key.id = 1337; + + /* Store the htlc_in */ + CHECK_MSG(wallet_htlc_save_in(w, &chan, &in), + tal_fmt(ctx, "Save htlc_in failed: %s", w->db->err)); + CHECK_MSG(in.dbid != 0, "HTLC DB ID was not set."); + /* Saving again should get us a collision */ + CHECK_MSG(!wallet_htlc_save_in(w, &chan, &in), + "Saving two HTLCs with the same data must not succeed."); + /* Update */ + CHECK_MSG(wallet_htlc_update(w, in.dbid, RCVD_ADD_HTLC, NULL), + "Update HTLC with null payment_key failed"); + CHECK_MSG( + wallet_htlc_update(w, in.dbid, SENT_REMOVE_HTLC, &payment_key), + "Update HTLC with payment_key failed"); + + CHECK_MSG(wallet_htlc_save_out(w, &chan, &out), + tal_fmt(ctx, "Save htlc_out failed: %s", w->db->err)); + CHECK_MSG(out.dbid != 0, "HTLC DB ID was not set."); + CHECK_MSG(!wallet_htlc_save_out(w, &chan, &out), + "Saving two HTLCs with the same data must not succeed."); + return true; +} + int main(void) { bool ok = true; @@ -303,6 +345,7 @@ int main(void) ok &= test_shachain_crud(); ok &= test_channel_crud(tmpctx); ok &= test_channel_config_crud(tmpctx); + ok &= test_htlc_crud(tmpctx); tal_free(tmpctx); return !ok;