feat(homeserver): treat (empty) as None'

This commit is contained in:
nazeh
2024-09-28 08:35:43 +03:00
parent e031c7a9dd
commit 842b9f32c8
5 changed files with 66 additions and 23 deletions

View File

@@ -84,7 +84,7 @@ impl Session {
pub type Result<T> = core::result::Result<T, Error>;
#[derive(thiserror::Error, Debug)]
#[derive(thiserror::Error, Debug, PartialEq)]
pub enum Error {
#[error("Empty payload")]
EmptyPayload,
@@ -132,6 +132,8 @@ mod tests {
#[test]
fn deserialize() {
let deseiralized = Session::deserialize(&[]).unwrap();
let result = Session::deserialize(&[]);
assert_eq!(result, Err(Error::EmptyPayload));
}
}

View File

@@ -67,7 +67,7 @@ impl DB {
pub fn list_events(
&self,
limit: Option<u16>,
cursor: Option<&str>,
cursor: Option<String>,
) -> anyhow::Result<Vec<String>> {
let txn = self.env.read_txn()?;
@@ -75,7 +75,7 @@ impl DB {
.unwrap_or(self.config.default_list_limit())
.min(self.config.max_list_limit());
let cursor = cursor.unwrap_or("0000000000000");
let cursor = cursor.unwrap_or("0000000000000".to_string());
let mut result: Vec<String> = vec![];
let mut next_cursor = cursor.to_string();

View File

@@ -2,7 +2,7 @@ use std::collections::HashMap;
use axum::{
async_trait,
extract::{FromRequestParts, Path},
extract::{FromRequestParts, Path, Query},
http::{request::Parts, StatusCode},
response::{IntoResponse, Response},
RequestPartsExt,
@@ -74,3 +74,50 @@ where
Ok(EntryPath(path.to_string()))
}
}
#[derive(Debug)]
pub struct ListQueryParams {
pub limit: Option<u16>,
pub cursor: Option<String>,
pub reverse: bool,
pub shallow: bool,
}
#[async_trait]
impl<S> FromRequestParts<S> for ListQueryParams
where
S: Send + Sync,
{
type Rejection = Response;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let params: Query<HashMap<String, String>> =
parts.extract().await.map_err(IntoResponse::into_response)?;
let reverse = params.contains_key("reverse");
let shallow = params.contains_key("shallow");
let limit = params
.get("limit")
// Treat `limit=` as None
.and_then(|l| if l.is_empty() { None } else { Some(l) })
.and_then(|l| l.parse::<u16>().ok());
let cursor = params
.get("cursor")
.map(|c| c.as_str())
// Treat `cursor=` as None
.and_then(|c| {
if c.is_empty() {
None
} else {
Some(c.to_string())
}
});
Ok(ListQueryParams {
reverse,
shallow,
limit,
cursor,
})
}
}

View File

@@ -1,8 +1,6 @@
use std::collections::HashMap;
use axum::{
body::Body,
extract::{Query, State},
extract::State,
http::{header, Response, StatusCode},
response::IntoResponse,
};
@@ -10,17 +8,15 @@ use pubky_common::timestamp::{Timestamp, TimestampError};
use crate::{
error::{Error, Result},
extractors::ListQueryParams,
server::AppState,
};
pub async fn feed(
State(state): State<AppState>,
Query(params): Query<HashMap<String, String>>,
params: ListQueryParams,
) -> Result<impl IntoResponse> {
let limit = params.get("limit").and_then(|l| l.parse::<u16>().ok());
let cursor = params.get("cursor").map(|c| c.as_str());
if let Some(cursor) = cursor {
if let Some(ref cursor) = params.cursor {
if let Err(timestmap_error) = Timestamp::try_from(cursor.to_string()) {
let cause = match timestmap_error {
TimestampError::InvalidEncoding => {
@@ -35,7 +31,7 @@ pub async fn feed(
}
}
let result = state.db.list_events(limit, cursor)?;
let result = state.db.list_events(params.limit, params.cursor)?;
Ok(Response::builder()
.status(StatusCode::OK)

View File

@@ -1,8 +1,6 @@
use std::collections::HashMap;
use axum::{
body::{Body, Bytes},
extract::{Query, State},
extract::State,
http::{header, Response, StatusCode},
response::IntoResponse,
};
@@ -12,7 +10,7 @@ use tower_cookies::Cookies;
use crate::{
error::{Error, Result},
extractors::{EntryPath, Pubky},
extractors::{EntryPath, ListQueryParams, Pubky},
server::AppState,
};
@@ -65,7 +63,7 @@ pub async fn get(
State(state): State<AppState>,
pubky: Pubky,
path: EntryPath,
Query(params): Query<HashMap<String, String>>,
params: ListQueryParams,
) -> Result<impl IntoResponse> {
verify(path.as_str())?;
let public_key = pubky.public_key();
@@ -88,10 +86,10 @@ pub async fn get(
let vec = state.db.list(
&txn,
&path,
params.contains_key("reverse"),
params.get("limit").and_then(|l| l.parse::<u16>().ok()),
params.get("cursor").map(|cursor| cursor.into()),
params.contains_key("shallow"),
params.reverse,
params.limit,
params.cursor,
params.shallow,
)?;
return Ok(Response::builder()