diff --git a/pubky-common/src/session.rs b/pubky-common/src/session.rs index 494265f..972652c 100644 --- a/pubky-common/src/session.rs +++ b/pubky-common/src/session.rs @@ -84,7 +84,7 @@ impl Session { pub type Result = core::result::Result; -#[derive(thiserror::Error, Debug)] +#[derive(thiserror::Error, Debug, PartialEq)] pub enum Error { #[error("Empty payload")] EmptyPayload, @@ -132,6 +132,8 @@ mod tests { #[test] fn deserialize() { - let deseiralized = Session::deserialize(&[]).unwrap(); + let result = Session::deserialize(&[]); + + assert_eq!(result, Err(Error::EmptyPayload)); } } diff --git a/pubky-homeserver/src/database/tables/events.rs b/pubky-homeserver/src/database/tables/events.rs index 1829765..76a4d46 100644 --- a/pubky-homeserver/src/database/tables/events.rs +++ b/pubky-homeserver/src/database/tables/events.rs @@ -67,7 +67,7 @@ impl DB { pub fn list_events( &self, limit: Option, - cursor: Option<&str>, + cursor: Option, ) -> anyhow::Result> { let txn = self.env.read_txn()?; @@ -75,7 +75,7 @@ impl DB { .unwrap_or(self.config.default_list_limit()) .min(self.config.max_list_limit()); - let cursor = cursor.unwrap_or("0000000000000"); + let cursor = cursor.unwrap_or("0000000000000".to_string()); let mut result: Vec = vec![]; let mut next_cursor = cursor.to_string(); diff --git a/pubky-homeserver/src/extractors.rs b/pubky-homeserver/src/extractors.rs index 567ca6b..779ce65 100644 --- a/pubky-homeserver/src/extractors.rs +++ b/pubky-homeserver/src/extractors.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use axum::{ async_trait, - extract::{FromRequestParts, Path}, + extract::{FromRequestParts, Path, Query}, http::{request::Parts, StatusCode}, response::{IntoResponse, Response}, RequestPartsExt, @@ -74,3 +74,50 @@ where Ok(EntryPath(path.to_string())) } } + +#[derive(Debug)] +pub struct ListQueryParams { + pub limit: Option, + pub cursor: Option, + pub reverse: bool, + pub shallow: bool, +} + +#[async_trait] +impl FromRequestParts for ListQueryParams +where + S: Send + Sync, +{ + type Rejection = Response; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + let params: Query> = + parts.extract().await.map_err(IntoResponse::into_response)?; + + let reverse = params.contains_key("reverse"); + let shallow = params.contains_key("shallow"); + let limit = params + .get("limit") + // Treat `limit=` as None + .and_then(|l| if l.is_empty() { None } else { Some(l) }) + .and_then(|l| l.parse::().ok()); + let cursor = params + .get("cursor") + .map(|c| c.as_str()) + // Treat `cursor=` as None + .and_then(|c| { + if c.is_empty() { + None + } else { + Some(c.to_string()) + } + }); + + Ok(ListQueryParams { + reverse, + shallow, + limit, + cursor, + }) + } +} diff --git a/pubky-homeserver/src/routes/feed.rs b/pubky-homeserver/src/routes/feed.rs index 627eb3e..6271aeb 100644 --- a/pubky-homeserver/src/routes/feed.rs +++ b/pubky-homeserver/src/routes/feed.rs @@ -1,8 +1,6 @@ -use std::collections::HashMap; - use axum::{ body::Body, - extract::{Query, State}, + extract::State, http::{header, Response, StatusCode}, response::IntoResponse, }; @@ -10,17 +8,15 @@ use pubky_common::timestamp::{Timestamp, TimestampError}; use crate::{ error::{Error, Result}, + extractors::ListQueryParams, server::AppState, }; pub async fn feed( State(state): State, - Query(params): Query>, + params: ListQueryParams, ) -> Result { - let limit = params.get("limit").and_then(|l| l.parse::().ok()); - let cursor = params.get("cursor").map(|c| c.as_str()); - - if let Some(cursor) = cursor { + if let Some(ref cursor) = params.cursor { if let Err(timestmap_error) = Timestamp::try_from(cursor.to_string()) { let cause = match timestmap_error { TimestampError::InvalidEncoding => { @@ -35,7 +31,7 @@ pub async fn feed( } } - let result = state.db.list_events(limit, cursor)?; + let result = state.db.list_events(params.limit, params.cursor)?; Ok(Response::builder() .status(StatusCode::OK) diff --git a/pubky-homeserver/src/routes/public.rs b/pubky-homeserver/src/routes/public.rs index 4cf2eed..8c6b2b9 100644 --- a/pubky-homeserver/src/routes/public.rs +++ b/pubky-homeserver/src/routes/public.rs @@ -1,8 +1,6 @@ -use std::collections::HashMap; - use axum::{ body::{Body, Bytes}, - extract::{Query, State}, + extract::State, http::{header, Response, StatusCode}, response::IntoResponse, }; @@ -12,7 +10,7 @@ use tower_cookies::Cookies; use crate::{ error::{Error, Result}, - extractors::{EntryPath, Pubky}, + extractors::{EntryPath, ListQueryParams, Pubky}, server::AppState, }; @@ -65,7 +63,7 @@ pub async fn get( State(state): State, pubky: Pubky, path: EntryPath, - Query(params): Query>, + params: ListQueryParams, ) -> Result { verify(path.as_str())?; let public_key = pubky.public_key(); @@ -88,10 +86,10 @@ pub async fn get( let vec = state.db.list( &txn, &path, - params.contains_key("reverse"), - params.get("limit").and_then(|l| l.parse::().ok()), - params.get("cursor").map(|cursor| cursor.into()), - params.contains_key("shallow"), + params.reverse, + params.limit, + params.cursor, + params.shallow, )?; return Ok(Response::builder()