utils for cleaner msg driver parsing and serialization

This commit is contained in:
Evan Feenstra
2022-06-08 11:15:56 -07:00
parent a58a09f099
commit 919b6e571e
5 changed files with 102 additions and 47 deletions

View File

@@ -1,6 +1,6 @@
use crate::mqtt::start_broker; use crate::mqtt::start_broker;
use crate::ChannelRequest; use crate::ChannelRequest;
use sphinx_key_parser::MsgDriver; use sphinx_key_parser as parser;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use vls_protocol::serde_bolt::WireString; use vls_protocol::serde_bolt::WireString;
use vls_protocol::{msgs, msgs::Message}; use vls_protocol::{msgs, msgs::Message};
@@ -31,24 +31,20 @@ pub async fn iteration(
sequence: u16, sequence: u16,
tx: mpsc::Sender<ChannelRequest>, tx: mpsc::Sender<ChannelRequest>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let mut md = MsgDriver::new_empty();
msgs::write_serial_request_header(&mut md, sequence, 0)?;
let ping = msgs::Ping { let ping = msgs::Ping {
id, id,
message: WireString("ping".as_bytes().to_vec()), message: WireString("ping".as_bytes().to_vec()),
}; };
msgs::write(&mut md, ping)?; let ping_bytes = parser::request_from_msg(ping, sequence, 0)?;
let (reply_tx, reply_rx) = oneshot::channel(); let (reply_tx, reply_rx) = oneshot::channel();
// Send a request to the MQTT handler to send to signer // Send a request to the MQTT handler to send to signer
let request = ChannelRequest { let request = ChannelRequest {
message: md.bytes(), message: ping_bytes,
reply_tx, reply_tx,
}; };
let _ = tx.send(request).await; let _ = tx.send(request).await;
let res = reply_rx.await?; let res = reply_rx.await?;
let mut ret = MsgDriver::new(res.reply); let reply = parser::response_from_bytes(res.reply, sequence)?;
msgs::read_serial_response_header(&mut ret, sequence)?;
let reply = msgs::read(&mut ret)?;
match reply { match reply {
Message::Pong(p) => { Message::Pong(p) => {
log::info!( log::info!(

View File

@@ -1,6 +1,6 @@
use log::*; use log::*;
use secp256k1::PublicKey; use secp256k1::PublicKey;
use sphinx_key_parser::MsgDriver; use sphinx_key_parser as parser;
use std::thread; use std::thread;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
// use tokio::task::spawn_blocking; // use tokio::task::spawn_blocking;
@@ -104,14 +104,11 @@ impl<C: 'static + Client> SignerLoop<C> {
fn handle_message(&mut self, message: Vec<u8>) -> Result<Vec<u8>> { fn handle_message(&mut self, message: Vec<u8>) -> Result<Vec<u8>> {
let dbid = self.client_id.as_ref().map(|c| c.dbid).unwrap_or(0); let dbid = self.client_id.as_ref().map(|c| c.dbid).unwrap_or(0);
let mut md = MsgDriver::new_empty(); let md = parser::raw_request_from_bytes(message, self.chan.sequence, dbid)?;
msgs::write_serial_request_header(&mut md, self.chan.sequence, dbid)?; let reply_rx = self.send_request(md)?;
msgs::write_vec(&mut md, message)?; let res = self.get_reply(reply_rx)?;
let reply_rx = self.send_request(md.bytes())?; let reply = parser::raw_response_from_bytes(res, self.chan.sequence)?;
let mut res = self.get_reply(reply_rx)?;
msgs::read_serial_response_header(&mut res, self.chan.sequence)?;
self.chan.sequence = self.chan.sequence.wrapping_add(1); self.chan.sequence = self.chan.sequence.wrapping_add(1);
let reply = msgs::read_raw(&mut res)?;
Ok(reply) Ok(reply)
} }

View File

@@ -5,6 +5,8 @@ edition = "2021"
[dependencies] [dependencies]
vls-protocol = { path = "../../validating-lightning-signer/vls-protocol" } vls-protocol = { path = "../../validating-lightning-signer/vls-protocol" }
serde = { version = "1.0", default-features = false }
serde_bolt = { version = "0.2", default-features = false }
[features] [features]
default = ["std"] default = ["std"]

View File

@@ -1,6 +1,62 @@
use vls_protocol::serde_bolt::{self, Read, Write}; use serde::ser;
use std::io;
use std::cmp::min; use std::cmp::min;
use std::io;
use vls_protocol::msgs::{self, DeBolt, Message};
use vls_protocol::serde_bolt::{Error, Read, Result, Write};
pub fn raw_request_from_bytes(
message: Vec<u8>,
sequence: u16,
dbid: u64,
) -> vls_protocol::Result<Vec<u8>> {
let mut md = MsgDriver::new_empty();
msgs::write_serial_request_header(&mut md, sequence, dbid)?;
msgs::write_vec(&mut md, message)?;
Ok(md.bytes())
}
pub fn request_from_msg<T: ser::Serialize + DeBolt>(
msg: T,
sequence: u16,
dbid: u64,
) -> vls_protocol::Result<Vec<u8>> {
let mut md = MsgDriver::new_empty();
msgs::write_serial_request_header(&mut md, sequence, dbid)?;
msgs::write(&mut md, msg).expect("failed to serial write");
Ok(md.bytes())
}
pub fn raw_response_from_msg<T: ser::Serialize + DeBolt>(
msg: T,
sequence: u16,
) -> vls_protocol::Result<Vec<u8>> {
let mut m = MsgDriver::new_empty();
msgs::write_serial_response_header(&mut m, sequence)?;
msgs::write(&mut m, msg)?;
Ok(m.bytes())
}
pub fn request_from_bytes<T: DeBolt>(msg: Vec<u8>) -> vls_protocol::Result<(T, u16, u64)> {
let mut m = MsgDriver::new(msg);
let (sequence, dbid) = msgs::read_serial_request_header(&mut m).expect("read ping header");
let reply: T = msgs::read_message(&mut m).expect("failed to read ping message");
Ok((reply, sequence, dbid))
}
pub fn raw_response_from_bytes(
res: Vec<u8>,
expected_sequence: u16,
) -> vls_protocol::Result<Vec<u8>> {
let mut m = MsgDriver::new(res);
msgs::read_serial_response_header(&mut m, expected_sequence)?;
Ok(msgs::read_raw(&mut m)?)
}
pub fn response_from_bytes(res: Vec<u8>, expected_sequence: u16) -> vls_protocol::Result<Message> {
let mut m = MsgDriver::new(res);
msgs::read_serial_response_header(&mut m, expected_sequence)?;
Ok(msgs::read(&mut m)?)
}
pub struct MsgDriver(Vec<u8>); pub struct MsgDriver(Vec<u8>);
@@ -20,35 +76,37 @@ impl MsgDriver {
} }
impl Read for MsgDriver { impl Read for MsgDriver {
type Error = serde_bolt::Error; type Error = Error;
// input: buf to be written. Should already be the right size // input: buf to be written. Should already be the right size
fn read(&mut self, mut buf: &mut [u8]) -> serde_bolt::Result<usize> { fn read(&mut self, mut buf: &mut [u8]) -> Result<usize> {
if buf.is_empty() { if buf.is_empty() {
return Ok(0); return Ok(0);
} }
let (mut content, remaining) = self.0.split_at( let (mut content, remaining) = self.0.split_at(min(buf.len(), self.0.len()));
min(buf.len(), self.0.len())
);
let bytes = &mut content; let bytes = &mut content;
match io::copy(bytes, &mut buf) { match io::copy(bytes, &mut buf) {
Ok(len) => { Ok(len) => {
self.0 = remaining.to_vec(); self.0 = remaining.to_vec();
Ok(len as usize) Ok(len as usize)
}, }
Err(_) => Ok(0) Err(_) => Ok(0),
} }
} }
fn peek(&mut self) -> serde_bolt::Result<Option<u8>> { fn peek(&mut self) -> Result<Option<u8>> {
Ok(if let Some(u) = self.0.get(0) { Some(u.clone()) } else { None}) Ok(if let Some(u) = self.0.get(0) {
Some(u.clone())
} else {
None
})
} }
} }
impl Write for MsgDriver { impl Write for MsgDriver {
type Error = serde_bolt::Error; type Error = Error;
fn write_all(&mut self, buf: &[u8]) -> serde_bolt::Result<()> { fn write_all(&mut self, buf: &[u8]) -> Result<()> {
self.0.extend(buf.iter().cloned()); self.0.extend(buf.iter().cloned());
Ok(()) Ok(())
} }

View File

@@ -1,7 +1,6 @@
use sphinx_key_parser as parser;
use sphinx_key_parser::MsgDriver; use rumqttc::{self, AsyncClient, Event, MqttOptions, Packet, QoS};
use rumqttc::{self, AsyncClient, MqttOptions, QoS, Event, Packet};
use std::error::Error; use std::error::Error;
use std::time::Duration; use std::time::Duration;
use vls_protocol::msgs; use vls_protocol::msgs;
@@ -27,39 +26,42 @@ async fn main() -> Result<(), Box<dyn Error>> {
.await .await
.expect("could not mqtt subscribe"); .expect("could not mqtt subscribe");
client.publish(PUB_TOPIC, QoS::AtMostOnce, false, "READY".as_bytes().to_vec()).await.expect("could not pub"); client
.publish(
PUB_TOPIC,
QoS::AtMostOnce,
false,
"READY".as_bytes().to_vec(),
)
.await
.expect("could not pub");
loop { loop {
let event = eventloop.poll().await; let event = eventloop.poll().await;
// println!("{:?}", event.unwrap()); // println!("{:?}", event.unwrap());
if let Some(mut m) = incoming_msg(event.expect("failed to unwrap event")) { if let Some(bs) = incoming_bytes(event.expect("failed to unwrap event")) {
let (sequence, dbid) = msgs::read_serial_request_header(&mut m).expect("read ping header"); let (ping, sequence, dbid): (msgs::Ping, u16, u64) =
parser::request_from_bytes(bs).expect("read ping header");
println!("sequence {}", sequence); println!("sequence {}", sequence);
println!("dbid {}", dbid); println!("dbid {}", dbid);
let ping: msgs::Ping =
msgs::read_message(&mut m).expect("failed to read ping message");
println!("INCOMING: {:?}", ping); println!("INCOMING: {:?}", ping);
let mut md = MsgDriver::new_empty();
msgs::write_serial_response_header(&mut md, sequence)
.expect("failed to write_serial_request_header");
let pong = msgs::Pong { let pong = msgs::Pong {
id: ping.id, id: ping.id,
message: ping.message message: ping.message,
}; };
msgs::write(&mut md, pong).expect("failed to serial write"); let bytes = parser::raw_response_from_msg(pong, sequence)?;
client client
.publish(PUB_TOPIC, QoS::AtMostOnce, false, md.bytes()) .publish(PUB_TOPIC, QoS::AtMostOnce, false, bytes)
.await .await
.expect("could not mqtt publish"); .expect("could not mqtt publish");
} }
} }
} }
fn incoming_msg(event: Event) -> Option<MsgDriver> { fn incoming_bytes(event: Event) -> Option<Vec<u8>> {
if let Event::Incoming(packet) = event { if let Event::Incoming(packet) = event {
if let Packet::Publish(p) = packet { if let Packet::Publish(p) = packet {
let m = MsgDriver::new(p.payload.to_vec()); return Some(p.payload.to_vec());
return Some(m)
} }
} }
None None