psbt: pull out changeset logic into common, update API

Greatly simplify the changeset API. Instead of 'diff' we simply generate
the changes.

Also pulls up the 'next message' method, as at some point the
interactive tx protocol will be used for other things as well
(splices/closes etc)

Suggested-By: @rustyrussell
This commit is contained in:
niftynei
2020-09-09 19:40:29 +09:30
committed by Rusty Russell
parent 5cd06227d7
commit c50f377a85
12 changed files with 190 additions and 204 deletions

View File

@@ -6,7 +6,9 @@
#include <ccan/asort/asort.h>
#include <ccan/ccan/endian/endian.h>
#include <ccan/ccan/mem/mem.h>
#include <common/channel_id.h>
#include <common/utils.h>
#include <wire/peer_wire.h>
bool psbt_get_serial_id(const struct wally_map *map, u16 *serial_id)
{
@@ -213,39 +215,45 @@ void psbt_sort_by_serial_id(struct wally_psbt *psbt)
struct type##_set a; \
a.type = from->type##s[index]; \
a.tx_##type = from->tx->type##s[index]; \
tal_arr_expand(add_to, a); \
tal_arr_expand(&add_to, a); \
} while (0)
static struct psbt_changeset *new_changeset(const tal_t *ctx)
{
struct psbt_changeset *set = tal(ctx, struct psbt_changeset);
set->added_ins = tal_arr(set, struct input_set, 0);
set->rm_ins = tal_arr(set, struct input_set, 0);
set->added_outs = tal_arr(set, struct output_set, 0);
set->rm_outs = tal_arr(set, struct output_set, 0);
return set;
}
/* this requires having a serial_id entry on everything */
/* YOU MUST KEEP orig + new AROUND TO USE THE RESULTING SETS */
bool psbt_has_diff(const tal_t *ctx,
struct wally_psbt *orig,
struct wally_psbt *new,
struct input_set **added_ins,
struct input_set **rm_ins,
struct output_set **added_outs,
struct output_set **rm_outs)
struct psbt_changeset *psbt_get_changeset(const tal_t *ctx,
struct wally_psbt *orig,
struct wally_psbt *new)
{
int result;
size_t i = 0, j = 0;
struct psbt_changeset *set;
psbt_sort_by_serial_id(orig);
psbt_sort_by_serial_id(new);
*added_ins = tal_arr(ctx, struct input_set, 0);
*rm_ins = tal_arr(ctx, struct input_set, 0);
*added_outs = tal_arr(ctx, struct output_set, 0);
*rm_outs = tal_arr(ctx, struct output_set, 0);
set = new_changeset(ctx);
/* Find the input diff */
while (i < orig->num_inputs || j < new->num_inputs) {
if (i >= orig->num_inputs) {
ADD(input, added_ins, new, j);
ADD(input, set->added_ins, new, j);
j++;
continue;
}
if (j >= new->num_inputs) {
ADD(input, rm_ins, orig, i);
ADD(input, set->rm_ins, orig, i);
i++;
continue;
}
@@ -253,19 +261,19 @@ bool psbt_has_diff(const tal_t *ctx,
result = compare_serials(&orig->inputs[i].unknowns,
&new->inputs[j].unknowns);
if (result == -1) {
ADD(input, rm_ins, orig, i);
ADD(input, set->rm_ins, orig, i);
i++;
continue;
}
if (result == 1) {
ADD(input, added_ins, new, j);
ADD(input, set->added_ins, new, j);
j++;
continue;
}
if (!input_identical(orig, i, new, j)) {
ADD(input, rm_ins, orig, i);
ADD(input, added_ins, new, j);
ADD(input, set->rm_ins, orig, i);
ADD(input, set->added_ins, new, j);
}
i++;
j++;
@@ -275,12 +283,12 @@ bool psbt_has_diff(const tal_t *ctx,
j = 0;
while (i < orig->num_outputs || j < new->num_outputs) {
if (i >= orig->num_outputs) {
ADD(output, added_outs, new, j);
ADD(output, set->added_outs, new, j);
j++;
continue;
}
if (j >= new->num_outputs) {
ADD(output, rm_outs, orig, i);
ADD(output, set->rm_outs, orig, i);
i++;
continue;
}
@@ -288,27 +296,106 @@ bool psbt_has_diff(const tal_t *ctx,
result = compare_serials(&orig->outputs[i].unknowns,
&new->outputs[j].unknowns);
if (result == -1) {
ADD(output, rm_outs, orig, i);
ADD(output, set->rm_outs, orig, i);
i++;
continue;
}
if (result == 1) {
ADD(output, added_outs, new, j);
ADD(output, set->added_outs, new, j);
j++;
continue;
}
if (!output_identical(orig, i, new, j)) {
ADD(output, rm_outs, orig, i);
ADD(output, added_outs, new, j);
ADD(output, set->rm_outs, orig, i);
ADD(output, set->added_outs, new, j);
}
i++;
j++;
}
return tal_count(*added_ins) != 0 ||
tal_count(*rm_ins) != 0 ||
tal_count(*added_outs) != 0 ||
tal_count(*rm_outs) != 0;
return set;
}
u8 *psbt_changeset_get_next(const tal_t *ctx, struct channel_id *cid,
struct psbt_changeset *set)
{
u16 serial_id;
u8 *msg;
if (tal_count(set->added_ins) != 0) {
const struct input_set *in = &set->added_ins[0];
u16 max_witness_len;
u8 *script;
if (!psbt_get_serial_id(&in->input.unknowns, &serial_id))
abort();
const u8 *prevtx = linearize_wtx(ctx,
in->input.utxo);
if (!psbt_input_get_max_witness_len(&in->input,
&max_witness_len))
abort();
if (in->input.redeem_script_len)
script = tal_dup_arr(ctx, u8,
in->input.redeem_script,
in->input.redeem_script_len, 0);
else
script = NULL;
msg = towire_tx_add_input(ctx, cid, serial_id,
prevtx, in->tx_input.index,
in->tx_input.sequence,
max_witness_len,
script,
NULL);
tal_arr_remove(&set->added_ins, 0);
return msg;
}
if (tal_count(set->rm_ins) != 0) {
if (!psbt_get_serial_id(&set->rm_ins[0].input.unknowns,
&serial_id))
abort();
msg = towire_tx_remove_input(ctx, cid, serial_id);
tal_arr_remove(&set->rm_ins, 0);
return msg;
}
if (tal_count(set->added_outs) != 0) {
struct amount_sat sats;
struct amount_asset asset_amt;
const struct output_set *out = &set->added_outs[0];
if (!psbt_get_serial_id(&out->output.unknowns, &serial_id))
abort();
asset_amt = wally_tx_output_get_amount(&out->tx_output);
sats = amount_asset_to_sat(&asset_amt);
const u8 *script = wally_tx_output_get_script(ctx,
&out->tx_output);
msg = towire_tx_add_output(ctx, cid, serial_id,
sats.satoshis, /* Raw: wire interface */
script);
tal_arr_remove(&set->added_outs, 0);
return msg;
}
if (tal_count(set->rm_outs) != 0) {
if (!psbt_get_serial_id(&set->rm_outs[0].output.unknowns,
&serial_id))
abort();
msg = towire_tx_remove_output(ctx, cid, serial_id);
/* Is this a kosher way to move the list forward? */
tal_arr_remove(&set->rm_outs, 0);
return msg;
}
return NULL;
}
void psbt_input_add_serial_id(struct wally_psbt_input *input,