Implement JOINs in the DBSP circuit
This PR improves the DBSP circuit so that it handles the JOIN operator. JOIN exposes a weakness in our current model: we usually pass a list of columns between operators and find the right column by name when needed. With JOINs, however, multiple tables can have columns with the same name, so an operator may resolve to the wrong column (same name, different table) and produce incorrect results. Fixing this requires two things: 1) change the Logical Plan so that it tracks table provenance, and 2) fix the aggregators so that they operate on column indexes, not names. For the aggregators, note that table provenance is the wrong abstraction: an aggregator typically works over a logical table produced by the previous nodes in the circuit, so all it needs to be told is which index in the column array to use. A sketch of the failure mode and the fix follows below.
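For illustration only (this is not code from the PR; the mini-schema and helper here are hypothetical), a minimal sketch of why name-based column lookup breaks under JOINs and why a planner-resolved index does not:

```rust
// Hypothetical flattened schema for `t1 JOIN t2`: both tables contribute
// an "id" column, so the joined column list contains "id" twice.
fn find_by_name(columns: &[&str], name: &str) -> Option<usize> {
    // Name-based lookup always returns the FIRST match (t1.id),
    // even when the query actually refers to t2.id.
    columns.iter().position(|c| *c == name)
}

fn main() {
    let joined_columns = ["id", "name", "id", "total"]; // t1.id, t1.name, t2.id, t2.total
    let row: [i64; 4] = [1, 42, 7, 500];

    // SUM(t2.id) resolved by name would silently aggregate t1.id (index 0).
    let wrong_idx = find_by_name(&joined_columns, "id").unwrap();
    assert_eq!(wrong_idx, 0);

    // With provenance tracked in the logical plan, the plan builder resolves
    // t2.id to index 2 once, and the aggregator just reads that index —
    // this is what the new index-carrying AggregateFunction variants enable.
    let resolved_idx = 2;
    assert_eq!(row[resolved_idx], 7);
}
```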
@@ -19,20 +19,20 @@ pub const AGG_TYPE_MINMAX: u8 = 0b01; // MIN/MAX (BTree ordering gives both)
 #[derive(Debug, Clone, PartialEq)]
 pub enum AggregateFunction {
     Count,
-    Sum(String),
-    Avg(String),
-    Min(String),
-    Max(String),
+    Sum(usize), // Column index
+    Avg(usize), // Column index
+    Min(usize), // Column index
+    Max(usize), // Column index
 }

 impl Display for AggregateFunction {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match self {
             AggregateFunction::Count => write!(f, "COUNT(*)"),
-            AggregateFunction::Sum(col) => write!(f, "SUM({col})"),
-            AggregateFunction::Avg(col) => write!(f, "AVG({col})"),
-            AggregateFunction::Min(col) => write!(f, "MIN({col})"),
-            AggregateFunction::Max(col) => write!(f, "MAX({col})"),
+            AggregateFunction::Sum(idx) => write!(f, "SUM(col{idx})"),
+            AggregateFunction::Avg(idx) => write!(f, "AVG(col{idx})"),
+            AggregateFunction::Min(idx) => write!(f, "MIN(col{idx})"),
+            AggregateFunction::Max(idx) => write!(f, "MAX(col{idx})"),
         }
     }
 }
@@ -48,16 +48,16 @@ impl AggregateFunction {
     /// Returns None if the function is not a supported aggregate
     pub fn from_sql_function(
         func: &crate::function::Func,
-        input_column: Option<String>,
+        input_column_idx: Option<usize>,
     ) -> Option<Self> {
         match func {
             Func::Agg(agg_func) => {
                 match agg_func {
                     AggFunc::Count | AggFunc::Count0 => Some(AggregateFunction::Count),
-                    AggFunc::Sum => input_column.map(AggregateFunction::Sum),
-                    AggFunc::Avg => input_column.map(AggregateFunction::Avg),
-                    AggFunc::Min => input_column.map(AggregateFunction::Min),
-                    AggFunc::Max => input_column.map(AggregateFunction::Max),
+                    AggFunc::Sum => input_column_idx.map(AggregateFunction::Sum),
+                    AggFunc::Avg => input_column_idx.map(AggregateFunction::Avg),
+                    AggFunc::Min => input_column_idx.map(AggregateFunction::Min),
+                    AggFunc::Max => input_column_idx.map(AggregateFunction::Max),
                     _ => None, // Other aggregate functions not yet supported in DBSP
                 }
             }
@@ -115,8 +115,8 @@ pub fn deserialize_value(blob: &[u8]) -> Option<(Value, usize)> {

 // group_key_str -> (group_key, state)
 type ComputedStates = HashMap<String, (Vec<Value>, AggregateState)>;
-// group_key_str -> (column_name, value_as_hashable_row) -> accumulated_weight
-pub type MinMaxDeltas = HashMap<String, HashMap<(String, HashableRow), isize>>;
+// group_key_str -> (column_index, value_as_hashable_row) -> accumulated_weight
+pub type MinMaxDeltas = HashMap<String, HashMap<(usize, HashableRow), isize>>;

 #[derive(Debug)]
 enum AggregateCommitState {
@@ -178,14 +178,14 @@ pub enum AggregateEvalState {
 pub struct AggregateOperator {
     // Unique operator ID for indexing in persistent storage
     pub operator_id: usize,
-    // GROUP BY columns
-    group_by: Vec<String>,
+    // GROUP BY column indices
+    group_by: Vec<usize>,
     // Aggregate functions to compute (including MIN/MAX)
     pub aggregates: Vec<AggregateFunction>,
     // Column names from input
     pub input_column_names: Vec<String>,
-    // Map from column name to aggregate info for quick lookup
-    pub column_min_max: HashMap<String, AggColumnInfo>,
+    // Map from column index to aggregate info for quick lookup
+    pub column_min_max: HashMap<usize, AggColumnInfo>,
     tracker: Option<Arc<Mutex<ComputationTracker>>>,

     // State machine for commit operation
@@ -197,14 +197,14 @@ pub struct AggregateOperator {
 pub struct AggregateState {
     // For COUNT: just the count
     pub count: i64,
-    // For SUM: column_name -> sum value
-    sums: HashMap<String, f64>,
-    // For AVG: column_name -> (sum, count) for computing average
-    avgs: HashMap<String, (f64, i64)>,
-    // For MIN: column_name -> minimum value
-    pub mins: HashMap<String, Value>,
-    // For MAX: column_name -> maximum value
-    pub maxs: HashMap<String, Value>,
+    // For SUM: column_index -> sum value
+    sums: HashMap<usize, f64>,
+    // For AVG: column_index -> (sum, count) for computing average
+    avgs: HashMap<usize, (f64, i64)>,
+    // For MIN: column_index -> minimum value
+    pub mins: HashMap<usize, Value>,
+    // For MAX: column_index -> maximum value
+    pub maxs: HashMap<usize, Value>,
 }

 impl AggregateEvalState {
@@ -520,14 +520,14 @@ impl AggregateState {
                 AggregateFunction::Sum(col_name) => {
                     let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
                     cursor += 8;
-                    state.sums.insert(col_name.clone(), sum);
+                    state.sums.insert(*col_name, sum);
                 }
                 AggregateFunction::Avg(col_name) => {
                     let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
                     cursor += 8;
                     let count = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
                     cursor += 8;
-                    state.avgs.insert(col_name.clone(), (sum, count));
+                    state.avgs.insert(*col_name, (sum, count));
                 }
                 AggregateFunction::Count => {
                     // Count was already read above
@@ -540,7 +540,7 @@ impl AggregateState {
                     if has_value == 1 {
                         let (min_value, bytes_consumed) = deserialize_value(&blob[cursor..])?;
                         cursor += bytes_consumed;
-                        state.mins.insert(col_name.clone(), min_value);
+                        state.mins.insert(*col_name, min_value);
                     }
                 }
                 AggregateFunction::Max(col_name) => {
@@ -551,7 +551,7 @@ impl AggregateState {
                     if has_value == 1 {
                         let (max_value, bytes_consumed) = deserialize_value(&blob[cursor..])?;
                         cursor += bytes_consumed;
-                        state.maxs.insert(col_name.clone(), max_value);
+                        state.maxs.insert(*col_name, max_value);
                     }
                 }
             }
@@ -566,7 +566,7 @@ impl AggregateState {
         values: &[Value],
         weight: isize,
         aggregates: &[AggregateFunction],
-        column_names: &[String],
+        _column_names: &[String], // No longer needed
     ) {
         // Update COUNT
         self.count += weight as i64;
@@ -577,32 +577,26 @@ impl AggregateState {
             AggregateFunction::Count => {
                 // Already handled above
             }
-            AggregateFunction::Sum(col_name) => {
-                if let Some(idx) = column_names.iter().position(|c| c == col_name) {
-                    if let Some(val) = values.get(idx) {
-                        let num_val = match val {
-                            Value::Integer(i) => *i as f64,
-                            Value::Float(f) => *f,
-                            _ => 0.0,
-                        };
-                        *self.sums.entry(col_name.clone()).or_insert(0.0) +=
-                            num_val * weight as f64;
-                    }
-                }
+            AggregateFunction::Sum(col_idx) => {
+                if let Some(val) = values.get(*col_idx) {
+                    let num_val = match val {
+                        Value::Integer(i) => *i as f64,
+                        Value::Float(f) => *f,
+                        _ => 0.0,
+                    };
+                    *self.sums.entry(*col_idx).or_insert(0.0) += num_val * weight as f64;
+                }
             }
-            AggregateFunction::Avg(col_name) => {
-                if let Some(idx) = column_names.iter().position(|c| c == col_name) {
-                    if let Some(val) = values.get(idx) {
-                        let num_val = match val {
-                            Value::Integer(i) => *i as f64,
-                            Value::Float(f) => *f,
-                            _ => 0.0,
-                        };
-                        let (sum, count) =
-                            self.avgs.entry(col_name.clone()).or_insert((0.0, 0));
-                        *sum += num_val * weight as f64;
-                        *count += weight as i64;
-                    }
-                }
+            AggregateFunction::Avg(col_idx) => {
+                if let Some(val) = values.get(*col_idx) {
+                    let num_val = match val {
+                        Value::Integer(i) => *i as f64,
+                        Value::Float(f) => *f,
+                        _ => 0.0,
+                    };
+                    let (sum, count) = self.avgs.entry(*col_idx).or_insert((0.0, 0));
+                    *sum += num_val * weight as f64;
+                    *count += weight as i64;
+                }
             }
             AggregateFunction::Min(_col_name) | AggregateFunction::Max(_col_name) => {
@@ -644,8 +638,8 @@ impl AggregateState {
             AggregateFunction::Count => {
                 result.push(Value::Integer(self.count));
             }
-            AggregateFunction::Sum(col_name) => {
-                let sum = self.sums.get(col_name).copied().unwrap_or(0.0);
+            AggregateFunction::Sum(col_idx) => {
+                let sum = self.sums.get(col_idx).copied().unwrap_or(0.0);
                 // Return as integer if it's a whole number, otherwise as float
                 if sum.fract() == 0.0 {
                     result.push(Value::Integer(sum as i64));
@@ -653,8 +647,8 @@ impl AggregateState {
                     result.push(Value::Float(sum));
                 }
             }
-            AggregateFunction::Avg(col_name) => {
-                if let Some((sum, count)) = self.avgs.get(col_name) {
+            AggregateFunction::Avg(col_idx) => {
+                if let Some((sum, count)) = self.avgs.get(col_idx) {
                     if *count > 0 {
                         result.push(Value::Float(sum / *count as f64));
                     } else {
@@ -664,13 +658,13 @@ impl AggregateState {
                     result.push(Value::Null);
                 }
             }
-            AggregateFunction::Min(col_name) => {
+            AggregateFunction::Min(col_idx) => {
                 // Return the MIN value from our state
-                result.push(self.mins.get(col_name).cloned().unwrap_or(Value::Null));
+                result.push(self.mins.get(col_idx).cloned().unwrap_or(Value::Null));
             }
-            AggregateFunction::Max(col_name) => {
+            AggregateFunction::Max(col_idx) => {
                 // Return the MAX value from our state
-                result.push(self.maxs.get(col_name).cloned().unwrap_or(Value::Null));
+                result.push(self.maxs.get(col_idx).cloned().unwrap_or(Value::Null));
             }
         }
     }
@@ -682,20 +676,20 @@ impl AggregateState {
 impl AggregateOperator {
     pub fn new(
         operator_id: usize,
-        group_by: Vec<String>,
+        group_by: Vec<usize>,
         aggregates: Vec<AggregateFunction>,
         input_column_names: Vec<String>,
     ) -> Self {
-        // Build map of column names to their MIN/MAX info with indices
+        // Build map of column indices to their MIN/MAX info
         let mut column_min_max = HashMap::new();
-        let mut column_indices = HashMap::new();
+        let mut storage_indices = HashMap::new();
         let mut current_index = 0;

-        // First pass: assign indices to unique MIN/MAX columns
+        // First pass: assign storage indices to unique MIN/MAX columns
         for agg in &aggregates {
             match agg {
-                AggregateFunction::Min(col) | AggregateFunction::Max(col) => {
-                    column_indices.entry(col.clone()).or_insert_with(|| {
+                AggregateFunction::Min(col_idx) | AggregateFunction::Max(col_idx) => {
+                    storage_indices.entry(*col_idx).or_insert_with(|| {
                         let idx = current_index;
                         current_index += 1;
                         idx
@@ -708,19 +702,19 @@ impl AggregateOperator {
         // Second pass: build the column info map
         for agg in &aggregates {
             match agg {
-                AggregateFunction::Min(col) => {
-                    let index = *column_indices.get(col).unwrap();
-                    let entry = column_min_max.entry(col.clone()).or_insert(AggColumnInfo {
-                        index,
+                AggregateFunction::Min(col_idx) => {
+                    let storage_index = *storage_indices.get(col_idx).unwrap();
+                    let entry = column_min_max.entry(*col_idx).or_insert(AggColumnInfo {
+                        index: storage_index,
                         has_min: false,
                         has_max: false,
                     });
                     entry.has_min = true;
                 }
-                AggregateFunction::Max(col) => {
-                    let index = *column_indices.get(col).unwrap();
-                    let entry = column_min_max.entry(col.clone()).or_insert(AggColumnInfo {
-                        index,
+                AggregateFunction::Max(col_idx) => {
+                    let storage_index = *storage_indices.get(col_idx).unwrap();
+                    let entry = column_min_max.entry(*col_idx).or_insert(AggColumnInfo {
+                        index: storage_index,
                         has_min: false,
                         has_max: false,
                     });
@@ -876,28 +870,24 @@ impl AggregateOperator {

         for agg in &self.aggregates {
             match agg {
-                AggregateFunction::Min(col_name) | AggregateFunction::Max(col_name) => {
-                    if let Some(idx) =
-                        self.input_column_names.iter().position(|c| c == col_name)
-                    {
-                        if let Some(val) = row.values.get(idx) {
-                            // Skip NULL values - they don't participate in MIN/MAX
-                            if val == &Value::Null {
-                                continue;
-                            }
-                            // Create a HashableRow with just this value
-                            // Use 0 as rowid since we only care about the value for comparison
-                            let hashable_value = HashableRow::new(0, vec![val.clone()]);
-                            let key = (col_name.clone(), hashable_value);
-
-                            let group_entry =
-                                min_max_deltas.entry(group_key_str.clone()).or_default();
-
-                            let value_entry = group_entry.entry(key).or_insert(0);
-
-                            // Accumulate the weight
-                            *value_entry += weight;
-                        }
+                AggregateFunction::Min(col_idx) | AggregateFunction::Max(col_idx) => {
+                    if let Some(val) = row.values.get(*col_idx) {
+                        // Skip NULL values - they don't participate in MIN/MAX
+                        if val == &Value::Null {
+                            continue;
+                        }
+                        // Create a HashableRow with just this value
+                        // Use 0 as rowid since we only care about the value for comparison
+                        let hashable_value = HashableRow::new(0, vec![val.clone()]);
+                        let key = (*col_idx, hashable_value);
+
+                        let group_entry =
+                            min_max_deltas.entry(group_key_str.clone()).or_default();
+
+                        let value_entry = group_entry.entry(key).or_insert(0);
+
+                        // Accumulate the weight
+                        *value_entry += weight;
                     }
                 }
                 _ => {} // Ignore non-MIN/MAX aggregates
@@ -929,13 +919,9 @@ impl AggregateOperator {
     pub fn extract_group_key(&self, values: &[Value]) -> Vec<Value> {
         let mut key = Vec::new();

-        for group_col in &self.group_by {
-            if let Some(idx) = self.input_column_names.iter().position(|c| c == group_col) {
-                if let Some(val) = values.get(idx) {
-                    key.push(val.clone());
-                } else {
-                    key.push(Value::Null);
-                }
-            }
+        for &idx in &self.group_by {
+            if let Some(val) = values.get(idx) {
+                key.push(val.clone());
+            } else {
+                key.push(Value::Null);
+            }
         }
@@ -1124,13 +1110,13 @@ pub enum RecomputeMinMax {
         /// Current column being processed
         current_column_idx: usize,
         /// Columns to process (combined MIN and MAX)
-        columns_to_process: Vec<(String, String, bool)>, // (group_key, column_name, is_min)
+        columns_to_process: Vec<(String, usize, bool)>, // (group_key, column_name, is_min)
         /// MIN/MAX deltas for checking values and weights
         min_max_deltas: MinMaxDeltas,
     },
     Scan {
         /// Columns still to process
-        columns_to_process: Vec<(String, String, bool)>,
+        columns_to_process: Vec<(String, usize, bool)>,
         /// Current index in columns_to_process (will resume from here)
         current_column_idx: usize,
         /// MIN/MAX deltas for checking values and weights
@@ -1138,7 +1124,7 @@ pub enum RecomputeMinMax {
         /// Current group key being processed
         group_key: String,
         /// Current column name being processed
-        column_name: String,
+        column_name: usize,
         /// Whether we're looking for MIN (true) or MAX (false)
         is_min: bool,
         /// The scan state machine for finding the new MIN/MAX
@@ -1153,7 +1139,7 @@ impl RecomputeMinMax {
         existing_groups: &HashMap<String, AggregateState>,
         operator: &AggregateOperator,
     ) -> Self {
-        let mut groups_to_check: HashSet<(String, String, bool)> = HashSet::new();
+        let mut groups_to_check: HashSet<(String, usize, bool)> = HashSet::new();

         // Remember the min_max_deltas are essentially just the only column that is affected by
         // this min/max, in delta (actually ZSet - consolidated delta) format. This makes it easier
@@ -1173,21 +1159,13 @@ impl RecomputeMinMax {
                     // Check for MIN
                     if let Some(current_min) = state.mins.get(col_name) {
                         if current_min == value {
-                            groups_to_check.insert((
-                                group_key_str.clone(),
-                                col_name.clone(),
-                                true,
-                            ));
+                            groups_to_check.insert((group_key_str.clone(), *col_name, true));
                         }
                     }
                     // Check for MAX
                     if let Some(current_max) = state.maxs.get(col_name) {
                         if current_max == value {
-                            groups_to_check.insert((
-                                group_key_str.clone(),
-                                col_name.clone(),
-                                false,
-                            ));
+                            groups_to_check.insert((group_key_str.clone(), *col_name, false));
                         }
                     }
                 }
@@ -1196,14 +1174,10 @@ impl RecomputeMinMax {
                 // about this if this is a new record being inserted
                 if let Some(info) = col_info {
                     if info.has_min {
-                        groups_to_check.insert((group_key_str.clone(), col_name.clone(), true));
+                        groups_to_check.insert((group_key_str.clone(), *col_name, true));
                     }
                     if info.has_max {
-                        groups_to_check.insert((
-                            group_key_str.clone(),
-                            col_name.clone(),
-                            false,
-                        ));
+                        groups_to_check.insert((group_key_str.clone(), *col_name, false));
                     }
                 }
             }
@@ -1245,12 +1219,13 @@ impl RecomputeMinMax {
                 let (group_key, column_name, is_min) =
                     columns_to_process[*current_column_idx].clone();

-                // Get column index from pre-computed info
-                let column_index = operator
+                // Column name is already the index
+                // Get the storage index from column_min_max map
+                let column_info = operator
                     .column_min_max
                     .get(&column_name)
-                    .map(|info| info.index)
-                    .unwrap(); // Should always exist since we're processing known columns
+                    .expect("Column should exist in column_min_max map");
+                let storage_index = column_info.index;

                 // Get current value from existing state
                 let current_value = existing_groups.get(&group_key).and_then(|state| {
@@ -1263,7 +1238,7 @@ impl RecomputeMinMax {

                 // Create storage keys for index lookup
                 let storage_id =
-                    generate_storage_id(operator.operator_id, column_index, AGG_TYPE_MINMAX);
+                    generate_storage_id(operator.operator_id, storage_index, AGG_TYPE_MINMAX);
                 let zset_id = operator.generate_group_rowid(&group_key);

                 // Get the values for this group from min_max_deltas
@@ -1276,7 +1251,7 @@ impl RecomputeMinMax {
                     Box::new(ScanState::new_for_min(
                         current_value,
                         group_key.clone(),
-                        column_name.clone(),
+                        column_name,
                         storage_id,
                         zset_id,
                         group_values,
@@ -1285,7 +1260,7 @@ impl RecomputeMinMax {
                     Box::new(ScanState::new_for_max(
                         current_value,
                         group_key.clone(),
-                        column_name.clone(),
+                        column_name,
                         storage_id,
                         zset_id,
                         group_values,
@@ -1319,12 +1294,12 @@ impl RecomputeMinMax {

                     if *is_min {
                         if let Some(min_val) = new_value {
-                            state.mins.insert(column_name.clone(), min_val);
+                            state.mins.insert(*column_name, min_val);
                         } else {
                             state.mins.remove(column_name);
                         }
                     } else if let Some(max_val) = new_value {
-                        state.maxs.insert(column_name.clone(), max_val);
+                        state.maxs.insert(*column_name, max_val);
                     } else {
                         state.maxs.remove(column_name);
                     }
@@ -1355,13 +1330,13 @@ pub enum ScanState {
         /// Group key being processed
         group_key: String,
         /// Column name being processed
-        column_name: String,
+        column_name: usize,
         /// Storage ID for the index seek
         storage_id: i64,
         /// ZSet ID for the group
         zset_id: i64,
         /// Group values from MinMaxDeltas: (column_name, HashableRow) -> weight
-        group_values: HashMap<(String, HashableRow), isize>,
+        group_values: HashMap<(usize, HashableRow), isize>,
         /// Whether we're looking for MIN (true) or MAX (false)
         is_min: bool,
     },
@@ -1371,13 +1346,13 @@ pub enum ScanState {
         /// Group key being processed
         group_key: String,
         /// Column name being processed
-        column_name: String,
+        column_name: usize,
         /// Storage ID for the index seek
         storage_id: i64,
         /// ZSet ID for the group
         zset_id: i64,
         /// Group values from MinMaxDeltas: (column_name, HashableRow) -> weight
-        group_values: HashMap<(String, HashableRow), isize>,
+        group_values: HashMap<(usize, HashableRow), isize>,
         /// Whether we're looking for MIN (true) or MAX (false)
         is_min: bool,
     },
@@ -1391,10 +1366,10 @@ impl ScanState {
     pub fn new_for_min(
         current_min: Option<Value>,
         group_key: String,
-        column_name: String,
+        column_name: usize,
         storage_id: i64,
         zset_id: i64,
-        group_values: HashMap<(String, HashableRow), isize>,
+        group_values: HashMap<(usize, HashableRow), isize>,
     ) -> Self {
         Self::CheckCandidate {
             candidate: current_min,
@@ -1460,10 +1435,10 @@ impl ScanState {
     pub fn new_for_max(
         current_max: Option<Value>,
         group_key: String,
-        column_name: String,
+        column_name: usize,
         storage_id: i64,
         zset_id: i64,
-        group_values: HashMap<(String, HashableRow), isize>,
+        group_values: HashMap<(usize, HashableRow), isize>,
     ) -> Self {
         Self::CheckCandidate {
             candidate: current_max,
@@ -1496,7 +1471,7 @@ impl ScanState {
                 // Check if the candidate is retracted (weight <= 0)
                 // Create a HashableRow to look up the weight
                 let hashable_cand = HashableRow::new(0, vec![cand_val.clone()]);
-                let key = (column_name.clone(), hashable_cand);
+                let key = (*column_name, hashable_cand);
                 let is_retracted =
                     group_values.get(&key).is_some_and(|weight| *weight <= 0);

@@ -1633,7 +1608,7 @@ pub enum MinMaxPersistState {
         group_idx: usize,
         value_idx: usize,
         value: Value,
-        column_name: String,
+        column_name: usize,
         weight: isize,
         write_row: WriteRow,
     },
@@ -1652,7 +1627,7 @@ impl MinMaxPersistState {
    pub fn persist_min_max(
         &mut self,
         operator_id: usize,
-        column_min_max: &HashMap<String, AggColumnInfo>,
+        column_min_max: &HashMap<usize, AggColumnInfo>,
         cursors: &mut DbspStateCursors,
         generate_group_rowid: impl Fn(&str) -> i64,
     ) -> Result<IOResult<()>> {
@@ -1699,7 +1674,7 @@ impl MinMaxPersistState {

                 // Process current value and extract what we need before taking ownership
                 let ((column_name, hashable_row), weight) = values_vec[*value_idx];
-                let column_name = column_name.clone();
+                let column_name = *column_name;
                 let value = hashable_row.values[0].clone(); // Extract the Value from HashableRow
                 let weight = *weight;

@@ -1731,9 +1706,9 @@ impl MinMaxPersistState {

                 let group_key_str = &group_keys[*group_idx];

-                // Get the column index from the pre-computed map
+                // Get the column info from the pre-computed map
                 let column_info = column_min_max
-                    .get(&*column_name)
+                    .get(column_name)
                     .expect("Column should exist in column_min_max map");
                 let column_index = column_info.index;

File diff suppressed because it is too large
@@ -3,7 +3,7 @@
 // Based on Feldera DBSP design but adapted for Turso's architecture

 pub use crate::incremental::aggregate_operator::{
-    AggregateEvalState, AggregateFunction, AggregateOperator, AggregateState,
+    AggregateEvalState, AggregateFunction, AggregateState,
 };
 pub use crate::incremental::filter_operator::{FilterOperator, FilterPredicate};
 pub use crate::incremental::input_operator::InputOperator;
@@ -251,7 +251,7 @@ pub trait IncrementalOperator: Debug {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::incremental::aggregate_operator::AGG_TYPE_REGULAR;
+    use crate::incremental::aggregate_operator::{AggregateOperator, AGG_TYPE_REGULAR};
     use crate::incremental::dbsp::HashableRow;
     use crate::storage::pager::CreateBTreeFlags;
     use crate::types::Text;
@@ -395,9 +395,9 @@ mod tests {

         // Create an aggregate operator for SUM(age) with no GROUP BY
         let mut agg = AggregateOperator::new(
-            1, // operator_id for testing
-            vec![], // No GROUP BY
-            vec![AggregateFunction::Sum("age".to_string())],
+            1, // operator_id for testing
+            vec![], // No GROUP BY
+            vec![AggregateFunction::Sum(2)], // age is at index 2
             vec!["id".to_string(), "name".to_string(), "age".to_string()],
         );

@@ -514,9 +514,9 @@ mod tests {
         let mut cursors = DbspStateCursors::new(table_cursor, index_cursor);

         let mut agg = AggregateOperator::new(
-            1, // operator_id for testing
-            vec!["team".to_string()], // GROUP BY team
-            vec![AggregateFunction::Sum("score".to_string())],
+            1, // operator_id for testing
+            vec![1], // GROUP BY team (index 1)
+            vec![AggregateFunction::Sum(3)], // score is at index 3
             vec![
                 "id".to_string(),
                 "team".to_string(),
@@ -666,8 +666,8 @@ mod tests {

         // Create COUNT(*) GROUP BY category
         let mut agg = AggregateOperator::new(
-            1, // operator_id for testing
-            vec!["category".to_string()],
+            1, // operator_id for testing
+            vec![1], // category is at index 1
             vec![AggregateFunction::Count],
             vec![
                 "item_id".to_string(),
@@ -746,9 +746,9 @@ mod tests {
         let mut cursors = DbspStateCursors::new(table_cursor, index_cursor);

         let mut agg = AggregateOperator::new(
-            1, // operator_id for testing
-            vec!["product".to_string()],
-            vec![AggregateFunction::Sum("amount".to_string())],
+            1, // operator_id for testing
+            vec![1], // product is at index 1
+            vec![AggregateFunction::Sum(2)], // amount is at index 2
             vec![
                 "sale_id".to_string(),
                 "product".to_string(),
@@ -843,11 +843,11 @@ mod tests {
         let mut cursors = DbspStateCursors::new(table_cursor, index_cursor);

         let mut agg = AggregateOperator::new(
-            1, // operator_id for testing
-            vec!["user_id".to_string()],
+            1, // operator_id for testing
+            vec![1], // user_id is at index 1
             vec![
                 AggregateFunction::Count,
-                AggregateFunction::Sum("amount".to_string()),
+                AggregateFunction::Sum(2), // amount is at index 2
             ],
             vec![
                 "order_id".to_string(),
@@ -935,9 +935,9 @@ mod tests {
         let mut cursors = DbspStateCursors::new(table_cursor, index_cursor);

         let mut agg = AggregateOperator::new(
-            1, // operator_id for testing
-            vec!["category".to_string()],
-            vec![AggregateFunction::Avg("value".to_string())],
+            1, // operator_id for testing
+            vec![1], // category is at index 1
+            vec![AggregateFunction::Avg(2)], // value is at index 2
             vec![
                 "id".to_string(),
                 "category".to_string(),
@@ -1035,11 +1035,11 @@ mod tests {
         let mut cursors = DbspStateCursors::new(table_cursor, index_cursor);

         let mut agg = AggregateOperator::new(
-            1, // operator_id for testing
-            vec!["category".to_string()],
+            1, // operator_id for testing
+            vec![1], // category is at index 1
             vec![
                 AggregateFunction::Count,
-                AggregateFunction::Sum("value".to_string()),
+                AggregateFunction::Sum(2), // value is at index 2
             ],
             vec![
                 "id".to_string(),
@@ -1108,7 +1108,7 @@ mod tests {
     #[test]
     fn test_count_aggregation_with_deletions() {
         let aggregates = vec![AggregateFunction::Count];
-        let group_by = vec!["category".to_string()];
+        let group_by = vec![0]; // category is at index 0
         let input_columns = vec!["category".to_string(), "value".to_string()];

         // Create a persistent pager for the test
@@ -1197,8 +1197,8 @@ mod tests {

     #[test]
     fn test_sum_aggregation_with_deletions() {
-        let aggregates = vec![AggregateFunction::Sum("value".to_string())];
-        let group_by = vec!["category".to_string()];
+        let aggregates = vec![AggregateFunction::Sum(1)]; // value is at index 1
+        let group_by = vec![0]; // category is at index 0
         let input_columns = vec!["category".to_string(), "value".to_string()];

         // Create a persistent pager for the test
@@ -1281,8 +1281,8 @@ mod tests {

     #[test]
     fn test_avg_aggregation_with_deletions() {
-        let aggregates = vec![AggregateFunction::Avg("value".to_string())];
-        let group_by = vec!["category".to_string()];
+        let aggregates = vec![AggregateFunction::Avg(1)]; // value is at index 1
+        let group_by = vec![0]; // category is at index 0
         let input_columns = vec!["category".to_string(), "value".to_string()];

         // Create a persistent pager for the test
@@ -1348,10 +1348,10 @@ mod tests {
         // Test COUNT, SUM, and AVG together
         let aggregates = vec![
             AggregateFunction::Count,
-            AggregateFunction::Sum("value".to_string()),
-            AggregateFunction::Avg("value".to_string()),
+            AggregateFunction::Sum(1), // value is at index 1
+            AggregateFunction::Avg(1), // value is at index 1
         ];
-        let group_by = vec!["category".to_string()];
+        let group_by = vec![0]; // category is at index 0
         let input_columns = vec!["category".to_string(), "value".to_string()];

         // Create a persistent pager for the test
@@ -1607,11 +1607,11 @@ mod tests {
         let mut cursors = DbspStateCursors::new(table_cursor, index_cursor);

         let mut agg = AggregateOperator::new(
-            1, // operator_id for testing
-            vec!["category".to_string()],
+            1, // operator_id for testing
+            vec![1], // category is at index 1
             vec![
                 AggregateFunction::Count,
-                AggregateFunction::Sum("amount".to_string()),
+                AggregateFunction::Sum(2), // amount is at index 2
             ],
             vec![
                 "id".to_string(),
@@ -1781,7 +1781,7 @@ mod tests {
             vec![], // No GROUP BY
             vec![
                 AggregateFunction::Count,
-                AggregateFunction::Sum("value".to_string()),
+                AggregateFunction::Sum(1), // value is at index 1
             ],
             vec!["id".to_string(), "value".to_string()],
         );
@@ -1859,8 +1859,8 @@ mod tests {
         let mut cursors = DbspStateCursors::new(table_cursor, index_cursor);

         let mut agg = AggregateOperator::new(
-            1, // operator_id for testing
-            vec!["type".to_string()],
+            1, // operator_id for testing
+            vec![1], // type is at index 1
             vec![AggregateFunction::Count],
             vec!["id".to_string(), "type".to_string()],
         );
@@ -1976,8 +1976,8 @@ mod tests {
             1, // operator_id
             vec![], // No GROUP BY
             vec![
-                AggregateFunction::Min("price".to_string()),
-                AggregateFunction::Max("price".to_string()),
+                AggregateFunction::Min(2), // price is at index 2
+                AggregateFunction::Max(2), // price is at index 2
             ],
             vec!["id".to_string(), "name".to_string(), "price".to_string()],
         );
@@ -2044,8 +2044,8 @@ mod tests {
             1, // operator_id
             vec![], // No GROUP BY
             vec![
-                AggregateFunction::Min("price".to_string()),
-                AggregateFunction::Max("price".to_string()),
+                AggregateFunction::Min(2), // price is at index 2
+                AggregateFunction::Max(2), // price is at index 2
             ],
             vec!["id".to_string(), "name".to_string(), "price".to_string()],
         );
@@ -2134,8 +2134,8 @@ mod tests {
             1, // operator_id
             vec![], // No GROUP BY
             vec![
-                AggregateFunction::Min("price".to_string()),
-                AggregateFunction::Max("price".to_string()),
+                AggregateFunction::Min(2), // price is at index 2
+                AggregateFunction::Max(2), // price is at index 2
             ],
             vec!["id".to_string(), "name".to_string(), "price".to_string()],
         );
@@ -2224,8 +2224,8 @@ mod tests {
             1, // operator_id
             vec![], // No GROUP BY
             vec![
-                AggregateFunction::Min("price".to_string()),
-                AggregateFunction::Max("price".to_string()),
+                AggregateFunction::Min(2), // price is at index 2
+                AggregateFunction::Max(2), // price is at index 2
             ],
             vec!["id".to_string(), "name".to_string(), "price".to_string()],
         );
@@ -2306,8 +2306,8 @@ mod tests {
             1, // operator_id
             vec![], // No GROUP BY
             vec![
-                AggregateFunction::Min("price".to_string()),
-                AggregateFunction::Max("price".to_string()),
+                AggregateFunction::Min(2), // price is at index 2
+                AggregateFunction::Max(2), // price is at index 2
             ],
             vec!["id".to_string(), "name".to_string(), "price".to_string()],
         );
@@ -2388,8 +2388,8 @@ mod tests {
             1, // operator_id
             vec![], // No GROUP BY
             vec![
-                AggregateFunction::Min("price".to_string()),
-                AggregateFunction::Max("price".to_string()),
+                AggregateFunction::Min(2), // price is at index 2
+                AggregateFunction::Max(2), // price is at index 2
             ],
             vec!["id".to_string(), "name".to_string(), "price".to_string()],
         );
@@ -2475,11 +2475,11 @@ mod tests {
         let mut cursors = DbspStateCursors::new(table_cursor, index_cursor);

         let mut agg = AggregateOperator::new(
-            1, // operator_id
-            vec!["category".to_string()], // GROUP BY category
+            1, // operator_id
+            vec![1], // GROUP BY category (index 1)
             vec![
-                AggregateFunction::Min("price".to_string()),
-                AggregateFunction::Max("price".to_string()),
+                AggregateFunction::Min(3), // price is at index 3
+                AggregateFunction::Max(3), // price is at index 3
             ],
             vec![
                 "id".to_string(),
@@ -2580,8 +2580,8 @@ mod tests {
             1, // operator_id
             vec![], // No GROUP BY
             vec![
-                AggregateFunction::Min("price".to_string()),
-                AggregateFunction::Max("price".to_string()),
+                AggregateFunction::Min(2), // price is at index 2
+                AggregateFunction::Max(2), // price is at index 2
             ],
             vec!["id".to_string(), "name".to_string(), "price".to_string()],
         );
@@ -2656,8 +2656,8 @@ mod tests {
             1, // operator_id
             vec![], // No GROUP BY
             vec![
-                AggregateFunction::Min("score".to_string()),
-                AggregateFunction::Max("score".to_string()),
+                AggregateFunction::Min(2), // score is at index 2
+                AggregateFunction::Max(2), // score is at index 2
             ],
             vec!["id".to_string(), "name".to_string(), "score".to_string()],
         );
@@ -2724,8 +2724,8 @@ mod tests {
             1, // operator_id
             vec![], // No GROUP BY
             vec![
-                AggregateFunction::Min("name".to_string()),
-                AggregateFunction::Max("name".to_string()),
+                AggregateFunction::Min(1), // name is at index 1
+                AggregateFunction::Max(1), // name is at index 1
             ],
             vec!["id".to_string(), "name".to_string()],
         );
@@ -2764,10 +2764,10 @@ mod tests {
             vec![], // No GROUP BY
             vec![
                 AggregateFunction::Count,
-                AggregateFunction::Sum("value".to_string()),
-                AggregateFunction::Min("value".to_string()),
-                AggregateFunction::Max("value".to_string()),
-                AggregateFunction::Avg("value".to_string()),
+                AggregateFunction::Sum(1), // value is at index 1
+                AggregateFunction::Min(1), // value is at index 1
+                AggregateFunction::Max(1), // value is at index 1
+                AggregateFunction::Avg(1), // value is at index 1
             ],
             vec!["id".to_string(), "value".to_string()],
         );
@@ -2855,9 +2855,9 @@ mod tests {
             1, // operator_id
             vec![], // No GROUP BY
             vec![
-                AggregateFunction::Min("col1".to_string()),
-                AggregateFunction::Max("col2".to_string()),
-                AggregateFunction::Min("col3".to_string()),
+                AggregateFunction::Min(0), // col1 is at index 0
+                AggregateFunction::Max(1), // col2 is at index 1
+                AggregateFunction::Min(2), // col3 is at index 2
             ],
             vec!["col1".to_string(), "col2".to_string(), "col3".to_string()],
         );

@@ -19,26 +19,35 @@ use turso_parser::ast;

 /// Result type for preprocessing aggregate expressions
 type PreprocessAggregateResult = (
-    bool,                // needs_pre_projection
-    Vec<LogicalExpr>,    // pre_projection_exprs
-    Vec<(String, Type)>, // pre_projection_schema
-    Vec<LogicalExpr>,    // modified_aggr_exprs
+    bool,             // needs_pre_projection
+    Vec<LogicalExpr>, // pre_projection_exprs
+    Vec<ColumnInfo>,  // pre_projection_schema
+    Vec<LogicalExpr>, // modified_aggr_exprs
 );

 /// Result type for parsing join conditions
 type JoinConditionsResult = (Vec<(LogicalExpr, LogicalExpr)>, Option<LogicalExpr>);

+/// Information about a column in a logical schema
+#[derive(Debug, Clone, PartialEq)]
+pub struct ColumnInfo {
+    pub name: String,
+    pub ty: Type,
+    pub database: Option<String>,
+    pub table: Option<String>,
+    pub table_alias: Option<String>,
+}
+
 /// Schema information for logical plan nodes
 #[derive(Debug, Clone, PartialEq)]
 pub struct LogicalSchema {
     /// Column names and types
-    pub columns: Vec<(String, Type)>,
+    pub columns: Vec<ColumnInfo>,
 }
 /// A reference to a schema that can be shared between nodes
 pub type SchemaRef = Arc<LogicalSchema>;

 impl LogicalSchema {
-    pub fn new(columns: Vec<(String, Type)>) -> Self {
+    pub fn new(columns: Vec<ColumnInfo>) -> Self {
         Self { columns }
     }

@@ -52,11 +61,42 @@ impl LogicalSchema {
         self.columns.len()
     }

-    pub fn find_column(&self, name: &str) -> Option<(usize, &Type)> {
-        self.columns
-            .iter()
-            .position(|(n, _)| n == name)
-            .map(|idx| (idx, &self.columns[idx].1))
+    pub fn find_column(&self, name: &str, table: Option<&str>) -> Option<(usize, &ColumnInfo)> {
+        if let Some(table_ref) = table {
+            // Check if it's a database.table format
+            if table_ref.contains('.') {
+                let parts: Vec<&str> = table_ref.splitn(2, '.').collect();
+                if parts.len() == 2 {
+                    let db = parts[0];
+                    let tbl = parts[1];
+                    return self
+                        .columns
+                        .iter()
+                        .position(|c| {
+                            c.name == name
+                                && c.database.as_deref() == Some(db)
+                                && c.table.as_deref() == Some(tbl)
+                        })
+                        .map(|idx| (idx, &self.columns[idx]));
+                }
+            }
+
+            // Try to match against table alias first, then table name
+            self.columns
+                .iter()
+                .position(|c| {
+                    c.name == name
+                        && (c.table_alias.as_deref() == Some(table_ref)
+                            || c.table.as_deref() == Some(table_ref))
+                })
+                .map(|idx| (idx, &self.columns[idx]))
+        } else {
+            // Unqualified lookup - just match by name
+            self.columns
+                .iter()
+                .position(|c| c.name == name)
+                .map(|idx| (idx, &self.columns[idx]))
+        }
     }
 }

@@ -548,14 +588,14 @@ impl<'a> LogicalPlanBuilder<'a> {
         }

         // Regular table scan
-        let table_schema = self.get_table_schema(&table_name)?;
         let table_alias = alias.as_ref().map(|a| match a {
             ast::As::As(name) => Self::name_to_string(name),
             ast::As::Elided(name) => Self::name_to_string(name),
         });
+        let table_schema = self.get_table_schema(&table_name, table_alias.as_deref())?;
         Ok(LogicalPlan::TableScan(TableScan {
             table_name,
-            alias: table_alias,
+            alias: table_alias.clone(),
             schema: table_schema,
             projection: None,
         }))
@@ -751,14 +791,14 @@ impl<'a> LogicalPlanBuilder<'a> {
                     let _left_idx = left_schema
                         .columns
                         .iter()
-                        .position(|(n, _)| n == &name)
+                        .position(|col| col.name == name)
                         .ok_or_else(|| {
                             LimboError::ParseError(format!("Column {name} not found in left table"))
                         })?;
                     let _right_idx = right_schema
                         .columns
                         .iter()
-                        .position(|(n, _)| n == &name)
+                        .position(|col| col.name == name)
                         .ok_or_else(|| {
                             LimboError::ParseError(format!("Column {name} not found in right table"))
                         })?;
@@ -790,9 +830,13 @@ impl<'a> LogicalPlanBuilder<'a> {

                 // Find common column names
                 let mut common_columns = Vec::new();
-                for (left_name, _) in &left_schema.columns {
-                    if right_schema.columns.iter().any(|(n, _)| n == left_name) {
-                        common_columns.push(ast::Name::Ident(left_name.clone()));
+                for left_col in &left_schema.columns {
+                    if right_schema
+                        .columns
+                        .iter()
+                        .any(|col| col.name == left_col.name)
+                    {
+                        common_columns.push(ast::Name::Ident(left_col.name.clone()));
                     }
                 }

@@ -833,10 +877,18 @@ impl<'a> LogicalPlanBuilder<'a> {
         let left_schema = left.schema();
         let right_schema = right.schema();

-        // For now, simply concatenate the schemas
-        // In a real implementation, we'd handle column name conflicts and nullable columns
-        let mut columns = left_schema.columns.clone();
-        columns.extend(right_schema.columns.clone());
+        // Concatenate the schemas, preserving all column information
+        let mut columns = Vec::new();
+
+        // Keep all columns from left with their table info
+        for col in &left_schema.columns {
+            columns.push(col.clone());
+        }
+
+        // Keep all columns from right with their table info
+        for col in &right_schema.columns {
+            columns.push(col.clone());
+        }

         Ok(Arc::new(LogicalSchema::new(columns)))
     }
@@ -870,7 +922,13 @@ impl<'a> LogicalPlanBuilder<'a> {
                     };
                    let col_type = Self::infer_expr_type(&logical_expr, input_schema)?;

-                    schema_columns.push((col_name.clone(), col_type));
+                    schema_columns.push(ColumnInfo {
+                        name: col_name.clone(),
+                        ty: col_type,
+                        database: None,
+                        table: None,
+                        table_alias: None,
+                    });

                     if let Some(as_alias) = alias {
                         let alias_name = match as_alias {
@@ -886,21 +944,21 @@ impl<'a> LogicalPlanBuilder<'a> {
                 }
                 ast::ResultColumn::Star => {
                     // Expand * to all columns
-                    for (name, typ) in &input_schema.columns {
-                        proj_exprs.push(LogicalExpr::Column(Column::new(name.clone())));
-                        schema_columns.push((name.clone(), *typ));
+                    for col in &input_schema.columns {
+                        proj_exprs.push(LogicalExpr::Column(Column::new(col.name.clone())));
+                        schema_columns.push(col.clone());
                     }
                 }
                 ast::ResultColumn::TableStar(table) => {
                     // Expand table.* to all columns from that table
                     let table_name = Self::name_to_string(table);
-                    for (name, typ) in &input_schema.columns {
+                    for col in &input_schema.columns {
                         // Simple check - would need proper table tracking in real implementation
                         proj_exprs.push(LogicalExpr::Column(Column::with_table(
-                            name.clone(),
+                            col.name.clone(),
                             table_name.clone(),
                         )));
-                        schema_columns.push((name.clone(), *typ));
+                        schema_columns.push(col.clone());
                     }
                 }
             }
@@ -938,7 +996,13 @@ impl<'a> LogicalPlanBuilder<'a> {
             if let LogicalExpr::Column(col) = expr {
                 pre_projection_exprs.push(expr.clone());
                 let col_type = Self::infer_expr_type(expr, input_schema)?;
-                pre_projection_schema.push((col.name.clone(), col_type));
+                pre_projection_schema.push(ColumnInfo {
+                    name: col.name.clone(),
+                    ty: col_type,
+                    database: None,
+                    table: col.table.clone(),
+                    table_alias: None,
+                });
             } else {
                 // Complex group by expression - project it
                 needs_pre_projection = true;
@@ -946,7 +1010,13 @@ impl<'a> LogicalPlanBuilder<'a> {
                 projected_col_counter += 1;
                 pre_projection_exprs.push(expr.clone());
                 let col_type = Self::infer_expr_type(expr, input_schema)?;
-                pre_projection_schema.push((proj_col_name.clone(), col_type));
+                pre_projection_schema.push(ColumnInfo {
+                    name: proj_col_name.clone(),
+                    ty: col_type,
+                    database: None,
+                    table: None,
+                    table_alias: None,
+                });
             }
         }

@@ -970,7 +1040,13 @@ impl<'a> LogicalPlanBuilder<'a> {
                         pre_projection_exprs.push(arg.clone());
                         let col_type = Self::infer_expr_type(arg, input_schema)?;
                         if let LogicalExpr::Column(col) = arg {
-                            pre_projection_schema.push((col.name.clone(), col_type));
+                            pre_projection_schema.push(ColumnInfo {
+                                name: col.name.clone(),
+                                ty: col_type,
+                                database: None,
+                                table: col.table.clone(),
+                                table_alias: None,
+                            });
                         }
                     }
                 }
@@ -983,7 +1059,13 @@ impl<'a> LogicalPlanBuilder<'a> {
                     // Add the expression to the pre-projection
                     pre_projection_exprs.push(arg.clone());
                     let col_type = Self::infer_expr_type(arg, input_schema)?;
-                    pre_projection_schema.push((proj_col_name.clone(), col_type));
+                    pre_projection_schema.push(ColumnInfo {
+                        name: proj_col_name.clone(),
+                        ty: col_type,
+                        database: None,
+                        table: None,
+                        table_alias: None,
+                    });

                     // In the aggregate, reference the projected column
                     modified_args.push(LogicalExpr::Column(Column::new(proj_col_name)));
@@ -1057,15 +1139,39 @@ impl<'a> LogicalPlanBuilder<'a> {
         // First, add GROUP BY columns to the aggregate output schema
         // These are always part of the aggregate operator's output
         for group_expr in &group_exprs {
-            let col_name = match group_expr {
-                LogicalExpr::Column(col) => col.name.clone(),
+            match group_expr {
+                LogicalExpr::Column(col) => {
+                    // For column references in GROUP BY, preserve the original column info
+                    if let Some((_, col_info)) =
+                        input_schema.find_column(&col.name, col.table.as_deref())
+                    {
+                        // Preserve the column with all its table information
+                        aggregate_schema_columns.push(col_info.clone());
+                    } else {
+                        // Fallback if column not found (shouldn't happen)
+                        let col_type = Self::infer_expr_type(group_expr, input_schema)?;
+                        aggregate_schema_columns.push(ColumnInfo {
+                            name: col.name.clone(),
+                            ty: col_type,
+                            database: None,
+                            table: col.table.clone(),
+                            table_alias: None,
+                        });
+                    }
+                }
                 _ => {
                     // For complex GROUP BY expressions, generate a name
-                    format!("__group_{}", aggregate_schema_columns.len())
+                    let col_name = format!("__group_{}", aggregate_schema_columns.len());
+                    let col_type = Self::infer_expr_type(group_expr, input_schema)?;
+                    aggregate_schema_columns.push(ColumnInfo {
+                        name: col_name,
+                        ty: col_type,
+                        database: None,
+                        table: None,
+                        table_alias: None,
+                    });
                 }
-            };
-            let col_type = Self::infer_expr_type(group_expr, input_schema)?;
-            aggregate_schema_columns.push((col_name, col_type));
+            }
         }

@@ -1098,7 +1204,13 @@ impl<'a> LogicalPlanBuilder<'a> {
             } else {
                 // New aggregate - add it
                 let col_type = Self::infer_expr_type(&logical_expr, input_schema)?;
-                aggregate_schema_columns.push((col_name.clone(), col_type));
+                aggregate_schema_columns.push(ColumnInfo {
+                    name: col_name.clone(),
+                    ty: col_type,
+                    database: None,
+                    table: None,
+                    table_alias: None,
+                });
                 aggr_exprs.push(logical_expr);
                 aggregate_map.insert(agg_key, col_name.clone());
                 col_name.clone()
@@ -1122,7 +1234,13 @@ impl<'a> LogicalPlanBuilder<'a> {
         // Add only new aggregates
         for (agg_expr, agg_name) in extracted_aggs {
             let agg_type = Self::infer_expr_type(&agg_expr, input_schema)?;
-            aggregate_schema_columns.push((agg_name, agg_type));
+            aggregate_schema_columns.push(ColumnInfo {
+                name: agg_name,
+                ty: agg_type,
+                database: None,
+                table: None,
+                table_alias: None,
+            });
             aggr_exprs.push(agg_expr);
         }

@@ -1197,7 +1315,13 @@ impl<'a> LogicalPlanBuilder<'a> {
             // For type inference, we need the aggregate schema for column references
             let aggregate_schema = LogicalSchema::new(aggregate_schema_columns.clone());
             let col_type = Self::infer_expr_type(expr, &Arc::new(aggregate_schema))?;
-            projection_schema_columns.push((col_name, col_type));
+            projection_schema_columns.push(ColumnInfo {
+                name: col_name,
+                ty: col_type,
+                database: None,
+                table: None,
+                table_alias: None,
+            });
         }

         // Create the input plan (with pre-projection if needed)
@@ -1220,11 +1344,11 @@ impl<'a> LogicalPlanBuilder<'a> {

         // Check if we need the outer projection
         // We need a projection if:
-        // 1. Any expression is more complex than a simple column reference (e.g., abs(sum(id)))
-        // 2. We're selecting a different set of columns than what the aggregate outputs
-        // 3. Columns are renamed or reordered
+        // 1. We have expressions that compute new values (e.g., SUM(x) * 2)
+        // 2. We're selecting a different set of columns than GROUP BY + aggregates
+        // 3. We're reordering columns from their natural aggregate output order
         let needs_outer_projection = {
-            // Check if any expression is more complex than a simple column reference
+            // Check for complex expressions
             let has_complex_exprs = projection_exprs
                 .iter()
                 .any(|expr| !matches!(expr, LogicalExpr::Column(_)));
@@ -1232,17 +1356,29 @@ impl<'a> LogicalPlanBuilder<'a> {
             if has_complex_exprs {
                 true
             } else {
-                // All are simple columns - check if we're selecting exactly what the aggregate outputs
-                // The projection might be selecting a subset (e.g., only aggregates without group columns)
-                // or reordering columns, or using different names
+                // Check if we're selecting exactly what aggregate outputs in the same order
+                // The aggregate outputs: all GROUP BY columns, then all aggregate expressions
+                // The projection might select a subset or reorder these

-                // For now, keep it simple: if schemas don't match exactly, we need projection
-                // This handles all cases: subset selection, reordering, renaming
-                projection_schema_columns != aggregate_schema_columns
+                if projection_exprs.len() != aggregate_schema_columns.len() {
+                    // Different number of columns
+                    true
+                } else {
+                    // Check if columns match in order and name
+                    !projection_exprs.iter().zip(&aggregate_schema_columns).all(
+                        |(expr, agg_col)| {
+                            if let LogicalExpr::Column(col) = expr {
+                                col.name == agg_col.name
+                            } else {
+                                false
+                            }
+                        },
+                    )
+                }
             }
         };

-        // Create the aggregate node
+        // Create the aggregate node with its natural schema
         let aggregate_plan = LogicalPlan::Aggregate(Aggregate {
             input: aggregate_input,
             group_expr: group_exprs,
@@ -1257,7 +1393,7 @@ impl<'a> LogicalPlanBuilder<'a> {
                 schema: Arc::new(LogicalSchema::new(projection_schema_columns)),
             }))
         } else {
-            // No projection needed - the aggregate output is exactly what we want
+            // No projection needed - aggregate output matches what we want
             Ok(aggregate_plan)
         }
     }
@@ -1275,7 +1411,13 @@ impl<'a> LogicalPlanBuilder<'a> {
         // Infer schema from first row
         let mut schema_columns = Vec::new();
         for (i, _) in values[0].iter().enumerate() {
-            schema_columns.push((format!("column{}", i + 1), Type::Text));
+            schema_columns.push(ColumnInfo {
+                name: format!("column{}", i + 1),
+                ty: Type::Text,
+                database: None,
+                table: None,
+                table_alias: None,
+            });
         }

         for row in values {
@@ -2003,17 +2145,31 @@ impl<'a> LogicalPlanBuilder<'a> {
     }

     // Get table schema
-    fn get_table_schema(&self, table_name: &str) -> Result<SchemaRef> {
+    fn get_table_schema(&self, table_name: &str, alias: Option<&str>) -> Result<SchemaRef> {
         // Look up table in schema
         let table = self
             .schema
             .get_table(table_name)
             .ok_or_else(|| LimboError::ParseError(format!("Table '{table_name}' not found")))?;

+        // Parse table_name which might be "db.table" for attached databases
+        let (database, actual_table) = if table_name.contains('.') {
+            let parts: Vec<&str> = table_name.splitn(2, '.').collect();
+            (Some(parts[0].to_string()), parts[1].to_string())
+        } else {
+            (None, table_name.to_string())
+        };
+
         let mut columns = Vec::new();
         for col in table.columns() {
             if let Some(ref name) = col.name {
-                columns.push((name.clone(), col.ty));
+                columns.push(ColumnInfo {
+                    name: name.clone(),
+                    ty: col.ty,
+                    database: database.clone(),
+                    table: Some(actual_table.clone()),
+                    table_alias: alias.map(|s| s.to_string()),
+                });
             }
         }

@@ -2024,8 +2180,8 @@ impl<'a> LogicalPlanBuilder<'a> {
     fn infer_expr_type(expr: &LogicalExpr, schema: &SchemaRef) -> Result<Type> {
         match expr {
             LogicalExpr::Column(col) => {
-                if let Some((_, typ)) = schema.find_column(&col.name) {
-                    Ok(*typ)
+                if let Some((_, col_info)) = schema.find_column(&col.name, col.table.as_deref()) {
+                    Ok(col_info.ty)
                 } else {
                     Ok(Type::Text)
                 }