diff --git a/mast/src/lib.rs b/mast/src/lib.rs index 2a0d595..dbfb5f1 100644 --- a/mast/src/lib.rs +++ b/mast/src/lib.rs @@ -2,6 +2,9 @@ mod mermaid; mod node; +mod operations; pub mod treap; pub(crate) use blake3::{Hash, Hasher}; + +pub const HASH_LEN: usize = 32; diff --git a/mast/src/mermaid.rs b/mast/src/mermaid.rs index 036274b..54877df 100644 --- a/mast/src/mermaid.rs +++ b/mast/src/mermaid.rs @@ -3,13 +3,13 @@ mod test { use crate::node::Node; use crate::treap::HashTreap; - impl<'a> HashTreap<'a> { + impl<'treap> HashTreap<'treap> { pub fn as_mermaid_graph(&self) -> String { let mut graph = String::new(); graph.push_str("graph TD;\n"); - if let Some(root) = self.root.clone() { + if let Some(root) = self.root() { self.build_graph_string(&root, &mut graph); } diff --git a/mast/src/node.rs b/mast/src/node.rs index ef7b72d..a67798e 100644 --- a/mast/src/node.rs +++ b/mast/src/node.rs @@ -1,11 +1,13 @@ -use redb::{Database, ReadableTable, Table, TableDefinition, WriteTransaction}; +//! In memory representation of a treap node. -use crate::{Hash, Hasher}; +use redb::{ReadableTable, Table}; + +use crate::{Hash, Hasher, HASH_LEN}; // TODO: Are we creating too many hashers? // TODO: are we calculating the rank and hash too often? - -const HASH_LEN: usize = 32; +// TODO: remove unused +// TODO: remove unwrap #[derive(Debug, Clone)] /// In memory reprsentation of treap node. 
@@ -28,6 +30,12 @@ pub(crate) enum Branch { Right, } +#[derive(Debug)] +enum RefCountDiff { + Increment, + Decrement, +} + impl Node { pub fn new(key: &[u8], value: &[u8]) -> Self { Self { @@ -58,7 +66,7 @@ impl Node { } }; - let (right, rest) = decode(rest); + let (right, _) = decode(rest); let right = match right.len() { 0 => None, 32 => { @@ -109,31 +117,7 @@ impl Node { hash(&self.canonical_encode()) } - pub(crate) fn set_child( - &mut self, - branch: &Branch, - new_child: Option, - table: &mut Table<&[u8], (u64, &[u8])>, - ) { - let old_child = match branch { - Branch::Left => self.left, - Branch::Right => self.right, - }; - - // increment old child's ref count. - decrement_ref_count(old_child, table); - - // increment new child's ref count. - increment_ref_count(new_child, table); - - // set new child - match branch { - Branch::Left => self.left = new_child, - Branch::Right => self.right = new_child, - } - - self.save(table); - } + pub(crate) fn decrement_ref_count(&self, table: &mut Table<&[u8], (u64, &[u8])>) {} pub(crate) fn save(&self, table: &mut Table<&[u8], (u64, &[u8])>) { let encoded = self.canonical_encode(); @@ -147,6 +131,27 @@ impl Node { // === Private Methods === + fn update_ref_count(&self, table: &mut Table<&[u8], (u64, &[u8])>, diff: RefCountDiff) { + let ref_count = match diff { + RefCountDiff::Increment => self.ref_count + 1, + RefCountDiff::Decrement => { + if self.ref_count > 0 { + self.ref_count - 1 + } else { + self.ref_count + } + } + }; + + let bytes = self.canonical_encode(); + let hash = hash(&bytes); + + match ref_count { + 0 => table.remove(hash.as_bytes().as_slice()), + _ => table.insert(hash.as_bytes().as_slice(), (ref_count, bytes.as_slice())), + }; + } + fn canonical_encode(&self) -> Vec { let mut bytes = vec![]; @@ -166,6 +171,40 @@ impl Node { } } +pub(crate) fn rank(key: &[u8]) -> Hash { + hash(key) +} + +/// Returns the node for a given hash. 
+pub(crate) fn get_node<'a>( + table: &'a impl ReadableTable<&'static [u8], (u64, &'static [u8])>, + hash: &[u8], +) -> Option { + let existing = table.get(hash).unwrap(); + + if existing.is_none() { + return None; + } + let data = existing.unwrap(); + + Some(Node::decode(data.value())) +} + +/// Returns the root hash for a given table. +pub(crate) fn get_root_hash<'a>( + table: &'a impl ReadableTable<&'static [u8], &'static [u8]>, + name: &str, +) -> Option { + let existing = table.get(name.as_bytes()).unwrap(); + if existing.is_none() { + return None; + } + let hash = existing.unwrap(); + + let hash: [u8; HASH_LEN] = hash.value().try_into().expect("Invalid root hash"); + Some(Hash::from_bytes(hash)) +} + fn encode(bytes: &[u8], out: &mut Vec) { // TODO: find a better way to reserve bytes. let current_len = out.len(); @@ -191,64 +230,3 @@ fn hash(bytes: &[u8]) -> Hash { hasher.finalize() } - -#[derive(Debug)] -enum RefCountDiff { - Increment, - Decrement, -} - -pub(crate) fn increment_ref_count(child: Option, table: &mut Table<&[u8], (u64, &[u8])>) { - update_ref_count(child, RefCountDiff::Increment, table); -} - -pub(crate) fn decrement_ref_count(child: Option, table: &mut Table<&[u8], (u64, &[u8])>) { - update_ref_count(child, RefCountDiff::Decrement, table); -} - -fn update_ref_count( - child: Option, - ref_diff: RefCountDiff, - table: &mut Table<&[u8], (u64, &[u8])>, -) { - if let Some(hash) = child { - dbg!("should update child ref", &child); - let mut existing = table - .get(hash.as_bytes().as_slice()) - .unwrap() - .expect("Child shouldn't be messing!"); - - let (ref_count, bytes) = { - let (r, v) = existing.value(); - (r, v.to_vec()) - }; - drop(existing); - dbg!(( - "\n\n decrmenting blah blah blah child", - &child, - &ref_count, - &ref_diff - )); - - let ref_count = match ref_diff { - RefCountDiff::Increment => ref_count + 1, - RefCountDiff::Decrement => { - if ref_count > 0 { - ref_count - 1 - } else { - ref_count - } - } - }; - - match ref_count { - 0 
=> { - // TODO: Confirm (read: test) this, because it is not easy to see in graphs. - table.remove(hash.as_bytes().as_slice()); - } - _ => { - table.insert(hash.as_bytes().as_slice(), (ref_count, bytes.as_slice())); - } - } - } -} diff --git a/mast/src/operations/gc.rs b/mast/src/operations/gc.rs new file mode 100644 index 0000000..49f6fd9 --- /dev/null +++ b/mast/src/operations/gc.rs @@ -0,0 +1,56 @@ +use blake3::Hash; +use redb::{Database, ReadableTable, Table, TableDefinition, WriteTransaction}; + +#[derive(Debug)] +enum RefCountDiff { + Increment, + Decrement, +} + +pub(crate) fn increment_ref_count(node: Option, table: &mut Table<&[u8], (u64, &[u8])>) { + update_ref_count(node, RefCountDiff::Increment, table); +} + +pub(crate) fn decrement_ref_count(node: Option, table: &mut Table<&[u8], (u64, &[u8])>) { + update_ref_count(node, RefCountDiff::Decrement, table); +} + +fn update_ref_count( + node: Option, + ref_diff: RefCountDiff, + table: &mut Table<&[u8], (u64, &[u8])>, +) { + if let Some(hash) = node { + let mut existing = table + .get(hash.as_bytes().as_slice()) + .unwrap() + .expect("node shouldn't be messing!"); + + let (ref_count, bytes) = { + let (r, v) = existing.value(); + (r, v.to_vec()) + }; + drop(existing); + + let ref_count = match ref_diff { + RefCountDiff::Increment => ref_count + 1, + RefCountDiff::Decrement => { + if ref_count > 0 { + ref_count - 1 + } else { + ref_count + } + } + }; + + match ref_count { + 0 => { + // TODO: Confirm (read: test) this, because it is not easy to see in graphs. 
+ table.remove(hash.as_bytes().as_slice()); + } + _ => { + table.insert(hash.as_bytes().as_slice(), (ref_count, bytes.as_slice())); + } + } + } +} diff --git a/mast/src/operations/insert.rs b/mast/src/operations/insert.rs new file mode 100644 index 0000000..0152450 --- /dev/null +++ b/mast/src/operations/insert.rs @@ -0,0 +1,142 @@ +use crate::node::{get_node, get_root_hash, rank, Branch, Node}; +use crate::treap::{HashTreap, NODES_TABLE, ROOTS_TABLE}; +use crate::HASH_LEN; +use blake3::Hash; +use redb::{Database, ReadTransaction, ReadableTable, Table, TableDefinition, WriteTransaction}; + +// Watch this [video](https://youtu.be/NxRXhBur6Xs?si=GNwaUOfuGwr_tBKI&t=1763) for a good explanation of the unzipping algorithm. +// Also see the Iterative insertion algorithm on page 12 of the [original paper](https://arxiv.org/pdf/1806.06726.pdf). +// The difference here is that in a Hash Treap, we need to update nodes bottom up. + +// Let's say we have the following tree: +// +// F +// / \ +// D P +// / / \ +// C H X +// / / \ \ +// A G M Y +// / +// I +// +// The binary search path for inserting `J` then is: +// +// F +// \ +// P +// / +// H +// \ +// M +// / +// I +// +// Then we define `upper_path` as the path from the root to the insertion point +// marked by the first node with a `rank` that is either: +// +// - less than the `rank` of the inserted key: +// +// F +// \ +// P +// ∧-- / --∧ upper path if rank(J) > rank(H) +// ∨-- H --∨ unzip path +// \ +// M Note that this is an arbitrary example, +// / do not expect the actual ranks of these keys to be the same in implementation. +// I +// +// Upper path doesn't change much beyond updating the hash of their child in the branch featured in +// this binary search path. +// +// We call the rest of the path `unzipping path` or `split path` and this is where we create two +// new paths (left and right), each containing the nodes with keys smaller than or larger than the +// inserted key respectively. 
+// +// We update these unzipped paths from the _bottom up_ to generate the new hashes for their +// parents. +// Once we have the two paths, we use their tips as the new children of the newly inserted node. +// Finally we update the hashes upwards until we reach the new root of the tree. +// +// - equal to the `rank` of the inserted key: +// +// F +// \ +// P +// / +// H --^ upper path if +// rank(J) = rank(H) +// +// This (exact key match) is the only way for the rank to match +// for secure hashes like blake3. +// +// This is a different case since we don't really need to split (unzip) the lower path, we just +// need to update the hash of the node (according to the new value) and update the hash of its +// parents until we reach the root. +// +// Also note that we need to update the `ref_count` of all the nodes, and delete the nodes with +// `ref_count` of zero. +// +// The simplest way to do so is to decrement all the nodes in the search path, and then increment +// all the new nodes (in both the upper and lower paths) before committing the write transaction. + +impl<'treap> HashTreap<'treap> { + pub fn insert(&mut self, key: &[u8], value: &[u8]) { + // TODO: validate key and value length. + + let write_txn = self.db.begin_write().unwrap(); + + 'transaction: { + let roots_table = write_txn.open_table(ROOTS_TABLE).unwrap(); + let mut nodes_table = write_txn.open_table(NODES_TABLE).unwrap(); + + let root = get_root_hash(&roots_table, &self.name); + + let mut path = upper_path(key, root, &nodes_table); + + path.iter_mut() + .for_each(|node| node.decrement_ref_count(&mut nodes_table)) + + // if path. + }; + + // Finally commit the changes to the storage. + write_txn.commit().unwrap(); + } +} + +/// Returns the current nodes from the root to the insertion point on the binary search path. 
+fn upper_path<'a>( + key: &[u8], + root: Option, + nodes_table: &'a impl ReadableTable<&'static [u8], (u64, &'static [u8])>, +) -> Vec { + let rank = rank(key); + + let mut path: Vec = Vec::new(); + + let mut previous_hash = root; + + while let Some(current_hash) = previous_hash { + let current_node = get_node(nodes_table, current_hash.as_bytes()).expect("Node not found!"); + + let current_key = current_node.key(); + + if key == current_key { + // We found an exact match, we don't need to unzip the rest. + path.push(current_node); + break; + } + + if key < current_key { + previous_hash = *current_node.left(); + } else { + previous_hash = *current_node.right(); + } + + path.push(current_node); + } + + path +} diff --git a/mast/src/operations/mod.rs b/mast/src/operations/mod.rs new file mode 100644 index 0000000..20b6b94 --- /dev/null +++ b/mast/src/operations/mod.rs @@ -0,0 +1,2 @@ +pub mod gc; +mod insert; diff --git a/mast/src/treap.rs b/mast/src/treap.rs index f551e13..ac45833 100644 --- a/mast/src/treap.rs +++ b/mast/src/treap.rs @@ -1,272 +1,61 @@ -use blake3::{Hash, Hasher}; -// use redb::{Database, ite, ReadableTable, Table, TableDefinition}; +use blake3::Hash; use redb::*; -use crate::node::{decrement_ref_count, increment_ref_count, Branch, Node}; - -// TODO: remove unused -// TODO: remove unwrap - -#[derive(Debug)] -pub struct HashTreap<'a> { - /// Redb database to store the nodes. - pub(crate) db: &'a Database, - pub(crate) root: Option, -} +use crate::node::{get_node, get_root_hash, Node}; // Table: Nodes v0 +// stores all the hash treap nodes from all the treaps in the storage. +// // Key: `[u8; 32]` # Node hash // Value: `(u64, [u8])` # (RefCount, EncodedNode) -const NODES_TABLE: TableDefinition<&[u8], (u64, &[u8])> = +pub const NODES_TABLE: TableDefinition<&[u8], (u64, &[u8])> = TableDefinition::new("kytz:hash_treap:nodes:v0"); -impl<'a> HashTreap<'a> { +// Table: Roots v0 +// stores all the current roots for all treaps in the storage. 
+// +// Key: `[u8]` # Treap name (UTF-8 bytes) +// Value: `[u8; 32]` # Hash +pub const ROOTS_TABLE: TableDefinition<&[u8], &[u8]> = + TableDefinition::new("kytz:hash_treap:roots:v0"); + +#[derive(Debug)] +pub struct HashTreap<'treap> { + /// Redb database to store the nodes. + pub(crate) db: &'treap Database, + pub(crate) name: &'treap str, +} + +impl<'treap> HashTreap<'treap> { // TODO: add name to open from storage with. - pub fn new(db: &'a Database) -> Self { + pub fn new(db: &'treap Database, name: &'treap str) -> Self { // Setup tables let write_tx = db.begin_write().unwrap(); { let _table = write_tx.open_table(NODES_TABLE).unwrap(); + let _table = write_tx.open_table(ROOTS_TABLE).unwrap(); } write_tx.commit().unwrap(); - // TODO: Try to open root (using this treaps or tags table). - // TODO: sould be checking for root on the fly probably! - - Self { root: None, db } + Self { name, db } } - pub fn insert(&mut self, key: &[u8], value: &[u8]) { - // TODO: validate key and value length. + // === Getters === - let mut node = Node::new(key, value); + pub(crate) fn root(&self) -> Option { + let read_txn = self.db.begin_read().unwrap(); - let write_txn = self.db.begin_write().unwrap(); + let roots_table = read_txn.open_table(ROOTS_TABLE).unwrap(); + let nodes_table = read_txn.open_table(NODES_TABLE).unwrap(); - let _ = 'transaction: { - let mut nodes_table = write_txn.open_table(NODES_TABLE).unwrap(); - - if self.root.is_none() { - // We are done. - self.update_root(&node, &mut nodes_table); - break 'transaction; - } - - // Watch this [video](https://youtu.be/NxRXhBur6Xs?si=GNwaUOfuGwr_tBKI&t=1763) for a good explanation of the unzipping algorithm. - // Also see the Iterative insertion algorithm in the page 12 of the [original paper](https://arxiv.org/pdf/1806.06726.pdf). - // The difference here is that in a Hash Treap, we need to update nodes bottom up. 
- - // Let's say we have the following tree: - // - // F - // / \ - // D P - // / / \ - // C H X - // / / \ \ - // A G M Y - // / - // I - // - // First we mark the binary search path to the leaf, going right if the key is greater than - // the current node's key and vice versa. - // - // F - // \ - // P - // / - // H - // \ - // M - // / - // I - // - - // Path before insertion point. (Node, Branch to update) - let mut top_path: Vec<(Node, Branch)> = Vec::new(); - // Subtree of nodes on the path smaller than the inserted key. - let mut left_unzip_path: Vec = Vec::new(); - // Subtree of nodes on the path larger than the inserted key. - let mut right_unzip_path: Vec = Vec::new(); - - let mut next = self.root.clone().map(|n| n.hash()); - - // Top down traversal of the binary search path. - while let Some(current) = self.get_node(&next) { - let node_rank = node.rank(); - let curr_rank = current.rank(); - - if node_rank.as_bytes() == curr_rank.as_bytes() { - // Same key, we should update the value and return. - self.update_root(&node, &mut nodes_table); - break 'transaction; - } - - let should_zip = node_rank.as_bytes() > curr_rank.as_bytes(); - - // Traverse left or right. - if key < current.key() { - next = *current.left(); - - if should_zip { - right_unzip_path.push(current) - } else { - top_path.push((current, Branch::Left)); - } - } else { - next = *current.right(); - - if should_zip { - left_unzip_path.push(current) - } else { - top_path.push((current, Branch::Right)); - } - }; - } - - // === Updating hashes bottom up === - - // We are at the unzipping part of the path. - // - // First do the unzipping bottom up. 
- // - // H - // \ - // M < current_right - // / - // I < current_left - // - // Into (hopefully you can see the "unzipping"): - // - // left right - // subtree subtree - // - // H | - // \ | - // I | M - - // dbg!(( - // "unzipping left", - // String::from_utf8(node.key().to_vec()).unwrap(), - // &left_unzip_path - // .iter() - // .map(|n| String::from_utf8(n.key().to_vec()).unwrap()) - // .collect::>(), - // &right_unzip_path - // .iter() - // .map(|n| String::from_utf8(n.key().to_vec()).unwrap()) - // .collect::>(), - // )); - - let left_unzip_path_len = left_unzip_path.len(); - for i in 0..left_unzip_path_len { - if i == left_unzip_path_len - 1 { - // The last node in the path is special, since we need to clear its right - // child from older versions. - let child = left_unzip_path.get_mut(i).unwrap(); - child.set_child(&Branch::Right, None, &mut nodes_table); - - // Skip the last element for the first iterator - break; - } - - let (first, second) = left_unzip_path.split_at_mut(i + 1); - let child = &first[i]; - let parent = &mut second[0]; - - parent.set_child(&Branch::Right, Some(child.hash()), &mut nodes_table); - } - - let right_unzip_path_len = right_unzip_path.len(); - for i in 0..right_unzip_path_len { - if i == right_unzip_path_len - 1 { - // The last node in the path is special, since we need to clear its right - // child from older versions. - let child = right_unzip_path.get_mut(i).unwrap(); - dbg!(("clearing the left child fuckin please", &child)); - child.set_child(&Branch::Left, None, &mut nodes_table); - dbg!(("clearing the left child fuckin please", &child)); - - // Skip the last element for the first iterator - break; - } - - let (first, second) = right_unzip_path.split_at_mut(i + 1); - let child = &first[i]; - let parent = &mut second[0]; - - parent.set_child(&Branch::Left, Some(child.hash()), &mut nodes_table); - } - - // Done unzipping, join the current_left and current_right to J and update hashes upwards. - // - // J < Insertion point. 
- // / \ - // H M - // \ - // I - - node.set_child( - &Branch::Left, - left_unzip_path.first().map(|n| n.hash()), - &mut nodes_table, - ); - node.set_child( - &Branch::Right, - right_unzip_path.first().map(|n| n.hash()), - &mut nodes_table, - ); - - // Update the rest of the path upwards with the new hashes. - // So the final tree should look like: - // - // F - // / \ - // D P - // / / \ - // C J X - // / / \ \ - // A H M Y - // / \ - // G I - - if top_path.is_empty() { - // The insertion point is at the root and we are done. - self.update_root(&node, &mut nodes_table) - } - - let mut previous = node; - - while let Some((mut parent, branch)) = top_path.pop() { - parent.set_child(&branch, Some(previous.hash()), &mut nodes_table); - - previous = parent; - } - - // Update the root pointer. - self.update_root(&previous, &mut nodes_table) - }; - - // Finally we should commit the changes to the storage. - write_txn.commit().unwrap(); + let hash = get_root_hash(&roots_table, self.name); + hash.and_then(|hash| get_node(&nodes_table, hash.as_bytes())) } // === Private Methods === - fn update_root(&mut self, node: &Node, table: &mut Table<&[u8], (u64, &[u8])>) { - // decrement_ref_count(self.root.clone().map(|n| n.hash()), table); - - node.save(table); - - // The tree is empty, the incoming node has to be the root, and we are done. - self.root = Some(node.clone()); - - // TODO: we need to persist the root change too to the storage. - // TODO: add a tag to persist snapshots. - increment_ref_count(self.root.clone().map(|n| n.hash()), table); - } - + /// Create a read transaction and get a node from the nodes table. pub(crate) fn get_node(&self, hash: &Option) -> Option { let read_txn = self.db.begin_read().unwrap(); let table = read_txn.open_table(NODES_TABLE).unwrap(); @@ -281,10 +70,12 @@ impl<'a> HashTreap<'a> { // === Test Methods === + // TODO: move tests and test helper methods to separate module. 
+ // Only keep the public methods here, and probably move it to lib.rs too. + #[cfg(test)] fn verify_ranks(&self) -> bool { - let node = self.get_node(&self.root.clone().map(|n| n.hash())); - self.check_rank(node) + self.check_rank(self.root()) } #[cfg(test)] @@ -292,11 +83,9 @@ impl<'a> HashTreap<'a> { match node { Some(n) => { let left_check = self.get_node(n.left()).map_or(true, |left| { - dbg!(("left", &left)); n.rank().as_bytes() > left.rank().as_bytes() && self.check_rank(Some(left)) }); let right_check = self.get_node(n.right()).map_or(true, |right| { - dbg!(("right", &right)); n.rank().as_bytes() > right.rank().as_bytes() && self.check_rank(Some(right)) }); @@ -308,6 +97,7 @@ impl<'a> HashTreap<'a> { #[cfg(test)] fn list_all_nodes(&self) { + // TODO: return all the nodes to verify GC in the test, or verify it here. let read_txn = self.db.begin_read().unwrap(); let nodes_table = read_txn.open_table(NODES_TABLE).unwrap(); @@ -337,17 +127,17 @@ mod test { use super::HashTreap; use super::Node; + use redb::backends::InMemoryBackend; use redb::{Database, Error, ReadableTable, TableDefinition}; // TODO: write a good test for GC. 
#[test] fn sorted_insert() { - // Create an in-memory database let file = tempfile::NamedTempFile::new().unwrap(); let db = Database::create(file.path()).unwrap(); - let mut treap = HashTreap::new(&db); + let mut treap = HashTreap::new(&db, "test"); let mut keys = [ "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", @@ -364,11 +154,10 @@ mod test { #[test] fn unsorted_insert() { - // Create an in-memory database let file = tempfile::NamedTempFile::new().unwrap(); let db = Database::create(file.path()).unwrap(); - let mut treap = HashTreap::new(&db); + let mut treap = HashTreap::new(&db, "test"); // TODO: fix this cases let mut keys = [ @@ -399,7 +188,7 @@ mod test { let file = tempfile::NamedTempFile::new().unwrap(); let db = Database::create(file.path()).unwrap(); - let mut treap = HashTreap::new(&db); + let mut treap = HashTreap::new(&db, "test"); let mut keys = ["X", "X"]; @@ -413,4 +202,24 @@ mod test { println!("{}", treap.as_mermaid_graph()) } + + #[test] + fn upsert_deeper_than_root() { + let file = tempfile::NamedTempFile::new().unwrap(); + let db = Database::create(file.path()).unwrap(); + + let mut treap = HashTreap::new(&db, "test"); + + let mut keys = ["F", "X", "X"]; + + for key in keys.iter() { + treap.insert(key.as_bytes(), b"0"); + } + + assert!(treap.verify_ranks(), "Ranks are not correct"); + + // TODO: check the value. + + println!("{}", treap.as_mermaid_graph()) + } }