cln-plugin: Add notification subscriptions and hooks to the plugins

For now hooks are treated identically to rpcmethods, with the
exception of not being returned in the `getmanifest` call. Later on we
can add typed handlers as well.
This commit is contained in:
Christian Decker
2022-02-23 19:00:25 +01:00
committed by Rusty Russell
parent 8c6af21169
commit 60e773239c
4 changed files with 141 additions and 75 deletions

View File

@@ -14,6 +14,8 @@ async fn main() -> Result<(), anyhow::Error> {
"a test-option with default 42", "a test-option with default 42",
)) ))
.rpcmethod("testmethod", "This is a test", Box::new(testmethod)) .rpcmethod("testmethod", "This is a test", Box::new(testmethod))
.subscribe("connect", Box::new(connect_handler))
.hook("peer_connected", Box::new(peer_connected_handler))
.start() .start()
.await?; .await?;
plugin.join().await plugin.join().await
@@ -22,3 +24,13 @@ async fn main() -> Result<(), anyhow::Error> {
fn testmethod(_p: Plugin<()>, _v: &serde_json::Value) -> Result<serde_json::Value, Error> { fn testmethod(_p: Plugin<()>, _v: &serde_json::Value) -> Result<serde_json::Value, Error> {
Ok(json!("Hello")) Ok(json!("Hello"))
} }
fn connect_handler(_p: Plugin<()>, v: &serde_json::Value) -> Result<(), Error> {
log::info!("Got a connect notification: {}", v);
Ok(())
}
fn peer_connected_handler(_p: Plugin<()>, v: &serde_json::Value) -> Result<serde_json::Value, Error> {
log::info!("Got a connect hook call: {}", v);
Ok(json!({"result": "continue"}))
}

View File

@@ -40,14 +40,10 @@ where
input: Option<I>, input: Option<I>,
output: Option<O>, output: Option<O>,
#[allow(dead_code)] hooks: HashMap<String, Hook<S>>,
hooks: Hooks,
#[allow(dead_code)]
subscriptions: Subscriptions,
options: Vec<ConfigOption>, options: Vec<ConfigOption>,
rpcmethods: HashMap<String, RpcMethod<S>>, rpcmethods: HashMap<String, RpcMethod<S>>,
subscriptions: HashMap<String, Subscription<S>>,
} }
impl<S, I, O> Builder<S, I, O> impl<S, I, O> Builder<S, I, O>
@@ -61,8 +57,8 @@ where
state, state,
input: Some(input), input: Some(input),
output: Some(output), output: Some(output),
hooks: Hooks::default(), hooks: HashMap::new(),
subscriptions: Subscriptions::default(), subscriptions: HashMap::new(),
options: vec![], options: vec![],
rpcmethods: HashMap::new(), rpcmethods: HashMap::new(),
} }
@@ -73,6 +69,21 @@ where
self self
} }
/// Subscribe to notifications for the given `topic`.
pub fn subscribe(mut self, topic: &str, callback: NotificationCallback<S>) -> Builder<S, I, O> {
self.subscriptions
.insert(topic.to_string(), Subscription { callback });
self
}
/// Add a subscription to a given `hookname`
pub fn hook(mut self, hookname: &str, callback: Callback<S>) -> Self {
self.hooks.insert(hookname.to_string(), Hook { callback });
self
}
/// Register a custom RPC method for the RPC passthrough from the
/// main daemon
pub fn rpcmethod( pub fn rpcmethod(
mut self, mut self,
name: &str, name: &str,
@@ -148,12 +159,6 @@ where
let (wait_handle, _) = tokio::sync::broadcast::channel(1); let (wait_handle, _) = tokio::sync::broadcast::channel(1);
// Collect the callbacks and create the hashmap for the dispatcher.
let mut rpcmethods = HashMap::new();
for (name, callback) in self.rpcmethods.drain().map(|(k, v)| (k, v.callback)) {
rpcmethods.insert(name, callback);
}
// An MPSC pair used by anything that needs to send messages // An MPSC pair used by anything that needs to send messages
// to the main daemon. // to the main daemon.
let (sender, receiver) = tokio::sync::mpsc::channel(4); let (sender, receiver) = tokio::sync::mpsc::channel(4);
@@ -164,11 +169,21 @@ where
sender, sender,
}; };
// TODO Split the two hashmaps once we fill in the hook
// payload structs in messages.rs
let mut rpcmethods: HashMap<String, Callback<S>> =
HashMap::from_iter(self.rpcmethods.drain().map(|(k, v)| (k, v.callback)));
rpcmethods.extend(self.hooks.clone().drain().map(|(k, v)| (k, v.callback)));
// Start the PluginDriver to handle plugin IO // Start the PluginDriver to handle plugin IO
tokio::spawn( tokio::spawn(
PluginDriver { PluginDriver {
plugin: plugin.clone(), plugin: plugin.clone(),
rpcmethods, rpcmethods,
hooks: HashMap::from_iter(self.hooks.drain().map(|(k, v)| (k, v.callback))),
subscriptions: HashMap::from_iter(
self.subscriptions.drain().map(|(k, v)| (k, v.callback)),
),
} }
.run(receiver, input, output), .run(receiver, input, output),
); );
@@ -192,6 +207,8 @@ where
messages::GetManifestResponse { messages::GetManifestResponse {
options: self.options.clone(), options: self.options.clone(),
subscriptions: self.subscriptions.keys().map(|s| s.clone()).collect(),
hooks: self.hooks.keys().map(|s| s.clone()).collect(),
rpcmethods, rpcmethods,
} }
} }
@@ -221,6 +238,7 @@ where
} }
type Callback<S> = Box<fn(Plugin<S>, &serde_json::Value) -> Result<serde_json::Value, Error>>; type Callback<S> = Box<fn(Plugin<S>, &serde_json::Value) -> Result<serde_json::Value, Error>>;
type NotificationCallback<S> = Box<fn(Plugin<S>, &serde_json::Value) -> Result<(), Error>>;
/// A struct collecting the metadata required to register a custom /// A struct collecting the metadata required to register a custom
/// rpcmethod with the main daemon upon init. It'll get deconstructed /// rpcmethod with the main daemon upon init. It'll get deconstructed
@@ -234,6 +252,21 @@ where
name: String, name: String,
} }
struct Subscription<S>
where
S: Clone + Send,
{
callback: NotificationCallback<S>,
}
#[derive(Clone)]
struct Hook<S>
where
S: Clone + Send,
{
callback: Callback<S>,
}
#[derive(Clone)] #[derive(Clone)]
pub struct Plugin<S> pub struct Plugin<S>
where where
@@ -258,9 +291,13 @@ struct PluginDriver<S>
where where
S: Send + Clone, S: Send + Clone,
{ {
#[allow(dead_code)]
plugin: Plugin<S>, plugin: Plugin<S>,
rpcmethods: HashMap<String, Callback<S>>, rpcmethods: HashMap<String, Callback<S>>,
#[allow(dead_code)] // Unused until we fill in the Hook structs.
hooks: HashMap<String, Callback<S>>,
subscriptions: HashMap<String, NotificationCallback<S>>,
} }
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
@@ -305,7 +342,7 @@ where
PluginDriver::<S>::dispatch_request(id, p, plugin).await PluginDriver::<S>::dispatch_request(id, p, plugin).await
} }
messages::JsonRpc::Notification(n) => { messages::JsonRpc::Notification(n) => {
PluginDriver::<S>::dispatch_notification(n, plugin).await self.dispatch_notification(n, plugin).await
} }
messages::JsonRpc::CustomRequest(id, p) => { messages::JsonRpc::CustomRequest(id, p) => {
match self.dispatch_custom_request(id, p, plugin).await { match self.dispatch_custom_request(id, p, plugin).await {
@@ -330,7 +367,7 @@ where
} }
} }
messages::JsonRpc::CustomNotification(n) => { messages::JsonRpc::CustomNotification(n) => {
PluginDriver::<S>::dispatch_custom_notification(n, plugin).await self.dispatch_custom_notification(n, plugin).await
} }
} }
} }
@@ -340,23 +377,24 @@ where
} }
async fn dispatch_request( async fn dispatch_request(
id: usize, _id: usize,
request: messages::Request, _request: messages::Request,
_plugin: &Plugin<S>, _plugin: &Plugin<S>,
) -> Result<(), Error> { ) -> Result<(), Error> {
panic!("Unexpected request {:?} with id {}", request, id); todo!("This is unreachable until we start filling in messages:Request. Until then the custom dispatcher below is used exclusively.")
} }
async fn dispatch_notification( async fn dispatch_notification(
notification: messages::Notification, &self,
_notification: messages::Notification,
_plugin: &Plugin<S>, _plugin: &Plugin<S>,
) -> Result<(), Error> ) -> Result<(), Error>
where where
S: Send + Clone, S: Send + Clone,
{ {
trace!("Dispatching notification {:?}", notification); todo!("As soon as we define the full structure of the messages::Notification we'll get here. Until then the custom dispatcher below is used.")
unimplemented!()
} }
async fn dispatch_custom_request( async fn dispatch_custom_request(
&self, &self,
_id: usize, _id: usize,
@@ -386,14 +424,35 @@ where
} }
async fn dispatch_custom_notification( async fn dispatch_custom_notification(
&self,
notification: serde_json::Value, notification: serde_json::Value,
_plugin: &Plugin<S>, plugin: &Plugin<S>,
) -> Result<(), Error> ) -> Result<(), Error>
where where
S: Send + Clone, S: Send + Clone,
{ {
trace!("Dispatching notification {:?}", notification); trace!("Dispatching custom notification {:?}", notification);
unimplemented!() let method = notification
.get("method")
.context("Missing 'method' in notification")?
.as_str()
.context("'method' is not a string")?;
let params = notification
.get("params")
.context("Missing 'params' field in notification")?;
let callback = self
.subscriptions
.get(method)
.with_context(|| anyhow!("No handler for method '{}' registered", method))?;
trace!(
"Dispatching custom request: method={}, params={}",
method,
params
);
if let Err(e) = callback(plugin.clone(), params) {
log::error!("Error in notification handler '{}': {}", method, e);
}
Ok(())
} }
} }
@@ -422,17 +481,6 @@ where
} }
} }
/// A container for all the configure hooks. It is just a collection
/// of callbacks that can be registered by the users of the
/// library. Based on this configuration we can then generate the
/// [`messages::GetManifestResponse`] from, populating our subscriptions
#[derive(Debug, Default)]
struct Hooks {}
/// A container for all the configured notifications.
#[derive(Debug, Default)]
struct Subscriptions {}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::*; use super::*;

View File

@@ -12,50 +12,49 @@ pub(crate) enum Request {
// Builtin // Builtin
Getmanifest(GetManifestCall), Getmanifest(GetManifestCall),
Init(InitCall), Init(InitCall),
// Hooks // Hooks
PeerConnected, // PeerConnected,
CommitmentRevocation, // CommitmentRevocation,
DbWrite, // DbWrite,
InvoicePayment, // InvoicePayment,
Openchannel, // Openchannel,
Openchannel2, // Openchannel2,
Openchannel2Changed, // Openchannel2Changed,
Openchannel2Sign, // Openchannel2Sign,
RbfChannel, // RbfChannel,
HtlcAccepted, // HtlcAccepted,
RpcCommand, // RpcCommand,
Custommsg, // Custommsg,
OnionMessage, // OnionMessage,
OnionMessageBlinded, // OnionMessageBlinded,
OnionMessageOurpath, // OnionMessageOurpath,
// Bitcoin backend // Bitcoin backend
Getchaininfo, // Getchaininfo,
Estimatefees, // Estimatefees,
Getrawblockbyheight, // Getrawblockbyheight,
Getutxout, // Getutxout,
Sendrawtransaction, // Sendrawtransaction,
} }
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
#[serde(tag = "method", content = "params")] #[serde(tag = "method", content = "params")]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub(crate) enum Notification { pub(crate) enum Notification {
ChannelOpened, // ChannelOpened,
ChannelOpenFailed, // ChannelOpenFailed,
ChannelStateChanged, // ChannelStateChanged,
Connect, // Connect,
Disconnect, // Disconnect,
InvoicePayment, // InvoicePayment,
InvoiceCreation, // InvoiceCreation,
Warning, // Warning,
ForwardEvent, // ForwardEvent,
SendpaySuccess, // SendpaySuccess,
SendpayFailure, // SendpayFailure,
CoinMovement, // CoinMovement,
OpenchannelPeerSigs, // OpenchannelPeerSigs,
Shutdown, // Shutdown,
} }
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
@@ -128,6 +127,8 @@ pub(crate) struct RpcMethod {
pub(crate) struct GetManifestResponse { pub(crate) struct GetManifestResponse {
pub(crate) options: Vec<ConfigOption>, pub(crate) options: Vec<ConfigOption>,
pub(crate) rpcmethods: Vec<RpcMethod>, pub(crate) rpcmethods: Vec<RpcMethod>,
pub(crate) subscriptions: Vec<String>,
pub(crate) hooks: Vec<String>,
} }
#[derive(Serialize, Default, Debug)] #[derive(Serialize, Default, Debug)]

View File

@@ -25,6 +25,7 @@ def test_plugin_start(node_factory):
""" """
bin_path = Path.cwd() / "target" / "debug" / "examples" / "cln-plugin-startup" bin_path = Path.cwd() / "target" / "debug" / "examples" / "cln-plugin-startup"
l1 = node_factory.get_node(options={"plugin": str(bin_path), 'test-option': 31337}) l1 = node_factory.get_node(options={"plugin": str(bin_path), 'test-option': 31337})
l2 = node_factory.get_node()
cfg = l1.rpc.listconfigs() cfg = l1.rpc.listconfigs()
p = cfg['plugins'][0] p = cfg['plugins'][0]
@@ -52,3 +53,7 @@ def test_plugin_start(node_factory):
} }
assert l1.rpc.testmethod() == "Hello" assert l1.rpc.testmethod() == "Hello"
l1.connect(l2)
l1.daemon.wait_for_log(r'Got a connect hook call')
l1.daemon.wait_for_log(r'Got a connect notification')