common: use struct onionreply.

This makes it clear we're dealing with a message which is a wrapped error
reply (needing unwrap_onionreply), not an already-wrapped one.

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
Rusty Russell
2020-01-23 10:08:04 +10:30
parent aa6aad0131
commit 1099f6a5e1
27 changed files with 126 additions and 90 deletions

View File

@@ -6,6 +6,7 @@
#include <ccan/mem/mem.h>
#include <common/node_id.h>
#include <common/onion.h>
#include <common/onionreply.h>
#include <common/sphinx.h>
#include <common/utils.h>
@@ -529,12 +530,14 @@ struct route_step *process_onionpacket(
return step;
}
u8 *create_onionreply(const tal_t *ctx, const struct secret *shared_secret,
const u8 *failure_msg)
struct onionreply *create_onionreply(const tal_t *ctx,
const struct secret *shared_secret,
const u8 *failure_msg)
{
size_t msglen = tal_count(failure_msg);
size_t padlen = ONION_REPLY_SIZE - msglen;
u8 *reply = tal_arr(ctx, u8, 0), *payload = tal_arr(ctx, u8, 0);
struct onionreply *reply = tal(ctx, struct onionreply);
u8 *payload = tal_arr(ctx, u8, 0);
u8 key[KEY_LEN];
u8 hmac[HMAC_SIZE];
@@ -574,21 +577,23 @@ u8 *create_onionreply(const tal_t *ctx, const struct secret *shared_secret,
generate_key(key, "um", 2, shared_secret->data);
compute_hmac(hmac, payload, tal_count(payload), key, KEY_LEN);
towire(&reply, hmac, sizeof(hmac));
reply->contents = tal_arr(reply, u8, 0),
towire(&reply->contents, hmac, sizeof(hmac));
towire(&reply, payload, tal_count(payload));
towire(&reply->contents, payload, tal_count(payload));
tal_free(payload);
return reply;
}
u8 *wrap_onionreply(const tal_t *ctx,
const struct secret *shared_secret, const u8 *reply)
struct onionreply *wrap_onionreply(const tal_t *ctx,
const struct secret *shared_secret,
const struct onionreply *reply)
{
u8 key[KEY_LEN];
size_t streamlen = tal_count(reply);
size_t streamlen = tal_count(reply->contents);
u8 stream[streamlen];
u8 *result = tal_arr(ctx, u8, streamlen);
struct onionreply *result = tal(ctx, struct onionreply);
/* BOLT #4:
*
@@ -600,39 +605,43 @@ u8 *wrap_onionreply(const tal_t *ctx,
*/
generate_key(key, "ammag", 5, shared_secret->data);
generate_cipher_stream(stream, key, streamlen);
xorbytes(result, stream, reply, streamlen);
result->contents = tal_arr(result, u8, streamlen);
xorbytes(result->contents, stream, reply->contents, streamlen);
return result;
}
u8 *unwrap_onionreply(const tal_t *ctx,
const struct secret *shared_secrets,
const int numhops, const u8 *reply,
const int numhops,
const struct onionreply *reply,
int *origin_index)
{
u8 *msg = tal_arr(tmpctx, u8, tal_count(reply)), *final;
struct onionreply *r;
u8 key[KEY_LEN], hmac[HMAC_SIZE];
const u8 *cursor;
u8 *final;
size_t max;
u16 msglen;
if (tal_count(reply) != ONION_REPLY_SIZE + sizeof(hmac) + 4) {
if (tal_count(reply->contents) != ONION_REPLY_SIZE + sizeof(hmac) + 4) {
return NULL;
}
memcpy(msg, reply, tal_count(reply));
r = new_onionreply(tmpctx, reply->contents);
*origin_index = -1;
for (int i = 0; i < numhops; i++) {
/* Since the encryption is just XORing with the cipher
* stream encryption is identical to decryption */
msg = wrap_onionreply(tmpctx, &shared_secrets[i], msg);
r = wrap_onionreply(tmpctx, &shared_secrets[i], r);
/* Check if the HMAC matches, this means that this is
* the origin */
generate_key(key, "um", 2, shared_secrets[i].data);
compute_hmac(hmac, msg + sizeof(hmac),
tal_count(msg) - sizeof(hmac), key, KEY_LEN);
if (memcmp(hmac, msg, sizeof(hmac)) == 0) {
compute_hmac(hmac, r->contents + sizeof(hmac),
tal_count(r->contents) - sizeof(hmac),
key, KEY_LEN);
if (memcmp(hmac, r->contents, sizeof(hmac)) == 0) {
*origin_index = i;
break;
}
@@ -641,8 +650,8 @@ u8 *unwrap_onionreply(const tal_t *ctx,
return NULL;
}
cursor = msg + sizeof(hmac);
max = tal_count(msg) - sizeof(hmac);
cursor = r->contents + sizeof(hmac);
max = tal_count(r->contents) - sizeof(hmac);
msglen = fromwire_u16(&cursor, &max);
if (msglen > ONION_REPLY_SIZE) {
@@ -650,7 +659,7 @@ u8 *unwrap_onionreply(const tal_t *ctx,
}
final = tal_arr(ctx, u8, msglen);
fromwire(&cursor, &max, final, msglen);
if (!fromwire(&cursor, &max, final, msglen))
return tal_free(final);
return final;
}