From cb3d0194cd3f9ba7335b34c0ac00f9dca72e9428 Mon Sep 17 00:00:00 2001 From: nazeh Date: Thu, 21 Dec 2023 19:27:16 +0300 Subject: [PATCH] fix: upsert exact key don't drop children --- mast/src/node.rs | 58 +++++++++++----- mast/src/operations/insert.rs | 10 ++- mast/src/test.rs | 124 ++++++++++++++++++++-------------- 3 files changed, 124 insertions(+), 68 deletions(-) diff --git a/mast/src/node.rs b/mast/src/node.rs index 55947e1..3b5716c 100644 --- a/mast/src/node.rs +++ b/mast/src/node.rs @@ -81,24 +81,6 @@ impl Node { hash } - /// Set the left child, save the updated node, and return the new hash. - pub(crate) fn set_left_child( - &mut self, - table: &mut Table<&[u8], (u64, &[u8])>, - child: Option, - ) -> Hash { - self.set_child(table, Branch::Left, child) - } - - /// Set the right child, save the updated node, and return the new hash. - pub(crate) fn set_right_child( - &mut self, - table: &mut Table<&[u8], (u64, &[u8])>, - child: Option, - ) -> Hash { - self.set_child(table, Branch::Right, child) - } - // === Getters === pub fn key(&self) -> &[u8] { @@ -128,6 +110,34 @@ impl Node { hash(&self.canonical_encode()) } + /// Set the value and save the updated node. + pub(crate) fn set_value( + &mut self, + table: &mut Table<&[u8], (u64, &[u8])>, + value: &[u8], + ) -> Hash { + self.value = value.into(); + self.save(table) + } + + /// Set the left child, save the updated node, and return the new hash. + pub(crate) fn set_left_child( + &mut self, + table: &mut Table<&[u8], (u64, &[u8])>, + child: Option, + ) -> Hash { + self.set_child(table, Branch::Left, child) + } + + /// Set the right child, save the updated node, and return the new hash. + pub(crate) fn set_right_child( + &mut self, + table: &mut Table<&[u8], (u64, &[u8])>, + child: Option, + ) -> Hash { + self.set_child(table, Branch::Right, child) + } + // === Private Methods === pub fn decrement_ref_count(&self, table: &mut Table<&[u8], (u64, &[u8])>) { @@ -156,6 +166,18 @@ impl Node { hash } + fn save(&mut self, table: &mut Table<&[u8], (u64, &[u8])>) -> Hash { + let encoded = self.canonical_encode(); + let hash = hash(&encoded); + + table.insert( + hash.as_bytes().as_slice(), + (self.ref_count, encoded.as_slice()), + ); + + hash + } + fn increment_ref_count(&self, table: &mut Table<&[u8], (u64, &[u8])>) { self.update_ref_count(table, RefCountDiff::Increment) } diff --git a/mast/src/operations/insert.rs b/mast/src/operations/insert.rs index d2f8e2a..ddfee3f 100644 --- a/mast/src/operations/insert.rs +++ b/mast/src/operations/insert.rs @@ -90,6 +90,7 @@ pub fn insert( value: &[u8], ) -> Hash { let mut path = binary_search_path(table, root, key); + dbg!(&path); let mut unzip_left_root: Option = None; let mut unzip_right_root: Option = None; @@ -101,7 +102,11 @@ pub fn insert( } } - let mut root = Node::insert(table, key, value, unzip_left_root, unzip_right_root); + let mut root = if let Some(mut existing) = path.existing { + existing.set_value(table, value) + } else { + Node::insert(table, key, value, unzip_left_root, unzip_right_root) + }; for (node, branch) in path.upper_path.iter_mut().rev() { match branch { @@ -114,6 +119,7 @@ pub fn insert( root } +#[derive(Debug)] struct BinarySearchPath { upper_path: Vec<(Node, Branch)>, existing: Option, @@ -160,6 +166,8 @@ fn binary_search_path( match key.cmp(current_node.key()) { Ordering::Equal => { // We found exact match. terminate the search. + + result.existing = Some(current_node); return result; } Ordering::Less => { diff --git a/mast/src/test.rs b/mast/src/test.rs index 8cce6e7..f841277 100644 --- a/mast/src/test.rs +++ b/mast/src/test.rs @@ -29,46 +29,55 @@ fn cases() { value: [b"v", key.as_bytes()].concat(), }); - let upsert_at_root = [ - ( - Entry { - key: b"X".to_vec(), - value: b"A".to_vec(), - }, - Operation::Insert, - ), - (( - Entry { - key: b"X".to_vec(), - value: b"B".to_vec(), - }, - Operation::Insert, - )), - ]; + let upsert_at_root = ["X", "X"] + .iter() + .enumerate() + .map(|(i, key)| { + ( + Entry { + key: b"X".to_vec(), + value: i.to_string().into(), + }, + Operation::Insert, + ) + }) + .collect::>(); - let upsert_deeper = [ - ( - Entry { - key: b"F".to_vec(), - value: b"A".to_vec(), - }, - Operation::Insert, - ), - ( - Entry { - key: b"X".to_vec(), - value: b"A".to_vec(), - }, - Operation::Insert, - ), - (( - Entry { - key: b"X".to_vec(), - value: b"B".to_vec(), - }, - Operation::Insert, - )), - ]; + // X has higher rank. + let upsert_deeper = ["X", "F", "F"] + .iter() + .enumerate() + .map(|(i, key)| { + ( + Entry { + key: key.as_bytes().to_vec(), + value: i.to_string().into(), + }, + Operation::Insert, + ) + }) + .collect::>(); + + let mut upsert_deeper_expected = upsert_deeper.clone(); + upsert_deeper_expected.remove(upsert_deeper.len() - 2); + + // X has higher rank. + let upsert_root_with_children = ["F", "X", "X"] + .iter() + .enumerate() + .map(|(i, key)| { + ( + Entry { + key: key.as_bytes().to_vec(), + value: i.to_string().into(), + }, + Operation::Insert, + ) + }) + .collect::>(); + + let mut upsert_root_with_children_expected = upsert_root_with_children.clone(); + upsert_root_with_children_expected.remove(upsert_root_with_children.len() - 2); let cases = [ ( @@ -105,23 +114,33 @@ fn cases() { Some("b3e862d316e6f5caca72c8f91b7a15015b4f7f8f970c2731433aad793f7fe3e6"), ), ( - "upsert at root", - upsert_at_root.to_vec(), + "upsert at root without children", + upsert_at_root.clone(), upsert_at_root[1..] .iter() .map(|(e, _)| e.clone()) .collect::>(), - Some("2947139081bbcc3816ebd73cb81ac0be5c564df55b88d6dbeb52c5254c1de887"), + Some("b1353174e730b9ff6850577357fd9ff608071bbab46ebe72c434133f5d4f0383"), ), ( "upsert deeper", upsert_deeper.to_vec(), - upsert_at_root[0..2] + upsert_deeper_expected + .to_vec() .iter() .map(|(e, _)| e.clone()) .collect::>(), - // Some("2947139081bbcc3816ebd73cb81ac0be5c564df55b88d6dbeb52c5254c1de887"), - None, + Some("58272c9e8c9e6b7266e4b60e45d55257b94e85561997f1706e0891ee542a8cd5"), + ), + ( + "upsert at root with children", + upsert_root_with_children.to_vec(), + upsert_root_with_children_expected + .to_vec() + .iter() + .map(|(e, _)| e.clone()) + .collect::>(), + Some("f46daf022dc852cd4e60a98a33de213f593e17bcd234d9abff7a178d8a5d0761"), ), ]; @@ -132,7 +151,7 @@ fn cases() { // === Helpers === -#[derive(Clone)] +#[derive(Clone, Debug)] enum Operation { Insert, Delete, @@ -150,7 +169,7 @@ impl std::fmt::Debug for Entry { } } -fn test(name: &str, input: &[(Entry, Operation)], output: &[Entry], root_hash: Option<&str>) { +fn test(name: &str, input: &[(Entry, Operation)], expected: &[Entry], root_hash: Option<&str>) { let inmemory = InMemoryBackend::new(); let db = Database::builder() .create_with_backend(inmemory) @@ -163,6 +182,12 @@ fn test(name: &str, input: &[(Entry, Operation)], output: &[Entry], root_hash: O Operation::Insert => treap.insert(&entry.key, &entry.value), Operation::Delete => todo!(), } + println!( + "{:?} {:?}\n{}", + &entry.key, + &entry.value, + into_mermaid_graph(&treap) + ); } let collected = treap @@ -173,10 +198,11 @@ fn test(name: &str, input: &[(Entry, Operation)], output: &[Entry], root_hash: O }) .collect::>(); - let mut sorted = output.to_vec(); + let mut sorted = expected.to_vec(); sorted.sort_by(|a, b| a.key.cmp(&b.key)); - // dbg!(&treap.root_hash()); + dbg!(&treap.root_hash()); + dbg!(&input, &expected); println!("{}", into_mermaid_graph(&treap)); if root_hash.is_some() {