diff --git a/plugins/examples/cln-plugin-startup.rs b/plugins/examples/cln-plugin-startup.rs index 506f8fbef..a377d81d6 100644 --- a/plugins/examples/cln-plugin-startup.rs +++ b/plugins/examples/cln-plugin-startup.rs @@ -14,6 +14,8 @@ async fn main() -> Result<(), anyhow::Error> { "a test-option with default 42", )) .rpcmethod("testmethod", "This is a test", Box::new(testmethod)) + .subscribe("connect", Box::new(connect_handler)) + .hook("peer_connected", Box::new(peer_connected_handler)) .start() .await?; plugin.join().await @@ -22,3 +24,13 @@ async fn main() -> Result<(), anyhow::Error> { fn testmethod(_p: Plugin<()>, _v: &serde_json::Value) -> Result { Ok(json!("Hello")) } + +fn connect_handler(_p: Plugin<()>, v: &serde_json::Value) -> Result<(), Error> { + log::info!("Got a connect notification: {}", v); + Ok(()) +} + +fn peer_connected_handler(_p: Plugin<()>, v: &serde_json::Value) -> Result { + log::info!("Got a connect hook call: {}", v); + Ok(json!({"result": "continue"})) +} diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index faf929b25..b77064995 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -40,14 +40,10 @@ where input: Option, output: Option, - #[allow(dead_code)] - hooks: Hooks, - - #[allow(dead_code)] - subscriptions: Subscriptions, - + hooks: HashMap>, options: Vec, rpcmethods: HashMap>, + subscriptions: HashMap>, } impl Builder @@ -61,8 +57,8 @@ where state, input: Some(input), output: Some(output), - hooks: Hooks::default(), - subscriptions: Subscriptions::default(), + hooks: HashMap::new(), + subscriptions: HashMap::new(), options: vec![], rpcmethods: HashMap::new(), } @@ -73,6 +69,21 @@ where self } + /// Subscribe to notifications for the given `topic`. + pub fn subscribe(mut self, topic: &str, callback: NotificationCallback) -> Builder { + self.subscriptions + .insert(topic.to_string(), Subscription { callback }); + self + } + + /// Add a subscription to a given `hookname` + pub fn hook(mut self, hookname: &str, callback: Callback) -> Self { + self.hooks.insert(hookname.to_string(), Hook { callback }); + self + } + + /// Register a custom RPC method for the RPC passthrough from the + /// main daemon pub fn rpcmethod( mut self, name: &str, @@ -148,12 +159,6 @@ where let (wait_handle, _) = tokio::sync::broadcast::channel(1); - // Collect the callbacks and create the hashmap for the dispatcher. - let mut rpcmethods = HashMap::new(); - for (name, callback) in self.rpcmethods.drain().map(|(k, v)| (k, v.callback)) { - rpcmethods.insert(name, callback); - } - // An MPSC pair used by anything that needs to send messages // to the main daemon. let (sender, receiver) = tokio::sync::mpsc::channel(4); @@ -164,11 +169,21 @@ where sender, }; + // TODO Split the two hashmaps once we fill in the hook + // payload structs in messages.rs + let mut rpcmethods: HashMap> = + HashMap::from_iter(self.rpcmethods.drain().map(|(k, v)| (k, v.callback))); + rpcmethods.extend(self.hooks.clone().drain().map(|(k, v)| (k, v.callback))); + // Start the PluginDriver to handle plugin IO tokio::spawn( PluginDriver { plugin: plugin.clone(), rpcmethods, + hooks: HashMap::from_iter(self.hooks.drain().map(|(k, v)| (k, v.callback))), + subscriptions: HashMap::from_iter( + self.subscriptions.drain().map(|(k, v)| (k, v.callback)), + ), } .run(receiver, input, output), ); @@ -192,6 +207,8 @@ where messages::GetManifestResponse { options: self.options.clone(), + subscriptions: self.subscriptions.keys().map(|s| s.clone()).collect(), + hooks: self.hooks.keys().map(|s| s.clone()).collect(), rpcmethods, } } @@ -221,6 +238,7 @@ where } type Callback = Box, &serde_json::Value) -> Result>; +type NotificationCallback = Box, &serde_json::Value) -> Result<(), Error>>; /// A struct collecting the metadata required to register a custom /// rpcmethod with the main daemon upon init. It'll get deconstructed @@ -234,6 +252,21 @@ where name: String, } +struct Subscription +where + S: Clone + Send, +{ + callback: NotificationCallback, +} + +#[derive(Clone)] +struct Hook +where + S: Clone + Send, +{ + callback: Callback, +} + #[derive(Clone)] pub struct Plugin where @@ -258,9 +291,13 @@ struct PluginDriver where S: Send + Clone, { - #[allow(dead_code)] + plugin: Plugin, rpcmethods: HashMap>, + + #[allow(dead_code)] // Unused until we fill in the Hook structs. + hooks: HashMap>, + subscriptions: HashMap>, } use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -281,9 +318,9 @@ where { loop { tokio::select! { - _ = self.dispatch_one(&mut input, &self.plugin) => {}, - v = receiver.recv() => {output.lock().await.send(v.unwrap()).await?}, - } + _ = self.dispatch_one(&mut input, &self.plugin) => {}, + v = receiver.recv() => {output.lock().await.send(v.unwrap()).await?}, + } } } @@ -305,7 +342,7 @@ where PluginDriver::::dispatch_request(id, p, plugin).await } messages::JsonRpc::Notification(n) => { - PluginDriver::::dispatch_notification(n, plugin).await + self.dispatch_notification(n, plugin).await } messages::JsonRpc::CustomRequest(id, p) => { match self.dispatch_custom_request(id, p, plugin).await { @@ -330,7 +367,7 @@ where } } messages::JsonRpc::CustomNotification(n) => { - PluginDriver::::dispatch_custom_notification(n, plugin).await + self.dispatch_custom_notification(n, plugin).await } } } @@ -340,23 +377,24 @@ where } async fn dispatch_request( - id: usize, - request: messages::Request, + _id: usize, + _request: messages::Request, _plugin: &Plugin, ) -> Result<(), Error> { - panic!("Unexpected request {:?} with id {}", request, id); + todo!("This is unreachable until we start filling in messages:Request. Until then the custom dispatcher below is used exclusively.") } async fn dispatch_notification( - notification: messages::Notification, + &self, + _notification: messages::Notification, _plugin: &Plugin, ) -> Result<(), Error> where S: Send + Clone, { - trace!("Dispatching notification {:?}", notification); - unimplemented!() + todo!("As soon as we define the full structure of the messages::Notification we'll get here. Until then the custom dispatcher below is used.") } + async fn dispatch_custom_request( &self, _id: usize, @@ -386,14 +424,35 @@ where } async fn dispatch_custom_notification( + &self, notification: serde_json::Value, - _plugin: &Plugin, + plugin: &Plugin, ) -> Result<(), Error> where S: Send + Clone, { - trace!("Dispatching notification {:?}", notification); - unimplemented!() + trace!("Dispatching custom notification {:?}", notification); + let method = notification + .get("method") + .context("Missing 'method' in notification")? + .as_str() + .context("'method' is not a string")?; + let params = notification + .get("params") + .context("Missing 'params' field in notification")?; + let callback = self + .subscriptions + .get(method) + .with_context(|| anyhow!("No handler for method '{}' registered", method))?; + trace!( + "Dispatching custom request: method={}, params={}", + method, + params + ); + if let Err(e) = callback(plugin.clone(), params) { + log::error!("Error in notification handler '{}': {}", method, e); + } + Ok(()) } } @@ -422,17 +481,6 @@ where } } -/// A container for all the configure hooks. It is just a collection -/// of callbacks that can be registered by the users of the -/// library. Based on this configuration we can then generate the -/// [`messages::GetManifestResponse`] from, populating our subscriptions -#[derive(Debug, Default)] -struct Hooks {} - -/// A container for all the configured notifications. -#[derive(Debug, Default)] -struct Subscriptions {} - #[cfg(test)] mod test { use super::*; diff --git a/plugins/src/messages.rs b/plugins/src/messages.rs index f32cb2086..ec11cb2eb 100644 --- a/plugins/src/messages.rs +++ b/plugins/src/messages.rs @@ -12,50 +12,49 @@ pub(crate) enum Request { // Builtin Getmanifest(GetManifestCall), Init(InitCall), - // Hooks - PeerConnected, - CommitmentRevocation, - DbWrite, - InvoicePayment, - Openchannel, - Openchannel2, - Openchannel2Changed, - Openchannel2Sign, - RbfChannel, - HtlcAccepted, - RpcCommand, - Custommsg, - OnionMessage, - OnionMessageBlinded, - OnionMessageOurpath, + // PeerConnected, + // CommitmentRevocation, + // DbWrite, + // InvoicePayment, + // Openchannel, + // Openchannel2, + // Openchannel2Changed, + // Openchannel2Sign, + // RbfChannel, + // HtlcAccepted, + // RpcCommand, + // Custommsg, + // OnionMessage, + // OnionMessageBlinded, + // OnionMessageOurpath, // Bitcoin backend - Getchaininfo, - Estimatefees, - Getrawblockbyheight, - Getutxout, - Sendrawtransaction, + // Getchaininfo, + // Estimatefees, + // Getrawblockbyheight, + // Getutxout, + // Sendrawtransaction, } #[derive(Deserialize, Debug)] #[serde(tag = "method", content = "params")] #[serde(rename_all = "snake_case")] pub(crate) enum Notification { - ChannelOpened, - ChannelOpenFailed, - ChannelStateChanged, - Connect, - Disconnect, - InvoicePayment, - InvoiceCreation, - Warning, - ForwardEvent, - SendpaySuccess, - SendpayFailure, - CoinMovement, - OpenchannelPeerSigs, - Shutdown, +// ChannelOpened, +// ChannelOpenFailed, +// ChannelStateChanged, +// Connect, +// Disconnect, +// InvoicePayment, +// InvoiceCreation, +// Warning, +// ForwardEvent, +// SendpaySuccess, +// SendpayFailure, +// CoinMovement, +// OpenchannelPeerSigs, +// Shutdown, } #[derive(Deserialize, Debug)] @@ -128,6 +127,8 @@ pub(crate) struct RpcMethod { pub(crate) struct GetManifestResponse { pub(crate) options: Vec, pub(crate) rpcmethods: Vec, + pub(crate) subscriptions: Vec, + pub(crate) hooks: Vec, } #[derive(Serialize, Default, Debug)] diff --git a/tests/test_cln_rs.py b/tests/test_cln_rs.py index c336c294d..5c2d89108 100644 --- a/tests/test_cln_rs.py +++ b/tests/test_cln_rs.py @@ -25,6 +25,7 @@ def test_plugin_start(node_factory): """ bin_path = Path.cwd() / "target" / "debug" / "examples" / "cln-plugin-startup" l1 = node_factory.get_node(options={"plugin": str(bin_path), 'test-option': 31337}) + l2 = node_factory.get_node() cfg = l1.rpc.listconfigs() p = cfg['plugins'][0] @@ -52,3 +53,7 @@ def test_plugin_start(node_factory): } assert l1.rpc.testmethod() == "Hello" + + l1.connect(l2) + l1.daemon.wait_for_log(r'Got a connect hook call') + l1.daemon.wait_for_log(r'Got a connect notification')