feat: Only block disabled users from writing files, not reading #130

This commit is contained in:
Severin Alexander Bühler
2025-05-16 14:05:38 +03:00
committed by GitHub
parent af206c7fc9
commit 93ae488196
6 changed files with 18 additions and 19 deletions

View File

@@ -94,11 +94,11 @@ async fn disabled_user() {
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
// Make sure the user cannot read their own file
// Make sure the user can still read their own file
let response = client.get(file_url.clone()).send().await.unwrap();
assert_eq!(response.status(), StatusCode::FORBIDDEN);
assert_eq!(response.status(), StatusCode::OK);
// Make sure the user cannot write to their own file
// Make sure the user cannot write a new file
let response = client
.put(file_url.clone())
.body(vec![])
@@ -107,9 +107,11 @@ async fn disabled_user() {
.unwrap();
assert_eq!(response.status(), StatusCode::FORBIDDEN);
// Make sure the user cannot sign in
let session = client.signin(&keypair).await;
assert!(session.is_err());
// Make sure the user can still sign in
client
.signin(&keypair)
.await
.expect("Signin should succeed");
}
#[tokio::test]

View File

@@ -6,10 +6,14 @@ use crate::persistence::lmdb::{tables::users::UserQueryError, LmDB};
use super::Error;
/// Returns an error if the user doesn't exist or is disabled.
pub fn err_if_user_is_invalid(pubkey: &PublicKey, db: &LmDB) -> super::error::Result<()> {
pub fn err_if_user_is_invalid(
pubkey: &PublicKey,
db: &LmDB,
err_if_disabled: bool,
) -> super::error::Result<()> {
match db.get_user(pubkey, &mut db.env.read_txn()?) {
Ok(user) => {
if user.disabled {
if err_if_disabled && user.disabled {
return Err(Error::with_status(StatusCode::FORBIDDEN));
}
}

View File

@@ -123,7 +123,7 @@ fn create_session_and_cookie(
capabilities: &[Capability],
user_agent: Option<TypedHeader<UserAgent>>,
) -> Result<impl IntoResponse> {
err_if_user_is_invalid(public_key, &state.db)?;
err_if_user_is_invalid(public_key, &state.db, false)?;
// 1) Create session
let session_secret = encode(Alphabet::Crockford, &random_bytes::<16>());

View File

@@ -9,7 +9,6 @@ use pkarr::PublicKey;
use std::str::FromStr;
use crate::core::{
err_if_user_is_invalid::err_if_user_is_invalid,
error::{Error, Result},
extractors::{ListQueryParams, PubkyHost},
AppState,
@@ -22,8 +21,6 @@ pub async fn head(
headers: HeaderMap,
path: OriginalUri,
) -> Result<impl IntoResponse> {
err_if_user_is_invalid(pubky.public_key(), &state.db)?;
let rtxn = state.db.env.read_txn()?;
get_entry(
headers,
@@ -41,8 +38,6 @@ pub async fn get(
path: OriginalUri,
params: ListQueryParams,
) -> Result<impl IntoResponse> {
err_if_user_is_invalid(pubky.public_key(), &state.db)?;
let public_key = pubky.public_key().clone();
let path = path.0.path().to_string();
@@ -86,8 +81,6 @@ pub fn list(
path: &str,
params: ListQueryParams,
) -> Result<Response<Body>> {
err_if_user_is_invalid(public_key, &state.db)?;
let txn = state.db.env.read_txn()?;
let path = format!("{public_key}{path}");

View File

@@ -14,7 +14,7 @@ pub async fn session(
cookies: Cookies,
pubky: PubkyHost,
) -> Result<impl IntoResponse> {
err_if_user_is_invalid(pubky.public_key(), &state.db)?;
err_if_user_is_invalid(pubky.public_key(), &state.db, false)?;
if let Some(secret) = session_secret_from_cookies(&cookies, pubky.public_key()) {
if let Some(session) = state.db.get_session(&secret)? {
// TODO: add content-type

View File

@@ -45,7 +45,7 @@ pub async fn delete(
pubky: PubkyHost,
path: OriginalUri,
) -> Result<impl IntoResponse> {
err_if_user_is_invalid(pubky.public_key(), &state.db)?;
err_if_user_is_invalid(pubky.public_key(), &state.db, false)?;
let public_key = pubky.public_key();
let full_path = path.0.path();
let existing_bytes = state.db.get_entry_content_length(public_key, full_path)?;
@@ -69,7 +69,7 @@ pub async fn put(
path: OriginalUri,
body: Body,
) -> Result<impl IntoResponse> {
err_if_user_is_invalid(pubky.public_key(), &state.db)?;
err_if_user_is_invalid(pubky.public_key(), &state.db, true)?;
let public_key = pubky.public_key();
let full_path = path.0.path();
let existing_entry_bytes = state.db.get_entry_content_length(public_key, full_path)?;