feat(homeserver): treat (empty) as None'

This commit is contained in:
nazeh
2024-09-28 08:35:43 +03:00
parent e031c7a9dd
commit 842b9f32c8
5 changed files with 66 additions and 23 deletions

View File

@@ -84,7 +84,7 @@ impl Session {
pub type Result<T> = core::result::Result<T, Error>; pub type Result<T> = core::result::Result<T, Error>;
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug, PartialEq)]
pub enum Error { pub enum Error {
#[error("Empty payload")] #[error("Empty payload")]
EmptyPayload, EmptyPayload,
@@ -132,6 +132,8 @@ mod tests {
#[test] #[test]
fn deserialize() { fn deserialize() {
let deseiralized = Session::deserialize(&[]).unwrap(); let result = Session::deserialize(&[]);
assert_eq!(result, Err(Error::EmptyPayload));
} }
} }

View File

@@ -67,7 +67,7 @@ impl DB {
pub fn list_events( pub fn list_events(
&self, &self,
limit: Option<u16>, limit: Option<u16>,
cursor: Option<&str>, cursor: Option<String>,
) -> anyhow::Result<Vec<String>> { ) -> anyhow::Result<Vec<String>> {
let txn = self.env.read_txn()?; let txn = self.env.read_txn()?;
@@ -75,7 +75,7 @@ impl DB {
.unwrap_or(self.config.default_list_limit()) .unwrap_or(self.config.default_list_limit())
.min(self.config.max_list_limit()); .min(self.config.max_list_limit());
let cursor = cursor.unwrap_or("0000000000000"); let cursor = cursor.unwrap_or("0000000000000".to_string());
let mut result: Vec<String> = vec![]; let mut result: Vec<String> = vec![];
let mut next_cursor = cursor.to_string(); let mut next_cursor = cursor.to_string();

View File

@@ -2,7 +2,7 @@ use std::collections::HashMap;
use axum::{ use axum::{
async_trait, async_trait,
extract::{FromRequestParts, Path}, extract::{FromRequestParts, Path, Query},
http::{request::Parts, StatusCode}, http::{request::Parts, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
RequestPartsExt, RequestPartsExt,
@@ -74,3 +74,50 @@ where
Ok(EntryPath(path.to_string())) Ok(EntryPath(path.to_string()))
} }
} }
#[derive(Debug)]
pub struct ListQueryParams {
pub limit: Option<u16>,
pub cursor: Option<String>,
pub reverse: bool,
pub shallow: bool,
}
#[async_trait]
impl<S> FromRequestParts<S> for ListQueryParams
where
S: Send + Sync,
{
type Rejection = Response;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let params: Query<HashMap<String, String>> =
parts.extract().await.map_err(IntoResponse::into_response)?;
let reverse = params.contains_key("reverse");
let shallow = params.contains_key("shallow");
let limit = params
.get("limit")
// Treat `limit=` as None
.and_then(|l| if l.is_empty() { None } else { Some(l) })
.and_then(|l| l.parse::<u16>().ok());
let cursor = params
.get("cursor")
.map(|c| c.as_str())
// Treat `cursor=` as None
.and_then(|c| {
if c.is_empty() {
None
} else {
Some(c.to_string())
}
});
Ok(ListQueryParams {
reverse,
shallow,
limit,
cursor,
})
}
}

View File

@@ -1,8 +1,6 @@
use std::collections::HashMap;
use axum::{ use axum::{
body::Body, body::Body,
extract::{Query, State}, extract::State,
http::{header, Response, StatusCode}, http::{header, Response, StatusCode},
response::IntoResponse, response::IntoResponse,
}; };
@@ -10,17 +8,15 @@ use pubky_common::timestamp::{Timestamp, TimestampError};
use crate::{ use crate::{
error::{Error, Result}, error::{Error, Result},
extractors::ListQueryParams,
server::AppState, server::AppState,
}; };
pub async fn feed( pub async fn feed(
State(state): State<AppState>, State(state): State<AppState>,
Query(params): Query<HashMap<String, String>>, params: ListQueryParams,
) -> Result<impl IntoResponse> { ) -> Result<impl IntoResponse> {
let limit = params.get("limit").and_then(|l| l.parse::<u16>().ok()); if let Some(ref cursor) = params.cursor {
let cursor = params.get("cursor").map(|c| c.as_str());
if let Some(cursor) = cursor {
if let Err(timestmap_error) = Timestamp::try_from(cursor.to_string()) { if let Err(timestmap_error) = Timestamp::try_from(cursor.to_string()) {
let cause = match timestmap_error { let cause = match timestmap_error {
TimestampError::InvalidEncoding => { TimestampError::InvalidEncoding => {
@@ -35,7 +31,7 @@ pub async fn feed(
} }
} }
let result = state.db.list_events(limit, cursor)?; let result = state.db.list_events(params.limit, params.cursor)?;
Ok(Response::builder() Ok(Response::builder()
.status(StatusCode::OK) .status(StatusCode::OK)

View File

@@ -1,8 +1,6 @@
use std::collections::HashMap;
use axum::{ use axum::{
body::{Body, Bytes}, body::{Body, Bytes},
extract::{Query, State}, extract::State,
http::{header, Response, StatusCode}, http::{header, Response, StatusCode},
response::IntoResponse, response::IntoResponse,
}; };
@@ -12,7 +10,7 @@ use tower_cookies::Cookies;
use crate::{ use crate::{
error::{Error, Result}, error::{Error, Result},
extractors::{EntryPath, Pubky}, extractors::{EntryPath, ListQueryParams, Pubky},
server::AppState, server::AppState,
}; };
@@ -65,7 +63,7 @@ pub async fn get(
State(state): State<AppState>, State(state): State<AppState>,
pubky: Pubky, pubky: Pubky,
path: EntryPath, path: EntryPath,
Query(params): Query<HashMap<String, String>>, params: ListQueryParams,
) -> Result<impl IntoResponse> { ) -> Result<impl IntoResponse> {
verify(path.as_str())?; verify(path.as_str())?;
let public_key = pubky.public_key(); let public_key = pubky.public_key();
@@ -88,10 +86,10 @@ pub async fn get(
let vec = state.db.list( let vec = state.db.list(
&txn, &txn,
&path, &path,
params.contains_key("reverse"), params.reverse,
params.get("limit").and_then(|l| l.parse::<u16>().ok()), params.limit,
params.get("cursor").map(|cursor| cursor.into()), params.cursor,
params.contains_key("shallow"), params.shallow,
)?; )?;
return Ok(Response::builder() return Ok(Response::builder()