Files
turso/core/incremental/merge_operator.rs
Glauber Costa b419db489a Implement the DBSP merge operator
The Merge operator is a stateless operator that merges two deltas.
There are two modes: Distinct, where we merge together values that
are the same, and All, where we preserve all values. We use the rowid of
the hashable row to guarantee that: In Distinct mode, the rowid is set
to 0 on both sides. If the values are the same, they will hash to the
same thing. For All, the rowids are different.

The merge operator is used for the UNION statement, which is a
cornerstone of Recursive CTEs.
2025-09-21 21:00:27 -03:00

188 lines
7.1 KiB
Rust

// Merge operator for DBSP - combines two delta streams
// Used in recursive CTEs and UNION operations
use crate::incremental::dbsp::{Delta, DeltaPair, HashableRow};
use crate::incremental::operator::{
ComputationTracker, DbspStateCursors, EvalState, IncrementalOperator,
};
use crate::types::IOResult;
use crate::Result;
use std::collections::{hash_map::DefaultHasher, HashMap};
use std::fmt::{self, Display};
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex};
/// Selects how [`MergeOperator`] assigns output rowids while combining its two inputs.
#[derive(Debug, Clone)]
pub enum UnionMode {
    /// `UNION` (distinct): rows are identified by their values alone, so equal
    /// rows coming from either side end up sharing one output rowid.
    Distinct,
    /// `UNION ALL`: rows are identified by their source and original rowid, so
    /// equal values coming from different inputs stay separate.
    All {
        /// Name identifying the left-hand input.
        left_table: String,
        /// Name identifying the right-hand input.
        right_table: String,
    },
}
/// Operator that folds a left and a right input delta into a single output delta.
///
/// It backs both `UNION` / `UNION ALL` and recursive CTEs. Although it never
/// touches storage cursors, it does keep in-memory bookkeeping (`seen_rows`,
/// `next_rowid`) so that the same logical row is always emitted with the same
/// output rowid.
#[derive(Debug)]
pub struct MergeOperator {
    /// Identifier of this operator instance.
    operator_id: usize,
    /// Rowid-assignment strategy (distinct vs. all).
    union_mode: UnionMode,
    /// Maps a 64-bit key hash to the output rowid assigned to it.
    ///
    /// * `UNION`: the key is the hash of the row's values.
    /// * `UNION ALL`: the key is the hash of `(source_id, original_rowid)`.
    seen_rows: HashMap<u64, i64>,
    /// Rowid that will be handed out to the next previously unseen key.
    next_rowid: i64,
}
impl MergeOperator {
    /// Creates a merge operator with the given id and union mode.
    ///
    /// The operator starts with no remembered rows; the first output rowid it
    /// hands out is `1`.
    pub fn new(operator_id: usize, mode: UnionMode) -> Self {
        Self {
            operator_id,
            union_mode: mode,
            seen_rows: HashMap::new(),
            next_rowid: 1,
        }
    }

    /// Returns the output rowid associated with `key`, assigning the next
    /// unused rowid (and remembering it) the first time `key` is seen.
    fn rowid_for_key(&mut self, key: u64) -> i64 {
        // Borrow the counter separately so the closure does not need `self`
        // while `seen_rows` is mutably borrowed by the entry.
        let next_rowid = &mut self.next_rowid;
        *self.seen_rows.entry(key).or_insert_with(|| {
            let fresh = *next_rowid;
            *next_rowid += 1;
            fresh
        })
    }

    /// Rewrites the rowids of `delta` according to the union mode and returns
    /// the rewritten delta. Weights are passed through unchanged.
    ///
    /// * `UNION` (distinct): each row is keyed by the hash of its values (the
    ///   row is rebuilt with rowid `0` before hashing), so equal values from
    ///   either side map to the same output rowid.
    /// * `UNION ALL`: each row is keyed by the hash of its source and original
    ///   rowid. The source id mixes in `is_left` as well as the table name, so
    ///   the two sides stay separate even when both carry the same table name.
    ///
    /// `is_left` tells which input `delta` came from. The call updates
    /// `seen_rows` / `next_rowid` for every key not seen before.
    ///
    /// NOTE(review): rows are tracked by 64-bit hash only; two different keys
    /// whose hashes collide would be given the same output rowid.
    fn transform_delta(&mut self, delta: Delta, is_left: bool) -> Delta {
        // Resolve the mode once, before the loop: `None` means UNION (distinct),
        // `Some(id)` carries the per-source identifier used by UNION ALL.
        let source_id = match &self.union_mode {
            UnionMode::Distinct => None,
            UnionMode::All {
                left_table,
                right_table,
            } => {
                let table = if is_left { left_table } else { right_table };
                let mut source_hasher = DefaultHasher::new();
                // Hash the side too: identical left/right names must not make
                // rows from the two inputs share a rowid.
                is_left.hash(&mut source_hasher);
                table.hash(&mut source_hasher);
                Some(source_hasher.finish())
            }
        };

        let mut output = Delta::new();
        for (row, weight) in delta.changes {
            // The row is owned here, so its values are moved rather than cloned.
            let (key, values) = match source_id {
                None => {
                    // Hash only the values (rowid forced to 0) for deduplication.
                    let probe = HashableRow::new(0, row.values);
                    let value_hash = probe.cached_hash();
                    (value_hash, probe.values)
                }
                Some(source_id) => {
                    // Unique key for this (source, rowid) pair.
                    let mut key_hasher = DefaultHasher::new();
                    source_id.hash(&mut key_hasher);
                    row.rowid.hash(&mut key_hasher);
                    (key_hasher.finish(), row.values)
                }
            };

            let assigned_rowid = self.rowid_for_key(key);
            output
                .changes
                .push((HashableRow::new(assigned_rowid, values), weight));
        }
        output
    }
}
impl Display for MergeOperator {
    /// Writes `MergeOperator(<id>, UNION)` or `MergeOperator(<id>, UNION ALL)`
    /// depending on the configured union mode.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let id = self.operator_id;
        match self.union_mode {
            UnionMode::All { .. } => write!(f, "MergeOperator({}, UNION ALL)", id),
            UnionMode::Distinct => write!(f, "MergeOperator({}, UNION)", id),
        }
    }
}
impl MergeOperator {
    /// Rewrites the rowids of both sides of `pair` (left first, then right)
    /// and merges the two results into one fresh delta.
    fn transform_and_merge(&mut self, pair: DeltaPair) -> Delta {
        let left = self.transform_delta(pair.left, true);
        let right = self.transform_delta(pair.right, false);

        let mut merged = Delta::new();
        merged.merge(&left);
        merged.merge(&right);
        merged
    }
}

impl IncrementalOperator for MergeOperator {
    /// Evaluates the operator for the current state.
    ///
    /// * `Init`: takes the pending deltas out of the state (leaving their
    ///   default value behind), transforms and merges them, flips the state to
    ///   `Done` and returns the merged delta.
    /// * `Done`: returns an empty delta.
    /// * Any other state: panics, since this operator only handles `Init`.
    ///
    /// The cursors are not used.
    fn eval(
        &mut self,
        input: &mut EvalState,
        _cursors: &mut DbspStateCursors,
    ) -> Result<IOResult<Delta>> {
        let pending = match input {
            EvalState::Init { deltas } => std::mem::take(deltas),
            // Already evaluated: nothing more to produce.
            EvalState::Done => return Ok(IOResult::Done(Delta::new())),
            EvalState::Aggregate(_) | EvalState::Join(_) | EvalState::Uninitialized => {
                unreachable!("MergeOperator only handles Init state")
            }
        };

        let merged = self.transform_and_merge(pending);
        *input = EvalState::Done;
        Ok(IOResult::Done(merged))
    }

    /// Transforms and merges `deltas` exactly like `eval` does for `Init`,
    /// updating the operator's rowid bookkeeping along the way.
    ///
    /// The cursors are not used.
    fn commit(
        &mut self,
        deltas: DeltaPair,
        _cursors: &mut DbspStateCursors,
    ) -> Result<IOResult<Delta>> {
        let merged = self.transform_and_merge(deltas);
        Ok(IOResult::Done(merged))
    }

    /// Intentionally a no-op: the merge operator does not report to a tracker.
    fn set_tracker(&mut self, _tracker: Arc<Mutex<ComputationTracker>>) {}
}