rs: Run hooks, methods and notification handlers in tokio tasks

Changelog-Changed: cln-plugin: Hooks, notifications and RPC methods now run asynchronously allowing for re-entrant handlers
This commit is contained in:
Christian Decker
2023-04-11 06:57:45 +09:30
committed by ShahanaFarooqui
parent db3707f957
commit f69da84256
5 changed files with 97 additions and 144 deletions

View File

@@ -11,8 +11,8 @@ use std::str::FromStr;
use std::{io, str}; use std::{io, str};
use tokio_util::codec::{Decoder, Encoder}; use tokio_util::codec::{Decoder, Encoder};
use crate::messages::{Notification, Request};
use crate::messages::JsonRpc; use crate::messages::JsonRpc;
use crate::messages::{Notification, Request};
/// A simple codec that parses messages separated by two successive /// A simple codec that parses messages separated by two successive
/// `\n` newlines. /// `\n` newlines.

View File

@@ -322,7 +322,7 @@ where
hooks: self.hooks.keys().map(|s| s.clone()).collect(), hooks: self.hooks.keys().map(|s| s.clone()).collect(),
rpcmethods, rpcmethods,
dynamic: self.dynamic, dynamic: self.dynamic,
nonnumericids: true, nonnumericids: true,
} }
} }
@@ -458,8 +458,8 @@ where
// Start the PluginDriver to handle plugin IO // Start the PluginDriver to handle plugin IO
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = driver.run(receiver, input, output).await { if let Err(e) = driver.run(receiver, input, output).await {
log::warn!("Plugin loop returned error {:?}", e); log::warn!("Plugin loop returned error {:?}", e);
} }
// Now that we have left the reader loop its time to // Now that we have left the reader loop its time to
// notify any waiting tasks. This most likely will cause // notify any waiting tasks. This most likely will cause
@@ -507,7 +507,7 @@ where
impl<S> PluginDriver<S> impl<S> PluginDriver<S>
where where
S: Send + Clone, S: Send + Clone + Sync,
{ {
/// Run the plugin until we get a shutdown command. /// Run the plugin until we get a shutdown command.
async fn run<I, O>( async fn run<I, O>(
@@ -526,17 +526,17 @@ where
// the user-code, which may require some cleanups or // the user-code, which may require some cleanups or
// similar. // similar.
tokio::select! { tokio::select! {
e = self.dispatch_one(&mut input, &self.plugin) => { e = self.dispatch_one(&mut input, &self.plugin) => {
if let Err(e) = e { if let Err(e) = e {
return Err(e) return Err(e)
} }
}, },
v = receiver.recv() => { v = receiver.recv() => {
output.lock().await.send( output.lock().await.send(
v.context("internal communication error")? v.context("internal communication error")?
).await?; ).await?;
}, },
} }
} }
} }
@@ -554,36 +554,74 @@ where
Some(Ok(msg)) => { Some(Ok(msg)) => {
trace!("Received a message: {:?}", msg); trace!("Received a message: {:?}", msg);
match msg { match msg {
messages::JsonRpc::Request(id, p) => { messages::JsonRpc::Request(_id, _p) => {
PluginDriver::<S>::dispatch_request(id, p, plugin).await todo!("This is unreachable until we start filling in messages:Request. Until then the custom dispatcher below is used exclusively.");
} }
messages::JsonRpc::Notification(n) => { messages::JsonRpc::Notification(_n) => {
self.dispatch_notification(n, plugin).await todo!("As soon as we define the full structure of the messages::Notification we'll get here. Until then the custom dispatcher below is used.")
} }
messages::JsonRpc::CustomRequest(id, p) => { messages::JsonRpc::CustomRequest(id, request) => {
match self.dispatch_custom_request(id.clone(), p, plugin).await { trace!("Dispatching custom method {:?}", request);
Ok(v) => plugin let method = request
.sender .get("method")
.send(json!({ .context("Missing 'method' in request")?
"jsonrpc": "2.0", .as_str()
"id": id, .context("'method' is not a string")?;
"result": v let callback = self.rpcmethods.get(method).with_context(|| {
})) anyhow!("No handler for method '{}' registered", method)
.await })?;
.context("returning custom result"), let params = request
Err(e) => plugin .get("params")
.sender .context("Missing 'params' field in request")?
.send(json!({ .clone();
"jsonrpc": "2.0",
"id": id, let plugin = plugin.clone();
"error": e.to_string(), let call = callback(plugin.clone(), params);
}))
.await tokio::spawn(async move {
.context("returning custom error"), match call.await {
} Ok(v) => plugin
.sender
.send(json!({
"jsonrpc": "2.0",
"id": id,
"result": v
}))
.await
.context("returning custom response"),
Err(e) => plugin
.sender
.send(json!({
"jsonrpc": "2.0",
"id": id,
"error": e.to_string(),
}))
.await
.context("returning custom error"),
}
});
Ok(())
} }
messages::JsonRpc::CustomNotification(n) => { messages::JsonRpc::CustomNotification(request) => {
self.dispatch_custom_notification(n, plugin).await trace!("Dispatching custom notification {:?}", request);
let method = request
.get("method")
.context("Missing 'method' in request")?
.as_str()
.context("'method' is not a string")?;
let callback = self.subscriptions.get(method).with_context(|| {
anyhow!("No handler for notification '{}' registered", method)
})?;
let params = request
.get("params")
.context("Missing 'params' field in request")?
.clone();
let plugin = plugin.clone();
let call = callback(plugin.clone(), params);
tokio::spawn(async move { call.await.unwrap() });
Ok(())
} }
} }
} }
@@ -591,85 +629,6 @@ where
None => Err(anyhow!("Error reading from master")), None => Err(anyhow!("Error reading from master")),
} }
} }
async fn dispatch_request(
_id: serde_json::Value,
_request: messages::Request,
_plugin: &Plugin<S>,
) -> Result<(), Error> {
todo!("This is unreachable until we start filling in messages:Request. Until then the custom dispatcher below is used exclusively.")
}
async fn dispatch_notification(
&self,
_notification: messages::Notification,
_plugin: &Plugin<S>,
) -> Result<(), Error>
where
S: Send + Clone,
{
todo!("As soon as we define the full structure of the messages::Notification we'll get here. Until then the custom dispatcher below is used.")
}
async fn dispatch_custom_request(
&self,
_id: serde_json::Value,
request: serde_json::Value,
plugin: &Plugin<S>,
) -> Result<serde_json::Value, Error> {
let method = request
.get("method")
.context("Missing 'method' in request")?
.as_str()
.context("'method' is not a string")?;
let params = request
.get("params")
.context("Missing 'params' field in request")?;
let callback = self
.rpcmethods
.get(method)
.with_context(|| anyhow!("No handler for method '{}' registered", method))?;
trace!(
"Dispatching custom request: method={}, params={}",
method,
params
);
callback(plugin.clone(), params.clone()).await
}
async fn dispatch_custom_notification(
&self,
notification: serde_json::Value,
plugin: &Plugin<S>,
) -> Result<(), Error>
where
S: Send + Clone,
{
trace!("Dispatching custom notification {:?}", notification);
let method = notification
.get("method")
.context("Missing 'method' in notification")?
.as_str()
.context("'method' is not a string")?;
let params = notification
.get("params")
.context("Missing 'params' field in notification")?;
let callback = self
.subscriptions
.get(method)
.with_context(|| anyhow!("No handler for method '{}' registered", method))?;
trace!(
"Dispatching custom request: method={}, params={}",
method,
params
);
if let Err(e) = callback(plugin.clone(), params.clone()).await {
log::error!("Error in notification handler '{}': {}", method, e);
}
Ok(())
}
} }
impl<S> Plugin<S> impl<S> Plugin<S>
@@ -715,7 +674,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn init() { async fn init() {
let state = (); let state = ();
let builder = Builder::new(tokio::io::stdin(), tokio::io::stdout()); let builder = Builder::new(tokio::io::stdin(), tokio::io::stdout());
let _ = builder.start(state); let _ = builder.start(state);
} }

View File

@@ -41,20 +41,20 @@ pub(crate) enum Request {
#[serde(tag = "method", content = "params")] #[serde(tag = "method", content = "params")]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub(crate) enum Notification { pub(crate) enum Notification {
// ChannelOpened, // ChannelOpened,
// ChannelOpenFailed, // ChannelOpenFailed,
// ChannelStateChanged, // ChannelStateChanged,
// Connect, // Connect,
// Disconnect, // Disconnect,
// InvoicePayment, // InvoicePayment,
// InvoiceCreation, // InvoiceCreation,
// Warning, // Warning,
// ForwardEvent, // ForwardEvent,
// SendpaySuccess, // SendpaySuccess,
// SendpayFailure, // SendpayFailure,
// CoinMovement, // CoinMovement,
// OpenchannelPeerSigs, // OpenchannelPeerSigs,
// Shutdown, // Shutdown,
} }
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]

View File

@@ -28,7 +28,7 @@ impl Value {
_ => None, _ => None,
} }
} }
/// Returns true if the `Value` is an integer between `i64::MIN` and /// Returns true if the `Value` is an integer between `i64::MIN` and
/// `i64::MAX`. /// `i64::MAX`.
/// ///
@@ -36,8 +36,6 @@ impl Value {
/// return the integer value. /// return the integer value.
pub fn is_i64(&self) -> bool { pub fn is_i64(&self) -> bool {
self.as_i64().is_some() self.as_i64().is_some()
} }
/// If the `Value` is an integer, represent it as i64. Returns /// If the `Value` is an integer, represent it as i64. Returns

View File

@@ -249,10 +249,6 @@ def test_grpc_wrong_auth(node_factory):
stub.Getinfo(nodepb.GetinfoRequest()) stub.Getinfo(nodepb.GetinfoRequest())
@pytest.mark.xfail(
reason="Times out because we can't call the RPC method while currently holding on to HTLCs",
strict=True,
)
def test_cln_plugin_reentrant(node_factory, executor): def test_cln_plugin_reentrant(node_factory, executor):
"""Ensure that we continue processing events while already handling. """Ensure that we continue processing events while already handling.