Implement JOINs in the DBSP circuit

This PR extends the DBSP circuit so that it handles the JOIN operator.
JOIN exposes a weakness in our current model: we usually pass a list of
columns between operators and find the column we need by name.

With JOINs, however, multiple tables can have columns with the same
name. An operator can then resolve a name to the wrong column (same
name, different table) and produce incorrect results.
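
To make the failure mode concrete, here is a minimal, self-contained
sketch; position_by_name and the users/orders schemas are illustrative
stand-ins, not code from this repository:

// Name-based column resolution, as our operators resolved columns
// before this change.
fn position_by_name(columns: &[&str], name: &str) -> Option<usize> {
    columns.iter().position(|c| *c == name)
}

fn main() {
    // Concatenated schema of `users JOIN orders`: both tables carry an "id".
    let joined = ["id", "name", "id", "amount"];
    // A lookup meant for orders.id still resolves to index 0 (users.id):
    // same name, wrong table.
    assert_eq!(position_by_name(&joined, "id"), Some(0));
}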

To fix this, we must do two things:
1) Change the Logical Plan so that it tracks table provenance.
2) Fix the aggregators so that they operate on column indices, not names.

For the aggregators, note that table provenance is the wrong
abstraction: an aggregator typically operates on a logical table
produced by the previous nodes in the circuit. All we need is a way to
tell it which index in the column array to use.
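
Concretely, the fix shows up at the operator's construction site. The
excerpt below is adapted from the updated tests in this diff: the
caller resolves column names to positions once, and the operator
indexes directly into each row from then on.

// Before: the operator searched for "age" by name on every row.
let agg = AggregateOperator::new(
    1,      // operator_id
    vec![], // no GROUP BY
    vec![AggregateFunction::Sum("age".to_string())],
    vec!["id".to_string(), "name".to_string(), "age".to_string()],
);

// After: the name is resolved to an index once, up front.
let agg = AggregateOperator::new(
    1,                               // operator_id
    vec![],                          // no GROUP BY
    vec![AggregateFunction::Sum(2)], // "age" is at index 2
    vec!["id".to_string(), "name".to_string(), "age".to_string()],
);

GROUP BY columns follow the same pattern: the operator now takes
group_by: Vec<usize> instead of Vec<String>.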
Author: Glauber Costa
Date: 2025-09-16 16:00:15 -05:00
parent 9f3d119a5a
commit f149b40e75
4 changed files with 1625 additions and 310 deletions


@@ -19,20 +19,20 @@ pub const AGG_TYPE_MINMAX: u8 = 0b01; // MIN/MAX (BTree ordering gives both)
#[derive(Debug, Clone, PartialEq)]
pub enum AggregateFunction {
Count,
Sum(String),
Avg(String),
Min(String),
Max(String),
Sum(usize), // Column index
Avg(usize), // Column index
Min(usize), // Column index
Max(usize), // Column index
}
impl Display for AggregateFunction {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
AggregateFunction::Count => write!(f, "COUNT(*)"),
AggregateFunction::Sum(col) => write!(f, "SUM({col})"),
AggregateFunction::Avg(col) => write!(f, "AVG({col})"),
AggregateFunction::Min(col) => write!(f, "MIN({col})"),
AggregateFunction::Max(col) => write!(f, "MAX({col})"),
AggregateFunction::Sum(idx) => write!(f, "SUM(col{idx})"),
AggregateFunction::Avg(idx) => write!(f, "AVG(col{idx})"),
AggregateFunction::Min(idx) => write!(f, "MIN(col{idx})"),
AggregateFunction::Max(idx) => write!(f, "MAX(col{idx})"),
}
}
}
@@ -48,16 +48,16 @@ impl AggregateFunction {
/// Returns None if the function is not a supported aggregate
pub fn from_sql_function(
func: &crate::function::Func,
input_column: Option<String>,
input_column_idx: Option<usize>,
) -> Option<Self> {
match func {
Func::Agg(agg_func) => {
match agg_func {
AggFunc::Count | AggFunc::Count0 => Some(AggregateFunction::Count),
AggFunc::Sum => input_column.map(AggregateFunction::Sum),
AggFunc::Avg => input_column.map(AggregateFunction::Avg),
AggFunc::Min => input_column.map(AggregateFunction::Min),
AggFunc::Max => input_column.map(AggregateFunction::Max),
AggFunc::Sum => input_column_idx.map(AggregateFunction::Sum),
AggFunc::Avg => input_column_idx.map(AggregateFunction::Avg),
AggFunc::Min => input_column_idx.map(AggregateFunction::Min),
AggFunc::Max => input_column_idx.map(AggregateFunction::Max),
_ => None, // Other aggregate functions not yet supported in DBSP
}
}
@@ -115,8 +115,8 @@ pub fn deserialize_value(blob: &[u8]) -> Option<(Value, usize)> {
// group_key_str -> (group_key, state)
type ComputedStates = HashMap<String, (Vec<Value>, AggregateState)>;
// group_key_str -> (column_name, value_as_hashable_row) -> accumulated_weight
pub type MinMaxDeltas = HashMap<String, HashMap<(String, HashableRow), isize>>;
// group_key_str -> (column_index, value_as_hashable_row) -> accumulated_weight
pub type MinMaxDeltas = HashMap<String, HashMap<(usize, HashableRow), isize>>;
#[derive(Debug)]
enum AggregateCommitState {
@@ -178,14 +178,14 @@ pub enum AggregateEvalState {
pub struct AggregateOperator {
// Unique operator ID for indexing in persistent storage
pub operator_id: usize,
// GROUP BY columns
group_by: Vec<String>,
// GROUP BY column indices
group_by: Vec<usize>,
// Aggregate functions to compute (including MIN/MAX)
pub aggregates: Vec<AggregateFunction>,
// Column names from input
pub input_column_names: Vec<String>,
// Map from column name to aggregate info for quick lookup
pub column_min_max: HashMap<String, AggColumnInfo>,
// Map from column index to aggregate info for quick lookup
pub column_min_max: HashMap<usize, AggColumnInfo>,
tracker: Option<Arc<Mutex<ComputationTracker>>>,
// State machine for commit operation
@@ -197,14 +197,14 @@ pub struct AggregateOperator {
pub struct AggregateState {
// For COUNT: just the count
pub count: i64,
// For SUM: column_name -> sum value
sums: HashMap<String, f64>,
// For AVG: column_name -> (sum, count) for computing average
avgs: HashMap<String, (f64, i64)>,
// For MIN: column_name -> minimum value
pub mins: HashMap<String, Value>,
// For MAX: column_name -> maximum value
pub maxs: HashMap<String, Value>,
// For SUM: column_index -> sum value
sums: HashMap<usize, f64>,
// For AVG: column_index -> (sum, count) for computing average
avgs: HashMap<usize, (f64, i64)>,
// For MIN: column_index -> minimum value
pub mins: HashMap<usize, Value>,
// For MAX: column_index -> maximum value
pub maxs: HashMap<usize, Value>,
}
impl AggregateEvalState {
@@ -520,14 +520,14 @@ impl AggregateState {
AggregateFunction::Sum(col_name) => {
let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
cursor += 8;
state.sums.insert(col_name.clone(), sum);
state.sums.insert(*col_name, sum);
}
AggregateFunction::Avg(col_name) => {
let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
cursor += 8;
let count = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
cursor += 8;
state.avgs.insert(col_name.clone(), (sum, count));
state.avgs.insert(*col_name, (sum, count));
}
AggregateFunction::Count => {
// Count was already read above
@@ -540,7 +540,7 @@ impl AggregateState {
if has_value == 1 {
let (min_value, bytes_consumed) = deserialize_value(&blob[cursor..])?;
cursor += bytes_consumed;
state.mins.insert(col_name.clone(), min_value);
state.mins.insert(*col_name, min_value);
}
}
AggregateFunction::Max(col_name) => {
@@ -551,7 +551,7 @@ impl AggregateState {
if has_value == 1 {
let (max_value, bytes_consumed) = deserialize_value(&blob[cursor..])?;
cursor += bytes_consumed;
state.maxs.insert(col_name.clone(), max_value);
state.maxs.insert(*col_name, max_value);
}
}
}
@@ -566,7 +566,7 @@ impl AggregateState {
values: &[Value],
weight: isize,
aggregates: &[AggregateFunction],
column_names: &[String],
_column_names: &[String], // No longer needed
) {
// Update COUNT
self.count += weight as i64;
@@ -577,32 +577,26 @@ impl AggregateState {
AggregateFunction::Count => {
// Already handled above
}
AggregateFunction::Sum(col_name) => {
if let Some(idx) = column_names.iter().position(|c| c == col_name) {
if let Some(val) = values.get(idx) {
let num_val = match val {
Value::Integer(i) => *i as f64,
Value::Float(f) => *f,
_ => 0.0,
};
*self.sums.entry(col_name.clone()).or_insert(0.0) +=
num_val * weight as f64;
}
AggregateFunction::Sum(col_idx) => {
if let Some(val) = values.get(*col_idx) {
let num_val = match val {
Value::Integer(i) => *i as f64,
Value::Float(f) => *f,
_ => 0.0,
};
*self.sums.entry(*col_idx).or_insert(0.0) += num_val * weight as f64;
}
}
AggregateFunction::Avg(col_name) => {
if let Some(idx) = column_names.iter().position(|c| c == col_name) {
if let Some(val) = values.get(idx) {
let num_val = match val {
Value::Integer(i) => *i as f64,
Value::Float(f) => *f,
_ => 0.0,
};
let (sum, count) =
self.avgs.entry(col_name.clone()).or_insert((0.0, 0));
*sum += num_val * weight as f64;
*count += weight as i64;
}
AggregateFunction::Avg(col_idx) => {
if let Some(val) = values.get(*col_idx) {
let num_val = match val {
Value::Integer(i) => *i as f64,
Value::Float(f) => *f,
_ => 0.0,
};
let (sum, count) = self.avgs.entry(*col_idx).or_insert((0.0, 0));
*sum += num_val * weight as f64;
*count += weight as i64;
}
}
AggregateFunction::Min(_col_name) | AggregateFunction::Max(_col_name) => {
@@ -644,8 +638,8 @@ impl AggregateState {
AggregateFunction::Count => {
result.push(Value::Integer(self.count));
}
AggregateFunction::Sum(col_name) => {
let sum = self.sums.get(col_name).copied().unwrap_or(0.0);
AggregateFunction::Sum(col_idx) => {
let sum = self.sums.get(col_idx).copied().unwrap_or(0.0);
// Return as integer if it's a whole number, otherwise as float
if sum.fract() == 0.0 {
result.push(Value::Integer(sum as i64));
@@ -653,8 +647,8 @@ impl AggregateState {
result.push(Value::Float(sum));
}
}
AggregateFunction::Avg(col_name) => {
if let Some((sum, count)) = self.avgs.get(col_name) {
AggregateFunction::Avg(col_idx) => {
if let Some((sum, count)) = self.avgs.get(col_idx) {
if *count > 0 {
result.push(Value::Float(sum / *count as f64));
} else {
@@ -664,13 +658,13 @@ impl AggregateState {
result.push(Value::Null);
}
}
AggregateFunction::Min(col_name) => {
AggregateFunction::Min(col_idx) => {
// Return the MIN value from our state
result.push(self.mins.get(col_name).cloned().unwrap_or(Value::Null));
result.push(self.mins.get(col_idx).cloned().unwrap_or(Value::Null));
}
AggregateFunction::Max(col_name) => {
AggregateFunction::Max(col_idx) => {
// Return the MAX value from our state
result.push(self.maxs.get(col_name).cloned().unwrap_or(Value::Null));
result.push(self.maxs.get(col_idx).cloned().unwrap_or(Value::Null));
}
}
}
@@ -682,20 +676,20 @@ impl AggregateState {
impl AggregateOperator {
pub fn new(
operator_id: usize,
group_by: Vec<String>,
group_by: Vec<usize>,
aggregates: Vec<AggregateFunction>,
input_column_names: Vec<String>,
) -> Self {
// Build map of column names to their MIN/MAX info with indices
// Build map of column indices to their MIN/MAX info
let mut column_min_max = HashMap::new();
let mut column_indices = HashMap::new();
let mut storage_indices = HashMap::new();
let mut current_index = 0;
// First pass: assign indices to unique MIN/MAX columns
// First pass: assign storage indices to unique MIN/MAX columns
for agg in &aggregates {
match agg {
AggregateFunction::Min(col) | AggregateFunction::Max(col) => {
column_indices.entry(col.clone()).or_insert_with(|| {
AggregateFunction::Min(col_idx) | AggregateFunction::Max(col_idx) => {
storage_indices.entry(*col_idx).or_insert_with(|| {
let idx = current_index;
current_index += 1;
idx
@@ -708,19 +702,19 @@ impl AggregateOperator {
// Second pass: build the column info map
for agg in &aggregates {
match agg {
AggregateFunction::Min(col) => {
let index = *column_indices.get(col).unwrap();
let entry = column_min_max.entry(col.clone()).or_insert(AggColumnInfo {
index,
AggregateFunction::Min(col_idx) => {
let storage_index = *storage_indices.get(col_idx).unwrap();
let entry = column_min_max.entry(*col_idx).or_insert(AggColumnInfo {
index: storage_index,
has_min: false,
has_max: false,
});
entry.has_min = true;
}
AggregateFunction::Max(col) => {
let index = *column_indices.get(col).unwrap();
let entry = column_min_max.entry(col.clone()).or_insert(AggColumnInfo {
index,
AggregateFunction::Max(col_idx) => {
let storage_index = *storage_indices.get(col_idx).unwrap();
let entry = column_min_max.entry(*col_idx).or_insert(AggColumnInfo {
index: storage_index,
has_min: false,
has_max: false,
});
@@ -876,28 +870,24 @@ impl AggregateOperator {
for agg in &self.aggregates {
match agg {
AggregateFunction::Min(col_name) | AggregateFunction::Max(col_name) => {
if let Some(idx) =
self.input_column_names.iter().position(|c| c == col_name)
{
if let Some(val) = row.values.get(idx) {
// Skip NULL values - they don't participate in MIN/MAX
if val == &Value::Null {
continue;
}
// Create a HashableRow with just this value
// Use 0 as rowid since we only care about the value for comparison
let hashable_value = HashableRow::new(0, vec![val.clone()]);
let key = (col_name.clone(), hashable_value);
let group_entry =
min_max_deltas.entry(group_key_str.clone()).or_default();
let value_entry = group_entry.entry(key).or_insert(0);
// Accumulate the weight
*value_entry += weight;
AggregateFunction::Min(col_idx) | AggregateFunction::Max(col_idx) => {
if let Some(val) = row.values.get(*col_idx) {
// Skip NULL values - they don't participate in MIN/MAX
if val == &Value::Null {
continue;
}
// Create a HashableRow with just this value
// Use 0 as rowid since we only care about the value for comparison
let hashable_value = HashableRow::new(0, vec![val.clone()]);
let key = (*col_idx, hashable_value);
let group_entry =
min_max_deltas.entry(group_key_str.clone()).or_default();
let value_entry = group_entry.entry(key).or_insert(0);
// Accumulate the weight
*value_entry += weight;
}
}
_ => {} // Ignore non-MIN/MAX aggregates
@@ -929,13 +919,9 @@ impl AggregateOperator {
pub fn extract_group_key(&self, values: &[Value]) -> Vec<Value> {
let mut key = Vec::new();
for group_col in &self.group_by {
if let Some(idx) = self.input_column_names.iter().position(|c| c == group_col) {
if let Some(val) = values.get(idx) {
key.push(val.clone());
} else {
key.push(Value::Null);
}
for &idx in &self.group_by {
if let Some(val) = values.get(idx) {
key.push(val.clone());
} else {
key.push(Value::Null);
}
@@ -1124,13 +1110,13 @@ pub enum RecomputeMinMax {
/// Current column being processed
current_column_idx: usize,
/// Columns to process (combined MIN and MAX)
columns_to_process: Vec<(String, String, bool)>, // (group_key, column_name, is_min)
columns_to_process: Vec<(String, usize, bool)>, // (group_key, column_index, is_min)
/// MIN/MAX deltas for checking values and weights
min_max_deltas: MinMaxDeltas,
},
Scan {
/// Columns still to process
columns_to_process: Vec<(String, String, bool)>,
columns_to_process: Vec<(String, usize, bool)>,
/// Current index in columns_to_process (will resume from here)
current_column_idx: usize,
/// MIN/MAX deltas for checking values and weights
@@ -1138,7 +1124,7 @@ pub enum RecomputeMinMax {
/// Current group key being processed
group_key: String,
/// Current column name being processed
column_name: String,
column_name: usize,
/// Whether we're looking for MIN (true) or MAX (false)
is_min: bool,
/// The scan state machine for finding the new MIN/MAX
@@ -1153,7 +1139,7 @@ impl RecomputeMinMax {
existing_groups: &HashMap<String, AggregateState>,
operator: &AggregateOperator,
) -> Self {
let mut groups_to_check: HashSet<(String, String, bool)> = HashSet::new();
let mut groups_to_check: HashSet<(String, usize, bool)> = HashSet::new();
// Remember the min_max_deltas are essentially just the only column that is affected by
// this min/max, in delta (actually ZSet - consolidated delta) format. This makes it easier
@@ -1173,21 +1159,13 @@ impl RecomputeMinMax {
// Check for MIN
if let Some(current_min) = state.mins.get(col_name) {
if current_min == value {
groups_to_check.insert((
group_key_str.clone(),
col_name.clone(),
true,
));
groups_to_check.insert((group_key_str.clone(), *col_name, true));
}
}
// Check for MAX
if let Some(current_max) = state.maxs.get(col_name) {
if current_max == value {
groups_to_check.insert((
group_key_str.clone(),
col_name.clone(),
false,
));
groups_to_check.insert((group_key_str.clone(), *col_name, false));
}
}
}
@@ -1196,14 +1174,10 @@ impl RecomputeMinMax {
// about this if this is a new record being inserted
if let Some(info) = col_info {
if info.has_min {
groups_to_check.insert((group_key_str.clone(), col_name.clone(), true));
groups_to_check.insert((group_key_str.clone(), *col_name, true));
}
if info.has_max {
groups_to_check.insert((
group_key_str.clone(),
col_name.clone(),
false,
));
groups_to_check.insert((group_key_str.clone(), *col_name, false));
}
}
}
@@ -1245,12 +1219,13 @@ impl RecomputeMinMax {
let (group_key, column_name, is_min) =
columns_to_process[*current_column_idx].clone();
// Get column index from pre-computed info
let column_index = operator
// column_name already holds the column index
// Get the storage index from column_min_max map
let column_info = operator
.column_min_max
.get(&column_name)
.map(|info| info.index)
.unwrap(); // Should always exist since we're processing known columns
.expect("Column should exist in column_min_max map");
let storage_index = column_info.index;
// Get current value from existing state
let current_value = existing_groups.get(&group_key).and_then(|state| {
@@ -1263,7 +1238,7 @@ impl RecomputeMinMax {
// Create storage keys for index lookup
let storage_id =
generate_storage_id(operator.operator_id, column_index, AGG_TYPE_MINMAX);
generate_storage_id(operator.operator_id, storage_index, AGG_TYPE_MINMAX);
let zset_id = operator.generate_group_rowid(&group_key);
// Get the values for this group from min_max_deltas
@@ -1276,7 +1251,7 @@ impl RecomputeMinMax {
Box::new(ScanState::new_for_min(
current_value,
group_key.clone(),
column_name.clone(),
column_name,
storage_id,
zset_id,
group_values,
@@ -1285,7 +1260,7 @@ impl RecomputeMinMax {
Box::new(ScanState::new_for_max(
current_value,
group_key.clone(),
column_name.clone(),
column_name,
storage_id,
zset_id,
group_values,
@@ -1319,12 +1294,12 @@ impl RecomputeMinMax {
if *is_min {
if let Some(min_val) = new_value {
state.mins.insert(column_name.clone(), min_val);
state.mins.insert(*column_name, min_val);
} else {
state.mins.remove(column_name);
}
} else if let Some(max_val) = new_value {
state.maxs.insert(column_name.clone(), max_val);
state.maxs.insert(*column_name, max_val);
} else {
state.maxs.remove(column_name);
}
@@ -1355,13 +1330,13 @@ pub enum ScanState {
/// Group key being processed
group_key: String,
/// Column name being processed
column_name: String,
column_name: usize,
/// Storage ID for the index seek
storage_id: i64,
/// ZSet ID for the group
zset_id: i64,
/// Group values from MinMaxDeltas: (column_name, HashableRow) -> weight
group_values: HashMap<(String, HashableRow), isize>,
group_values: HashMap<(usize, HashableRow), isize>,
/// Whether we're looking for MIN (true) or MAX (false)
is_min: bool,
},
@@ -1371,13 +1346,13 @@ pub enum ScanState {
/// Group key being processed
group_key: String,
/// Column name being processed
column_name: String,
column_name: usize,
/// Storage ID for the index seek
storage_id: i64,
/// ZSet ID for the group
zset_id: i64,
/// Group values from MinMaxDeltas: (column_name, HashableRow) -> weight
group_values: HashMap<(String, HashableRow), isize>,
group_values: HashMap<(usize, HashableRow), isize>,
/// Whether we're looking for MIN (true) or MAX (false)
is_min: bool,
},
@@ -1391,10 +1366,10 @@ impl ScanState {
pub fn new_for_min(
current_min: Option<Value>,
group_key: String,
column_name: String,
column_name: usize,
storage_id: i64,
zset_id: i64,
group_values: HashMap<(String, HashableRow), isize>,
group_values: HashMap<(usize, HashableRow), isize>,
) -> Self {
Self::CheckCandidate {
candidate: current_min,
@@ -1460,10 +1435,10 @@ impl ScanState {
pub fn new_for_max(
current_max: Option<Value>,
group_key: String,
column_name: String,
column_name: usize,
storage_id: i64,
zset_id: i64,
group_values: HashMap<(String, HashableRow), isize>,
group_values: HashMap<(usize, HashableRow), isize>,
) -> Self {
Self::CheckCandidate {
candidate: current_max,
@@ -1496,7 +1471,7 @@ impl ScanState {
// Check if the candidate is retracted (weight <= 0)
// Create a HashableRow to look up the weight
let hashable_cand = HashableRow::new(0, vec![cand_val.clone()]);
let key = (column_name.clone(), hashable_cand);
let key = (*column_name, hashable_cand);
let is_retracted =
group_values.get(&key).is_some_and(|weight| *weight <= 0);
@@ -1633,7 +1608,7 @@ pub enum MinMaxPersistState {
group_idx: usize,
value_idx: usize,
value: Value,
column_name: String,
column_name: usize,
weight: isize,
write_row: WriteRow,
},
@@ -1652,7 +1627,7 @@ impl MinMaxPersistState {
pub fn persist_min_max(
&mut self,
operator_id: usize,
column_min_max: &HashMap<String, AggColumnInfo>,
column_min_max: &HashMap<usize, AggColumnInfo>,
cursors: &mut DbspStateCursors,
generate_group_rowid: impl Fn(&str) -> i64,
) -> Result<IOResult<()>> {
@@ -1699,7 +1674,7 @@ impl MinMaxPersistState {
// Process current value and extract what we need before taking ownership
let ((column_name, hashable_row), weight) = values_vec[*value_idx];
let column_name = column_name.clone();
let column_name = *column_name;
let value = hashable_row.values[0].clone(); // Extract the Value from HashableRow
let weight = *weight;
@@ -1731,9 +1706,9 @@ impl MinMaxPersistState {
let group_key_str = &group_keys[*group_idx];
// Get the column index from the pre-computed map
// Get the column info from the pre-computed map
let column_info = column_min_max
.get(&*column_name)
.get(column_name)
.expect("Column should exist in column_min_max map");
let column_index = column_info.index;

File diff suppressed because it is too large.


@@ -3,7 +3,7 @@
// Based on Feldera DBSP design but adapted for Turso's architecture
pub use crate::incremental::aggregate_operator::{
AggregateEvalState, AggregateFunction, AggregateOperator, AggregateState,
AggregateEvalState, AggregateFunction, AggregateState,
};
pub use crate::incremental::filter_operator::{FilterOperator, FilterPredicate};
pub use crate::incremental::input_operator::InputOperator;
@@ -251,7 +251,7 @@ pub trait IncrementalOperator: Debug {
#[cfg(test)]
mod tests {
use super::*;
use crate::incremental::aggregate_operator::AGG_TYPE_REGULAR;
use crate::incremental::aggregate_operator::{AggregateOperator, AGG_TYPE_REGULAR};
use crate::incremental::dbsp::HashableRow;
use crate::storage::pager::CreateBTreeFlags;
use crate::types::Text;
@@ -395,9 +395,9 @@ mod tests {
// Create an aggregate operator for SUM(age) with no GROUP BY
let mut agg = AggregateOperator::new(
1, // operator_id for testing
vec![], // No GROUP BY
vec![AggregateFunction::Sum("age".to_string())],
1, // operator_id for testing
vec![], // No GROUP BY
vec![AggregateFunction::Sum(2)], // age is at index 2
vec!["id".to_string(), "name".to_string(), "age".to_string()],
);
@@ -514,9 +514,9 @@ mod tests {
let mut cursors = DbspStateCursors::new(table_cursor, index_cursor);
let mut agg = AggregateOperator::new(
1, // operator_id for testing
vec!["team".to_string()], // GROUP BY team
vec![AggregateFunction::Sum("score".to_string())],
1, // operator_id for testing
vec![1], // GROUP BY team (index 1)
vec![AggregateFunction::Sum(3)], // score is at index 3
vec![
"id".to_string(),
"team".to_string(),
@@ -666,8 +666,8 @@ mod tests {
// Create COUNT(*) GROUP BY category
let mut agg = AggregateOperator::new(
1, // operator_id for testing
vec!["category".to_string()],
1, // operator_id for testing
vec![1], // category is at index 1
vec![AggregateFunction::Count],
vec![
"item_id".to_string(),
@@ -746,9 +746,9 @@ mod tests {
let mut cursors = DbspStateCursors::new(table_cursor, index_cursor);
let mut agg = AggregateOperator::new(
1, // operator_id for testing
vec!["product".to_string()],
vec![AggregateFunction::Sum("amount".to_string())],
1, // operator_id for testing
vec![1], // product is at index 1
vec![AggregateFunction::Sum(2)], // amount is at index 2
vec![
"sale_id".to_string(),
"product".to_string(),
@@ -843,11 +843,11 @@ mod tests {
let mut cursors = DbspStateCursors::new(table_cursor, index_cursor);
let mut agg = AggregateOperator::new(
1, // operator_id for testing
vec!["user_id".to_string()],
1, // operator_id for testing
vec![1], // user_id is at index 1
vec![
AggregateFunction::Count,
AggregateFunction::Sum("amount".to_string()),
AggregateFunction::Sum(2), // amount is at index 2
],
vec![
"order_id".to_string(),
@@ -935,9 +935,9 @@ mod tests {
let mut cursors = DbspStateCursors::new(table_cursor, index_cursor);
let mut agg = AggregateOperator::new(
1, // operator_id for testing
vec!["category".to_string()],
vec![AggregateFunction::Avg("value".to_string())],
1, // operator_id for testing
vec![1], // category is at index 1
vec![AggregateFunction::Avg(2)], // value is at index 2
vec![
"id".to_string(),
"category".to_string(),
@@ -1035,11 +1035,11 @@ mod tests {
let mut cursors = DbspStateCursors::new(table_cursor, index_cursor);
let mut agg = AggregateOperator::new(
1, // operator_id for testing
vec!["category".to_string()],
1, // operator_id for testing
vec![1], // category is at index 1
vec![
AggregateFunction::Count,
AggregateFunction::Sum("value".to_string()),
AggregateFunction::Sum(2), // value is at index 2
],
vec![
"id".to_string(),
@@ -1108,7 +1108,7 @@ mod tests {
#[test]
fn test_count_aggregation_with_deletions() {
let aggregates = vec![AggregateFunction::Count];
let group_by = vec!["category".to_string()];
let group_by = vec![0]; // category is at index 0
let input_columns = vec!["category".to_string(), "value".to_string()];
// Create a persistent pager for the test
@@ -1197,8 +1197,8 @@ mod tests {
#[test]
fn test_sum_aggregation_with_deletions() {
let aggregates = vec![AggregateFunction::Sum("value".to_string())];
let group_by = vec!["category".to_string()];
let aggregates = vec![AggregateFunction::Sum(1)]; // value is at index 1
let group_by = vec![0]; // category is at index 0
let input_columns = vec!["category".to_string(), "value".to_string()];
// Create a persistent pager for the test
@@ -1281,8 +1281,8 @@ mod tests {
#[test]
fn test_avg_aggregation_with_deletions() {
let aggregates = vec![AggregateFunction::Avg("value".to_string())];
let group_by = vec!["category".to_string()];
let aggregates = vec![AggregateFunction::Avg(1)]; // value is at index 1
let group_by = vec![0]; // category is at index 0
let input_columns = vec!["category".to_string(), "value".to_string()];
// Create a persistent pager for the test
@@ -1348,10 +1348,10 @@ mod tests {
// Test COUNT, SUM, and AVG together
let aggregates = vec![
AggregateFunction::Count,
AggregateFunction::Sum("value".to_string()),
AggregateFunction::Avg("value".to_string()),
AggregateFunction::Sum(1), // value is at index 1
AggregateFunction::Avg(1), // value is at index 1
];
let group_by = vec!["category".to_string()];
let group_by = vec![0]; // category is at index 0
let input_columns = vec!["category".to_string(), "value".to_string()];
// Create a persistent pager for the test
@@ -1607,11 +1607,11 @@ mod tests {
let mut cursors = DbspStateCursors::new(table_cursor, index_cursor);
let mut agg = AggregateOperator::new(
1, // operator_id for testing
vec!["category".to_string()],
1, // operator_id for testing
vec![1], // category is at index 1
vec![
AggregateFunction::Count,
AggregateFunction::Sum("amount".to_string()),
AggregateFunction::Sum(2), // amount is at index 2
],
vec![
"id".to_string(),
@@ -1781,7 +1781,7 @@ mod tests {
vec![], // No GROUP BY
vec![
AggregateFunction::Count,
AggregateFunction::Sum("value".to_string()),
AggregateFunction::Sum(1), // value is at index 1
],
vec!["id".to_string(), "value".to_string()],
);
@@ -1859,8 +1859,8 @@ mod tests {
let mut cursors = DbspStateCursors::new(table_cursor, index_cursor);
let mut agg = AggregateOperator::new(
1, // operator_id for testing
vec!["type".to_string()],
1, // operator_id for testing
vec![1], // type is at index 1
vec![AggregateFunction::Count],
vec!["id".to_string(), "type".to_string()],
);
@@ -1976,8 +1976,8 @@ mod tests {
1, // operator_id
vec![], // No GROUP BY
vec![
AggregateFunction::Min("price".to_string()),
AggregateFunction::Max("price".to_string()),
AggregateFunction::Min(2), // price is at index 2
AggregateFunction::Max(2), // price is at index 2
],
vec!["id".to_string(), "name".to_string(), "price".to_string()],
);
@@ -2044,8 +2044,8 @@ mod tests {
1, // operator_id
vec![], // No GROUP BY
vec![
AggregateFunction::Min("price".to_string()),
AggregateFunction::Max("price".to_string()),
AggregateFunction::Min(2), // price is at index 2
AggregateFunction::Max(2), // price is at index 2
],
vec!["id".to_string(), "name".to_string(), "price".to_string()],
);
@@ -2134,8 +2134,8 @@ mod tests {
1, // operator_id
vec![], // No GROUP BY
vec![
AggregateFunction::Min("price".to_string()),
AggregateFunction::Max("price".to_string()),
AggregateFunction::Min(2), // price is at index 2
AggregateFunction::Max(2), // price is at index 2
],
vec!["id".to_string(), "name".to_string(), "price".to_string()],
);
@@ -2224,8 +2224,8 @@ mod tests {
1, // operator_id
vec![], // No GROUP BY
vec![
AggregateFunction::Min("price".to_string()),
AggregateFunction::Max("price".to_string()),
AggregateFunction::Min(2), // price is at index 2
AggregateFunction::Max(2), // price is at index 2
],
vec!["id".to_string(), "name".to_string(), "price".to_string()],
);
@@ -2306,8 +2306,8 @@ mod tests {
1, // operator_id
vec![], // No GROUP BY
vec![
AggregateFunction::Min("price".to_string()),
AggregateFunction::Max("price".to_string()),
AggregateFunction::Min(2), // price is at index 2
AggregateFunction::Max(2), // price is at index 2
],
vec!["id".to_string(), "name".to_string(), "price".to_string()],
);
@@ -2388,8 +2388,8 @@ mod tests {
1, // operator_id
vec![], // No GROUP BY
vec![
AggregateFunction::Min("price".to_string()),
AggregateFunction::Max("price".to_string()),
AggregateFunction::Min(2), // price is at index 2
AggregateFunction::Max(2), // price is at index 2
],
vec!["id".to_string(), "name".to_string(), "price".to_string()],
);
@@ -2475,11 +2475,11 @@ mod tests {
let mut cursors = DbspStateCursors::new(table_cursor, index_cursor);
let mut agg = AggregateOperator::new(
1, // operator_id
vec!["category".to_string()], // GROUP BY category
1, // operator_id
vec![1], // GROUP BY category (index 1)
vec![
AggregateFunction::Min("price".to_string()),
AggregateFunction::Max("price".to_string()),
AggregateFunction::Min(3), // price is at index 3
AggregateFunction::Max(3), // price is at index 3
],
vec![
"id".to_string(),
@@ -2580,8 +2580,8 @@ mod tests {
1, // operator_id
vec![], // No GROUP BY
vec![
AggregateFunction::Min("price".to_string()),
AggregateFunction::Max("price".to_string()),
AggregateFunction::Min(2), // price is at index 2
AggregateFunction::Max(2), // price is at index 2
],
vec!["id".to_string(), "name".to_string(), "price".to_string()],
);
@@ -2656,8 +2656,8 @@ mod tests {
1, // operator_id
vec![], // No GROUP BY
vec![
AggregateFunction::Min("score".to_string()),
AggregateFunction::Max("score".to_string()),
AggregateFunction::Min(2), // score is at index 2
AggregateFunction::Max(2), // score is at index 2
],
vec!["id".to_string(), "name".to_string(), "score".to_string()],
);
@@ -2724,8 +2724,8 @@ mod tests {
1, // operator_id
vec![], // No GROUP BY
vec![
AggregateFunction::Min("name".to_string()),
AggregateFunction::Max("name".to_string()),
AggregateFunction::Min(1), // name is at index 1
AggregateFunction::Max(1), // name is at index 1
],
vec!["id".to_string(), "name".to_string()],
);
@@ -2764,10 +2764,10 @@ mod tests {
vec![], // No GROUP BY
vec![
AggregateFunction::Count,
AggregateFunction::Sum("value".to_string()),
AggregateFunction::Min("value".to_string()),
AggregateFunction::Max("value".to_string()),
AggregateFunction::Avg("value".to_string()),
AggregateFunction::Sum(1), // value is at index 1
AggregateFunction::Min(1), // value is at index 1
AggregateFunction::Max(1), // value is at index 1
AggregateFunction::Avg(1), // value is at index 1
],
vec!["id".to_string(), "value".to_string()],
);
@@ -2855,9 +2855,9 @@ mod tests {
1, // operator_id
vec![], // No GROUP BY
vec![
AggregateFunction::Min("col1".to_string()),
AggregateFunction::Max("col2".to_string()),
AggregateFunction::Min("col3".to_string()),
AggregateFunction::Min(0), // col1 is at index 0
AggregateFunction::Max(1), // col2 is at index 1
AggregateFunction::Min(2), // col3 is at index 2
],
vec!["col1".to_string(), "col2".to_string(), "col3".to_string()],
);


@@ -19,26 +19,35 @@ use turso_parser::ast;
/// Result type for preprocessing aggregate expressions
type PreprocessAggregateResult = (
bool, // needs_pre_projection
Vec<LogicalExpr>, // pre_projection_exprs
Vec<(String, Type)>, // pre_projection_schema
Vec<LogicalExpr>, // modified_aggr_exprs
bool, // needs_pre_projection
Vec<LogicalExpr>, // pre_projection_exprs
Vec<ColumnInfo>, // pre_projection_schema
Vec<LogicalExpr>, // modified_aggr_exprs
);
/// Result type for parsing join conditions
type JoinConditionsResult = (Vec<(LogicalExpr, LogicalExpr)>, Option<LogicalExpr>);
/// Information about a column in a logical schema
#[derive(Debug, Clone, PartialEq)]
pub struct ColumnInfo {
pub name: String,
pub ty: Type,
pub database: Option<String>,
pub table: Option<String>,
pub table_alias: Option<String>,
}
/// Schema information for logical plan nodes
#[derive(Debug, Clone, PartialEq)]
pub struct LogicalSchema {
/// Column names and types
pub columns: Vec<(String, Type)>,
pub columns: Vec<ColumnInfo>,
}
/// A reference to a schema that can be shared between nodes
pub type SchemaRef = Arc<LogicalSchema>;
impl LogicalSchema {
pub fn new(columns: Vec<(String, Type)>) -> Self {
pub fn new(columns: Vec<ColumnInfo>) -> Self {
Self { columns }
}
@@ -52,11 +61,42 @@ impl LogicalSchema {
self.columns.len()
}
pub fn find_column(&self, name: &str) -> Option<(usize, &Type)> {
self.columns
.iter()
.position(|(n, _)| n == name)
.map(|idx| (idx, &self.columns[idx].1))
pub fn find_column(&self, name: &str, table: Option<&str>) -> Option<(usize, &ColumnInfo)> {
if let Some(table_ref) = table {
// Check if it's a database.table format
if table_ref.contains('.') {
let parts: Vec<&str> = table_ref.splitn(2, '.').collect();
if parts.len() == 2 {
let db = parts[0];
let tbl = parts[1];
return self
.columns
.iter()
.position(|c| {
c.name == name
&& c.database.as_deref() == Some(db)
&& c.table.as_deref() == Some(tbl)
})
.map(|idx| (idx, &self.columns[idx]));
}
}
// Try to match against table alias first, then table name
self.columns
.iter()
.position(|c| {
c.name == name
&& (c.table_alias.as_deref() == Some(table_ref)
|| c.table.as_deref() == Some(table_ref))
})
.map(|idx| (idx, &self.columns[idx]))
} else {
// Unqualified lookup - just match by name
self.columns
.iter()
.position(|c| c.name == name)
.map(|idx| (idx, &self.columns[idx]))
}
}
}
@@ -548,14 +588,14 @@ impl<'a> LogicalPlanBuilder<'a> {
}
// Regular table scan
let table_schema = self.get_table_schema(&table_name)?;
let table_alias = alias.as_ref().map(|a| match a {
ast::As::As(name) => Self::name_to_string(name),
ast::As::Elided(name) => Self::name_to_string(name),
});
let table_schema = self.get_table_schema(&table_name, table_alias.as_deref())?;
Ok(LogicalPlan::TableScan(TableScan {
table_name,
alias: table_alias,
alias: table_alias.clone(),
schema: table_schema,
projection: None,
}))
@@ -751,14 +791,14 @@ impl<'a> LogicalPlanBuilder<'a> {
let _left_idx = left_schema
.columns
.iter()
.position(|(n, _)| n == &name)
.position(|col| col.name == name)
.ok_or_else(|| {
LimboError::ParseError(format!("Column {name} not found in left table"))
})?;
let _right_idx = right_schema
.columns
.iter()
.position(|(n, _)| n == &name)
.position(|col| col.name == name)
.ok_or_else(|| {
LimboError::ParseError(format!("Column {name} not found in right table"))
})?;
@@ -790,9 +830,13 @@ impl<'a> LogicalPlanBuilder<'a> {
// Find common column names
let mut common_columns = Vec::new();
for (left_name, _) in &left_schema.columns {
if right_schema.columns.iter().any(|(n, _)| n == left_name) {
common_columns.push(ast::Name::Ident(left_name.clone()));
for left_col in &left_schema.columns {
if right_schema
.columns
.iter()
.any(|col| col.name == left_col.name)
{
common_columns.push(ast::Name::Ident(left_col.name.clone()));
}
}
@@ -833,10 +877,18 @@ impl<'a> LogicalPlanBuilder<'a> {
let left_schema = left.schema();
let right_schema = right.schema();
// For now, simply concatenate the schemas
// In a real implementation, we'd handle column name conflicts and nullable columns
let mut columns = left_schema.columns.clone();
columns.extend(right_schema.columns.clone());
// Concatenate the schemas, preserving all column information
let mut columns = Vec::new();
// Keep all columns from left with their table info
for col in &left_schema.columns {
columns.push(col.clone());
}
// Keep all columns from right with their table info
for col in &right_schema.columns {
columns.push(col.clone());
}
Ok(Arc::new(LogicalSchema::new(columns)))
}
@@ -870,7 +922,13 @@ impl<'a> LogicalPlanBuilder<'a> {
};
let col_type = Self::infer_expr_type(&logical_expr, input_schema)?;
schema_columns.push((col_name.clone(), col_type));
schema_columns.push(ColumnInfo {
name: col_name.clone(),
ty: col_type,
database: None,
table: None,
table_alias: None,
});
if let Some(as_alias) = alias {
let alias_name = match as_alias {
@@ -886,21 +944,21 @@ impl<'a> LogicalPlanBuilder<'a> {
}
ast::ResultColumn::Star => {
// Expand * to all columns
for (name, typ) in &input_schema.columns {
proj_exprs.push(LogicalExpr::Column(Column::new(name.clone())));
schema_columns.push((name.clone(), *typ));
for col in &input_schema.columns {
proj_exprs.push(LogicalExpr::Column(Column::new(col.name.clone())));
schema_columns.push(col.clone());
}
}
ast::ResultColumn::TableStar(table) => {
// Expand table.* to all columns from that table
let table_name = Self::name_to_string(table);
for (name, typ) in &input_schema.columns {
for col in &input_schema.columns {
// Simple check - would need proper table tracking in real implementation
proj_exprs.push(LogicalExpr::Column(Column::with_table(
name.clone(),
col.name.clone(),
table_name.clone(),
)));
schema_columns.push((name.clone(), *typ));
schema_columns.push(col.clone());
}
}
}
@@ -938,7 +996,13 @@ impl<'a> LogicalPlanBuilder<'a> {
if let LogicalExpr::Column(col) = expr {
pre_projection_exprs.push(expr.clone());
let col_type = Self::infer_expr_type(expr, input_schema)?;
pre_projection_schema.push((col.name.clone(), col_type));
pre_projection_schema.push(ColumnInfo {
name: col.name.clone(),
ty: col_type,
database: None,
table: col.table.clone(),
table_alias: None,
});
} else {
// Complex group by expression - project it
needs_pre_projection = true;
@@ -946,7 +1010,13 @@ impl<'a> LogicalPlanBuilder<'a> {
projected_col_counter += 1;
pre_projection_exprs.push(expr.clone());
let col_type = Self::infer_expr_type(expr, input_schema)?;
pre_projection_schema.push((proj_col_name.clone(), col_type));
pre_projection_schema.push(ColumnInfo {
name: proj_col_name.clone(),
ty: col_type,
database: None,
table: None,
table_alias: None,
});
}
}
@@ -970,7 +1040,13 @@ impl<'a> LogicalPlanBuilder<'a> {
pre_projection_exprs.push(arg.clone());
let col_type = Self::infer_expr_type(arg, input_schema)?;
if let LogicalExpr::Column(col) = arg {
pre_projection_schema.push((col.name.clone(), col_type));
pre_projection_schema.push(ColumnInfo {
name: col.name.clone(),
ty: col_type,
database: None,
table: col.table.clone(),
table_alias: None,
});
}
}
}
@@ -983,7 +1059,13 @@ impl<'a> LogicalPlanBuilder<'a> {
// Add the expression to the pre-projection
pre_projection_exprs.push(arg.clone());
let col_type = Self::infer_expr_type(arg, input_schema)?;
pre_projection_schema.push((proj_col_name.clone(), col_type));
pre_projection_schema.push(ColumnInfo {
name: proj_col_name.clone(),
ty: col_type,
database: None,
table: None,
table_alias: None,
});
// In the aggregate, reference the projected column
modified_args.push(LogicalExpr::Column(Column::new(proj_col_name)));
@@ -1057,15 +1139,39 @@ impl<'a> LogicalPlanBuilder<'a> {
// First, add GROUP BY columns to the aggregate output schema
// These are always part of the aggregate operator's output
for group_expr in &group_exprs {
let col_name = match group_expr {
LogicalExpr::Column(col) => col.name.clone(),
match group_expr {
LogicalExpr::Column(col) => {
// For column references in GROUP BY, preserve the original column info
if let Some((_, col_info)) =
input_schema.find_column(&col.name, col.table.as_deref())
{
// Preserve the column with all its table information
aggregate_schema_columns.push(col_info.clone());
} else {
// Fallback if column not found (shouldn't happen)
let col_type = Self::infer_expr_type(group_expr, input_schema)?;
aggregate_schema_columns.push(ColumnInfo {
name: col.name.clone(),
ty: col_type,
database: None,
table: col.table.clone(),
table_alias: None,
});
}
}
_ => {
// For complex GROUP BY expressions, generate a name
format!("__group_{}", aggregate_schema_columns.len())
let col_name = format!("__group_{}", aggregate_schema_columns.len());
let col_type = Self::infer_expr_type(group_expr, input_schema)?;
aggregate_schema_columns.push(ColumnInfo {
name: col_name,
ty: col_type,
database: None,
table: None,
table_alias: None,
});
}
};
let col_type = Self::infer_expr_type(group_expr, input_schema)?;
aggregate_schema_columns.push((col_name, col_type));
}
}
// Track aggregates we've already seen to avoid duplicates
@@ -1098,7 +1204,13 @@ impl<'a> LogicalPlanBuilder<'a> {
} else {
// New aggregate - add it
let col_type = Self::infer_expr_type(&logical_expr, input_schema)?;
aggregate_schema_columns.push((col_name.clone(), col_type));
aggregate_schema_columns.push(ColumnInfo {
name: col_name.clone(),
ty: col_type,
database: None,
table: None,
table_alias: None,
});
aggr_exprs.push(logical_expr);
aggregate_map.insert(agg_key, col_name.clone());
col_name.clone()
@@ -1122,7 +1234,13 @@ impl<'a> LogicalPlanBuilder<'a> {
// Add only new aggregates
for (agg_expr, agg_name) in extracted_aggs {
let agg_type = Self::infer_expr_type(&agg_expr, input_schema)?;
aggregate_schema_columns.push((agg_name, agg_type));
aggregate_schema_columns.push(ColumnInfo {
name: agg_name,
ty: agg_type,
database: None,
table: None,
table_alias: None,
});
aggr_exprs.push(agg_expr);
}
@@ -1197,7 +1315,13 @@ impl<'a> LogicalPlanBuilder<'a> {
// For type inference, we need the aggregate schema for column references
let aggregate_schema = LogicalSchema::new(aggregate_schema_columns.clone());
let col_type = Self::infer_expr_type(expr, &Arc::new(aggregate_schema))?;
projection_schema_columns.push((col_name, col_type));
projection_schema_columns.push(ColumnInfo {
name: col_name,
ty: col_type,
database: None,
table: None,
table_alias: None,
});
}
// Create the input plan (with pre-projection if needed)
@@ -1220,11 +1344,11 @@ impl<'a> LogicalPlanBuilder<'a> {
// Check if we need the outer projection
// We need a projection if:
// 1. Any expression is more complex than a simple column reference (e.g., abs(sum(id)))
// 2. We're selecting a different set of columns than what the aggregate outputs
// 3. Columns are renamed or reordered
// 1. We have expressions that compute new values (e.g., SUM(x) * 2)
// 2. We're selecting a different set of columns than GROUP BY + aggregates
// 3. We're reordering columns from their natural aggregate output order
let needs_outer_projection = {
// Check if any expression is more complex than a simple column reference
// Check for complex expressions
let has_complex_exprs = projection_exprs
.iter()
.any(|expr| !matches!(expr, LogicalExpr::Column(_)));
@@ -1232,17 +1356,29 @@ impl<'a> LogicalPlanBuilder<'a> {
if has_complex_exprs {
true
} else {
// All are simple columns - check if we're selecting exactly what the aggregate outputs
// The projection might be selecting a subset (e.g., only aggregates without group columns)
// or reordering columns, or using different names
// Check if we're selecting exactly what aggregate outputs in the same order
// The aggregate outputs: all GROUP BY columns, then all aggregate expressions
// The projection might select a subset or reorder these
// For now, keep it simple: if schemas don't match exactly, we need projection
// This handles all cases: subset selection, reordering, renaming
projection_schema_columns != aggregate_schema_columns
if projection_exprs.len() != aggregate_schema_columns.len() {
// Different number of columns
true
} else {
// Check if columns match in order and name
!projection_exprs.iter().zip(&aggregate_schema_columns).all(
|(expr, agg_col)| {
if let LogicalExpr::Column(col) = expr {
col.name == agg_col.name
} else {
false
}
},
)
}
}
};
// Create the aggregate node
// Create the aggregate node with its natural schema
let aggregate_plan = LogicalPlan::Aggregate(Aggregate {
input: aggregate_input,
group_expr: group_exprs,
@@ -1257,7 +1393,7 @@ impl<'a> LogicalPlanBuilder<'a> {
schema: Arc::new(LogicalSchema::new(projection_schema_columns)),
}))
} else {
// No projection needed - the aggregate output is exactly what we want
// No projection needed - aggregate output matches what we want
Ok(aggregate_plan)
}
}
@@ -1275,7 +1411,13 @@ impl<'a> LogicalPlanBuilder<'a> {
// Infer schema from first row
let mut schema_columns = Vec::new();
for (i, _) in values[0].iter().enumerate() {
schema_columns.push((format!("column{}", i + 1), Type::Text));
schema_columns.push(ColumnInfo {
name: format!("column{}", i + 1),
ty: Type::Text,
database: None,
table: None,
table_alias: None,
});
}
for row in values {
@@ -2003,17 +2145,31 @@ impl<'a> LogicalPlanBuilder<'a> {
}
// Get table schema
fn get_table_schema(&self, table_name: &str) -> Result<SchemaRef> {
fn get_table_schema(&self, table_name: &str, alias: Option<&str>) -> Result<SchemaRef> {
// Look up table in schema
let table = self
.schema
.get_table(table_name)
.ok_or_else(|| LimboError::ParseError(format!("Table '{table_name}' not found")))?;
// Parse table_name which might be "db.table" for attached databases
let (database, actual_table) = if table_name.contains('.') {
let parts: Vec<&str> = table_name.splitn(2, '.').collect();
(Some(parts[0].to_string()), parts[1].to_string())
} else {
(None, table_name.to_string())
};
let mut columns = Vec::new();
for col in table.columns() {
if let Some(ref name) = col.name {
columns.push((name.clone(), col.ty));
columns.push(ColumnInfo {
name: name.clone(),
ty: col.ty,
database: database.clone(),
table: Some(actual_table.clone()),
table_alias: alias.map(|s| s.to_string()),
});
}
}
@@ -2024,8 +2180,8 @@ impl<'a> LogicalPlanBuilder<'a> {
fn infer_expr_type(expr: &LogicalExpr, schema: &SchemaRef) -> Result<Type> {
match expr {
LogicalExpr::Column(col) => {
if let Some((_, typ)) = schema.find_column(&col.name) {
Ok(*typ)
if let Some((_, col_info)) = schema.find_column(&col.name, col.table.as_deref()) {
Ok(col_info.ty)
} else {
Ok(Type::Text)
}