From 9e8a483551d2394b46d96c832f7fd9ad31404e65 Mon Sep 17 00:00:00 2001 From: nazeh Date: Wed, 20 Dec 2023 21:05:55 +0300 Subject: [PATCH] wip: step closer to finishing insert --- mast/src/node.rs | 153 ++++++++++++++++------------------ mast/src/operations/gc.rs | 56 ------------- mast/src/operations/insert.rs | 96 +++++++++++++-------- mast/src/operations/mod.rs | 3 +- mast/src/treap.rs | 61 ++++++++++---- 5 files changed, 179 insertions(+), 190 deletions(-) delete mode 100644 mast/src/operations/gc.rs diff --git a/mast/src/node.rs b/mast/src/node.rs index a67798e..68a313b 100644 --- a/mast/src/node.rs +++ b/mast/src/node.rs @@ -37,55 +37,42 @@ enum RefCountDiff { } impl Node { - pub fn new(key: &[u8], value: &[u8]) -> Self { - Self { + pub(crate) fn open<'a>( + table: &'a impl ReadableTable<&'static [u8], (u64, &'static [u8])>, + hash: Hash, + ) -> Option { + let mut existing = table.get(hash.as_bytes().as_slice()).unwrap(); + + existing.map(|existing| { + let (ref_count, bytes) = { + let (r, v) = existing.value(); + (r, v.to_vec()) + }; + drop(existing); + + decode_node((ref_count, &bytes)) + }) + } + + pub(crate) fn insert(table: &mut Table<&[u8], (u64, &[u8])>, key: &[u8], value: &[u8]) -> Hash { + let node = Self { key: key.into(), value: value.into(), left: None, right: None, - ref_count: 0, - } - } - - pub fn decode(data: (u64, &[u8])) -> Node { - let (ref_count, encoded_node) = data; - - let (key, rest) = decode(encoded_node); - let (value, rest) = decode(rest); - - let (left, rest) = decode(rest); - let left = match left.len() { - 0 => None, - 32 => { - let bytes: [u8; HASH_LEN] = left.try_into().unwrap(); - Some(Hash::from_bytes(bytes)) - } - _ => { - panic!("invalid hash length!") - } + ref_count: 1, }; - let (right, _) = decode(rest); - let right = match right.len() { - 0 => None, - 32 => { - let bytes: [u8; HASH_LEN] = right.try_into().unwrap(); - Some(Hash::from_bytes(bytes)) - } - _ => { - panic!("invalid hash length!") - } - }; + let encoded = node.canonical_encode(); + let hash = hash(&encoded); - Node { - key: key.into(), - value: value.into(), - left, - right, + table.insert( + hash.as_bytes().as_slice(), + (node.ref_count, encoded.as_slice()), + ); - ref_count, - } + hash } // === Getters === @@ -117,20 +104,16 @@ impl Node { hash(&self.canonical_encode()) } - pub(crate) fn decrement_ref_count(&self, table: &mut Table<&[u8], (u64, &[u8])>) {} - - pub(crate) fn save(&self, table: &mut Table<&[u8], (u64, &[u8])>) { - let encoded = self.canonical_encode(); - let hash = hash(&encoded); - - table.insert( - hash.as_bytes().as_slice(), - (self.ref_count, encoded.as_slice()), - ); + pub(crate) fn decrement_ref_count(&self, table: &mut Table<&[u8], (u64, &[u8])>) { + self.update_ref_count(table, RefCountDiff::Decrement) } // === Private Methods === + fn increment_ref_count(&self, table: &mut Table<&[u8], (u64, &[u8])>) { + self.update_ref_count(table, RefCountDiff::Increment) + } + fn update_ref_count(&self, table: &mut Table<&[u8], (u64, &[u8])>, diff: RefCountDiff) { let ref_count = match diff { RefCountDiff::Increment => self.ref_count + 1, @@ -175,36 +158,6 @@ pub(crate) fn rank(key: &[u8]) -> Hash { hash(key) } -/// Returns the node for a given hash. -pub(crate) fn get_node<'a>( - table: &'a impl ReadableTable<&'static [u8], (u64, &'static [u8])>, - hash: &[u8], -) -> Option { - let existing = table.get(hash).unwrap(); - - if existing.is_none() { - return None; - } - let data = existing.unwrap(); - - Some(Node::decode(data.value())) -} - -/// Returns the root hash for a given table. -pub(crate) fn get_root_hash<'a>( - table: &'a impl ReadableTable<&'static [u8], &'static [u8]>, - name: &str, -) -> Option { - let existing = table.get(name.as_bytes()).unwrap(); - if existing.is_none() { - return None; - } - let hash = existing.unwrap(); - - let hash: [u8; HASH_LEN] = hash.value().try_into().expect("Invalid root hash"); - Some(Hash::from_bytes(hash)) -} - fn encode(bytes: &[u8], out: &mut Vec) { // TODO: find a better way to reserve bytes. let current_len = out.len(); @@ -230,3 +183,43 @@ fn hash(bytes: &[u8]) -> Hash { hasher.finalize() } + +pub fn decode_node(data: (u64, &[u8])) -> Node { + let (ref_count, encoded_node) = data; + + let (key, rest) = decode(encoded_node); + let (value, rest) = decode(rest); + + let (left, rest) = decode(rest); + let left = match left.len() { + 0 => None, + 32 => { + let bytes: [u8; HASH_LEN] = left.try_into().unwrap(); + Some(Hash::from_bytes(bytes)) + } + _ => { + panic!("invalid hash length!") + } + }; + + let (right, _) = decode(rest); + let right = match right.len() { + 0 => None, + 32 => { + let bytes: [u8; HASH_LEN] = right.try_into().unwrap(); + Some(Hash::from_bytes(bytes)) + } + _ => { + panic!("invalid hash length!") + } + }; + + Node { + key: key.into(), + value: value.into(), + left, + right, + + ref_count, + } +} diff --git a/mast/src/operations/gc.rs b/mast/src/operations/gc.rs deleted file mode 100644 index 49f6fd9..0000000 --- a/mast/src/operations/gc.rs +++ /dev/null @@ -1,56 +0,0 @@ -use blake3::Hash; -use redb::{Database, ReadableTable, Table, TableDefinition, WriteTransaction}; - -#[derive(Debug)] -enum RefCountDiff { - Increment, - Decrement, -} - -pub(crate) fn increment_ref_count(node: Option, table: &mut Table<&[u8], (u64, &[u8])>) { - update_ref_count(node, RefCountDiff::Increment, table); -} - -pub(crate) fn decrement_ref_count(node: Option, table: &mut Table<&[u8], (u64, &[u8])>) { - update_ref_count(node, RefCountDiff::Decrement, table); -} - -fn update_ref_count( - node: Option, - ref_diff: RefCountDiff, - table: &mut Table<&[u8], (u64, &[u8])>, -) { - if let Some(hash) = node { - let mut existing = table - .get(hash.as_bytes().as_slice()) - .unwrap() - .expect("node shouldn't be messing!"); - - let (ref_count, bytes) = { - let (r, v) = existing.value(); - (r, v.to_vec()) - }; - drop(existing); - - let ref_count = match ref_diff { - RefCountDiff::Increment => ref_count + 1, - RefCountDiff::Decrement => { - if ref_count > 0 { - ref_count - 1 - } else { - ref_count - } - } - }; - - match ref_count { - 0 => { - // TODO: Confirm (read: test) this, because it is not easy to see in graphs. - table.remove(hash.as_bytes().as_slice()); - } - _ => { - table.insert(hash.as_bytes().as_slice(), (ref_count, bytes.as_slice())); - } - } - } -} diff --git a/mast/src/operations/insert.rs b/mast/src/operations/insert.rs index 0152450..69369a7 100644 --- a/mast/src/operations/insert.rs +++ b/mast/src/operations/insert.rs @@ -1,4 +1,4 @@ -use crate::node::{get_node, get_root_hash, rank, Branch, Node}; +use crate::node::{rank, Branch, Node}; use crate::treap::{HashTreap, NODES_TABLE, ROOTS_TABLE}; use crate::HASH_LEN; use blake3::Hash; @@ -81,61 +81,83 @@ use redb::{Database, ReadTransaction, ReadableTable, Table, TableDefinition, Wri // The simplest way to do so, is to decrement all the nodes in the search path, and then increment // all then new nodes (in both the upper and lower paths) before comitting the write transaction. -impl<'treap> HashTreap<'treap> { - pub fn insert(&mut self, key: &[u8], value: &[u8]) { - // TODO: validate key and value length. +pub fn insert<'a>( + table: &'a mut Table<&'static [u8], (u64, &'static [u8])>, + root: Option, + key: &[u8], + value: &[u8], +) { + let mut path = binary_search_path(table, root, key); - let write_txn = self.db.begin_write().unwrap(); + path.iter_mut() + .for_each(|(node, _)| node.decrement_ref_count(table)); - 'transaction: { - let roots_table = write_txn.open_table(ROOTS_TABLE).unwrap(); - let mut nodes_table = write_txn.open_table(NODES_TABLE).unwrap(); + let mut unzipped_left: Option = None; + let mut unzipped_right: Option = None; + let mut upper_path: Option = None; - let root = get_root_hash(&roots_table, &self.name); + let rank = rank(key); - let mut path = upper_path(key, root, &nodes_table); + for (node, branch) in path.iter().rev() { + match node.rank().as_bytes().cmp(&rank.as_bytes()) { + std::cmp::Ordering::Equal => { + // We found an exact match, we update the value and proceed. - path.iter_mut() - .for_each(|node| node.decrement_ref_count(&mut nodes_table)) - - // if path. - }; - - // Finally commit the changes to the storage. - write_txn.commit().unwrap(); + upper_path = Some(Node::insert(table, key, value)) + } + std::cmp::Ordering::Less => { + // previous_hash = *current_node.left(); + // + // path.push((current_node, Branch::Left)); + } + std::cmp::Ordering::Greater => { + // previous_hash = *current_node.right(); + // + // path.push((current_node, Branch::Right)); + } + } } + + // if let Some((node, _)) = path.last_mut() { + // // If the last node is an exact match + // } else { + // // handle lower path + // } } /// Returns the current nodes from the root to the insertion point on the binary search path. -fn upper_path<'a>( - key: &[u8], - root: Option, +fn binary_search_path<'a>( nodes_table: &'a impl ReadableTable<&'static [u8], (u64, &'static [u8])>, -) -> Vec { - let rank = rank(key); - - let mut path: Vec = Vec::new(); + root: Option, + key: &[u8], +) -> Vec<(Node, Branch)> { + let mut path: Vec<(Node, Branch)> = Vec::new(); let mut previous_hash = root; while let Some(current_hash) = previous_hash { - let current_node = get_node(nodes_table, current_hash.as_bytes()).expect("Node not found!"); + let current_node = Node::open(nodes_table, current_hash).expect("Node not found!"); let current_key = current_node.key(); - if key == current_key { - // We found an exact match, we don't need to unzip the rest. - path.push(current_node); - break; - } + match key.cmp(current_key) { + std::cmp::Ordering::Equal => { + // We found an exact match, we don't need to unzip the rest. + // Branch here doesn't matter + path.push((current_node, Branch::Left)); + break; + } + std::cmp::Ordering::Less => { + previous_hash = *current_node.left(); - if key < current_key { - previous_hash = *current_node.left(); - } else { - previous_hash = *current_node.right(); - } + path.push((current_node, Branch::Left)); + } + std::cmp::Ordering::Greater => { + previous_hash = *current_node.right(); - path.push(current_node); + path.push((current_node, Branch::Right)); + } + } } path diff --git a/mast/src/operations/mod.rs b/mast/src/operations/mod.rs index 20b6b94..0173d2e 100644 --- a/mast/src/operations/mod.rs +++ b/mast/src/operations/mod.rs @@ -1,2 +1 @@ -pub mod gc; -mod insert; +pub mod insert; diff --git a/mast/src/treap.rs b/mast/src/treap.rs index ac45833..40a64c2 100644 --- a/mast/src/treap.rs +++ b/mast/src/treap.rs @@ -1,7 +1,7 @@ use blake3::Hash; use redb::*; -use crate::node::{get_node, get_root_hash, Node}; +use crate::{node::Node, HASH_LEN}; // Table: Nodes v0 // stores all the hash treap nodes from all the treaps in the storage. @@ -49,8 +49,43 @@ impl<'treap> HashTreap<'treap> { let roots_table = read_txn.open_table(ROOTS_TABLE).unwrap(); let nodes_table = read_txn.open_table(NODES_TABLE).unwrap(); - let hash = get_root_hash(&roots_table, self.name); - hash.and_then(|hash| get_node(&nodes_table, hash.as_bytes())) + self.root_hash(&roots_table) + .and_then(|hash| Node::open(&nodes_table, hash)) + } + + fn root_hash<'a>( + &self, + table: &'a impl ReadableTable<&'static [u8], &'static [u8]>, + ) -> Option { + let existing = table.get(self.name.as_bytes()).unwrap(); + if existing.is_none() { + return None; + } + let hash = existing.unwrap(); + + let hash: [u8; HASH_LEN] = hash.value().try_into().expect("Invalid root hash"); + + Some(Hash::from_bytes(hash)) + } + + // === Public Methods === + + pub fn insert(&mut self, key: &[u8], value: &[u8]) { + // TODO: validate key and value length. + + let write_txn = self.db.begin_write().unwrap(); + + 'transaction: { + let roots_table = write_txn.open_table(ROOTS_TABLE).unwrap(); + let mut nodes_table = write_txn.open_table(NODES_TABLE).unwrap(); + + let root = self.root_hash(&roots_table); + + crate::operations::insert::insert(&mut nodes_table, root, key, value) + }; + + // Finally commit the changes to the storage. + write_txn.commit().unwrap(); } // === Private Methods === @@ -60,12 +95,7 @@ impl<'treap> HashTreap<'treap> { let read_txn = self.db.begin_read().unwrap(); let table = read_txn.open_table(NODES_TABLE).unwrap(); - hash.and_then(|h| { - table - .get(h.as_bytes().as_slice()) - .unwrap() - .map(|existing| Node::decode(existing.value())) - }) + hash.and_then(|h| Node::open(&table, h)) } // === Test Methods === @@ -112,12 +142,13 @@ impl<'treap> HashTreap<'treap> { data = existing.1.value(); } - println!( - "HEre is a node key:{:?} ref_count:{:?} node:{:?}", - Hash::from_bytes(key.try_into().unwrap()), - data.0, - Node::decode(data) - ); + // TODO: iterate over nodes + // println!( + // "HEre is a node key:{:?} ref_count:{:?} node:{:?}", + // Hash::from_bytes(key.try_into().unwrap()), + // data.0, + // Node::open(data) + // ); } } }