Merge pull request #51 from penberg/moreskipmap

database: migrate RefCell<HashMap> and RefCell<BTreeMap> to SkipMap
This commit is contained in:
Piotr Sarna
2023-06-06 09:58:37 +02:00
committed by GitHub
2 changed files with 133 additions and 79 deletions

View File

@@ -1,11 +1,11 @@
use crate::clock::LogicalClock;
use crate::errors::DatabaseError;
use crate::persistent_storage::Storage;
use crossbeam_skiplist::SkipMap;
use crossbeam_skiplist::{SkipMap, SkipSet};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::cell::RefCell;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::collections::BTreeMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
@@ -66,7 +66,7 @@ enum TxTimestampOrID {
}
/// Transaction
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
pub struct Transaction {
/// The state of the transaction.
state: TransactionState,
@@ -75,9 +75,55 @@ pub struct Transaction {
/// The transaction begin timestamp.
begin_ts: u64,
/// The transaction write set.
write_set: HashSet<RowID>,
#[serde(with = "skipset_rowid")]
write_set: SkipSet<RowID>,
/// The transaction read set.
read_set: RefCell<HashSet<RowID>>,
#[serde(with = "skipset_rowid")]
read_set: SkipSet<RowID>,
}
mod skipset_rowid {
use super::*;
use serde::{de, ser, ser::SerializeSeq};
struct SkipSetDeserializer;
impl<'de> serde::de::Visitor<'de> for SkipSetDeserializer {
type Value = SkipSet<RowID>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("SkipSet<RowID> key value sequence.")
}
fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let new_skipset = SkipSet::new();
while let Some(elem) = seq.next_element()? {
new_skipset.insert(elem);
}
Ok(new_skipset)
}
}
pub fn serialize<S: ser::Serializer>(
value: &SkipSet<RowID>,
ser: S,
) -> std::result::Result<S::Ok, S::Error> {
let mut set = ser.serialize_seq(Some(value.len()))?;
for v in value {
set.serialize_element(v.value())?;
}
set.end()
}
pub fn deserialize<'de, D: de::Deserializer<'de>>(
de: D,
) -> std::result::Result<SkipSet<RowID>, D::Error> {
de.deserialize_seq(SkipSetDeserializer)
}
}
impl Transaction {
@@ -86,14 +132,13 @@ impl Transaction {
state: TransactionState::Active,
tx_id,
begin_ts,
write_set: HashSet::new(),
read_set: RefCell::new(HashSet::new()),
write_set: SkipSet::new(),
read_set: SkipSet::new(),
}
}
/// Records `id` in this transaction's read set.
///
/// Takes `&self` because `SkipSet::insert` operates through a shared
/// reference; no `RefCell` borrow is involved any more.
fn insert_to_read_set(&self, id: RowID) {
    self.read_set.insert(id);
}
fn insert_to_write_set(&mut self, id: RowID) {
@@ -103,18 +148,21 @@ impl Transaction {
impl std::fmt::Display for Transaction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
match self.read_set.try_borrow() {
Ok(read_set) => write!(
f,
"{{ id: {}, begin_ts: {}, write_set: {:?}, read_set: {:?} }}",
self.tx_id, self.begin_ts, self.write_set, read_set
),
Err(_) => write!(
f,
"{{ id: {}, begin_ts: {}, write_set: {:?}, read_set: <borrowed> }}",
self.tx_id, self.begin_ts, self.write_set
),
}
write!(
f,
"{{ id: {}, begin_ts: {}, write_set: {:?}, read_set: {:?}",
self.tx_id,
self.begin_ts,
// FIXME: I'm sorry, we obviously shouldn't be cloning here.
self.write_set
.iter()
.map(|v| *v.value())
.collect::<Vec<RowID>>(),
self.read_set
.iter()
.map(|v| *v.value())
.collect::<Vec<RowID>>()
)
}
}
@@ -138,8 +186,8 @@ impl<Clock: LogicalClock> Database<Clock> {
/// Creates a new database.
pub fn new(clock: Clock, storage: Storage) -> Self {
let inner = DatabaseInner {
rows: RefCell::new(SkipMap::new()),
txs: RefCell::new(HashMap::new()),
rows: SkipMap::new(),
txs: SkipMap::new(),
tx_timestamps: RefCell::new(BTreeMap::new()),
tx_ids: AtomicU64::new(1), // let's reserve transaction 0 for special purposes
clock,
@@ -293,8 +341,8 @@ impl<Clock: LogicalClock> Database<Clock> {
#[derive(Debug)]
pub struct DatabaseInner<Clock: LogicalClock> {
rows: RefCell<SkipMap<RowID, RwLock<Vec<RowVersion>>>>,
txs: RefCell<HashMap<TxID, Transaction>>,
rows: SkipMap<RowID, RwLock<Vec<RowVersion>>>,
txs: SkipMap<TxID, RwLock<Transaction>>,
tx_timestamps: RefCell<BTreeMap<u64, usize>>,
tx_ids: AtomicU64,
clock: Clock,
@@ -303,10 +351,11 @@ pub struct DatabaseInner<Clock: LogicalClock> {
impl<Clock: LogicalClock> DatabaseInner<Clock> {
fn insert(&self, tx_id: TxID, row: Row) -> Result<()> {
let mut txs = self.txs.borrow_mut();
let tx = txs
.get_mut(&tx_id)
let tx = self
.txs
.get(&tx_id)
.ok_or(DatabaseError::NoSuchTransactionID(tx_id))?;
let mut tx = tx.value().write().unwrap();
assert!(tx.state == TransactionState::Active);
let id = row.id;
let row_version = RowVersion {
@@ -314,41 +363,39 @@ impl<Clock: LogicalClock> DatabaseInner<Clock> {
end: None,
row,
};
let rows = self.rows.borrow_mut();
let versions = rows.get_or_insert_with(id, || RwLock::new(Vec::new()));
let versions = self.rows.get_or_insert_with(id, || RwLock::new(Vec::new()));
let mut versions = versions.value().write().unwrap();
versions.push(row_version);
tx.insert_to_write_set(id);
Ok(())
}
#[allow(clippy::await_holding_refcell_ref)]
fn delete(&self, tx_id: TxID, id: RowID) -> Result<bool> {
// NOTICE: They *are* dropped before an await point!!! But the await is conditional,
// so I think clippy is just confused.
let mut txs = self.txs.borrow_mut();
let rows = self.rows.borrow_mut();
let row_versions_opt = rows.get(&id);
let row_versions_opt = self.rows.get(&id);
if let Some(ref row_versions) = row_versions_opt {
let mut row_versions = row_versions.value().write().unwrap();
for rv in row_versions.iter_mut().rev() {
let tx = txs
let tx = self
.txs
.get(&tx_id)
.ok_or(DatabaseError::NoSuchTransactionID(tx_id))?;
let tx = tx.value().read().unwrap();
assert!(tx.state == TransactionState::Active);
if is_write_write_conflict(&txs, tx, rv) {
drop(txs);
if is_write_write_conflict(&self.txs, &tx, rv) {
drop(row_versions);
drop(row_versions_opt);
drop(rows);
drop(tx);
self.rollback_tx(tx_id);
return Err(DatabaseError::WriteWriteConflict);
}
if is_version_visible(&txs, tx, rv) {
if is_version_visible(&self.txs, &tx, rv) {
rv.end = Some(TxTimestampOrID::TxID(tx.tx_id));
let tx = txs
.get_mut(&tx_id)
drop(tx); // FIXME: maybe just grab the write lock above? Do we ever expect conflicts?
let tx = self
.txs
.get(&tx_id)
.ok_or(DatabaseError::NoSuchTransactionID(tx_id))?;
let mut tx = tx.value().write().unwrap();
tx.insert_to_write_set(id);
return Ok(true);
}
@@ -358,14 +405,13 @@ impl<Clock: LogicalClock> DatabaseInner<Clock> {
}
fn read(&self, tx_id: TxID, id: RowID) -> Result<Option<Row>> {
let txs = self.txs.borrow_mut();
let tx = txs.get(&tx_id).unwrap();
let tx = self.txs.get(&tx_id).unwrap();
let tx = tx.value().read().unwrap();
assert!(tx.state == TransactionState::Active);
let rows = self.rows.borrow();
if let Some(row_versions) = rows.get(&id) {
if let Some(row_versions) = self.rows.get(&id) {
let row_versions = row_versions.value().read().unwrap();
for rv in row_versions.iter().rev() {
if is_version_visible(&txs, tx, rv) {
if is_version_visible(&self.txs, &tx, rv) {
tx.insert_to_read_set(id);
return Ok(Some(rv.row.clone()));
}
@@ -375,14 +421,13 @@ impl<Clock: LogicalClock> DatabaseInner<Clock> {
}
/// Returns the IDs of all rows currently present in the row map, in the
/// map's key order.
fn scan_row_ids(&self) -> Result<Vec<RowID>> {
    // `rows` is a `SkipMap` now, so it is iterated directly instead of
    // going through a `RefCell` borrow first.
    Ok(self.rows.iter().map(|entry| *entry.key()).collect())
}
fn scan_row_ids_for_table(&self, table_id: u64) -> Result<Vec<RowID>> {
let rows = &self.rows.borrow();
Ok(rows
Ok(self
.rows
.range(
RowID {
table_id,
@@ -401,29 +446,28 @@ impl<Clock: LogicalClock> DatabaseInner<Clock> {
let begin_ts = self.get_timestamp();
let tx = Transaction::new(tx_id, begin_ts);
tracing::trace!("BEGIN {tx}");
let mut txs = self.txs.borrow_mut();
let mut tx_timestamps = self.tx_timestamps.borrow_mut();
txs.insert(tx_id, tx);
self.txs.insert(tx_id, RwLock::new(tx));
*tx_timestamps.entry(begin_ts).or_insert(0) += 1;
tx_id
}
fn commit_tx(&mut self, tx_id: TxID) -> Result<()> {
let end_ts = self.get_timestamp();
let mut txs = self.txs.borrow_mut();
let tx = txs.get_mut(&tx_id).unwrap();
let tx = self.txs.get(&tx_id).unwrap();
let mut tx = tx.value().write().unwrap();
match tx.state {
TransactionState::Terminated => return Err(DatabaseError::TxTerminated),
_ => {
assert!(tx.state == TransactionState::Active);
}
}
let rows = self.rows.borrow_mut();
tx.state = TransactionState::Preparing;
tracing::trace!("PREPARE {tx}");
let mut log_record: LogRecord = LogRecord::new(end_ts);
for id in &tx.write_set {
if let Some(row_versions) = rows.get(id) {
let id = id.value();
if let Some(row_versions) = self.rows.get(id) {
let mut row_versions = row_versions.value().write().unwrap();
for row_version in row_versions.iter_mut() {
if let TxTimestampOrID::TxID(id) = row_version.begin {
@@ -456,9 +500,7 @@ impl<Clock: LogicalClock> DatabaseInner<Clock> {
tx_timestamps.remove(&tx.begin_ts);
}
}
txs.remove(&tx_id);
drop(rows);
drop(txs);
self.txs.remove(&tx_id);
if !log_record.row_versions.is_empty() {
self.storage.log_tx(log_record)?;
}
@@ -466,18 +508,18 @@ impl<Clock: LogicalClock> DatabaseInner<Clock> {
}
fn rollback_tx(&self, tx_id: TxID) {
let mut txs = self.txs.borrow_mut();
let tx = txs.get_mut(&tx_id).unwrap();
let tx = self.txs.get(&tx_id).unwrap();
let mut tx = tx.value().write().unwrap();
assert!(tx.state == TransactionState::Active);
tx.state = TransactionState::Aborted;
tracing::trace!("ABORT {tx}");
let rows = self.rows.borrow_mut();
for id in &tx.write_set {
if let Some(row_versions) = rows.get(id) {
let id = id.value();
if let Some(row_versions) = self.rows.get(id) {
let mut row_versions = row_versions.value().write().unwrap();
row_versions.retain(|rv| rv.begin != TxTimestampOrID::TxID(tx_id));
if row_versions.is_empty() {
rows.remove(id);
self.rows.remove(id);
}
}
}
@@ -504,11 +546,9 @@ impl<Clock: LogicalClock> DatabaseInner<Clock> {
/// We can do better by keeping an index of row versions ordered
/// by their end timestamps.
fn drop_unused_row_versions(&self) {
let txs = self.txs.borrow();
let tx_timestamps = self.tx_timestamps.borrow();
let rows = self.rows.borrow_mut();
let mut to_remove = Vec::new();
for entry in rows.iter() {
for entry in self.rows.iter() {
let mut row_versions = entry.value().write().unwrap();
row_versions.retain(|rv| {
let should_stay = match rv.end {
@@ -524,7 +564,7 @@ impl<Clock: LogicalClock> DatabaseInner<Clock> {
// Let's skip potentially complex logic if the transaction is still
// active/tracked. We will drop the row version when the transaction
// gets garbage-collected itself, it will always happen eventually.
Some(TxTimestampOrID::TxID(tx_id)) => !txs.contains_key(&tx_id),
Some(TxTimestampOrID::TxID(tx_id)) => !self.txs.contains_key(&tx_id),
// this row version is current, ergo visible
None => true,
};
@@ -543,7 +583,7 @@ impl<Clock: LogicalClock> DatabaseInner<Clock> {
}
}
for id in to_remove {
rows.remove(&id);
self.rows.remove(&id);
}
}
@@ -552,9 +592,9 @@ impl<Clock: LogicalClock> DatabaseInner<Clock> {
for record in tx_log {
tracing::debug!("RECOVERING {:?}", record);
for version in record.row_versions {
let rows = self.rows.borrow_mut();
let row_versions =
rows.get_or_insert_with(version.row.id, || RwLock::new(Vec::new()));
let row_versions = self
.rows
.get_or_insert_with(version.row.id, || RwLock::new(Vec::new()));
let mut row_versions = row_versions.value().write().unwrap();
row_versions.push(version);
}
@@ -567,13 +607,14 @@ impl<Clock: LogicalClock> DatabaseInner<Clock> {
/// A write-write conflict happens when transaction T_m attempts to update a
/// row version that is currently being updated by an active transaction T_n.
fn is_write_write_conflict(
txs: &HashMap<TxID, Transaction>,
txs: &SkipMap<TxID, RwLock<Transaction>>,
tx: &Transaction,
rv: &RowVersion,
) -> bool {
match rv.end {
Some(TxTimestampOrID::TxID(rv_end)) => {
let te = txs.get(&rv_end).unwrap();
let te = te.value().read().unwrap();
match te.state {
TransactionState::Active => tx.tx_id != te.tx_id,
TransactionState::Preparing => todo!(),
@@ -587,15 +628,24 @@ fn is_write_write_conflict(
}
}
fn is_version_visible(txs: &HashMap<TxID, Transaction>, tx: &Transaction, rv: &RowVersion) -> bool {
fn is_version_visible(
txs: &SkipMap<TxID, RwLock<Transaction>>,
tx: &Transaction,
rv: &RowVersion,
) -> bool {
is_begin_visible(txs, tx, rv) && is_end_visible(txs, tx, rv)
}
fn is_begin_visible(txs: &HashMap<TxID, Transaction>, tx: &Transaction, rv: &RowVersion) -> bool {
fn is_begin_visible(
txs: &SkipMap<TxID, RwLock<Transaction>>,
tx: &Transaction,
rv: &RowVersion,
) -> bool {
match rv.begin {
TxTimestampOrID::Timestamp(rv_begin_ts) => tx.begin_ts >= rv_begin_ts,
TxTimestampOrID::TxID(rv_begin) => {
let tb = txs.get(&rv_begin).unwrap();
let tb = tb.value().read().unwrap();
match tb.state {
TransactionState::Active => tx.tx_id == tb.tx_id && rv.end.is_none(),
TransactionState::Preparing => todo!(),
@@ -607,11 +657,16 @@ fn is_begin_visible(txs: &HashMap<TxID, Transaction>, tx: &Transaction, rv: &Row
}
}
fn is_end_visible(txs: &HashMap<TxID, Transaction>, tx: &Transaction, rv: &RowVersion) -> bool {
fn is_end_visible(
txs: &SkipMap<TxID, RwLock<Transaction>>,
tx: &Transaction,
rv: &RowVersion,
) -> bool {
match rv.end {
Some(TxTimestampOrID::Timestamp(rv_end_ts)) => tx.begin_ts < rv_end_ts,
Some(TxTimestampOrID::TxID(rv_end)) => {
let te = txs.get(&rv_end).unwrap();
let te = te.value().read().unwrap();
match te.state {
TransactionState::Active => tx.tx_id != te.tx_id,
TransactionState::Preparing => todo!(),

View File

@@ -1,4 +1,3 @@
use super::*;
use crate::clock::LocalClock;
use tracing_test::traced_test;