fix: passing all test cases

This commit is contained in:
nazeh
2023-12-23 08:40:57 +03:00
parent e9d939f017
commit 05e8bb1720
4 changed files with 314 additions and 329 deletions

View File

@@ -36,10 +36,22 @@ enum RefCountDiff {
}
impl Node {
pub(crate) fn new(key: &[u8], value: &[u8]) -> Self {
Self {
key: key.into(),
value: value.into(),
left: None,
right: None,
ref_count: 0,
}
}
pub(crate) fn open(
table: &'_ impl ReadableTable<&'static [u8], (u64, &'static [u8])>,
hash: Hash,
) -> Option<Node> {
// TODO: make it Result instead!
let existing = table.get(hash.as_bytes().as_slice()).unwrap();
existing.map(|existing| {
@@ -53,35 +65,6 @@ impl Node {
})
}
pub(crate) fn insert(
table: &mut Table<&[u8], (u64, &[u8])>,
key: &[u8],
value: &[u8],
left: Option<Hash>,
right: Option<Hash>,
) -> Hash {
let node = Self {
key: key.into(),
value: value.into(),
left,
right,
ref_count: 1,
};
let encoded = node.canonical_encode();
let hash = hash(&encoded);
table
.insert(
hash.as_bytes().as_slice(),
(node.ref_count, encoded.as_slice()),
)
.unwrap();
hash
}
// === Getters ===
pub fn key(&self) -> &[u8] {
@@ -100,98 +83,57 @@ impl Node {
&self.right
}
pub fn rank(&self) -> Hash {
hash(self.key())
}
pub(crate) fn ref_count(&self) -> &u64 {
&self.ref_count
}
// === Public Methods ===
pub fn rank(&self) -> Hash {
hash(&self.key)
}
/// Returns the hash of the node.
pub fn hash(&self) -> Hash {
hash(&self.canonical_encode())
}
/// Set the value and save the updated node.
pub(crate) fn set_value(
&mut self,
table: &mut Table<&[u8], (u64, &[u8])>,
value: &[u8],
) -> Hash {
self.value = value.into();
self.save(table)
}
/// Set the left child, save the updated node, and return the new hash.
pub(crate) fn set_left_child(
&mut self,
table: &mut Table<&[u8], (u64, &[u8])>,
child: Option<Hash>,
) -> Hash {
self.set_child(table, Branch::Left, child)
}
/// Set the right child, save the updated node, and return the new hash.
pub(crate) fn set_right_child(
&mut self,
table: &mut Table<&[u8], (u64, &[u8])>,
child: Option<Hash>,
) -> Hash {
self.set_child(table, Branch::Right, child)
let encoded = self.canonical_encode();
hash(&encoded)
}
// === Private Methods ===
pub fn decrement_ref_count(&self, table: &mut Table<&[u8], (u64, &[u8])>) {
self.update_ref_count(table, RefCountDiff::Decrement)
/// Set the value.
pub(crate) fn set_value(&mut self, value: &[u8]) -> &mut Self {
self.value = value.into();
self
}
fn set_child(
&mut self,
table: &mut Table<&[u8], (u64, &[u8])>,
branch: Branch,
child: Option<Hash>,
) -> Hash {
/// Set the left child, save the updated node, and return the new hash.
pub(crate) fn set_left_child(&mut self, child: Option<&mut Node>) -> &mut Self {
self.set_child(Branch::Left, child)
}
/// Set the right child, save the updated node, and return the new hash.
pub(crate) fn set_right_child(&mut self, child: Option<&mut Node>) -> &mut Self {
self.set_child(Branch::Right, child)
}
/// Set the child, update its ref_count, save the updated node and return it.
fn set_child(&mut self, branch: Branch, new_child: Option<&mut Node>) -> &mut Self {
match branch {
Branch::Left => self.left = child,
Branch::Right => self.right = child,
}
Branch::Left => self.left = new_child.as_ref().map(|n| n.hash()),
Branch::Right => self.right = new_child.as_ref().map(|n| n.hash()),
};
let encoded = self.canonical_encode();
let hash = hash(&encoded);
table
.insert(
hash.as_bytes().as_slice(),
(self.ref_count, encoded.as_slice()),
)
.unwrap();
hash
self
}
fn save(&mut self, table: &mut Table<&[u8], (u64, &[u8])>) -> Hash {
let encoded = self.canonical_encode();
let hash = hash(&encoded);
table
.insert(
hash.as_bytes().as_slice(),
(self.ref_count, encoded.as_slice()),
)
.unwrap();
hash
pub(crate) fn increment_ref_count(&mut self) -> &mut Self {
self.update_ref_count(RefCountDiff::Increment)
}
fn increment_ref_count(&self, table: &mut Table<&[u8], (u64, &[u8])>) {
self.update_ref_count(table, RefCountDiff::Increment)
pub(crate) fn decrement_ref_count(&mut self) -> &mut Self {
self.update_ref_count(RefCountDiff::Decrement)
}
fn update_ref_count(&self, table: &mut Table<&[u8], (u64, &[u8])>, diff: RefCountDiff) {
fn update_ref_count(&mut self, diff: RefCountDiff) -> &mut Self {
let ref_count = match diff {
RefCountDiff::Increment => self.ref_count + 1,
RefCountDiff::Decrement => {
@@ -203,14 +145,23 @@ impl Node {
}
};
let bytes = self.canonical_encode();
let hash = hash(&bytes);
// We only updaet the ref count, and handle deletion elsewhere.
self.ref_count = ref_count;
self
}
match ref_count {
0 => table.remove(hash.as_bytes().as_slice()),
_ => table.insert(hash.as_bytes().as_slice(), (ref_count, bytes.as_slice())),
}
.unwrap();
pub(crate) fn save(&mut self, table: &mut Table<&[u8], (u64, &[u8])>) -> &mut Self {
// TODO: keep data in encoded in a bytes field.
let encoded = self.canonical_encode();
table
.insert(
hash(&encoded).as_bytes().as_slice(),
(self.ref_count, encoded.as_slice()),
)
.unwrap();
self
}
fn canonical_encode(&self) -> Vec<u8> {
@@ -232,8 +183,11 @@ impl Node {
}
}
pub(crate) fn rank(key: &[u8]) -> Hash {
hash(key)
pub(crate) fn hash(bytes: &[u8]) -> Hash {
let mut hasher = Hasher::new();
hasher.update(bytes);
hasher.finalize()
}
fn encode(bytes: &[u8], out: &mut Vec<u8>) {
@@ -255,14 +209,7 @@ fn decode(bytes: &[u8]) -> (&[u8], &[u8]) {
(value, rest)
}
fn hash(bytes: &[u8]) -> Hash {
let mut hasher = Hasher::new();
hasher.update(bytes);
hasher.finalize()
}
pub fn decode_node(data: (u64, &[u8])) -> Node {
fn decode_node(data: (u64, &[u8])) -> Node {
let (ref_count, encoded_node) = data;
let (key, rest) = decode(encoded_node);

View File

@@ -1,7 +1,6 @@
use std::cmp::Ordering;
use crate::node::{rank, Branch, Node};
use blake3::Hash;
use crate::node::{hash, Branch, Node};
use redb::Table;
// Watch this [video](https://youtu.be/NxRXhBur6Xs?si=GNwaUOfuGwr_tBKI&t=1763) for a good explanation of the unzipping algorithm.
@@ -82,38 +81,86 @@ use redb::Table;
// all then new nodes (in both the upper and lower paths) before comitting the write transaction.
pub fn insert(
table: &'_ mut Table<&'static [u8], (u64, &'static [u8])>,
root: Option<Hash>,
nodes_table: &'_ mut Table<&'static [u8], (u64, &'static [u8])>,
root: Option<Node>,
key: &[u8],
value: &[u8],
) -> Hash {
let mut path = binary_search_path(table, root, key);
) -> Node {
let mut path = binary_search_path(nodes_table, root, key);
let mut unzip_left_root: Option<Hash> = None;
let mut unzip_right_root: Option<Hash> = None;
let mut unzip_left_root: Option<&mut Node> = None;
let mut unzip_right_root: Option<&mut Node> = None;
// Unzip the lower path to get left and right children of the inserted node.
for (node, branch) in path.unzip_path.iter_mut().rev() {
node.decrement_ref_count().save(nodes_table);
match branch {
Branch::Right => unzip_left_root = Some(node.set_right_child(table, unzip_left_root)),
Branch::Left => unzip_right_root = Some(node.set_left_child(table, unzip_right_root)),
Branch::Right => {
node.set_right_child(unzip_left_root)
.increment_ref_count()
.save(nodes_table);
unzip_left_root = Some(node);
}
Branch::Left => {
node.set_left_child(unzip_right_root)
.increment_ref_count()
.save(nodes_table);
unzip_right_root = Some(node);
}
}
}
let mut root = if let Some(mut existing) = path.existing {
existing.set_value(table, value)
let mut root = path.existing;
if let Some(mut existing) = root {
if existing.value() == value {
// There is really nothing to update. Skip traversing upwards.
return path.upper_path.pop().map(|(n, _)| n).unwrap_or(existing);
}
existing.decrement_ref_count().save(nodes_table);
// Else, update the value and rehashe the node so that we can update the hashes upwards.
existing
.set_value(value)
.increment_ref_count()
.save(nodes_table);
root = Some(existing)
} else {
Node::insert(table, key, value, unzip_left_root, unzip_right_root)
// Insert the new node.
let mut node = Node::new(key, value);
// TODO: we do hash the node twice here, can we do better?
node.set_left_child(unzip_left_root)
.set_right_child(unzip_right_root)
.increment_ref_count()
.save(nodes_table);
root = Some(node);
};
for (node, branch) in path.upper_path.iter_mut().rev() {
let mut upper_path = path.upper_path;
// Propagate the new hashes upwards if there are any nodes in the upper_path.
while let Some((mut node, branch)) = upper_path.pop() {
node.decrement_ref_count().save(nodes_table);
match branch {
Branch::Left => root = node.set_left_child(table, Some(root)),
Branch::Right => root = node.set_right_child(table, Some(root)),
}
Branch::Left => node.set_left_child(root.as_mut()),
Branch::Right => node.set_right_child(root.as_mut()),
};
node.increment_ref_count().save(nodes_table);
root = Some(node);
}
// Finally return the new root to be committed.
root
// Finally return the new root to be set to the root.
root.expect("Root should be set by now")
}
#[derive(Debug)]
@@ -130,11 +177,11 @@ struct BinarySearchPath {
///
/// If a match was found, the `lower_path` will be empty.
fn binary_search_path(
table: &'_ mut Table<&'static [u8], (u64, &'static [u8])>,
root: Option<Hash>,
table: &Table<&'static [u8], (u64, &'static [u8])>,
root: Option<Node>,
key: &[u8],
) -> BinarySearchPath {
let rank = rank(key);
let rank = hash(key);
let mut result = BinarySearchPath {
upper_path: Default::default(),
@@ -142,43 +189,182 @@ fn binary_search_path(
unzip_path: Default::default(),
};
let mut previous_hash = root;
let mut next = root;
while let Some(current_hash) = previous_hash {
let current_node = Node::open(table, current_hash).expect("Node not found!");
// Decrement each node in the binary search path.
// if it doesn't change, we will increment it again later.
//
// It is important then to terminate the loop if we found an exact match,
// as lower nodes shouldn't change then.
current_node.decrement_ref_count(table);
let path = if current_node.rank().as_bytes() > rank.as_bytes() {
while let Some(current) = next {
let path = if current.rank().as_bytes() > rank.as_bytes() {
&mut result.upper_path
} else {
&mut result.unzip_path
};
match key.cmp(current_node.key()) {
match key.cmp(current.key()) {
Ordering::Equal => {
// We found exact match. terminate the search.
result.existing = Some(current_node);
result.existing = Some(current);
return result;
}
Ordering::Less => {
previous_hash = *current_node.left();
next = current.left().and_then(|n| Node::open(table, n));
path.push((current_node, Branch::Left));
path.push((current, Branch::Left));
}
Ordering::Greater => {
previous_hash = *current_node.right();
next = current.right().and_then(|n| Node::open(table, n));
path.push((current_node, Branch::Right));
path.push((current, Branch::Right));
}
};
}
result
}
#[cfg(test)]
mod test {
use crate::test::{test_operations, Entry, Operation};
#[test]
fn insert_single_entry() {
let case = ["A"];
let expected = case.map(|key| Entry {
key: key.as_bytes().to_vec(),
value: [b"v", key.as_bytes()].concat(),
});
test_operations(
&expected.clone().map(|e| (e, Operation::Insert)),
&expected,
Some("78fd7507ef338f1a5816ffd702394999680a9694a85f4b8af77795d9fdd5854d"),
)
}
#[test]
fn sorted_alphabets() {
let case = [
"A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q",
"R", "S", "T", "U", "V", "W", "X", "Y", "Z",
];
let expected = case.map(|key| Entry {
key: key.as_bytes().to_vec(),
value: [b"v", key.as_bytes()].concat(),
});
test_operations(
&expected.clone().map(|e| (e, Operation::Insert)),
&expected,
Some("02af3de6ed6368c5abc16f231a17d1140e7bfec483c8d0aa63af4ef744d29bc3"),
);
}
#[test]
fn reverse_alphabets() {
let mut case = [
"A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q",
"R", "S", "T", "U", "V", "W", "X", "Y", "Z",
];
case.reverse();
let expected = case.map(|key| Entry {
key: key.as_bytes().to_vec(),
value: [b"v", key.as_bytes()].concat(),
});
test_operations(
&expected.clone().map(|e| (e, Operation::Insert)),
&expected,
Some("02af3de6ed6368c5abc16f231a17d1140e7bfec483c8d0aa63af4ef744d29bc3"),
)
}
#[test]
fn unsorted() {
let case = ["D", "N", "P", "X", "A", "G", "C", "M", "H", "I", "J"];
let expected = case.map(|key| Entry {
key: key.as_bytes().to_vec(),
value: [b"v", key.as_bytes()].concat(),
});
test_operations(
&expected.clone().map(|e| (e, Operation::Insert)),
&expected,
Some("0957cc9b87c11cef6d88a95328cfd9043a3d6a99e9ba35ee5c9c47e53fb6d42b"),
)
}
#[test]
fn upsert_at_root() {
let case = ["X", "X"];
let mut i = 0;
let entries = case.map(|key| {
i += 1;
Entry {
key: key.as_bytes().to_vec(),
value: i.to_string().into(),
}
});
test_operations(
&entries.clone().map(|e| (e, Operation::Insert)),
&entries[1..],
Some("4538b4de5e58f9be9d54541e69fab8c94c31553a1dec579227ef9b572d1c1dff"),
)
}
#[test]
fn upsert_deeper() {
// X has higher rank.
let case = ["X", "F", "F"];
let mut i = 0;
let entries = case.map(|key| {
i += 1;
Entry {
key: key.as_bytes().to_vec(),
value: i.to_string().into(),
}
});
let mut expected = entries.to_vec();
expected.sort_by(|a, b| a.key.cmp(&b.key));
test_operations(
&entries.clone().map(|e| (e, Operation::Insert)),
&expected[1..],
Some("c9f7aaefb18ec8569322b9621fc64f430a7389a790e0bf69ec0ad02879d6ce54"),
)
}
#[test]
fn upsert_root_with_children() {
// X has higher rank.
let case = ["F", "X", "X"];
let mut i = 0;
let entries = case.map(|key| {
i += 1;
Entry {
key: key.as_bytes().to_vec(),
value: i.to_string().into(),
}
});
let mut expected = entries.to_vec();
expected.remove(1);
test_operations(
&entries.clone().map(|e| (e, Operation::Insert)),
&expected,
Some("02e26311f2b55bf6d4a7163399f99e17c975891a05af2f1e09bc969f8bf0f95d"),
)
}
}

View File

@@ -7,162 +7,18 @@ use crate::Hash;
use redb::backends::InMemoryBackend;
use redb::Database;
#[test]
fn cases() {
let sorted_alphabets = [
"A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R",
"S", "T", "U", "V", "W", "X", "Y", "Z",
]
.map(|key| Entry {
key: key.as_bytes().to_vec(),
value: [b"v", key.as_bytes()].concat(),
});
let mut reverse_alphabets = sorted_alphabets.clone();
reverse_alphabets.reverse();
let unsorted = ["D", "N", "P", "X", "A", "G", "C", "M", "H", "I", "J"].map(|key| Entry {
key: key.as_bytes().to_vec(),
value: [b"v", key.as_bytes()].concat(),
});
let single_entry = ["X"].map(|key| Entry {
key: key.as_bytes().to_vec(),
value: [b"v", key.as_bytes()].concat(),
});
let upsert_at_root = ["X", "X"]
.iter()
.enumerate()
.map(|(i, _)| {
(
Entry {
key: b"X".to_vec(),
value: i.to_string().into(),
},
Operation::Insert,
)
})
.collect::<Vec<_>>();
// X has higher rank.
let upsert_deeper = ["X", "F", "F"]
.iter()
.enumerate()
.map(|(i, key)| {
(
Entry {
key: key.as_bytes().to_vec(),
value: i.to_string().into(),
},
Operation::Insert,
)
})
.collect::<Vec<_>>();
let mut upsert_deeper_expected = upsert_deeper.clone();
upsert_deeper_expected.remove(upsert_deeper.len() - 2);
// X has higher rank.
let upsert_root_with_children = ["F", "X", "X"]
.iter()
.enumerate()
.map(|(i, key)| {
(
Entry {
key: key.as_bytes().to_vec(),
value: i.to_string().into(),
},
Operation::Insert,
)
})
.collect::<Vec<_>>();
let mut upsert_root_with_children_expected = upsert_root_with_children.clone();
upsert_root_with_children_expected.remove(upsert_root_with_children.len() - 2);
let cases = [
(
"sorted alphabets",
sorted_alphabets
.clone()
.map(|e| (e, Operation::Insert))
.to_vec(),
sorted_alphabets.to_vec(),
Some("02af3de6ed6368c5abc16f231a17d1140e7bfec483c8d0aa63af4ef744d29bc3"),
),
(
"reversed alphabets",
sorted_alphabets
.clone()
.map(|e| (e, Operation::Insert))
.to_vec(),
sorted_alphabets.to_vec(),
Some("02af3de6ed6368c5abc16f231a17d1140e7bfec483c8d0aa63af4ef744d29bc3"),
),
(
"unsorted alphabets",
unsorted.clone().map(|e| (e, Operation::Insert)).to_vec(),
unsorted.to_vec(),
Some("0957cc9b87c11cef6d88a95328cfd9043a3d6a99e9ba35ee5c9c47e53fb6d42b"),
),
(
"Single insert",
single_entry
.clone()
.map(|e| (e, Operation::Insert))
.to_vec(),
single_entry.to_vec(),
Some("b3e862d316e6f5caca72c8f91b7a15015b4f7f8f970c2731433aad793f7fe3e6"),
),
(
"upsert at root without children",
upsert_at_root.clone(),
upsert_at_root[1..]
.iter()
.map(|(e, _)| e.clone())
.collect::<Vec<_>>(),
Some("b1353174e730b9ff6850577357fd9ff608071bbab46ebe72c434133f5d4f0383"),
),
(
"upsert deeper",
upsert_deeper.to_vec(),
upsert_deeper_expected
.to_vec()
.iter()
.map(|(e, _)| e.clone())
.collect::<Vec<_>>(),
Some("58272c9e8c9e6b7266e4b60e45d55257b94e85561997f1706e0891ee542a8cd5"),
),
(
"upsert at root with children",
upsert_root_with_children.to_vec(),
upsert_root_with_children_expected
.to_vec()
.iter()
.map(|(e, _)| e.clone())
.collect::<Vec<_>>(),
Some("f46daf022dc852cd4e60a98a33de213f593e17bcd234d9abff7a178d8a5d0761"),
),
];
for case in cases {
test(case.0, &case.1, &case.2, case.3);
}
}
// === Helpers ===
#[derive(Clone, Debug)]
enum Operation {
pub enum Operation {
Insert,
Delete,
}
#[derive(Clone, PartialEq)]
struct Entry {
key: Vec<u8>,
value: Vec<u8>,
pub struct Entry {
pub(crate) key: Vec<u8>,
pub(crate) value: Vec<u8>,
}
impl std::fmt::Debug for Entry {
@@ -171,7 +27,7 @@ impl std::fmt::Debug for Entry {
}
}
fn test(name: &str, input: &[(Entry, Operation)], expected: &[Entry], root_hash: Option<&str>) {
pub fn test_operations(input: &[(Entry, Operation)], expected: &[Entry], root_hash: Option<&str>) {
let inmemory = InMemoryBackend::new();
let db = Database::builder()
.create_with_backend(inmemory)
@@ -184,18 +40,20 @@ fn test(name: &str, input: &[(Entry, Operation)], expected: &[Entry], root_hash:
Operation::Insert => treap.insert(&entry.key, &entry.value),
Operation::Delete => todo!(),
}
println!(
"{:?} {:?}\n{}",
&entry.key,
&entry.value,
into_mermaid_graph(&treap)
);
}
// Uncomment to see the graph (only if values are utf8)
// println!("{}", into_mermaid_graph(&treap));
let collected = treap
.iter()
.map(|n| {
assert_eq!(*n.ref_count(), 1_u64, "Node has wrong ref count");
assert_eq!(
*n.ref_count(),
1_u64,
"{}",
format!("Node has wrong ref count {:?}", n)
);
Entry {
key: n.key().to_vec(),
@@ -207,22 +65,13 @@ fn test(name: &str, input: &[(Entry, Operation)], expected: &[Entry], root_hash:
let mut sorted = expected.to_vec();
sorted.sort_by(|a, b| a.key.cmp(&b.key));
// println!("{}", into_mermaid_graph(&treap));
verify_ranks(&treap);
assert_eq!(collected, sorted, "{}", format!("Entries do not match"));
if root_hash.is_some() {
assert_root(&treap, root_hash.unwrap());
} else {
dbg!(&treap.root_hash());
verify_ranks(&treap);
}
assert_eq!(
collected,
sorted,
"{}",
format!("Entries do not match at: \"{}\"", name)
);
}
/// Verify that every node has higher rank than its children.

View File

@@ -65,12 +65,15 @@ impl<'treap> HashTreap<'treap> {
let mut roots_table = write_txn.open_table(ROOTS_TABLE).unwrap();
let mut nodes_table = write_txn.open_table(NODES_TABLE).unwrap();
let root = self.root_hash_inner(&roots_table);
let old_root = self
.root_hash_inner(&roots_table)
.and_then(|hash| Node::open(&nodes_table, hash));
let new_root = crate::operations::insert::insert(&mut nodes_table, root, key, value);
let new_root =
crate::operations::insert::insert(&mut nodes_table, old_root, key, value);
roots_table
.insert(self.name.as_bytes(), new_root.as_bytes().as_slice())
.insert(self.name.as_bytes(), new_root.hash().as_bytes().as_slice())
.unwrap();
};