From ebd4ef32d0a0ce8515e3b12b53ca310898c7bb17 Mon Sep 17 00:00:00 2001 From: nazeh Date: Wed, 20 Dec 2023 23:03:25 +0300 Subject: [PATCH] feat: insert passes all the eyeball tests --- mast/src/node.rs | 60 ++++++++++++++++--- mast/src/operations/insert.rs | 109 ++++++++++++++++++---------------- mast/src/treap.rs | 46 +++++++------- 3 files changed, 132 insertions(+), 83 deletions(-) diff --git a/mast/src/node.rs b/mast/src/node.rs index 68a313b..329be72 100644 --- a/mast/src/node.rs +++ b/mast/src/node.rs @@ -37,8 +37,8 @@ enum RefCountDiff { } impl Node { - pub(crate) fn open<'a>( - table: &'a impl ReadableTable<&'static [u8], (u64, &'static [u8])>, + pub(crate) fn open( + table: &'_ impl ReadableTable<&'static [u8], (u64, &'static [u8])>, hash: Hash, ) -> Option { let mut existing = table.get(hash.as_bytes().as_slice()).unwrap(); @@ -54,12 +54,18 @@ impl Node { }) } - pub(crate) fn insert(table: &mut Table<&[u8], (u64, &[u8])>, key: &[u8], value: &[u8]) -> Hash { - let node = Self { + pub(crate) fn insert( + table: &mut Table<&[u8], (u64, &[u8])>, + key: &[u8], + value: &[u8], + left: Option, + right: Option, + ) -> Hash { + let mut node = Self { key: key.into(), value: value.into(), - left: None, - right: None, + left, + right, ref_count: 1, }; @@ -75,6 +81,24 @@ impl Node { hash } + /// Set the left child, save the updated node, and return the new hash. + pub(crate) fn set_left_child( + &mut self, + table: &mut Table<&[u8], (u64, &[u8])>, + child: Option, + ) -> Hash { + self.set_child(table, Branch::Left, child) + } + + /// Set the right child, save the updated node, and return the new hash. + pub(crate) fn set_right_child( + &mut self, + table: &mut Table<&[u8], (u64, &[u8])>, + child: Option, + ) -> Hash { + self.set_child(table, Branch::Right, child) + } + // === Getters === pub(crate) fn key(&self) -> &[u8] { @@ -110,6 +134,28 @@ impl Node { // === Private Methods === + fn set_child( + &mut self, + table: &mut Table<&[u8], (u64, &[u8])>, + branch: Branch, + child: Option, + ) -> Hash { + match branch { + Branch::Left => self.left = child, + Branch::Right => self.right = child, + } + + let encoded = self.canonical_encode(); + let hash = hash(&encoded); + + table.insert( + hash.as_bytes().as_slice(), + (self.ref_count, encoded.as_slice()), + ); + + hash + } + fn increment_ref_count(&self, table: &mut Table<&[u8], (u64, &[u8])>) { self.update_ref_count(table, RefCountDiff::Increment) } @@ -172,7 +218,7 @@ fn encode(bytes: &[u8], out: &mut Vec) { fn decode(bytes: &[u8]) -> (&[u8], &[u8]) { let (len, remaining) = varu64::decode(bytes).unwrap(); let value = &remaining[..len as usize]; - let rest = &remaining[value.len() as usize..]; + let rest = &remaining[value.len()..]; (value, rest) } diff --git a/mast/src/operations/insert.rs b/mast/src/operations/insert.rs index 69369a7..abd98af 100644 --- a/mast/src/operations/insert.rs +++ b/mast/src/operations/insert.rs @@ -1,3 +1,5 @@ +use std::cmp::Ordering; + use crate::node::{rank, Branch, Node}; use crate::treap::{HashTreap, NODES_TABLE, ROOTS_TABLE}; use crate::HASH_LEN; @@ -81,84 +83,91 @@ use redb::{Database, ReadTransaction, ReadableTable, Table, TableDefinition, Wri // The simplest way to do so, is to decrement all the nodes in the search path, and then increment // all then new nodes (in both the upper and lower paths) before comitting the write transaction. -pub fn insert<'a>( - table: &'a mut Table<&'static [u8], (u64, &'static [u8])>, +pub fn insert( + table: &'_ mut Table<&'static [u8], (u64, &'static [u8])>, root: Option, key: &[u8], value: &[u8], -) { +) -> Hash { let mut path = binary_search_path(table, root, key); - path.iter_mut() - .for_each(|(node, _)| node.decrement_ref_count(table)); + let mut unzip_left_root: Option = None; + let mut unzip_right_root: Option = None; - let mut unzipped_left: Option = None; - let mut unzipped_right: Option = None; - let mut upper_path: Option = None; - - let rank = rank(key); - - for (node, branch) in path.iter().rev() { - match node.rank().as_bytes().cmp(&rank.as_bytes()) { - std::cmp::Ordering::Equal => { - // We found an exact match, we update the value and proceed. - - upper_path = Some(Node::insert(table, key, value)) - } - std::cmp::Ordering::Less => { - // previous_hash = *current_node.left(); - // - // path.push((current_node, Branch::Left)); - } - std::cmp::Ordering::Greater => { - // previous_hash = *current_node.right(); - // - // path.push((current_node, Branch::Right)); - } + for (node, branch) in path.unzip_path.iter_mut().rev() { + match branch { + Branch::Right => unzip_left_root = Some(node.set_right_child(table, unzip_left_root)), + Branch::Left => unzip_right_root = Some(node.set_left_child(table, unzip_right_root)), } } - // if let Some((node, _)) = path.last_mut() { - // // If the last node is an exact match - // } else { - // // handle lower path - // } + let mut root = Node::insert(table, key, value, unzip_left_root, unzip_right_root); + + for (node, branch) in path.upper_path.iter_mut().rev() { + match branch { + Branch::Left => root = node.set_left_child(table, Some(root)), + Branch::Right => root = node.set_right_child(table, Some(root)), + } + } + + // Finally return the new root to be committed. + root } -/// Returns the current nodes from the root to the insertion point on the binary search path. -fn binary_search_path<'a>( - nodes_table: &'a impl ReadableTable<&'static [u8], (u64, &'static [u8])>, +struct BinarySearchPath { + upper_path: Vec<(Node, Branch)>, + exact_match: Option, + unzip_path: Vec<(Node, Branch)>, +} + +fn binary_search_path( + table: &'_ mut Table<&'static [u8], (u64, &'static [u8])>, root: Option, key: &[u8], -) -> Vec<(Node, Branch)> { - let mut path: Vec<(Node, Branch)> = Vec::new(); +) -> BinarySearchPath { + let rank = rank(key); + + let mut result = BinarySearchPath { + upper_path: Default::default(), + exact_match: None, + unzip_path: Default::default(), + }; let mut previous_hash = root; while let Some(current_hash) = previous_hash { - let current_node = Node::open(nodes_table, current_hash).expect("Node not found!"); + let current_node = Node::open(table, current_hash).expect("Node not found!"); - let current_key = current_node.key(); + // Decrement each node in the binary search path. + // if it doesn't change, we will increment it again later. + // + // It is important then to terminate the loop if we found an exact match, + // as lower nodes shouldn't change then. + current_node.decrement_ref_count(table); - match key.cmp(current_key) { - std::cmp::Ordering::Equal => { - // We found an exact match, we don't need to unzip the rest. - // Branch here doesn't matter - path.push((current_node, Branch::Left)); - break; + let mut path = if current_node.rank().as_bytes() > rank.as_bytes() { + &mut result.upper_path + } else { + &mut result.unzip_path + }; + + match key.cmp(current_node.key()) { + Ordering::Equal => { + // We found exact match. terminate the search. + return result; } - std::cmp::Ordering::Less => { + Ordering::Less => { previous_hash = *current_node.left(); path.push((current_node, Branch::Left)); } - std::cmp::Ordering::Greater => { + Ordering::Greater => { previous_hash = *current_node.right(); path.push((current_node, Branch::Right)); } - } + }; } - path + result } diff --git a/mast/src/treap.rs b/mast/src/treap.rs index 40a64c2..e2a2dfc 100644 --- a/mast/src/treap.rs +++ b/mast/src/treap.rs @@ -3,6 +3,9 @@ use redb::*; use crate::{node::Node, HASH_LEN}; +// TODO: test that order is correct +// TODO: test that there are no extr anodes. + // Table: Nodes v0 // stores all the hash treap nodes from all the treaps in the storage. // @@ -53,14 +56,13 @@ impl<'treap> HashTreap<'treap> { .and_then(|hash| Node::open(&nodes_table, hash)) } - fn root_hash<'a>( + fn root_hash( &self, - table: &'a impl ReadableTable<&'static [u8], &'static [u8]>, + table: &'_ impl ReadableTable<&'static [u8], &'static [u8]>, ) -> Option { let existing = table.get(self.name.as_bytes()).unwrap(); - if existing.is_none() { - return None; - } + existing.as_ref()?; + let hash = existing.unwrap(); let hash: [u8; HASH_LEN] = hash.value().try_into().expect("Invalid root hash"); @@ -75,13 +77,15 @@ impl<'treap> HashTreap<'treap> { let write_txn = self.db.begin_write().unwrap(); - 'transaction: { - let roots_table = write_txn.open_table(ROOTS_TABLE).unwrap(); + { + let mut roots_table = write_txn.open_table(ROOTS_TABLE).unwrap(); let mut nodes_table = write_txn.open_table(NODES_TABLE).unwrap(); let root = self.root_hash(&roots_table); - crate::operations::insert::insert(&mut nodes_table, root, key, value) + let new_root = crate::operations::insert::insert(&mut nodes_table, root, key, value); + + roots_table.insert(self.name.as_bytes(), new_root.as_bytes().as_slice()); }; // Finally commit the changes to the storage. @@ -90,7 +94,13 @@ impl<'treap> HashTreap<'treap> { // === Private Methods === + // === Test Methods === + + // TODO: move tests and test helper methods to separate module. + // Only keep the public methods here, and probably move it to lib.rs too. + /// Create a read transaction and get a node from the nodes table. + #[cfg(test)] pub(crate) fn get_node(&self, hash: &Option) -> Option { let read_txn = self.db.begin_read().unwrap(); let table = read_txn.open_table(NODES_TABLE).unwrap(); @@ -98,11 +108,6 @@ impl<'treap> HashTreap<'treap> { hash.and_then(|h| Node::open(&table, h)) } - // === Test Methods === - - // TODO: move tests and test helper methods to separate module. - // Only keep the public methods here, and probably move it to lib.rs too. - #[cfg(test)] fn verify_ranks(&self) -> bool { self.check_rank(self.root()) @@ -190,18 +195,7 @@ mod test { let mut treap = HashTreap::new(&db, "test"); - // TODO: fix this cases - let mut keys = [ - // "D", "N", "P", - "X", // "F", "Z", "Y", - "A", // - // "G", // - // "C", // - //"M", "H", "I", "J", - ]; - - // TODO: fix without sort. - // keys.sort(); + let mut keys = ["D", "N", "P", "X", "A", "G", "C", "M", "H", "I", "J"]; for key in keys.iter() { treap.insert(key.as_bytes(), b"0"); @@ -241,7 +235,7 @@ mod test { let mut treap = HashTreap::new(&db, "test"); - let mut keys = ["F", "X", "X"]; + let keys = ["F", "X", "X"]; for key in keys.iter() { treap.insert(key.as_bytes(), b"0");