diff --git a/common/derive_basepoints.c b/common/derive_basepoints.c index 1f03236a4..176963da6 100644 --- a/common/derive_basepoints.c +++ b/common/derive_basepoints.c @@ -51,26 +51,30 @@ bool derive_basepoints(const struct secret *seed, return true; } -void per_commit_secret(const struct sha256 *shaseed, +bool per_commit_secret(const struct sha256 *shaseed, struct secret *commit_secret, u64 per_commit_index) { struct sha256 s; + + if (per_commit_index >= (1ULL << SHACHAIN_BITS)) + return false; + shachain_from_seed(shaseed, shachain_index(per_commit_index), &s); BUILD_ASSERT(sizeof(s) == sizeof(*commit_secret)); memcpy(commit_secret, &s, sizeof(s)); + return true; } bool per_commit_point(const struct sha256 *shaseed, struct pubkey *commit_point, u64 per_commit_index) { - struct sha256 per_commit_secret; + struct secret secret; - /* Derive new per-commitment-point. */ - shachain_from_seed(shaseed, shachain_index(per_commit_index), - &per_commit_secret); + if (!per_commit_secret(shaseed, &secret, per_commit_index)) + return false; /* BOLT #3: * @@ -81,7 +85,7 @@ bool per_commit_point(const struct sha256 *shaseed, */ if (secp256k1_ec_pubkey_create(secp256k1_ctx, &commit_point->pubkey, - per_commit_secret.u.u8) != 1) + secret.data) != 1) return false; return true; @@ -235,6 +239,9 @@ bool shachain_get_secret(const struct shachain *shachain, { struct sha256 sha; + if (commit_num >= (1ULL << SHACHAIN_BITS)) + return false; + if (!shachain_get_hash(shachain, shachain_index(commit_num), &sha)) return false; BUILD_ASSERT(sizeof(*preimage) == sizeof(sha)); diff --git a/common/derive_basepoints.h b/common/derive_basepoints.h index 6dc24a9b7..84b6c25c5 100644 --- a/common/derive_basepoints.h +++ b/common/derive_basepoints.h @@ -112,8 +112,10 @@ bool derive_htlc_basepoint(const struct secret *seed, * @shaseed: the sha256 seed * @commit_secret: the returned per-commit secret. * @per_commit_index: (in) which @commit_secret to return. + * + * Returns false if per_commit_index is invalid, or can't derive. */ -void per_commit_secret(const struct sha256 *shaseed, +bool per_commit_secret(const struct sha256 *shaseed, struct secret *commit_secret, u64 per_commit_index); diff --git a/hsmd/hsm.c b/hsmd/hsm.c index 6c62a3f51..deb4699d9 100644 --- a/hsmd/hsm.c +++ b/hsmd/hsm.c @@ -628,7 +628,14 @@ handle_get_per_commitment_point(struct io_conn *conn, struct client *c) if (n >= 2) { old_secret = tal(tmpctx, struct secret); - per_commit_secret(&shaseed, old_secret, n - 2); + if (!per_commit_secret(&shaseed, old_secret, n - 2)) { + status_broken("Cannot derive secret %"PRIu64 + " for client %s", + n - 1, + type_to_string(tmpctx, + struct pubkey, &c->id)); + goto fail; + } } else old_secret = NULL;