mirror of
https://github.com/stakwork/sphinx-key.git
synced 2025-12-19 00:04:25 +01:00
utils for cleaner msg driver parsing and serialization
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
use crate::mqtt::start_broker;
|
||||
use crate::ChannelRequest;
|
||||
use sphinx_key_parser::MsgDriver;
|
||||
use sphinx_key_parser as parser;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use vls_protocol::serde_bolt::WireString;
|
||||
use vls_protocol::{msgs, msgs::Message};
|
||||
@@ -31,24 +31,20 @@ pub async fn iteration(
|
||||
sequence: u16,
|
||||
tx: mpsc::Sender<ChannelRequest>,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut md = MsgDriver::new_empty();
|
||||
msgs::write_serial_request_header(&mut md, sequence, 0)?;
|
||||
let ping = msgs::Ping {
|
||||
id,
|
||||
message: WireString("ping".as_bytes().to_vec()),
|
||||
};
|
||||
msgs::write(&mut md, ping)?;
|
||||
let ping_bytes = parser::request_from_msg(ping, sequence, 0)?;
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
// Send a request to the MQTT handler to send to signer
|
||||
let request = ChannelRequest {
|
||||
message: md.bytes(),
|
||||
message: ping_bytes,
|
||||
reply_tx,
|
||||
};
|
||||
let _ = tx.send(request).await;
|
||||
let res = reply_rx.await?;
|
||||
let mut ret = MsgDriver::new(res.reply);
|
||||
msgs::read_serial_response_header(&mut ret, sequence)?;
|
||||
let reply = msgs::read(&mut ret)?;
|
||||
let reply = parser::response_from_bytes(res.reply, sequence)?;
|
||||
match reply {
|
||||
Message::Pong(p) => {
|
||||
log::info!(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use log::*;
|
||||
use secp256k1::PublicKey;
|
||||
use sphinx_key_parser::MsgDriver;
|
||||
use sphinx_key_parser as parser;
|
||||
use std::thread;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
// use tokio::task::spawn_blocking;
|
||||
@@ -104,14 +104,11 @@ impl<C: 'static + Client> SignerLoop<C> {
|
||||
|
||||
fn handle_message(&mut self, message: Vec<u8>) -> Result<Vec<u8>> {
|
||||
let dbid = self.client_id.as_ref().map(|c| c.dbid).unwrap_or(0);
|
||||
let mut md = MsgDriver::new_empty();
|
||||
msgs::write_serial_request_header(&mut md, self.chan.sequence, dbid)?;
|
||||
msgs::write_vec(&mut md, message)?;
|
||||
let reply_rx = self.send_request(md.bytes())?;
|
||||
let mut res = self.get_reply(reply_rx)?;
|
||||
msgs::read_serial_response_header(&mut res, self.chan.sequence)?;
|
||||
let md = parser::raw_request_from_bytes(message, self.chan.sequence, dbid)?;
|
||||
let reply_rx = self.send_request(md)?;
|
||||
let res = self.get_reply(reply_rx)?;
|
||||
let reply = parser::raw_response_from_bytes(res, self.chan.sequence)?;
|
||||
self.chan.sequence = self.chan.sequence.wrapping_add(1);
|
||||
let reply = msgs::read_raw(&mut res)?;
|
||||
Ok(reply)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,8 @@ edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
vls-protocol = { path = "../../validating-lightning-signer/vls-protocol" }
|
||||
serde = { version = "1.0", default-features = false }
|
||||
serde_bolt = { version = "0.2", default-features = false }
|
||||
|
||||
[features]
|
||||
default = ["std"]
|
||||
|
||||
@@ -1,6 +1,62 @@
|
||||
use vls_protocol::serde_bolt::{self, Read, Write};
|
||||
use std::io;
|
||||
use serde::ser;
|
||||
use std::cmp::min;
|
||||
use std::io;
|
||||
use vls_protocol::msgs::{self, DeBolt, Message};
|
||||
use vls_protocol::serde_bolt::{Error, Read, Result, Write};
|
||||
|
||||
pub fn raw_request_from_bytes(
|
||||
message: Vec<u8>,
|
||||
sequence: u16,
|
||||
dbid: u64,
|
||||
) -> vls_protocol::Result<Vec<u8>> {
|
||||
let mut md = MsgDriver::new_empty();
|
||||
msgs::write_serial_request_header(&mut md, sequence, dbid)?;
|
||||
msgs::write_vec(&mut md, message)?;
|
||||
Ok(md.bytes())
|
||||
}
|
||||
|
||||
pub fn request_from_msg<T: ser::Serialize + DeBolt>(
|
||||
msg: T,
|
||||
sequence: u16,
|
||||
dbid: u64,
|
||||
) -> vls_protocol::Result<Vec<u8>> {
|
||||
let mut md = MsgDriver::new_empty();
|
||||
msgs::write_serial_request_header(&mut md, sequence, dbid)?;
|
||||
msgs::write(&mut md, msg).expect("failed to serial write");
|
||||
Ok(md.bytes())
|
||||
}
|
||||
|
||||
pub fn raw_response_from_msg<T: ser::Serialize + DeBolt>(
|
||||
msg: T,
|
||||
sequence: u16,
|
||||
) -> vls_protocol::Result<Vec<u8>> {
|
||||
let mut m = MsgDriver::new_empty();
|
||||
msgs::write_serial_response_header(&mut m, sequence)?;
|
||||
msgs::write(&mut m, msg)?;
|
||||
Ok(m.bytes())
|
||||
}
|
||||
|
||||
pub fn request_from_bytes<T: DeBolt>(msg: Vec<u8>) -> vls_protocol::Result<(T, u16, u64)> {
|
||||
let mut m = MsgDriver::new(msg);
|
||||
let (sequence, dbid) = msgs::read_serial_request_header(&mut m).expect("read ping header");
|
||||
let reply: T = msgs::read_message(&mut m).expect("failed to read ping message");
|
||||
Ok((reply, sequence, dbid))
|
||||
}
|
||||
|
||||
pub fn raw_response_from_bytes(
|
||||
res: Vec<u8>,
|
||||
expected_sequence: u16,
|
||||
) -> vls_protocol::Result<Vec<u8>> {
|
||||
let mut m = MsgDriver::new(res);
|
||||
msgs::read_serial_response_header(&mut m, expected_sequence)?;
|
||||
Ok(msgs::read_raw(&mut m)?)
|
||||
}
|
||||
|
||||
pub fn response_from_bytes(res: Vec<u8>, expected_sequence: u16) -> vls_protocol::Result<Message> {
|
||||
let mut m = MsgDriver::new(res);
|
||||
msgs::read_serial_response_header(&mut m, expected_sequence)?;
|
||||
Ok(msgs::read(&mut m)?)
|
||||
}
|
||||
|
||||
pub struct MsgDriver(Vec<u8>);
|
||||
|
||||
@@ -20,35 +76,37 @@ impl MsgDriver {
|
||||
}
|
||||
|
||||
impl Read for MsgDriver {
|
||||
type Error = serde_bolt::Error;
|
||||
type Error = Error;
|
||||
|
||||
// input: buf to be written. Should already be the right size
|
||||
fn read(&mut self, mut buf: &mut [u8]) -> serde_bolt::Result<usize> {
|
||||
fn read(&mut self, mut buf: &mut [u8]) -> Result<usize> {
|
||||
if buf.is_empty() {
|
||||
return Ok(0);
|
||||
}
|
||||
let (mut content, remaining) = self.0.split_at(
|
||||
min(buf.len(), self.0.len())
|
||||
);
|
||||
let (mut content, remaining) = self.0.split_at(min(buf.len(), self.0.len()));
|
||||
let bytes = &mut content;
|
||||
match io::copy(bytes, &mut buf) {
|
||||
Ok(len) => {
|
||||
self.0 = remaining.to_vec();
|
||||
Ok(len as usize)
|
||||
},
|
||||
Err(_) => Ok(0)
|
||||
}
|
||||
Err(_) => Ok(0),
|
||||
}
|
||||
}
|
||||
|
||||
fn peek(&mut self) -> serde_bolt::Result<Option<u8>> {
|
||||
Ok(if let Some(u) = self.0.get(0) { Some(u.clone()) } else { None})
|
||||
fn peek(&mut self) -> Result<Option<u8>> {
|
||||
Ok(if let Some(u) = self.0.get(0) {
|
||||
Some(u.clone())
|
||||
} else {
|
||||
None
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for MsgDriver {
|
||||
type Error = serde_bolt::Error;
|
||||
type Error = Error;
|
||||
|
||||
fn write_all(&mut self, buf: &[u8]) -> serde_bolt::Result<()> {
|
||||
fn write_all(&mut self, buf: &[u8]) -> Result<()> {
|
||||
self.0.extend(buf.iter().cloned());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use sphinx_key_parser as parser;
|
||||
|
||||
use sphinx_key_parser::MsgDriver;
|
||||
|
||||
use rumqttc::{self, AsyncClient, MqttOptions, QoS, Event, Packet};
|
||||
use rumqttc::{self, AsyncClient, Event, MqttOptions, Packet, QoS};
|
||||
use std::error::Error;
|
||||
use std::time::Duration;
|
||||
use vls_protocol::msgs;
|
||||
@@ -27,39 +26,42 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
||||
.await
|
||||
.expect("could not mqtt subscribe");
|
||||
|
||||
client.publish(PUB_TOPIC, QoS::AtMostOnce, false, "READY".as_bytes().to_vec()).await.expect("could not pub");
|
||||
client
|
||||
.publish(
|
||||
PUB_TOPIC,
|
||||
QoS::AtMostOnce,
|
||||
false,
|
||||
"READY".as_bytes().to_vec(),
|
||||
)
|
||||
.await
|
||||
.expect("could not pub");
|
||||
|
||||
loop {
|
||||
let event = eventloop.poll().await;
|
||||
// println!("{:?}", event.unwrap());
|
||||
if let Some(mut m) = incoming_msg(event.expect("failed to unwrap event")) {
|
||||
let (sequence, dbid) = msgs::read_serial_request_header(&mut m).expect("read ping header");
|
||||
if let Some(bs) = incoming_bytes(event.expect("failed to unwrap event")) {
|
||||
let (ping, sequence, dbid): (msgs::Ping, u16, u64) =
|
||||
parser::request_from_bytes(bs).expect("read ping header");
|
||||
println!("sequence {}", sequence);
|
||||
println!("dbid {}", dbid);
|
||||
let ping: msgs::Ping =
|
||||
msgs::read_message(&mut m).expect("failed to read ping message");
|
||||
println!("INCOMING: {:?}", ping);
|
||||
let mut md = MsgDriver::new_empty();
|
||||
msgs::write_serial_response_header(&mut md, sequence)
|
||||
.expect("failed to write_serial_request_header");
|
||||
let pong = msgs::Pong {
|
||||
id: ping.id,
|
||||
message: ping.message
|
||||
message: ping.message,
|
||||
};
|
||||
msgs::write(&mut md, pong).expect("failed to serial write");
|
||||
let bytes = parser::raw_response_from_msg(pong, sequence)?;
|
||||
client
|
||||
.publish(PUB_TOPIC, QoS::AtMostOnce, false, md.bytes())
|
||||
.publish(PUB_TOPIC, QoS::AtMostOnce, false, bytes)
|
||||
.await
|
||||
.expect("could not mqtt publish");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn incoming_msg(event: Event) -> Option<MsgDriver> {
|
||||
fn incoming_bytes(event: Event) -> Option<Vec<u8>> {
|
||||
if let Event::Incoming(packet) = event {
|
||||
if let Packet::Publish(p) = packet {
|
||||
let m = MsgDriver::new(p.payload.to_vec());
|
||||
return Some(m)
|
||||
return Some(p.payload.to_vec());
|
||||
}
|
||||
}
|
||||
None
|
||||
|
||||
Reference in New Issue
Block a user