#![allow(dead_code)]

use crate::incremental::dbsp::{Delta, DeltaPair, HashableRow};
use crate::incremental::operator::{
    generate_storage_id, ComputationTracker, DbspStateCursors, EvalState, IncrementalOperator,
};
use crate::incremental::persistence::WriteRow;
use crate::types::{IOResult, ImmutableRecord, SeekKey, SeekOp, SeekResult};
use crate::{return_and_restore_if_io, return_if_io, Result, Value};
use std::sync::{Arc, Mutex};

#[derive(Debug, Clone, PartialEq)]
pub enum JoinType {
    Inner,
    Left,
    Right,
    Full,
    Cross,
}

// Helper function to read the next row from the BTree for joins
fn read_next_join_row(
    storage_id: i64,
    join_key: &HashableRow,
    last_element_id: i64,
    cursors: &mut DbspStateCursors,
) -> Result<IOResult<Option<(i64, HashableRow, isize)>>> {
    // Build the index key: (storage_id, zset_id, element_id)
    // zset_id is the hash of the join key
    let zset_id = join_key.cached_hash() as i64;

    let index_key_values = vec![
        Value::Integer(storage_id),
        Value::Integer(zset_id),
        Value::Integer(last_element_id),
    ];

    let index_record = ImmutableRecord::from_values(&index_key_values, index_key_values.len());
    let seek_result = return_if_io!(cursors
        .index_cursor
        .seek(SeekKey::IndexKey(&index_record), SeekOp::GT));

    if !matches!(seek_result, SeekResult::Found) {
        return Ok(IOResult::Done(None));
    }

    // Check if we're still in the same (storage_id, zset_id) range
    let current_record = return_if_io!(cursors.index_cursor.record());

    // Extract all needed values from the record before dropping it
    let (found_storage_id, found_zset_id, element_id) = if let Some(rec) = current_record {
        let values = rec.get_values();

        // Index has 4 values: storage_id, zset_id, element_id, rowid (appended by WriteRow)
        if values.len() >= 3 {
            let found_storage_id = match &values[0].to_owned() {
                Value::Integer(id) => *id,
                _ => return Ok(IOResult::Done(None)),
            };
            let found_zset_id = match &values[1].to_owned() {
                Value::Integer(id) => *id,
                _ => return Ok(IOResult::Done(None)),
            };
            let element_id = match &values[2].to_owned() {
                Value::Integer(id) => *id,
                _ => return Ok(IOResult::Done(None)),
            };
            (found_storage_id, found_zset_id, element_id)
        } else {
            return Ok(IOResult::Done(None));
        }
    } else {
        return Ok(IOResult::Done(None));
    };

    // Now we can safely check if we're in the right range:
    // if we've moved to a different storage_id or zset_id, we're done.
    if found_storage_id != storage_id || found_zset_id != zset_id {
        return Ok(IOResult::Done(None));
    }

    // Now get the actual row from the table using the rowid from the index
    let rowid = return_if_io!(cursors.index_cursor.rowid());
    if let Some(rowid) = rowid {
        return_if_io!(cursors
            .table_cursor
            .seek(SeekKey::TableRowId(rowid), SeekOp::GE { eq_only: true }));

        let table_record = return_if_io!(cursors.table_cursor.record());
        if let Some(rec) = table_record {
            let table_values = rec.get_values();
            // Table format: [storage_id, zset_id, element_id, value_blob, weight]
            if table_values.len() >= 5 {
                // Deserialize the row from the value blob
                let value_at_3 = table_values[3].to_owned();
                let blob = match value_at_3 {
                    Value::Blob(ref b) => b,
                    _ => return Ok(IOResult::Done(None)),
                };

                // The blob contains the serialized HashableRow
                let row = deserialize_hashable_row(blob)?;

                let weight = match &table_values[4].to_owned() {
                    Value::Integer(w) => *w as isize,
                    _ => return Ok(IOResult::Done(None)),
                };

                return Ok(IOResult::Done(Some((element_id, row, weight))));
            }
        }
    }
    Ok(IOResult::Done(None))
}
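
// A minimal usage sketch (hypothetical helper, not used by the operator below):
// drain every stored row matching `join_key` by feeding each returned
// element_id back in as `last_element_id`. Note that an early return on
// IOResult::IO loses the local state and restarts the scan from scratch,
// which is exactly why the join state machines below thread `last_row_scanned`
// through their enum variants instead of using a plain loop like this one.
fn collect_join_matches(
    storage_id: i64,
    join_key: &HashableRow,
    cursors: &mut DbspStateCursors,
) -> Result<IOResult<Vec<(HashableRow, isize)>>> {
    let mut matches = Vec::new();
    let mut last_element_id = i64::MIN;
    while let Some((element_id, row, weight)) =
        return_if_io!(read_next_join_row(storage_id, join_key, last_element_id, cursors))
    {
        last_element_id = element_id;
        matches.push((row, weight));
    }
    Ok(IOResult::Done(matches))
}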

// Join-specific eval states
#[derive(Debug)]
pub enum JoinEvalState {
    ProcessDeltaJoin {
        deltas: DeltaPair,
        output: Delta,
    },
    ProcessLeftJoin {
        deltas: DeltaPair,
        output: Delta,
        current_idx: usize,
        last_row_scanned: i64,
    },
    ProcessRightJoin {
        deltas: DeltaPair,
        output: Delta,
        current_idx: usize,
        last_row_scanned: i64,
    },
    Done {
        output: Delta,
    },
}
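
// Each variant above carries everything needed to resume after an I/O yield:
// the unprocessed deltas, the output accumulated so far, the index of the
// delta row currently being matched, and the element_id of the last stored
// row scanned for it (i64::MIN means "start of the zset").
// process_join_state re-enters at the same variant after every IOResult::IO.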

impl JoinEvalState {
    fn combine_rows(
        left_row: &HashableRow,
        left_weight: i64,
        right_row: &HashableRow,
        right_weight: i64,
        output: &mut Delta,
    ) {
        // Combine the rows
        let mut combined_values = left_row.values.clone();
        combined_values.extend(right_row.values.clone());
        // Use hash of the combined values as rowid to ensure uniqueness
        let temp_row = HashableRow::new(0, combined_values.clone());
        let joined_rowid = temp_row.cached_hash() as i64;
        let joined_row = HashableRow::new(joined_rowid, combined_values);

        // Add to output with combined weight
        let combined_weight = left_weight * right_weight;
        output.changes.push((joined_row, combined_weight as isize));
    }
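
    // Weight multiplication follows Z-set semantics: joining an insertion
    // (+1) with an insertion (+1) emits an insertion (+1), while joining a
    // deletion (-1) with an insertion (+1) emits a deletion (-1), so
    // retractions propagate through the join with no special casing.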

    fn process_join_state(
        &mut self,
        cursors: &mut DbspStateCursors,
        left_key_indices: &[usize],
        right_key_indices: &[usize],
        left_storage_id: i64,
        right_storage_id: i64,
    ) -> Result<IOResult<Delta>> {
        loop {
            match self {
                JoinEvalState::ProcessDeltaJoin { deltas, output } => {
                    // Move to ProcessLeftJoin
                    *self = JoinEvalState::ProcessLeftJoin {
                        deltas: std::mem::take(deltas),
                        output: std::mem::take(output),
                        current_idx: 0,
                        last_row_scanned: i64::MIN,
                    };
                }
                JoinEvalState::ProcessLeftJoin {
                    deltas,
                    output,
                    current_idx,
                    last_row_scanned,
                } => {
                    if *current_idx >= deltas.left.changes.len() {
                        *self = JoinEvalState::ProcessRightJoin {
                            deltas: std::mem::take(deltas),
                            output: std::mem::take(output),
                            current_idx: 0,
                            last_row_scanned: i64::MIN,
                        };
                    } else {
                        let (left_row, left_weight) = &deltas.left.changes[*current_idx];
                        // Extract join key using provided indices
                        let key_values: Vec<Value> = left_key_indices
                            .iter()
                            .map(|&idx| left_row.values.get(idx).cloned().unwrap_or(Value::Null))
                            .collect();
                        let left_key = HashableRow::new(0, key_values);

                        let next_row = return_if_io!(read_next_join_row(
                            right_storage_id,
                            &left_key,
                            *last_row_scanned,
                            cursors
                        ));
                        match next_row {
                            Some((element_id, right_row, right_weight)) => {
                                Self::combine_rows(
                                    left_row,
                                    (*left_weight) as i64,
                                    &right_row,
                                    right_weight as i64,
                                    output,
                                );
                                // Continue scanning with this left row
                                *self = JoinEvalState::ProcessLeftJoin {
                                    deltas: std::mem::take(deltas),
                                    output: std::mem::take(output),
                                    current_idx: *current_idx,
                                    last_row_scanned: element_id,
                                };
                            }
                            None => {
                                // No more matches for this left row, move to next
                                *self = JoinEvalState::ProcessLeftJoin {
                                    deltas: std::mem::take(deltas),
                                    output: std::mem::take(output),
                                    current_idx: *current_idx + 1,
                                    last_row_scanned: i64::MIN,
                                };
                            }
                        }
                    }
                }
                JoinEvalState::ProcessRightJoin {
                    deltas,
                    output,
                    current_idx,
                    last_row_scanned,
                } => {
                    if *current_idx >= deltas.right.changes.len() {
                        *self = JoinEvalState::Done {
                            output: std::mem::take(output),
                        };
                    } else {
                        let (right_row, right_weight) = &deltas.right.changes[*current_idx];
                        // Extract join key using provided indices
                        let key_values: Vec<Value> = right_key_indices
                            .iter()
                            .map(|&idx| right_row.values.get(idx).cloned().unwrap_or(Value::Null))
                            .collect();
                        let right_key = HashableRow::new(0, key_values);

                        let next_row = return_if_io!(read_next_join_row(
                            left_storage_id,
                            &right_key,
                            *last_row_scanned,
                            cursors
                        ));
                        match next_row {
                            Some((element_id, left_row, left_weight)) => {
                                Self::combine_rows(
                                    &left_row,
                                    left_weight as i64,
                                    right_row,
                                    (*right_weight) as i64,
                                    output,
                                );
                                // Continue scanning with this right row
                                *self = JoinEvalState::ProcessRightJoin {
                                    deltas: std::mem::take(deltas),
                                    output: std::mem::take(output),
                                    current_idx: *current_idx,
                                    last_row_scanned: element_id,
                                };
                            }
                            None => {
                                // No more matches for this right row, move to next
                                *self = JoinEvalState::ProcessRightJoin {
                                    deltas: std::mem::take(deltas),
                                    output: std::mem::take(output),
                                    current_idx: *current_idx + 1,
                                    last_row_scanned: i64::MIN,
                                };
                            }
                        }
                    }
                }
                JoinEvalState::Done { output } => {
                    return Ok(IOResult::Done(std::mem::take(output)));
                }
            }
        }
    }
}

#[derive(Debug)]
enum JoinCommitState {
    Idle,
    Eval {
        eval_state: EvalState,
    },
    CommitLeftDelta {
        deltas: DeltaPair,
        output: Delta,
        current_idx: usize,
        write_row: WriteRow,
    },
    CommitRightDelta {
        deltas: DeltaPair,
        output: Delta,
        current_idx: usize,
        write_row: WriteRow,
    },
    Invalid,
}
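
// Commit drives these states in order: Idle -> Eval (compute the output delta
// against the stored state) -> CommitLeftDelta (persist each left-side change)
// -> CommitRightDelta (persist each right-side change) -> back to Idle.
// The state is swapped to Invalid while an arm runs so that an I/O yield can
// restore it via return_and_restore_if_io!.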

/// Join operator - performs an incremental join between two relations.
/// Implements the DBSP formula: δ(R ⋈ S) = (δR ⋈ S) ∪ (R ⋈ δS) ∪ (δR ⋈ δS)
#[derive(Debug)]
pub struct JoinOperator {
    /// Unique operator ID for indexing in persistent storage
    operator_id: usize,
    /// Type of join to perform
    join_type: JoinType,
    /// Column indices for extracting join keys from left input
    left_key_indices: Vec<usize>,
    /// Column indices for extracting join keys from right input
    right_key_indices: Vec<usize>,
    /// Column names from left input
    left_columns: Vec<String>,
    /// Column names from right input
    right_columns: Vec<String>,
    /// Tracker for computation statistics
    tracker: Option<Arc<Mutex<ComputationTracker>>>,

    commit_state: JoinCommitState,
}

impl JoinOperator {
    pub fn new(
        operator_id: usize,
        join_type: JoinType,
        left_key_indices: Vec<usize>,
        right_key_indices: Vec<usize>,
        left_columns: Vec<String>,
        right_columns: Vec<String>,
    ) -> Result<Self> {
        // Check for unsupported join types
        match join_type {
            JoinType::Left => {
                return Err(crate::LimboError::ParseError(
                    "LEFT OUTER JOIN is not yet supported in incremental views".to_string(),
                ))
            }
            JoinType::Right => {
                return Err(crate::LimboError::ParseError(
                    "RIGHT OUTER JOIN is not yet supported in incremental views".to_string(),
                ))
            }
            JoinType::Full => {
                return Err(crate::LimboError::ParseError(
                    "FULL OUTER JOIN is not yet supported in incremental views".to_string(),
                ))
            }
            JoinType::Cross => {
                return Err(crate::LimboError::ParseError(
                    "CROSS JOIN is not yet supported in incremental views".to_string(),
                ))
            }
            JoinType::Inner => {} // Inner join is supported
        }

        Ok(Self {
            operator_id,
            join_type,
            left_key_indices,
            right_key_indices,
            left_columns,
            right_columns,
            tracker: None,
            commit_state: JoinCommitState::Idle,
        })
    }

    /// Extract the join key from row values using the specified indices
    fn extract_join_key(&self, values: &[Value], indices: &[usize]) -> HashableRow {
        let key_values: Vec<Value> = indices
            .iter()
            .map(|&idx| values.get(idx).cloned().unwrap_or(Value::Null))
            .collect();
        // Use 0 as a dummy rowid for join keys. They don't come from a table,
        // so they don't need a rowid. Their key will be the hash of the row values.
        HashableRow::new(0, key_values)
    }

    /// Generate the storage ID for the left table
    fn left_storage_id(&self) -> i64 {
        // Use column_index=0 for the left side
        generate_storage_id(self.operator_id, 0, 0)
    }

    /// Generate the storage ID for the right table
    fn right_storage_id(&self) -> i64 {
        // Use column_index=1 for the right side
        generate_storage_id(self.operator_id, 1, 0)
    }

    /// SQL-compliant comparison for join keys.
    /// Returns true if the keys match according to SQL semantics (NULL != NULL).
    fn sql_keys_equal(left_key: &HashableRow, right_key: &HashableRow) -> bool {
        if left_key.values.len() != right_key.values.len() {
            return false;
        }

        for (left_val, right_val) in left_key.values.iter().zip(right_key.values.iter()) {
            // In SQL, NULL never equals NULL
            if matches!(left_val, Value::Null) || matches!(right_val, Value::Null) {
                return false;
            }

            // For non-NULL values, use regular comparison
            if left_val != right_val {
                return false;
            }
        }

        true
    }
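
    // Example: keys (1, NULL) and (1, NULL) do NOT match here, because a NULL
    // in either key fails the comparison outright; rows whose join key
    // contains NULL never pair up, per SQL inner-join semantics.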

    fn process_join_state(
        &mut self,
        state: &mut EvalState,
        cursors: &mut DbspStateCursors,
    ) -> Result<IOResult<Delta>> {
        // Dispatch on the join state inside the enum
        match state {
            EvalState::Join(js) => js.process_join_state(
                cursors,
                &self.left_key_indices,
                &self.right_key_indices,
                self.left_storage_id(),
                self.right_storage_id(),
            ),
            _ => panic!("process_join_state called with non-join state"),
        }
    }

    fn eval_internal(
        &mut self,
        state: &mut EvalState,
        cursors: &mut DbspStateCursors,
    ) -> Result<IOResult<Delta>> {
        loop {
            let loop_state = std::mem::replace(state, EvalState::Uninitialized);
            match loop_state {
                EvalState::Uninitialized => {
                    panic!("Cannot eval JoinOperator with Uninitialized state");
                }
                EvalState::Init { deltas } => {
                    let mut output = Delta::new();

                    // Component 3: δR ⋈ δS (left delta join right delta)
                    for (left_row, left_weight) in &deltas.left.changes {
                        let left_key =
                            self.extract_join_key(&left_row.values, &self.left_key_indices);

                        for (right_row, right_weight) in &deltas.right.changes {
                            let right_key =
                                self.extract_join_key(&right_row.values, &self.right_key_indices);

                            if Self::sql_keys_equal(&left_key, &right_key) {
                                if let Some(tracker) = &self.tracker {
                                    tracker.lock().unwrap().record_join_lookup();
                                }

                                // Combine the rows
                                let mut combined_values = left_row.values.clone();
                                combined_values.extend(right_row.values.clone());

                                // Create the joined row with a unique rowid:
                                // use the hash of the combined values to ensure uniqueness
                                let temp_row = HashableRow::new(0, combined_values.clone());
                                let joined_rowid = temp_row.cached_hash() as i64;
                                let joined_row =
                                    HashableRow::new(joined_rowid, combined_values.clone());

                                // Add to output with combined weight
                                let combined_weight = left_weight * right_weight;
                                output.changes.push((joined_row, combined_weight));
                            }
                        }
                    }

                    *state = EvalState::Join(Box::new(JoinEvalState::ProcessDeltaJoin {
                        deltas,
                        output,
                    }));
                }
                EvalState::Join(join_state) => {
                    *state = EvalState::Join(join_state);
                    let output = return_if_io!(self.process_join_state(state, cursors));
                    return Ok(IOResult::Done(output));
                }
                EvalState::Done => {
                    return Ok(IOResult::Done(Delta::new()));
                }
                EvalState::Aggregate(_) => {
                    panic!("Aggregate state should not appear in join operator");
                }
            }
        }
    }
}
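
// A small sketch of the constructor contract above: every join type except
// Inner is rejected up front. The operator id, key indices, and column names
// here are placeholder values for illustration.
#[cfg(test)]
mod join_type_gate {
    use super::*;

    #[test]
    fn only_inner_joins_are_accepted() {
        for join_type in [JoinType::Left, JoinType::Right, JoinType::Full, JoinType::Cross] {
            let result = JoinOperator::new(
                0,
                join_type,
                vec![0],
                vec![0],
                vec!["l".to_string()],
                vec!["r".to_string()],
            );
            assert!(result.is_err());
        }
        let inner = JoinOperator::new(
            0,
            JoinType::Inner,
            vec![0],
            vec![0],
            vec!["l".to_string()],
            vec!["r".to_string()],
        );
        assert!(inner.is_ok());
    }
}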

// Helper to deserialize a HashableRow from a blob.
// The format must match serialize_hashable_row below:
// [rowid: 8 bytes][num_values: 4 bytes][tagged values...]
fn deserialize_hashable_row(blob: &[u8]) -> Result<HashableRow> {
    if blob.len() < 12 {
        return Err(crate::LimboError::InternalError(
            "Invalid blob size".to_string(),
        ));
    }

    let rowid = i64::from_le_bytes(blob[0..8].try_into().unwrap());
    let num_values = u32::from_le_bytes(blob[8..12].try_into().unwrap()) as usize;

    let mut values = Vec::new();
    let mut offset = 12;

    for _ in 0..num_values {
        if offset >= blob.len() {
            break;
        }

        let type_tag = blob[offset];
        offset += 1;

        match type_tag {
            0 => values.push(Value::Null),
            1 => {
                if offset + 8 <= blob.len() {
                    let i = i64::from_le_bytes(blob[offset..offset + 8].try_into().unwrap());
                    values.push(Value::Integer(i));
                    offset += 8;
                }
            }
            2 => {
                if offset + 8 <= blob.len() {
                    let f = f64::from_le_bytes(blob[offset..offset + 8].try_into().unwrap());
                    values.push(Value::Float(f));
                    offset += 8;
                }
            }
            3 => {
                if offset + 4 <= blob.len() {
                    let len =
                        u32::from_le_bytes(blob[offset..offset + 4].try_into().unwrap()) as usize;
                    offset += 4;
                    // Need len text bytes plus one trailing subtype byte
                    if offset + len < blob.len() {
                        let text_bytes = blob[offset..offset + len].to_vec();
                        offset += len;
                        let subtype = match blob[offset] {
                            0 => crate::types::TextSubtype::Text,
                            1 => crate::types::TextSubtype::Json,
                            _ => crate::types::TextSubtype::Text,
                        };
                        offset += 1;
                        values.push(Value::Text(crate::types::Text {
                            value: text_bytes,
                            subtype,
                        }));
                    }
                }
            }
            4 => {
                if offset + 4 <= blob.len() {
                    let len =
                        u32::from_le_bytes(blob[offset..offset + 4].try_into().unwrap()) as usize;
                    offset += 4;
                    if offset + len <= blob.len() {
                        let blob_data = blob[offset..offset + len].to_vec();
                        values.push(Value::Blob(blob_data));
                        offset += len;
                    }
                }
            }
            _ => break, // Unknown type tag
        }
    }

    Ok(HashableRow::new(rowid, values))
}

// Helper to serialize a HashableRow to a blob
fn serialize_hashable_row(row: &HashableRow) -> Vec<u8> {
    let mut blob = Vec::new();

    // Write rowid
    blob.extend_from_slice(&row.rowid.to_le_bytes());

    // Write number of values
    blob.extend_from_slice(&(row.values.len() as u32).to_le_bytes());

    // Write each value directly with type tags (like AggregateState does)
    for value in &row.values {
        match value {
            Value::Null => blob.push(0u8),
            Value::Integer(i) => {
                blob.push(1u8);
                blob.extend_from_slice(&i.to_le_bytes());
            }
            Value::Float(f) => {
                blob.push(2u8);
                blob.extend_from_slice(&f.to_le_bytes());
            }
            Value::Text(s) => {
                blob.push(3u8);
                let bytes = &s.value;
                blob.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
                blob.extend_from_slice(bytes);
                blob.push(s.subtype as u8);
            }
            Value::Blob(b) => {
                blob.push(4u8);
                blob.extend_from_slice(&(b.len() as u32).to_le_bytes());
                blob.extend_from_slice(b);
            }
        }
    }

    blob
}
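
// A minimal round-trip sketch for the blob format defined above; the sample
// values are illustrative, but the encoding is the one implemented by
// serialize_hashable_row/deserialize_hashable_row in this file.
#[cfg(test)]
mod blob_round_trip {
    use super::*;

    #[test]
    fn serialize_then_deserialize_preserves_row() {
        let row = HashableRow::new(
            42,
            vec![
                Value::Null,
                Value::Integer(-7),
                Value::Float(3.5),
                Value::Text(crate::types::Text {
                    value: b"hello".to_vec(),
                    subtype: crate::types::TextSubtype::Text,
                }),
                Value::Blob(vec![0xde, 0xad, 0xbe, 0xef]),
            ],
        );
        let blob = serialize_hashable_row(&row);
        let decoded = deserialize_hashable_row(&blob).unwrap();
        assert_eq!(decoded.rowid, 42);
        assert_eq!(decoded.values, row.values);
    }
}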

impl IncrementalOperator for JoinOperator {
    fn eval(
        &mut self,
        state: &mut EvalState,
        cursors: &mut DbspStateCursors,
    ) -> Result<IOResult<Delta>> {
        let delta = return_if_io!(self.eval_internal(state, cursors));
        Ok(IOResult::Done(delta))
    }

    fn commit(
        &mut self,
        deltas: DeltaPair,
        cursors: &mut DbspStateCursors,
    ) -> Result<IOResult<Delta>> {
        loop {
            let mut state = std::mem::replace(&mut self.commit_state, JoinCommitState::Invalid);
            match &mut state {
                JoinCommitState::Idle => {
                    self.commit_state = JoinCommitState::Eval {
                        eval_state: deltas.clone().into(),
                    }
                }
                JoinCommitState::Eval { ref mut eval_state } => {
                    let output = return_and_restore_if_io!(
                        &mut self.commit_state,
                        state,
                        self.eval(eval_state, cursors)
                    );
                    self.commit_state = JoinCommitState::CommitLeftDelta {
                        deltas: deltas.clone(),
                        output,
                        current_idx: 0,
                        write_row: WriteRow::new(),
                    };
                }
                JoinCommitState::CommitLeftDelta {
                    deltas,
                    output,
                    current_idx,
                    ref mut write_row,
                } => {
                    if *current_idx >= deltas.left.changes.len() {
                        self.commit_state = JoinCommitState::CommitRightDelta {
                            deltas: std::mem::take(deltas),
                            output: std::mem::take(output),
                            current_idx: 0,
                            write_row: WriteRow::new(),
                        };
                        continue;
                    }

                    let (row, weight) = &deltas.left.changes[*current_idx];
                    // Extract the join key from the left row
                    let join_key = self.extract_join_key(&row.values, &self.left_key_indices);

                    // The index key: (storage_id, zset_id, element_id).
                    // zset_id is the hash of the join key, element_id the hash of the row.
                    let storage_id = self.left_storage_id();
                    let zset_id = join_key.cached_hash() as i64;
                    let element_id = row.cached_hash() as i64;
                    let index_key = vec![
                        Value::Integer(storage_id),
                        Value::Integer(zset_id),
                        Value::Integer(element_id),
                    ];

                    // The record values: the serialized row is stored as a blob
                    let row_blob = serialize_hashable_row(row);
                    let record_values = vec![
                        Value::Integer(storage_id),
                        Value::Integer(zset_id),
                        Value::Integer(element_id),
                        Value::Blob(row_blob),
                    ];

                    // Use return_and_restore_if_io to handle I/O properly
                    return_and_restore_if_io!(
                        &mut self.commit_state,
                        state,
                        write_row.write_row(cursors, index_key, record_values, *weight)
                    );

                    self.commit_state = JoinCommitState::CommitLeftDelta {
                        deltas: deltas.clone(),
                        output: output.clone(),
                        current_idx: *current_idx + 1,
                        write_row: WriteRow::new(),
                    };
                }
                JoinCommitState::CommitRightDelta {
                    deltas,
                    output,
                    current_idx,
                    ref mut write_row,
                } => {
                    if *current_idx >= deltas.right.changes.len() {
                        // Reset to Idle state for the next commit
                        self.commit_state = JoinCommitState::Idle;
                        return Ok(IOResult::Done(output.clone()));
                    }

                    let (row, weight) = &deltas.right.changes[*current_idx];
                    // Extract the join key from the right row
                    let join_key = self.extract_join_key(&row.values, &self.right_key_indices);

                    // The index key: (storage_id, zset_id, element_id)
                    let index_key = vec![
                        Value::Integer(self.right_storage_id()),
                        Value::Integer(join_key.cached_hash() as i64),
                        Value::Integer(row.cached_hash() as i64),
                    ];

                    // The record values: the serialized row is stored as a blob
                    let row_blob = serialize_hashable_row(row);
                    let record_values = vec![
                        Value::Integer(self.right_storage_id()),
                        Value::Integer(join_key.cached_hash() as i64),
                        Value::Integer(row.cached_hash() as i64),
                        Value::Blob(row_blob),
                    ];

                    // Use return_and_restore_if_io to handle I/O properly
                    return_and_restore_if_io!(
                        &mut self.commit_state,
                        state,
                        write_row.write_row(cursors, index_key, record_values, *weight)
                    );

                    self.commit_state = JoinCommitState::CommitRightDelta {
                        deltas: std::mem::take(deltas),
                        output: std::mem::take(output),
                        current_idx: *current_idx + 1,
                        write_row: WriteRow::new(),
                    };
                }
                JoinCommitState::Invalid => {
                    panic!("Invalid join commit state");
                }
            }
        }
    }

    fn set_tracker(&mut self, tracker: Arc<Mutex<ComputationTracker>>) {
        self.tracker = Some(tracker);
    }
}