From 10917743fe3fedd9d00b69ef1fa42d47989955ae Mon Sep 17 00:00:00 2001 From: elsirion Date: Fri, 13 May 2022 19:43:36 +0000 Subject: [PATCH] Implement a typed version of `call` to avoid useless matching --- cln-rpc/src/lib.rs | 27 +++++++++++++++++ contrib/msggen/msggen/gen/rust.py | 48 ++++++++++++++++++++++++++++++- 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/cln-rpc/src/lib.rs b/cln-rpc/src/lib.rs index cd2758d0a..1e603d1e2 100644 --- a/cln-rpc/src/lib.rs +++ b/cln-rpc/src/lib.rs @@ -24,6 +24,7 @@ pub use crate::{ notifications::Notification, primitives::RpcError, }; +use crate::model::IntoRequest; /// pub struct ClnRpc { @@ -105,6 +106,13 @@ impl ClnRpc { }) } } + + pub async fn call_typed(&mut self, request: R) -> Result { + Ok(self.call(request.into()) + .await? + .try_into() + .expect("CLN will reply correctly")) + } } /// Used to skip optional arrays when serializing requests. @@ -142,4 +150,23 @@ mod test { read_req ); } + + #[tokio::test] + async fn test_typed_call() { + let req = requests::GetinfoRequest {}; + let (uds1, uds2) = UnixStream::pair().unwrap(); + let mut cln = ClnRpc::from_stream(uds1).unwrap(); + + let mut read = FramedRead::new(uds2, JsonCodec::default()); + tokio::task::spawn(async move { + let _: GetinfoResponse = cln.call_typed(req).await.unwrap(); + }); + + let read_req = dbg!(read.next().await.unwrap().unwrap()); + + assert_eq!( + json!({"id": 1, "method": "getinfo", "params": {}, "jsonrpc": "2.0"}), + read_req + ); + } } diff --git a/contrib/msggen/msggen/gen/rust.py b/contrib/msggen/msggen/gen/rust.py index 7eac8ecf6..88d0d6bd9 100644 --- a/contrib/msggen/msggen/gen/rust.py +++ b/contrib/msggen/msggen/gen/rust.py @@ -6,7 +6,7 @@ import sys import re from msggen.model import (ArrayField, CompositeField, EnumField, - PrimitiveField, Service) + PrimitiveField, Service, Method) from msggen.gen.generator import IGenerator logger = logging.getLogger(__name__) @@ -214,6 +214,7 @@ class RustGenerator(IGenerator): use crate::primitives::*; #[allow(unused_imports)] use serde::{{Deserialize, Serialize}}; + use super::{IntoRequest, Request}; """) @@ -221,9 +222,24 @@ class RustGenerator(IGenerator): req = meth.request _, decl = gen_composite(req) self.write(decl, numindent=1) + self.generate_request_trait_impl(meth) self.write("}\n\n") + def generate_request_trait_impl(self, method: Method): + self.write(dedent(f"""\ + impl From<{method.request.typename}> for Request {{ + fn from(r: {method.request.typename}) -> Self {{ + Request::{method.name}(r) + }} + }} + + impl IntoRequest for {method.request.typename} {{ + type Response = super::responses::{method.response.typename}; + }} + + """), numindent=1) + def generate_responses(self, service: Service): self.write(""" pub mod responses { @@ -231,6 +247,7 @@ class RustGenerator(IGenerator): use crate::primitives::*; #[allow(unused_imports)] use serde::{{Deserialize, Serialize}}; + use super::{TryFromResponseError, Response}; """) @@ -238,9 +255,25 @@ class RustGenerator(IGenerator): res = meth.response _, decl = gen_composite(res) self.write(decl, numindent=1) + self.generate_response_trait_impl(meth) self.write("}\n\n") + def generate_response_trait_impl(self, method: Method): + self.write(dedent(f"""\ + impl TryFrom for {method.response.typename} {{ + type Error = super::TryFromResponseError; + + fn try_from(response: Response) -> Result {{ + match response {{ + Response::{method.name}(response) => Ok(response), + _ => Err(TryFromResponseError) + }} + }} + }} + + """), numindent=1) + def generate_enums(self, service: Service): """The Request and Response enums serve as parsing primitives. """ @@ -275,10 +308,23 @@ class RustGenerator(IGenerator): """) + def generate_request_trait(self): + self.write(""" + pub trait IntoRequest: Into { + type Response: TryFrom; + } + + #[derive(Debug)] + pub struct TryFromResponseError; + + """) + def generate(self, service: Service) -> None: self.write(header) self.generate_enums(service) + self.generate_request_trait() + self.generate_requests(service) self.generate_responses(service)