Implement a typed version of call to avoid useless matching

This commit is contained in:
elsirion
2022-05-13 19:43:36 +00:00
committed by Christian Decker
parent 7046252f96
commit 10917743fe
2 changed files with 74 additions and 1 deletions

View File

@@ -24,6 +24,7 @@ pub use crate::{
notifications::Notification, notifications::Notification,
primitives::RpcError, primitives::RpcError,
}; };
use crate::model::IntoRequest;
/// ///
pub struct ClnRpc { pub struct ClnRpc {
@@ -105,6 +106,13 @@ impl ClnRpc {
}) })
} }
} }
pub async fn call_typed<R: IntoRequest>(&mut self, request: R) -> Result<R::Response, RpcError> {
Ok(self.call(request.into())
.await?
.try_into()
.expect("CLN will reply correctly"))
}
} }
/// Used to skip optional arrays when serializing requests. /// Used to skip optional arrays when serializing requests.
@@ -142,4 +150,23 @@ mod test {
read_req read_req
); );
} }
#[tokio::test]
async fn test_typed_call() {
let req = requests::GetinfoRequest {};
let (uds1, uds2) = UnixStream::pair().unwrap();
let mut cln = ClnRpc::from_stream(uds1).unwrap();
let mut read = FramedRead::new(uds2, JsonCodec::default());
tokio::task::spawn(async move {
let _: GetinfoResponse = cln.call_typed(req).await.unwrap();
});
let read_req = dbg!(read.next().await.unwrap().unwrap());
assert_eq!(
json!({"id": 1, "method": "getinfo", "params": {}, "jsonrpc": "2.0"}),
read_req
);
}
} }

View File

@@ -6,7 +6,7 @@ import sys
import re import re
from msggen.model import (ArrayField, CompositeField, EnumField, from msggen.model import (ArrayField, CompositeField, EnumField,
PrimitiveField, Service) PrimitiveField, Service, Method)
from msggen.gen.generator import IGenerator from msggen.gen.generator import IGenerator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -214,6 +214,7 @@ class RustGenerator(IGenerator):
use crate::primitives::*; use crate::primitives::*;
#[allow(unused_imports)] #[allow(unused_imports)]
use serde::{{Deserialize, Serialize}}; use serde::{{Deserialize, Serialize}};
use super::{IntoRequest, Request};
""") """)
@@ -221,9 +222,24 @@ class RustGenerator(IGenerator):
req = meth.request req = meth.request
_, decl = gen_composite(req) _, decl = gen_composite(req)
self.write(decl, numindent=1) self.write(decl, numindent=1)
self.generate_request_trait_impl(meth)
self.write("}\n\n") self.write("}\n\n")
def generate_request_trait_impl(self, method: Method):
self.write(dedent(f"""\
impl From<{method.request.typename}> for Request {{
fn from(r: {method.request.typename}) -> Self {{
Request::{method.name}(r)
}}
}}
impl IntoRequest for {method.request.typename} {{
type Response = super::responses::{method.response.typename};
}}
"""), numindent=1)
def generate_responses(self, service: Service): def generate_responses(self, service: Service):
self.write(""" self.write("""
pub mod responses { pub mod responses {
@@ -231,6 +247,7 @@ class RustGenerator(IGenerator):
use crate::primitives::*; use crate::primitives::*;
#[allow(unused_imports)] #[allow(unused_imports)]
use serde::{{Deserialize, Serialize}}; use serde::{{Deserialize, Serialize}};
use super::{TryFromResponseError, Response};
""") """)
@@ -238,9 +255,25 @@ class RustGenerator(IGenerator):
res = meth.response res = meth.response
_, decl = gen_composite(res) _, decl = gen_composite(res)
self.write(decl, numindent=1) self.write(decl, numindent=1)
self.generate_response_trait_impl(meth)
self.write("}\n\n") self.write("}\n\n")
def generate_response_trait_impl(self, method: Method):
self.write(dedent(f"""\
impl TryFrom<Response> for {method.response.typename} {{
type Error = super::TryFromResponseError;
fn try_from(response: Response) -> Result<Self, Self::Error> {{
match response {{
Response::{method.name}(response) => Ok(response),
_ => Err(TryFromResponseError)
}}
}}
}}
"""), numindent=1)
def generate_enums(self, service: Service): def generate_enums(self, service: Service):
"""The Request and Response enums serve as parsing primitives. """The Request and Response enums serve as parsing primitives.
""" """
@@ -275,10 +308,23 @@ class RustGenerator(IGenerator):
""") """)
def generate_request_trait(self):
self.write("""
pub trait IntoRequest: Into<Request> {
type Response: TryFrom<Response, Error = TryFromResponseError>;
}
#[derive(Debug)]
pub struct TryFromResponseError;
""")
def generate(self, service: Service) -> None: def generate(self, service: Service) -> None:
self.write(header) self.write(header)
self.generate_enums(service) self.generate_enums(service)
self.generate_request_trait()
self.generate_requests(service) self.generate_requests(service)
self.generate_responses(service) self.generate_responses(service)